torch.distributed.tensor — PyTorch 2.7 documentation

Note

torch.distributed.tensor is currently in an alpha state and under development. We are committed to backward compatibility for most of the APIs listed in this doc, but API changes may still happen if necessary.

PyTorch DTensor (Distributed Tensor)

PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handle distributed logic, including sharded storage, operator computation, and collective communications across devices/hosts. DTensor can be used to build different parallelism solutions and supports sharded state_dict representation when working with multi-dimensional sharding.

Please see the PyTorch native parallelism solutions that are built on top of DTensor for examples.

DTensor follows the SPMD (single program, multiple data) programming model to empower users to write distributed programs as if they were single-device programs with the same convergence property. It provides a uniform tensor sharding layout (DTensor Layout) through specifying the DeviceMesh and Placement.
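For example, a minimal sketch (assuming the program is launched via torchrun with 4 processes and a CUDA backend; the mesh shape and tensor sizes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# A 1-D DeviceMesh spanning the 4 ranks of this job
mesh = init_device_mesh("cuda", (4,))
big_tensor = torch.randn(888, 12)
# Shard the tensor on its first dimension across the mesh dimension
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])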

DTensor Class APIs

DTensor is a torch.Tensor subclass. This means that once a DTensor is created, it can be used in a very similar way to torch.Tensor, including running different types of PyTorch operators as if running them on a single device, allowing proper distributed computation for PyTorch operators.

In addition to existing torch.Tensor methods, it also offers a set of additional methods to interact with torch.Tensor, redistribute the DTensor Layout to a new DTensor, get the full tensor content on all devices, etc.

class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)

DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides a single-device-like abstraction to program with multi-device torch.Tensor. It describes the distributed tensor sharding layout (DTensor Layout) through the DeviceMesh and the following types of Placement: Shard, Replicate, and Partial.

When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue communications whenever necessary. Along with the operator computation, DTensor will transform or propagate the placements (DTensor Layout) properly (based on the operator semantic itself) and generate new DTensor outputs.

To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor requires every Tensor argument of the operator to be a DTensor.

Note

Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor (i.e. it does not handle autograd correctly and hence is not a public API). Please refer to the create_dtensor section to see how to create a DTensor.

Return type

DTensor

__create_chunk_list__()[source][source]

Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica on current rank. For DTensor, each rank will have a single local shard/replica, so the returned list usually only has one element.

This dunder method is primarily used for distributed checkpoint purposes.

Returns

A List[ChunkStorageMetadata] object that represents the shard size/offset on the current rank.

static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source][source]

Create a DTensor from a local torch.Tensor on each rank according to the device_mesh and placements specified.

Parameters

Keyword Arguments

Returns

A DTensor object

Return type

DTensor

Note

When run_check=False, it is the user’s responsibility to ensure the local tensor passed in is correct across ranks (i.e. the tensor is sharded for the Shard(dim) placement or replicated for the Replicate() placement). If not, the behavior of the created DTensor is undefined.

Note

from_local is differentiable; the requires_grad of the created DTensor object will depend on whether local_tensor requires grad or not.
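A minimal sketch of from_local (assuming each of the 4 ranks already holds its own 2x16 shard of a logical 8x16 tensor; shapes and mesh size are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

mesh = init_device_mesh("cuda", (4,))
local_shard = torch.randn(2, 16)  # this rank's local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
dtensor.shape  # global shape inferred as torch.Size([8, 16])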

full_tensor(*, grad_placements=None)[source][source]

Return the full tensor of this DTensor. It will perform the necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate them together. It is syntactic sugar for the following code:

dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()

Keyword Arguments

grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the full Tensor returned from this function. full_tensor converts DTensor to a full torch.Tensor and the returned torch.Tensor might not be used with the original replicated DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original replicated DTensor layout. If not specified, we will assume the gradient layout of the full tensor to be replicated.

Returns

A torch.Tensor object that represents the full tensor of this DTensor.

Return type

Tensor

Note

full_tensor is differentiable.
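A minimal sketch of full_tensor (assuming a 1-D mesh over 4 ranks; shapes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (4,))
dt = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])
full = dt.full_tensor()  # gathers the shards from all ranks
full.shape               # torch.Size([8, 16]) on every rank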

redistribute(device_mesh=None, placements=None, *, async_op=False)[source][source]

redistribute performs the necessary collective operations to redistribute the current DTensor from its current placements to new placements, or from its current DeviceMesh to a new DeviceMesh. For example, we can turn a sharded DTensor into a replicated DTensor by specifying a Replicate placement for each dimension of the DeviceMesh.

When redistributing from the current to the new placements on one device mesh dimension, we will perform the following operations, including communication collectives or local operations:

  1. Shard(dim) -> Replicate(): all_gather
  2. Shard(src_dim) -> Shard(dst_dim): all_to_all
  3. Replicate() -> Shard(dim): local chunking (i.e. torch.chunk)
  4. Partial() -> Replicate(): all_reduce
  5. Partial() -> Shard(dim): reduce_scatter

redistribute correctly figures out the necessary redistribution steps for DTensors that are created on either a 1-D or an N-D DeviceMesh.

Parameters

Keyword Arguments

async_op (bool, optional) – whether to perform the DTensor redistribute operation asynchronously or not. Default: False

Returns

A DTensor object

Return type

DTensor

Note

redistribute is differentiable, which means users do not need to worry about the backward formula of the redistribute operation.

Note

redistribute currently only supports redistributing a DTensor on the same DeviceMesh. Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.
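A minimal sketch of redistribute (assuming a 1-D mesh over 4 ranks; the Shard(0) -> Replicate() transition below triggers an all_gather):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

mesh = init_device_mesh("cuda", (4,))
sharded = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])
replicated = sharded.redistribute(placements=[Replicate()])  # all_gather under the hood
replicated.to_local().shape  # torch.Size([8, 16]) replica on every rank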

to_local(*, grad_placements=None)[source][source]

Get the local tensor of this DTensor on its current rank. For sharding, it returns a local shard of the logical tensor view; for replication, it returns the replica on its current rank.

Keyword Arguments

grad_placements (List[Placement], optional) – the placements that describe the future layout of any gradient of the Tensor returned from this function. to_local converts DTensor to a local tensor and the returned local tensor might not be used with the original DTensor layout later in the code. This argument is a hint that the user can give to autograd in case the gradient layout of the returned tensor does not match the original DTensor layout. If not specified, we will assume the gradient layout remains the same as the original DTensor and use that for gradient computation.

Returns

A torch.Tensor or AsyncCollectiveTensor object. It represents the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, it means the local tensor is not ready yet (i.e. communication is not finished). In this case, the user needs to call wait to wait for the local tensor to be ready.

Return type

Tensor

Note

to_local is differentiable; the requires_grad of the local tensor returned will depend on whether the DTensor requires grad or not.
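A minimal sketch of to_local (assuming a 1-D mesh over 4 ranks; shapes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (4,))
dt = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])
local = dt.to_local()
local.shape  # torch.Size([2, 16]) shard on each of the 4 ranks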

property device_mesh: DeviceMesh

The DeviceMesh attribute associated with this DTensor object.

Note

device_mesh is a read-only property; it cannot be set.

property placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]

The placements attribute of this DTensor that describes the layout of this DTensor on its DeviceMesh.

Note

placements is a read-only property; it cannot be set.

DeviceMesh as the distributed communicator

DeviceMesh was built from DTensor as the abstraction to describe cluster’s device topology and represent multi-dimensional communicators (on top of ProcessGroup). To see the details of how to create/use a DeviceMesh, please refer to the DeviceMesh recipe.
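A minimal sketch of creating a DeviceMesh (assuming torchrun with 8 processes; the 2x4 shape and the "dp"/"tp" dimension names are illustrative):

from torch.distributed.device_mesh import init_device_mesh

# 2-D mesh: 2-way on the "dp" dimension, 4-way on the "tp" dimension
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # 1-D sub-mesh / communicator for the "tp" dimension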

DTensor Placement Types

DTensor supports the following types of Placement on each DeviceMesh dimension:

class torch.distributed.tensor.placement_types.Shard(dim)[source][source]

The Shard(dim) placement describes the DTensor sharding on tensor dimension dim over a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. The Shard(dim) placement follows the torch.chunk(dim) semantic, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. The Shard placement can be used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.)

Parameters

dim (int) – The tensor dimension over which the DTensor is sharded on its corresponding DeviceMesh dimension.

Warning

Sharding on a tensor dimension where the tensor dimension size is not evenly divisible by the DeviceMesh dimension is currently experimental and subject to change.

dim: int
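To illustrate the torch.chunk(dim) semantic that Shard(dim) follows, consider a 10-row tensor split 4 ways on dim 0 (a local illustration only, no distribution involved):

import torch

t = torch.arange(10).reshape(10, 1)
[c.shape[0] for c in torch.chunk(t, 4, dim=0)]  # [3, 3, 3, 1]: the last shard is smaller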

class torch.distributed.tensor.placement_types.Replicate[source][source]

The Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (i.e. distribute_tensor, DTensor.from_local, etc.)

class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source][source]

The Partial(reduce_op) placement describes the DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds the partial value of the global Tensor. The user can redistribute the Partial DTensor to a Replicate or Shard(dim) placement on the specified DeviceMesh dimension using redistribute, which would trigger the necessary communication operations under the hood (i.e. allreduce, reduce_scatter).

Parameters

reduce_op (str, optional) – The reduction op to be used for the partial DTensor to produce Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: “sum”, “avg”, “product”, “max”, “min”, default: “sum”.

Note

The Partial placement can be generated as a result of the DTensor operators, and can only be used by the DTensor.from_local API.

reduce_op: str = 'sum'
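A hedged sketch of using Partial with DTensor.from_local (assuming a 1-D mesh over 4 ranks and that each rank holds a partial sum of the same logical tensor; shapes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

mesh = init_device_mesh("cuda", (4,))
partial_local = torch.randn(8, 16)  # this rank's partial value of the global tensor
dt = DTensor.from_local(partial_local, mesh, [Partial()])
summed = dt.redistribute(placements=[Replicate()])  # performs the pending all_reduce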

class torch.distributed.tensor.placement_types.Placement[source][source]

The base class for the Placement type, which describes how a DTensor is placed onto the DeviceMesh. Placement and DeviceMesh together can describe the DTensor Layout. It is the base class of the three main DTensor Placement types: Shard, Replicate, and Partial.

This class is not meant to be used directly; it mainly serves as a typing stub.

is_partial()[source][source]

Return type

bool

is_replicate()[source][source]

Return type

bool

is_shard(dim=None)[source][source]

Return type

bool

Different ways to create a DTensor

There are three ways to construct a DTensor:

Create DTensor from a logical torch.Tensor

The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes (i.e. via torchrun) to execute the same program. This means that the model inside the program would first be initialized on different processes (i.e. the model might be initialized on CPU, on the meta device, or directly on GPU if there is enough memory).

DTensor offers a distribute_tensor() API that can shard the model weights or Tensors to DTensor s, where it would create a DTensor from the “logical” Tensor on each process. This empowers the created DTensor s to comply with the single device semantic, which is critical for numerical correctness.

torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)[source]

Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or “global” tensor, and the API would use the tensor from first rank of the DeviceMesh dimension as the source of truth to preserve the single-device semantic. If you want to construct a DTensor in the middle of the Autograd computation, please use DTensor.from_local() instead.

Parameters

Keyword Arguments

src_data_rank (int, optional) – the rank of the source data for the logical/global tensor, it is used by distribute_tensor() to scatter/broadcast the shards/replicas to other ranks. By default, we use group_rank=0 on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing None explicitly, distribute_tensor() simply uses its local data instead of trying to preserve the single-device semantic via scatter/broadcast. Default: 0

Returns

A DTensor or XLAShardedTensor object.

Return type

DTensor

Note

When initializing the DeviceMesh with the xla device_type, distribute_tensor returns XLAShardedTensor instead. See this issue for more details. The XLA integration is experimental and subject to change.
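A minimal sketch of distribute_tensor on a 2-D mesh (assuming torchrun with 8 processes; the mesh layout and shapes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

mesh_2d = init_device_mesh("cuda", (2, 4))
weight = torch.randn(128, 64)  # the "global" tensor, initialized on every rank
# One placement per mesh dimension: shard on mesh dim 0, replicate on mesh dim 1
w_dt = distribute_tensor(weight, mesh_2d, [Shard(0), Replicate()])
w_dt.to_local().shape  # torch.Size([64, 64]): sharded 2-way, replicated 4-way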

Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier sharding on the nn.Module level.

torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)[source]

This function exposes three functions to control the parameters/inputs/outputs of the module:

  1. To perform sharding on the module before runtime execution by specifying the partition_fn (i.e. allow user to convert Module parameters to DTensor parameters according to the partition_fn specified).
  2. To control the inputs or outputs of the module during runtime execution by specifying the input_fn and output_fn (i.e. convert the input to DTensor, convert the output back to torch.Tensor).

Parameters

Returns

A module that contains parameters/buffers that are all DTensor s.

Return type

Module

Note

When initializing the DeviceMesh with the xla device_type, distribute_module returns nn.Module with PyTorch/XLA SPMD annotated parameters. See this issue for more details. The XLA integration is experimental and subject to change.
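A hedged sketch of distribute_module with a partition_fn that simply replicates every parameter (a trivial placement choice to show the mechanics; a real sharding plan would pick Shard(...) placements per parameter):

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module, distribute_tensor, Replicate

mesh = init_device_mesh("cuda", (4,))

def replicate_params(name, module, device_mesh):
    # Convert each parameter of this submodule into a replicated DTensor parameter
    for pname, param in module.named_parameters(recurse=False):
        dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()]))
        module.register_parameter(pname, dist_param)

sharded_module = distribute_module(nn.Linear(8, 8), mesh, partition_fn=replicate_params)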

DTensor Factory Functions

DTensor also provides dedicated tensor factory functions to allow creating DTensor directly using torch.Tensor-like factory function APIs (i.e. torch.ones, torch.empty, etc.), by additionally specifying the DeviceMesh and Placement for the DTensor created:

torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

Returns a DTensor filled with the scalar value 0.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))

Keyword Arguments

Returns

A DTensor object on each rank

Return type

DTensor
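A minimal sketch of a DTensor factory function (assuming a 1-D mesh over 4 ranks; the same pattern applies to ones, empty, full, rand, and randn below):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import zeros, Shard

mesh = init_device_mesh("cuda", (4,))
dt = zeros(8, 16, device_mesh=mesh, placements=[Shard(0)])
dt.shape             # global shape: torch.Size([8, 16])
dt.to_local().shape  # per-rank shard: torch.Size([2, 16])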

torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

Returns a DTensor filled with the scalar value 1, with the shape defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))

Keyword Arguments

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))

Keyword Arguments

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]

Returns a DTensor filled with fill_value according to device_mesh and placements, with the shape defined by the argument size.

Parameters

Keyword Arguments

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))

Keyword Arguments

Returns

A DTensor object on each rank

Return type

DTensor

torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]

Returns a DTensor filled with random numbers from a normal distribution with mean 0 and variance 1. The shape of the tensor is defined by the variable argument size.

Parameters

size (int...) – a sequence of integers defining the shape of the output DTensor. Can be a variable number of arguments or a collection like a list or tuple. E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))

Keyword Arguments

Returns

A DTensor object on each rank

Return type

DTensor

Debugging

Logging

When launching the program, you can turn on additional logging using the TORCH_LOGS environment variable from torch._logging:

Experimental Features

DTensor also provides a set of experimental features. These features are either in the prototyping stage, or their basic functionality is done and we are looking for user feedback. Please submit an issue to PyTorch if you have feedback on these features.

torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)[source]

context_parallel is an experimental API to enable context parallelism (CP). This API performs two actions: 1) patch the SDPA (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled one, 2) shard the buffers along the sequence dimension, where each rank will preserve the corresponding shard according to the mesh.

Parameters

Return type

Generator[None, None, None]

Warning

torch.distributed._tensor.experimental.attention.context_parallel is a prototype feature in PyTorch. The API is subject to change.
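A hedged sketch of context_parallel (assuming a 1-D mesh dedicated to CP and q/k/v laid out as [batch, heads, seq, head_dim], so the sequence dimension is 2; shapes and mesh size are illustrative):

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

cp_mesh = init_device_mesh("cuda", (4,))
q = torch.randn(2, 8, 1024, 64, device="cuda", requires_grad=True)
k = torch.randn(2, 8, 1024, 64, device="cuda", requires_grad=True)
v = torch.randn(2, 8, 1024, 64, device="cuda", requires_grad=True)

# Inside the context, SDPA is patched with its CP-enabled variant and the listed
# buffers are sharded along their sequence dimension across cp_mesh.
with context_parallel(cp_mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)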

torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)[source]

local_map() is an experimental API that allows users to pass DTensor s to a function that is written to be applied on torch.Tensor s. It is done by extracting the local components of the DTensor, calling the function, and wrapping the outputs into DTensor according to the out_placements.

Parameters

Returns

A Callable that applies func to each local shard of the input DTensor and returns a DTensor constructed from the return value of func.

Raises

Example

def mm_allreduce_forward(device_mesh, W, X):
    partial_sum_tensor = torch.mm(W, X)
    reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
    return reduced_tensor

W = torch.randn(12, 8, requires_grad=False)
X = torch.randn(8, 16, requires_grad=False)
Y = torch.mm(W, X)
row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh

# local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
local_mm_allreduce_forward = local_map(
    mm_allreduce_forward,
    out_placements=[Replicate()],
    in_placements=[col_wise, row_wise],
    device_mesh=device_mesh,
)

W_dt = distribute_tensor(W, device_mesh, (col_wise))  # col-wisely sharded W tensor
X_dt = distribute_tensor(X, device_mesh, (row_wise))  # row-wisely sharded X tensor
Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply local_mm_allreduce_forward to DTensors

Note

This API is currently experimental and subject to change

torch.distributed.tensor.experimental.register_sharding(op)[source]

register_sharding() is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are DTensor. It can be useful when: (1) there doesn’t exist a default sharding strategy for op, e.g. when op is a custom operator that is not supported by DTensor; (2) when users would like to overwrite default sharding strategies of existing operators.

Parameters

op (Union [ OpOverload , List [ OpOverload ] ]) – An op or a list of ops to register the customized sharding function.

Returns

A function decorator which can be used to wrap a function that defines the sharding strategy for the operator specified in op. The defined sharding strategy will be registered to DTensor and will override the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op (except that if an arg is a torch.Tensor, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and its corresponding input placements.

Example

@register_sharding(aten._softmax.default)
def custom_softmax_sharding(x, dim, half_to_float):
    softmax_dim = dim if dim >= 0 else dim + x.ndim
    acceptable_shardings = []

    all_replicate = ([Replicate()], [Replicate(), None, None])
    acceptable_shardings.append(all_replicate)

    for sharding_dim in range(x.ndim):
        if sharding_dim != softmax_dim:
            all_sharded = (
                [Shard(sharding_dim)],
                [Shard(sharding_dim), None, None],
            )
            acceptable_shardings.append(all_sharded)

    return acceptable_shardings

Note

This API is currently experimental and subject to change