torch.distributed.tensor — PyTorch 2.7 documentation
Note
torch.distributed.tensor
is currently in an alpha state and under development. We are committed to backward compatibility for most of the APIs listed in this document, but API changes may be made if necessary.
PyTorch DTensor (Distributed Tensor)
PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed logic, including sharded storage, operator computation and collective communications across devices/hosts. DTensor can be used to build different parallelism solutions and supports a sharded state_dict representation when working with multi-dimensional sharding.
Please see examples from the PyTorch native parallelism solutions that are built on top of DTensor
:
DTensor follows the SPMD (single program, multiple data) programming model to empower users to write a distributed program as if it were a single-device program with the same convergence property. It provides a uniform tensor sharding layout (DTensor Layout) by specifying the DeviceMesh and Placement:
- DeviceMesh represents the device topology and the communicators of the cluster using an n-dimensional array.
- Placement describes the sharding layout of the logical tensor on the DeviceMesh. DTensor supports three types of placements: Shard, Replicate and Partial.
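To make the "n-dimensional array" picture concrete, the sketch below models a 2-D mesh as a plain nested list of ranks and works out, for one rank, its coordinate and the peers it shares each mesh dimension with. It uses no PyTorch APIs: the 2x4 layout and the helper names `coordinate_of` and `peers_along` are assumptions made for illustration and are not part of DeviceMesh.

```python
# Illustrative model only: a 2-D mesh of 8 ranks laid out as a 2x4 array.
MESH = [
    [0, 1, 2, 3],
    [4, 5, 6, 7],
]


def coordinate_of(rank, mesh):
    """Return the (row, col) position of `rank` inside the 2-D mesh."""
    for row_idx, row in enumerate(mesh):
        if rank in row:
            return (row_idx, row.index(rank))
    raise ValueError(f"rank {rank} is not in the mesh")


def peers_along(rank, mesh, mesh_dim):
    """Return the ranks that share mesh dimension `mesh_dim` with `rank`.

    Moving along mesh dimension 0 changes the row index; moving along
    mesh dimension 1 changes the column index.
    """
    row_idx, col_idx = coordinate_of(rank, mesh)
    if mesh_dim == 0:
        return [row[col_idx] for row in mesh]
    if mesh_dim == 1:
        return list(mesh[row_idx])
    raise ValueError("this sketch only models 2-D meshes")


print(coordinate_of(5, MESH))   # (1, 1)
print(peers_along(5, MESH, 0))  # [1, 5]
print(peers_along(5, MESH, 1))  # [4, 5, 6, 7]
```

Read this way, a placement list with one entry per mesh dimension, such as [Replicate(), Shard(0)], would mean: replicate across the two rows, and shard tensor dimension 0 across the four ranks within each row.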
DTensor Class APIs
DTensor is a torch.Tensor subclass. This means that once a DTensor is created, it can be used in much the same way as a torch.Tensor, including running different types of PyTorch operators as if on a single device, while DTensor takes care of the proper distributed computation for those operators.
In addition to the existing torch.Tensor methods, it offers a set of additional methods to interact with torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content on all devices, etc.
class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)
DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides a single-device-like abstraction for programming with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and the following types of Placement:
- Shard: Tensor sharded on the tensor dimension dim on the devices of the DeviceMesh dimension
- Replicate: Tensor replicated on the devices of the DeviceMesh dimension
- Partial: Tensor is pending reduction on the devices of the DeviceMesh dimension
When calling PyTorch operators, DTensor
overrides the PyTorch operators to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor
will transform or propagate the placements (DTensor Layout) properly (based on the operator semantic itself) and generate new DTensor
outputs.
To ensure numerical correctness of the DTensor
sharded computation when calling PyTorch operators, DTensor
requires every Tensor argument of the operator be DTensor.
Note
Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor (i.e. it does not handle autograd correctly, hence it is not the public API). Please refer to the create_dtensor section to see how to create a DTensor.
Return type
__create_chunk_list__()
Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica on current rank. For DTensor, each rank will have a single local shard/replica, so the returned list usually only has one element.
This dunder method is primarily used for distributed checkpointing.
Returns
A List[ChunkStorageMetadata] object that represents the shard size/offset on the current rank.
static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)
Create a DTensor from a local torch.Tensor on each rank according to the device_mesh and placements specified.
Parameters
- local_tensor (torch.Tensor) – local torch.Tensor on each rank.
- device_mesh (DeviceMesh, optional) – DeviceMesh to place the tensor. If not specified, from_local must be called under a DeviceMesh context manager. Default: None
- placements (List[Placement], optional) – the placements that describe how to place the local torch.Tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim.
Keyword Arguments
- run_check (bool, optional) – at the cost of extra communication, perform a sanity check across ranks on each local tensor’s meta information to ensure correctness. If placements contains Replicate, the data on the first rank of the device mesh dimension will be broadcast to the other ranks. Default: False
- shape (torch.Size, optional) – a list of ints that specifies the size of the DTensor built on top of local_tensor. Note that this needs to be provided if the shape of local_tensor differs across the ranks. If not provided, shape will be computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
- stride (tuple, optional) – a list of ints that specifies the stride of the DTensor. If not provided, stride will be computed assuming the given distributed tensor is evenly sharded across ranks. Default: None
Returns
A DTensor object
Return type
Note
When run_check=False
, it is the user’s responsibility to ensure the local tensor passed in is correct across ranks (i.e. the tensor is sharded for the Shard(dim)
placement or replicated for the Replicate()
placement). If not, the behavior of the created DTensor is undefined.
Note
from_local is differentiable; the requires_grad of the created DTensor object depends on whether local_tensor requires grad.
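The "evenly sharded" assumption behind the default shape can be sketched in plain Python: for every mesh dimension that shards a tensor dimension, the local size is multiplied by the number of ranks on that mesh dimension. The helper `infer_global_shape` and the tuple encoding of placements are hypothetical stand-ins for illustration, not the library's implementation.

```python
def infer_global_shape(local_shape, mesh_shape, placements):
    """Infer the global shape from a local shard, assuming even sharding.

    `placements` holds one entry per mesh dimension: ("shard", tensor_dim)
    or ("replicate",). These tuples stand in for Shard(dim) and
    Replicate() so the sketch needs no PyTorch install.
    """
    if len(placements) != len(mesh_shape):
        raise ValueError("need exactly one placement per mesh dimension")
    global_shape = list(local_shape)
    for mesh_dim_size, placement in zip(mesh_shape, placements):
        if placement[0] == "shard":
            tensor_dim = placement[1]
            # Every rank on this mesh dimension holds an equally sized piece.
            global_shape[tensor_dim] *= mesh_dim_size
    return tuple(global_shape)


# Local shard of shape (4, 8) on a 2x4 mesh.
print(infer_global_shape((4, 8), (2, 4), [("replicate",), ("shard", 0)]))  # (16, 8)
print(infer_global_shape((4, 8), (2, 4), [("shard", 0), ("shard", 1)]))    # (8, 32)
```

If the local shapes differ across ranks, this kind of inference gives the wrong answer, which is why shape has to be passed explicitly in that case.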
full_tensor(*, grad_placements=None)
Return the full tensor of this DTensor. It performs the necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate them together. It is syntactic sugar for the following code:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the full tensor returned from this function. full_tensor converts a DTensor to a full torch.Tensor, and the returned torch.Tensor might not be used with the original replicated DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, the gradient layout of the full tensor is assumed to be replicated.
Returns
A torch.Tensor object that represents the full tensor of this DTensor.
Return type
Note
full_tensor
is differentiable.
redistribute(device_mesh=None, placements=None, *, async_op=False)
redistribute performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh. For example, a sharded DTensor can be turned into a replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh.
When redistributing from the current to the new placements on one device mesh dimension, one of the following operations (a communication collective or a local operation) is performed:
- Shard(dim) -> Replicate(): all_gather
- Shard(src_dim) -> Shard(dst_dim): all_to_all
- Replicate() -> Shard(dim): local chunking (i.e. torch.chunk)
- Partial() -> Replicate(): all_reduce
- Partial() -> Shard(dim): reduce_scatter
redistribute
would correctly figure out the necessary redistribute steps for DTensors that are created either on 1-D or N-D DeviceMesh.
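The transition list above can be restated as a small lookup table. The sketch below is only a restatement of that list in plain Python, with tuples standing in for the placement types; `TRANSITIONS`, `operation_for` and `plan` are hypothetical names, and the real redistribute logic may order or combine steps differently on an N-D mesh.

```python
# Placement stand-ins: ("shard", dim), ("replicate",) or ("partial",).
TRANSITIONS = {
    ("shard", "replicate"): "all_gather",
    ("shard", "shard"): "all_to_all",
    ("replicate", "shard"): "local chunking (torch.chunk)",
    ("partial", "replicate"): "all_reduce",
    ("partial", "shard"): "reduce_scatter",
}


def operation_for(src, dst):
    """Return the operation listed above for one mesh dimension."""
    if src == dst:
        return "no-op"  # assumption of this sketch: nothing to do
    key = (src[0], dst[0])
    if key not in TRANSITIONS:
        raise ValueError(f"transition {src} -> {dst} is not in the documented list")
    return TRANSITIONS[key]


def plan(src_placements, dst_placements):
    """Look up one operation per mesh dimension (no reordering or fusion)."""
    return [operation_for(s, d) for s, d in zip(src_placements, dst_placements)]


# 2-D mesh: gather on mesh dim 0, reduce-scatter on mesh dim 1.
print(plan([("shard", 0), ("partial",)], [("replicate",), ("shard", 1)]))
# ['all_gather', 'reduce_scatter']
```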
Parameters
- device_mesh (DeviceMesh, optional) – DeviceMesh to place the DTensor. If not specified, the current DTensor’s DeviceMesh is used. Default: None
- placements (List[Placement], optional) – the new placements that describe how to place the DTensor into the DeviceMesh; must have the same number of elements as device_mesh.ndim. Default: replicate on all mesh dimensions
Keyword Arguments
async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously or not. Default: False
Returns
A DTensor object
Return type
Note
redistribute is differentiable, which means users do not need to worry about the backward formula of the redistribute operation.
Note
redistribute currently only supports redistributing a DTensor on the same DeviceMesh. Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.
to_local(*, grad_placements=None)
Get the local tensor of this DTensor on its current rank. For sharding, it returns a local shard of the logical tensor view; for replication, it returns the replica on its current rank.
Keyword Arguments
grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the Tensor returned from this function. to_local converts a DTensor to a local tensor, and the returned local tensor might not be used with the original DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, the gradient layout is assumed to remain the same as the original DTensor and is used for gradient computation.
Returns
A torch.Tensor or AsyncCollectiveTensor object that represents the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, the local tensor is not ready yet (i.e. communication has not finished). In this case, the user needs to call wait to wait for the local tensor to be ready.
Return type
Note
to_local
is differentiable, the requires_grad
of the local tensor returned will depend on if the DTensor requires_grad or not.
property device_mesh: DeviceMesh
The DeviceMesh
attribute that associates with this DTensor object.
Note
device_mesh
is a read-only property, it can not be set.
property placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]
The placements attribute of this DTensor that describes the layout of this DTensor on its DeviceMesh.
Note
placements
is a read-only property, it can not be set.
DeviceMesh as the distributed communicator
DeviceMesh was built from DTensor as the abstraction to describe cluster’s device topology and represent multi-dimensional communicators (on top of ProcessGroup
). To see the details of how to create/use a DeviceMesh, please refer to the DeviceMesh recipe.
DTensor Placement Types
DTensor supports the following types of Placement on each DeviceMesh
dimension:
class torch.distributed.tensor.placement_types.Shard(dim)
The Shard(dim) placement describes the DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantic, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (e.g. distribute_tensor, from_local, etc.)
Parameters
dim (int) – the tensor dimension along which the DTensor is sharded over its corresponding DeviceMesh dimension.
Warning
sharding on a tensor dimension where the tensor dimension size is not evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
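The torch.chunk semantic mentioned for Shard(dim) can be worked through with plain arithmetic: each non-final chunk has ceil(dim_size / num_ranks) elements, and trailing ranks can end up with a smaller or empty shard. The helper `chunk_sizes` below is a hypothetical name used only for this sketch; it computes sizes and does not touch any tensor.

```python
import math


def chunk_sizes(dim_size, num_ranks):
    """Sizes of the shards along one tensor dimension, one entry per rank.

    Mirrors torch.chunk: every chunk has ceil(dim_size / num_ranks)
    elements except possibly the last non-empty one; ranks left without
    a chunk are reported with size 0.
    """
    if num_ranks <= 0:
        raise ValueError("num_ranks must be positive")
    full = math.ceil(dim_size / num_ranks)
    sizes = []
    remaining = dim_size
    for _ in range(num_ranks):
        take = min(full, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


print(chunk_sizes(8, 4))  # [2, 2, 2, 2]
print(chunk_sizes(5, 4))  # [2, 2, 1, 0]
print(chunk_sizes(9, 4))  # [3, 3, 3, 0]
```

The second and third calls show the uneven case: the last rank on the mesh dimension holds an empty shard.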
class torch.distributed.tensor.placement_types.Replicate
The Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (e.g. distribute_tensor, DTensor.from_local, etc.)
class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')
The Partial(reduce_op) placement describes the DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds the partial value of the global Tensor. Users can redistribute the Partial DTensor to a Replicate or Shard(dim) placement on the specified DeviceMesh dimension using redistribute, which triggers the necessary communication operations under the hood (i.e. allreduce, reduce_scatter).
Parameters
reduce_op (str, optional) – The reduction op to be used for the partial DTensor to produce Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: “sum”, “avg”, “product”, “max”, “min”, default: “sum”.
Note
The Partial
placement can be generated as a result of the DTensor operators, and can only be used by the DTensor.from_local
API.
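What "pending reduction" means can be simulated without any communication library. In the sketch below each simulated rank holds a partial value of the same 4-element vector; summing across ranks corresponds to Partial() -> Replicate(), and summing then keeping one chunk per rank corresponds to Partial() -> Shard(0). The function names are hypothetical, and only the default "sum" reduction is modeled.

```python
def all_reduce_sum(partials):
    """Partial() -> Replicate(): every rank ends up with the elementwise sum."""
    total = [sum(values) for values in zip(*partials)]
    return [list(total) for _ in partials]


def reduce_scatter_sum(partials):
    """Partial() -> Shard(0): rank i keeps only the i-th chunk of the sum.

    Assumes the vector length divides evenly by the number of ranks.
    """
    total = [sum(values) for values in zip(*partials)]
    num_ranks = len(partials)
    chunk = len(total) // num_ranks
    return [total[i * chunk:(i + 1) * chunk] for i in range(num_ranks)]


# Two simulated ranks, each holding a partial value of the same vector.
partials = [[1, 2, 3, 4], [10, 20, 30, 40]]
print(all_reduce_sum(partials))      # [[11, 22, 33, 44], [11, 22, 33, 44]]
print(reduce_scatter_sum(partials))  # [[11, 22], [33, 44]]
```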
class torch.distributed.tensor.placement_types.Placement
The base class for the Placement type, which describes how a DTensor is placed onto the DeviceMesh. Placement and DeviceMesh together describe the DTensor Layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial.
This class is not meant to be used directly; it mainly serves as a typing stub.
Return type
is_replicate()
Return type
is_shard(dim=None)
Return type
Different ways to create a DTensor
There are three ways to construct a DTensor:
- distribute_tensor() creates a DTensor from a logical or “global” torch.Tensor on each rank. This can be used to shard leaf torch.Tensors (e.g. model parameters/buffers and inputs).
- DTensor.from_local() creates a DTensor from a local torch.Tensor on each rank, which can be used to create a DTensor from non-leaf torch.Tensors (e.g. intermediate activation tensors during forward/backward).
- DTensor provides dedicated tensor factory functions (e.g. empty(), ones(), randn(), etc.) to allow different DTensor creations by directly specifying the DeviceMesh and Placement. Compared to distribute_tensor(), these can directly materialize the sharded memory on device, instead of performing sharding after initializing the logical Tensor memory.
Create DTensor from a logical torch.Tensor
The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes (e.g. via torchrun) to execute the same program. This means that the model inside the program is first initialized on each process (i.e. the model might be initialized on CPU, on the meta device, or directly on GPU if there is enough memory).
DTensor offers a distribute_tensor() API that can shard the model weights or Tensors into DTensors by creating a DTensor from the “logical” Tensor on each process. This lets the created DTensors comply with the single-device semantic, which is critical for numerical correctness.
torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)
Distribute a leaf torch.Tensor
(i.e. nn.Parameter/buffers) to the device_mesh
according to the placements
specified. The rank of device_mesh
and placements
must be the same. The tensor
to distribute is the logical or “global” tensor, and the API would use the tensor
from first rank of the DeviceMesh dimension as the source of truth to preserve the single-device semantic. If you want to construct a DTensor in the middle of the Autograd computation, please use DTensor.from_local() instead.
Parameters
- tensor (torch.Tensor) – torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, the torch.chunk semantic is used to shard the tensor and scatter the shards. The uneven sharding behavior is experimental and subject to change.
- device_mesh (DeviceMesh, optional) – DeviceMesh to distribute the tensor. If not specified, distribute_tensor must be called under a DeviceMesh context manager. Default: None
- placements (List[Placement], optional) – the placements that describe how to place the tensor on the DeviceMesh; must have the same number of elements as device_mesh.ndim. If not specified, the tensor is by default replicated across the device_mesh from the first rank of each dimension of the device_mesh.
Keyword Arguments
src_data_rank (int, optional) – the rank of the source data for the logical/global tensor, it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0
on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing None
explicitly, distribute_tensor() simply uses its local data instead of trying to preserve the single-device semantic via scatter/broadcast. Default: 0
Returns
A DTensor or XLAShardedTensor
object.
Return type
Note
When initializing the DeviceMesh with the xla device_type, distribute_tensor returns an XLAShardedTensor instead. See this issue for more details. The XLA integration is experimental and subject to change.
Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier sharding on the nn.Module
level
torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)
This function exposes three functions to control the parameters/inputs/outputs of the module:
1. To perform sharding on the module before runtime execution by specifying the partition_fn (i.e. allow the user to convert Module parameters to DTensor parameters according to the partition_fn specified).
2. To control the inputs or outputs of the module during runtime execution by specifying the input_fn and output_fn (i.e. convert the input to DTensor, convert the output back to torch.Tensor).
Parameters
- module (nn.Module) – user module to be partitioned.
- device_mesh (DeviceMesh) – the device mesh to place the module.
- partition_fn (Callable) – the function to partition parameters (i.e. shard certain parameters across the device_mesh). If partition_fn is not specified, by default all module parameters of module are replicated across the mesh.
- input_fn (Callable) – specifies the input distribution, i.e. controls how the input of the module is sharded. input_fn will be installed as a module forward_pre_hook (pre forward hook).
- output_fn (Callable) – specifies the output distribution, i.e. controls how the output is sharded, or converts it back to torch.Tensor. output_fn will be installed as a module forward_hook (post forward hook).
Returns
A module that contains parameters/buffers that are all DTensor
s.
Return type
Note
When initializing the DeviceMesh with the xla device_type, distribute_module returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change.
DTensor Factory Functions
DTensor also provides dedicated tensor factory functions to allow creating DTensor directly using torch.Tensor like factory function APIs (i.e. torch.ones, torch.empty, etc), by additionally specifying the DeviceMesh
and Placement
for the DTensor created:
torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with the scalar value 0.
Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
Keyword Arguments
- requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
- dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
- layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided.
- device_mesh – DeviceMesh type, contains the mesh info of ranks
- placements – a sequence of Placement type: Shard, Replicate
Returns
A DTensor object on each rank
Return type
torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size.
Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
Keyword Arguments
- dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
- layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided.
- requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
- device_mesh – DeviceMesh type, contains the mesh info of ranks
- placements – a sequence of Placement type: Shard, Replicate
Returns
A DTensor object on each rank
Return type
torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.
Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
Keyword Arguments
- dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
- layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided.
- requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
- device_mesh – DeviceMesh type, contains the mesh info of ranks
- placements – a sequence of Placement type: Shard, Replicate
Returns
A DTensor object on each rank
Return type
torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)
Returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size.
Parameters
- size (int...) – a sequence of integers defining the shape of the output DTensor, given as a collection like a list or tuple. E.g.: full([1,2,3..], fill_value) or full((1,2,3..), fill_value)
- fill_value (Scalar) – the value to fill the output tensor with.
Keyword Arguments
- dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
- layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided.
- requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
- device_mesh – DeviceMesh type, contains the mesh info of ranks.
- placements – a sequence of Placement type: Shard, Replicate
Returns
A DTensor object on each rank
Return type
torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.
Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
Keyword Arguments
- dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
- layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided.
- requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
- device_mesh – DeviceMesh type, contains the mesh info of ranks.
- placements – a sequence of Placement type: Shard, Replicate
Returns
A DTensor object on each rank
Return type
torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)
Returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size.
Parameters
size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
Keyword Arguments
- dtype (torch.dtype, optional) – the desired data type of returned DTensor. Default: if None, uses a global default (see torch.set_default_dtype()).
- layout (torch.layout, optional) – the desired layout of returned DTensor. Default: torch.strided.
- requires_grad (bool, optional) – If autograd should record operations on the returned DTensor. Default: False.
- device_mesh – DeviceMesh type, contains the mesh info of ranks.
- placements – a sequence of Placement type: Shard, Replicate
Returns
A DTensor object on each rank
Return type
Debugging
Logging
When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:
- TORCH_LOGS=+dtensor will display logging.DEBUG messages and all levels above it.
- TORCH_LOGS=dtensor will display logging.INFO messages and above.
- TORCH_LOGS=-dtensor will display logging.WARNING messages and above.
Experimental Features
DTensor also provides a set of experimental features. These features are either in the prototyping stage, or their basic functionality is done but they are still looking for user feedback. Please submit an issue to PyTorch if you have feedback on these features.
torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)
context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions: 1) patch the SDPA (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled one, 2) shard buffers along the sequence dimension, with each rank keeping the corresponding shard according to mesh.
Parameters
- mesh (DeviceMesh) – the device mesh for the context parallelism.
- buffers (Optional[List[torch.Tensor]]) – buffers whose usage depends on the sequence dimension. Examples are input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure accuracy. The sharding happens in-place, so the buffer’s shape changes within the context. The buffers are restored after the context finishes. no_restore_buffers can be used to specify which buffers don’t need to be restored. Note that buffers should not contain any nn.Parameter.
- buffer_seq_dims (Optional[List[int]]) – the sequence dimensions of buffers.
- no_restore_buffers (Optional[Set[torch.Tensor]]) – buffers in this set won’t be restored after the context exits. This set must be a subset of buffers. If the buffers won’t be used after the context exits, they can be put in this set to avoid extra restore time.
Return type
Generator[None, None, None]
Warning
torch.distributed._tensor.experimental.attention.context_parallel is a prototype feature in PyTorch. The API is subject to change.
torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)
local_map() is an experimental API that allows users to pass DTensors to a function that is written to be applied on torch.Tensors. It works by extracting the local components of the DTensors, calling the function, and wrapping the outputs into DTensors according to out_placements.
Parameters
- func (Callable) – the function to be applied on each local shard of DTensors.
- out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – the desired placements of the DTensors in func’s flattened output. If the flattened output is a single value, out_placements should be of type PlacementType. Otherwise, if the flattened output has multiple values, out_placements should be a tuple of PlacementType values mapping 1:1 to the flattened output. For a Tensor output, PlacementType is used as its placements (a Tuple[Placement] value). For a non-Tensor output, the PlacementType should be None. Note that the only exception is when no DTensor argument is passed in. In this case, even if out_placements is not None, the resulting function should ignore the desired placements because the function is not running with DTensors.
- in_placements (Tuple[PlacementType, …], optional) – the required placements of the DTensors in the flattened inputs of func. If in_placements is specified, local_map() examines whether the placements of each DTensor argument are the same as the required placements. If the placements are not the same and redistribute_inputs is False, an exception is raised. Otherwise, if redistribute_inputs is True, the argument is first redistributed to the required sharding placements before its local tensor is passed to func. The only exception is when the required placements are not None and the argument is a torch.Tensor. In this case, the placements examination is skipped and the argument is passed directly to func. If in_placements is None, no placements examination is performed. Default: None
- device_mesh (DeviceMesh, optional) – the device mesh that all the DTensors are placed on. If not specified, this is inferred from the input DTensors’ device mesh. local_map requires all DTensors to be placed on the same device mesh. Default: None.
- redistribute_inputs (bool, optional) – whether to reshard the input DTensors when their placements are different from the required input placements. If this value is False and some DTensor input has a different placement, an exception is raised. Default: False.
Returns
A Callable
that applies func
to each local shard of the input DTensor
and returns a DTensor
constructed from the return value of func
.
Raises
- AssertionError – If the input DTensors are not placed on the same device mesh, or if they are placed on a different device mesh than the device_mesh argument passed in.
- AssertionError – For any non-DTensor output, its corresponding output placement in out_placements is required to be None. An AssertionError is raised if this is not the case.
- ValueError – If redistribute_inputs=False but the input DTensor needs a redistribution according to in_placements.
Example
    import torch
    import torch.distributed._functional_collectives as funcol
    from torch.distributed.tensor import Replicate, Shard, distribute_tensor
    from torch.distributed.tensor.experimental import local_map

    # `device_mesh` is assumed to be an already-initialized 1-d DeviceMesh.

    def mm_allreduce_forward(device_mesh, W, X):
        partial_sum_tensor = torch.mm(W, X)
        reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
        return reduced_tensor

    W = torch.randn(12, 8, requires_grad=False)
    X = torch.randn(8, 16, requires_grad=False)
    Y = torch.mm(W, X)
    row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
    col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh

    # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
    local_mm_allreduce_forward = local_map(
        mm_allreduce_forward,
        out_placements=[Replicate()],
        in_placements=[col_wise, row_wise],
        device_mesh=device_mesh,
    )

    W_dt = distribute_tensor(W, device_mesh, (col_wise))  # col-wisely sharded W tensor
    X_dt = distribute_tensor(X, device_mesh, (row_wise))  # row-wisely sharded X tensor
    Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply local_mm_allreduce_forward to DTensors
Note
This API is currently experimental and subject to change
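The arithmetic behind the example above can be checked without any distributed setup. When W is sharded by columns and X by rows, each rank multiplies its local pieces and obtains a full-shaped partial result; adding those partial results (the role of the all_reduce with "sum") gives W @ X. The sketch below simulates two ranks with nested Python lists; the small matrices and the helper names `matmul` and `add` are made up for illustration.

```python
def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


W = [[1, 2, 3, 4],
     [5, 6, 7, 8]]  # shape 2 x 4
X = [[1, 0],
     [0, 1],
     [1, 1],
     [2, 0]]        # shape 4 x 2

# Two simulated ranks: W is split by columns (Shard(1)), X by rows (Shard(0)).
W_shards = [[row[:2] for row in W], [row[2:] for row in W]]
X_shards = [X[:2], X[2:]]

# Each rank multiplies only its local pieces and gets a partial result.
partials = [matmul(w, x) for w, x in zip(W_shards, X_shards)]

# The all_reduce("sum") step adds the partial results across ranks.
reduced = add(partials[0], partials[1])

print(partials)                 # [[[1, 2], [5, 6]], [[11, 3], [23, 7]]]
print(reduced)                  # [[12, 5], [28, 13]]
print(reduced == matmul(W, X))  # True
```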
torch.distributed.tensor.experimental.register_sharding(op)
register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensor. It can be useful when: (1) there doesn’t exist a default sharding strategy for op
, e.g. when op
is a custom operator that is not supported by DTensor
; (2) when users would like to overwrite default sharding strategies of existing operators.
Parameters
op (Union[OpOverload, List[OpOverload]]) – An op or a list of ops to register the customized sharding function.
Returns
A function decorator which can be used to wrap a function that defines the sharding strategy for the operator specified in op. The defined sharding strategy will be registered to DTensor and will override the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op (except that if an arg is a torch.Tensor, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and its corresponding input placements.
Example
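    import torch
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.experimental import register_sharding

    aten = torch.ops.aten

    @register_sharding(aten._softmax.default)
    def custom_softmax_sharding(x, dim, half_to_float):
        softmax_dim = dim if dim >= 0 else dim + x.ndim
        acceptable_shardings = []

        # Fully replicated input and output is always acceptable.
        all_replicate = ([Replicate()], [Replicate(), None, None])
        acceptable_shardings.append(all_replicate)

        # Sharding on any dimension other than the softmax dimension is acceptable.
        for sharding_dim in range(x.ndim):
            if sharding_dim != softmax_dim:
                all_sharded = (
                    [Shard(sharding_dim)],
                    [Shard(sharding_dim), None, None],
                )
                acceptable_shardings.append(all_sharded)

        return acceptable_shardings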
Note
This API is currently experimental and subject to change