Distributed communication package - torch.distributed - PyTorch 2.7 documentation

Backends

torch.distributed supports three built-in backends, each with different capabilities. The table below shows which functions are available for use with CPU / CUDA tensors. MPI supports CUDA only if the implementation used to build PyTorch supports it.

Backend gloo mpi nccl
Device CPU GPU CPU GPU CPU GPU
send ?
recv ?
broadcast ?
all_reduce ?
reduce ?
all_gather ?
gather ?
scatter ?
reduce_scatter
all_to_all ?
barrier ?

Backends that come with PyTorch

The PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). By default for Linux, the Gloo and NCCL backends are built and included in PyTorch distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be included if you build PyTorch from source (e.g. by building PyTorch on a host that has MPI installed).

Note

As of PyTorch v1.8, Windows supports all collective communication backends except NCCL. If the init_method argument of init_process_group() points to a file, it must adhere to the following schema:

As on Linux, you can enable TCPStore by setting the environment variables MASTER_ADDR and MASTER_PORT.

Which backend to use?

In the past, we were often asked: “which backend should I use?”.

Common environment variables

Choosing the network interface to use

By default, both the NCCL and Gloo backends will try to find the right network interface to use. If the automatically detected interface is not correct, you can override it using the following environment variables (applicable to the respective backend):

If you’re using the Gloo backend, you can specify multiple interfaces by separating them by a comma, like this: export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3. The backend will dispatch operations in a round-robin fashion across these interfaces. It is imperative that all processes specify the same number of interfaces in this variable.
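The following is a minimal sketch of setting this variable from Python instead of the shell, on the assumption that it runs in every process before the process group is initialized. The helper name set_gloo_interfaces and the interface names are illustrative, not part of PyTorch.

```python
import os


def set_gloo_interfaces(interfaces):
    """Hypothetical helper: export GLOO_SOCKET_IFNAME for this process.

    Intended to run before the process group is initialized, so that
    Gloo sees the value when it selects its network interfaces.
    """
    if not interfaces:
        raise ValueError("at least one interface name is required")
    value = ",".join(interfaces)
    os.environ["GLOO_SOCKET_IFNAME"] = value
    return value


# Every process should pass a list of the same length.
value = set_gloo_interfaces(["eth0", "eth1", "eth2", "eth3"])
num_interfaces = len(value.split(","))
print(value, num_interfaces)
```

Counting the entries, as num_interfaces does here, gives a cheap value to log and compare across processes when checking that they all specify the same number of interfaces.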

Other NCCL environment variables

Debugging - in case of NCCL failure, you can set NCCL_DEBUG=INFO to print an explicit warning message as well as basic NCCL initialization information.

You may also use NCCL_DEBUG_SUBSYS to get more details about a specific aspect of NCCL. For example, NCCL_DEBUG_SUBSYS=COLL would print logs of collective calls, which may be helpful when debugging hangs, especially those caused by collective type or message size mismatch. In case of topology detection failure, it would be helpful to set NCCL_DEBUG_SUBSYS=GRAPH to inspect the detailed detection result and save it as a reference if further help from the NCCL team is needed.
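As a rough sketch, the same debugging variables can be set from Python, assuming this happens before NCCL is initialized in the process. The helper enable_nccl_debug is a hypothetical name used only for illustration.

```python
import os


def enable_nccl_debug(subsystems=None):
    """Hypothetical helper: request NCCL logging for this process."""
    os.environ["NCCL_DEBUG"] = "INFO"
    if subsystems:
        # e.g. ["COLL"] for collective calls, ["GRAPH"] for topology detection
        os.environ["NCCL_DEBUG_SUBSYS"] = ",".join(subsystems)


enable_nccl_debug(["COLL"])
print(os.environ["NCCL_DEBUG"], os.environ["NCCL_DEBUG_SUBSYS"])
```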

Performance tuning - NCCL performs automatic tuning based on its topology detection to save users’ tuning effort. On some socket-based systems, users may still try tuning NCCL_SOCKET_NTHREADS and NCCL_NSOCKS_PERTHREAD to increase socket network bandwidth. These two environment variables have been pre-tuned by NCCL for some cloud providers, such as AWS or GCP.

For a full list of NCCL environment variables, please refer to NVIDIA NCCL’s official documentation.

Basics

The torch.distributed package provides PyTorch support and communication primitives for multiprocess parallelism across several computation nodes running on one or more machines. The class torch.nn.parallel.DistributedDataParallel() builds on this functionality to provide synchronous distributed training as a wrapper around any PyTorch model. This differs from the kinds of parallelism provided by Multiprocessing package - torch.multiprocessing and torch.nn.DataParallel() in that it supports multiple network-connected machines and in that the user must explicitly launch a separate copy of the main training script for each process.

In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel():

Initialization

The package needs to be initialized using the torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh() function before calling any other methods. Both block until all processes have joined.

Warning

Initialization is not thread-safe. Process group creation should be performed from a single thread, to prevent inconsistent ‘UUID’ assignment across ranks, and to prevent races during initialization that can lead to hangs.

torch.distributed.is_available()

Return True if the distributed package is available.

Otherwise, torch.distributed does not expose any other APIs. Currently, torch.distributed is available on Linux, MacOS and Windows. Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source. Currently, the default value is USE_DISTRIBUTED=1 for Linux and Windows, and USE_DISTRIBUTED=0 for MacOS.

Return type

bool

torch.distributed.init_process_group(backend=None, init_method=None, timeout=None, world_size=-1, rank=-1, store=None, group_name='', pg_options=None, device_id=None)

Initialize the default distributed process group.

This will also initialize the distributed package.

There are 2 main ways to initialize a process group:

  1. Specify store, rank, and world_size explicitly.
  2. Specify init_method (a URL string) which indicates where/how to discover peers. Optionally specify rank and world_size, or encode all required parameters in the URL and omit them.

If neither is specified, init_method is assumed to be “env://”.

Parameters

Note

To enable backend == Backend.MPI, PyTorch needs to be built from source on a system that supports MPI.

Note

Support for multiple backends is experimental. Currently when no backend is specified, both gloo and nccl backends will be created. The gloo backend will be used for collectives with CPU tensors and the nccl backend will be used for collectives with CUDA tensors. A custom backend can be specified by passing in a string with format “<device_type>:<backend_name>,<device_type>:<backend_name>”, e.g. “cpu:gloo,cuda:custom_backend”.
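To make the string format concrete, here is a small sketch that splits such a specification into a device-to-backend mapping. The function parse_backend_config is a hypothetical helper that only models the format described above; it is not the parser PyTorch uses internally.

```python
def parse_backend_config(spec):
    """Hypothetical helper: split "cpu:gloo,cuda:nccl" into a dict."""
    mapping = {}
    for pair in spec.lower().split(","):
        device_type, sep, backend_name = pair.strip().partition(":")
        if not sep or not device_type or not backend_name:
            raise ValueError(
                f"expected '<device_type>:<backend_name>', got {pair!r}"
            )
        mapping[device_type] = backend_name
    return mapping


config = parse_backend_config("cpu:gloo,cuda:custom_backend")
print(config)  # {'cpu': 'gloo', 'cuda': 'custom_backend'}
```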

torch.distributed.device_mesh.init_device_mesh(device_type, mesh_shape, *, mesh_dim_names=None)

Initializes a DeviceMesh based on device_type, mesh_shape, and mesh_dim_names parameters.

This creates a DeviceMesh with an n-dimensional array layout, where n is the length of mesh_shape. If mesh_dim_names is provided, each dimension is labeled as mesh_dim_names[i].

Note

init_device_mesh follows SPMD programming model, meaning the same PyTorch Python program runs on all processes/ranks in the cluster. Ensure mesh_shape (the dimensions of the nD array describing device layout) is identical across all ranks. Inconsistent mesh_shape may lead to hanging.

Note

If no process group is found, init_device_mesh will initialize the distributed process group(s) required for distributed communication behind the scenes.

Parameters

Returns

A DeviceMesh object representing the device layout.

Return type

DeviceMesh

Example::

    from torch.distributed.device_mesh import init_device_mesh

    mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
    mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

torch.distributed.is_initialized()

Check if the default process group has been initialized.

Return type

bool

torch.distributed.is_mpi_available()

Check if the MPI backend is available.

Return type

bool

torch.distributed.is_nccl_available()

Check if the NCCL backend is available.

Return type

bool

torch.distributed.is_gloo_available()

Check if the Gloo backend is available.

Return type

bool

torch.distributed.distributed_c10d.is_xccl_available()

Check if the XCCL backend is available.

Return type

bool

torch.distributed.is_torchelastic_launched()

Check whether this process was launched with torch.distributed.elastic (aka torchelastic).

The existence of the TORCHELASTIC_RUN_ID environment variable is used as a proxy to determine whether the current process was launched with torchelastic. This is a reasonable proxy since TORCHELASTIC_RUN_ID maps to the rendezvous id, which is always a non-null value indicating the job id for peer discovery purposes.

Return type

bool


Currently three initialization methods are supported:

TCP initialization

There are two ways to initialize using TCP, both requiring a network address reachable from all processes and a desired world_size. The first way requires specifying an address that belongs to the rank 0 process. This initialization method requires that all processes have manually specified ranks.

Note that multicast addresses are no longer supported in the latest distributed package. group_name is deprecated as well.

    import torch.distributed as dist

    # Use address of one of the machines
    dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)

Environment variable initialization

This method will read the configuration from environment variables, allowing one to fully customize how the information is obtained. The variables to be set are:

The machine with rank 0 will be used to set up all connections.

This is the default method, meaning that init_method does not have to be specified (or can be env://).
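A pre-flight check along the following lines can catch a missing or inconsistent variable before initialization blocks. This is only a sketch: read_env_rendezvous is a hypothetical helper, and the variable names WORLD_SIZE and RANK are not listed in the text above; they are assumed here from the usual env:// convention alongside MASTER_ADDR and MASTER_PORT.

```python
import os

# Assumed variable names for env:// initialization.
REQUIRED_VARS = ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK")


def read_env_rendezvous(environ=None):
    """Hypothetical pre-flight check for env:// initialization."""
    environ = os.environ if environ is None else environ
    missing = [name for name in REQUIRED_VARS if name not in environ]
    if missing:
        raise RuntimeError(f"missing environment variables: {missing}")
    rank = int(environ["RANK"])
    world_size = int(environ["WORLD_SIZE"])
    if not 0 <= rank < world_size:
        raise ValueError(f"RANK={rank} is outside [0, {world_size})")
    return {
        "master_addr": environ["MASTER_ADDR"],
        "master_port": int(environ["MASTER_PORT"]),
        "world_size": world_size,
        "rank": rank,
    }


# Example values; a launcher would normally export these per process.
example_env = {
    "MASTER_ADDR": "10.1.1.20",
    "MASTER_PORT": "23456",
    "WORLD_SIZE": "4",
    "RANK": "0",
}
settings = read_env_rendezvous(example_env)
print(settings)
```

Once the variables are exported, init_process_group() can be called with only the backend argument, since env:// is the default.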

Post-Initialization

Once torch.distributed.init_process_group() has been run, the following functions can be used. To check whether the process group has already been initialized, use torch.distributed.is_initialized().

class torch.distributed.Backend(name)

An enum-like class for backends.

Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends.

The values of this class are lowercase strings, e.g., "gloo". They can be accessed as attributes, e.g., Backend.NCCL.

This class can be directly called to parse the string, e.g., Backend(backend_str) will check if backend_str is valid, and return the parsed lowercase string if so. It also accepts uppercase strings, e.g., Backend("GLOO") returns "gloo".

Note

The entry Backend.UNDEFINED is present but only used as initial value of some fields. Users should neither use it directly nor assume its existence.

classmethod register_backend(name, func, extended_api=False, devices=None)

Register a new backend with the given name and instantiating function.

This class method is used by 3rd party ProcessGroup extensions to register new backends.

Parameters

Note

This support of 3rd party backend is experimental and subject to change.

torch.distributed.get_backend(group=None)

Return the backend of the given process group.

Parameters

group (ProcessGroup, optional) – The process group to work on. The default is the general main process group. If another specific group is specified, the calling process must be part of group.

Returns

The backend of the given process group as a lower case string.

Return type

Backend

torch.distributed.get_rank(group=None)

Return the rank of the current process in the provided group, default otherwise.

Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.

Parameters

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns

The rank of the current process in the group, or -1 if it is not part of the group.

Return type

int

torch.distributed.get_world_size(group=None)

Return the number of processes in the current process group.

Parameters

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns

The world size of the process group, or -1 if the calling process is not part of the group.

Return type

int

Shutdown

It is important to clean up resources on exit by calling destroy_process_group().

The simplest pattern to follow is to destroy every process group and backend by calling destroy_process_group() with the default value of None for the group argument, at a point in the training script where communications are no longer needed, usually near the end of main(). The call should be made once per trainer-process, not at the outer process-launcher level.

If destroy_process_group() is not called by all ranks in a process group within the timeout duration, hangs on exit are possible, especially when the application has multiple process groups (e.g. for N-D parallelism). This is because the destructor for ProcessGroupNCCL calls ncclCommAbort, which must be called collectively, but the order in which Python’s GC invokes ProcessGroupNCCL’s destructor is not deterministic. Calling destroy_process_group() helps by ensuring ncclCommAbort is called in a consistent order across ranks, and avoids calling ncclCommAbort during ProcessGroupNCCL’s destructor.

Reinitialization

destroy_process_group() can also be used to destroy individual process groups. One use case could be fault tolerant training, where a process group may be destroyed and then a new one initialized during runtime. In this case, it’s critical to synchronize the trainer processes using some means other than torch.distributed primitives after calling destroy and before subsequently initializing. This behavior is currently unsupported/untested, due to the difficulty of achieving this synchronization, and is considered a known issue. Please file a GitHub issue or RFC if this is a use case that’s blocking you.


Groups

By default, collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. The new_group() function can be used to create new groups, with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions to exchange information in certain well-known programming patterns).

torch.distributed.new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None, device_id=None)

Create a new distributed group.

This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes.

Warning

Safe concurrent usage: When using multiple process groups with the NCCL backend, the user must ensure a globally consistent execution order of collectives across ranks.

If multiple threads within a process issue collectives, explicit synchronization is necessary to ensure consistent ordering.

When using async variants of torch.distributed communication APIs, a work object is returned and the communication kernel is enqueued on a separate CUDA stream, allowing overlap of communication and computation. Once one or more async ops have been issued on one process group, they must be synchronized with other CUDA streams by calling work.wait() before using another process group.

See Using multiple NCCL communicators concurrently for more details.

Parameters

Returns

A handle of distributed group that can be given to collective calls or GroupMember.NON_GROUP_MEMBER if the rank is not part of ranks.

N.B. use_local_synchronization doesn’t work with MPI.

N.B. While use_local_synchronization=True can be significantly faster with larger clusters and small process groups, care must be taken since it changes cluster behavior as non-member ranks don’t join the group barrier().

N.B. use_local_synchronization=True can lead to deadlocks when each rank creates multiple overlapping process groups. To avoid that, make sure all ranks follow the same global creation order.
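One way to keep the creation order identical on every rank is to derive the list of groups deterministically from world_size, so each process iterates over the same plan. The sketch below models only that planning step in plain Python; plan_groups and my_group_index are hypothetical helpers, and the actual new_group() calls are indicated in a comment rather than executed.

```python
def plan_groups(world_size, group_size):
    """Hypothetical helper: split ranks 0..world_size-1 into equal chunks."""
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    return [
        list(range(start, start + group_size))
        for start in range(0, world_size, group_size)
    ]


def my_group_index(rank, plan):
    """Index of the planned group that contains `rank`."""
    for index, ranks in enumerate(plan):
        if rank in ranks:
            return index
    raise ValueError(f"rank {rank} is not in any planned group")


plan = plan_groups(world_size=8, group_size=4)
# In a real job, each rank would loop over `plan` in this same order and call
# new_group(ranks=ranks) for every entry, keeping only its own handle.
print(plan, my_group_index(5, plan))
```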

torch.distributed.get_group_rank(group, global_rank)

Translate a global rank into a group rank.

global_rank must be part of group, otherwise this raises RuntimeError.

Parameters

Returns

Group rank of global_rank relative to group

Return type

int

N.B. calling this function on the default process group returns identity

torch.distributed.get_global_rank(group, group_rank)

Translate a group rank into a global rank.

group_rank must be part of group, otherwise this raises RuntimeError.

Parameters

Returns

Global rank of group_rank relative to group

Return type

int

N.B. calling this function on the default process group returns identity

torch.distributed.get_process_group_ranks(group)

Get all ranks associated with group.

Parameters

group (ProcessGroup) – ProcessGroup to get all ranks from.

Returns

List of global ranks ordered by group rank.

Return type

list[int]
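The relationship between these three functions can be modeled with an ordinary list: if the list holds the global ranks ordered by group rank, then translating between the two numbering schemes is a lookup by value or by index. The sketch below is a plain-Python model under that assumption; group_rank_of and global_rank_of are hypothetical stand-ins, not the library functions.

```python
def group_rank_of(group_ranks, global_rank):
    """Model of get_group_rank: position of a global rank inside the group."""
    if global_rank not in group_ranks:
        raise RuntimeError(f"global rank {global_rank} is not part of the group")
    return group_ranks.index(global_rank)


def global_rank_of(group_ranks, group_rank):
    """Model of get_global_rank: global rank stored at a group position."""
    if not 0 <= group_rank < len(group_ranks):
        raise RuntimeError(f"group rank {group_rank} is not part of the group")
    return group_ranks[group_rank]


# Stand-in for get_process_group_ranks(group): global ranks ordered by group rank.
group_ranks = [2, 5, 7]
print(group_rank_of(group_ranks, 5))   # 1
print(global_rank_of(group_ranks, 2))  # 7
```

For the default group the list is simply 0..world_size-1, which is why both translations return their input unchanged.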

DeviceMesh

DeviceMesh is a higher-level abstraction that manages process groups (or NCCL communicators). It allows users to easily create inter-node and intra-node process groups without worrying about how to set up the ranks correctly for different sub-process groups, and it helps manage those distributed process groups. The init_device_mesh() function can be used to create a new DeviceMesh, with a mesh shape describing the device topology.

class torch.distributed.device_mesh.DeviceMesh(device_type, mesh, *, mesh_dim_names=None, _init_backend=True)

DeviceMesh represents a mesh of devices, where the layout of devices can be represented as an n-dimensional array, and each value of the array is the global id of a rank in the default process group.

DeviceMesh could be used to describe the layout of devices across the cluster, and serves as a proxy for communication among the device lists within the cluster.

DeviceMesh can be used as a context manager.

Note

DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program is running on all processes/ranks in the cluster. Therefore, users need to make sure the mesh array (which describes the layout of devices) is identical across all ranks. An inconsistent mesh will lead to a silent hang.

Parameters

Returns

A DeviceMesh object representing the device layout.

Return type

DeviceMesh

The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. A reduction over the first dimension of mesh will reduce across columns (0, 4), ..., and (3, 7); a reduction over the second dimension of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).

Example::

    from torch.distributed.device_mesh import DeviceMesh

    # Initialize device mesh as (2, 4) to represent the topology
    # of cross-host (dim 0), and within-host (dim 1).
    mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
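To see which ranks end up communicating together for each mesh dimension, the grouping can be worked out on the nested list alone. This is a plain-Python sketch for a 2D mesh, with groups_along_dim as a hypothetical helper; it does not create any process groups.

```python
mesh = [[0, 1, 2, 3], [4, 5, 6, 7]]  # 2 hosts (dim 0) x 4 GPUs (dim 1)


def groups_along_dim(mesh, dim):
    """Ranks that reduce together over `dim` of a 2D mesh."""
    if dim == 0:
        # Fix the column, vary the host: one group per column.
        return [tuple(col) for col in zip(*mesh)]
    if dim == 1:
        # Fix the host, vary the GPU: one group per row.
        return [tuple(row) for row in mesh]
    raise ValueError("this sketch only handles 2D meshes")


print(groups_along_dim(mesh, 0))  # [(0, 4), (1, 5), (2, 6), (3, 7)]
print(groups_along_dim(mesh, 1))  # [(0, 1, 2, 3), (4, 5, 6, 7)]
```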

static from_group(group, device_type, mesh=None, *, mesh_dim_names=None)

Constructs a DeviceMesh with device_type from an existing ProcessGroup or a list of existing ProcessGroups.

The constructed device mesh has a number of dimensions equal to the number of groups passed. For example, if a single process group is passed in, the resulting DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, the resulting DeviceMesh is a 2D mesh.

If more than one group is passed, then the mesh and mesh_dim_names arguments are required. The order of the process groups passed in determines the topology of the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. The mesh tensor passed in must have the same number of dimensions as the number of process groups passed in, and the order of the dimensions in the mesh tensor must match the order in the process groups passed in.

Parameters

Returns

A DeviceMesh object representing the device layout.

Return type

DeviceMesh

get_all_groups()

Returns a list of ProcessGroups for all mesh dimensions.

Returns

A list of ProcessGroup objects.

Return type

list[torch.distributed.distributed_c10d.ProcessGroup]

get_coordinate()

Return the indices of this rank along each dimension of the mesh. If this rank is not part of the mesh, return None.

Return type

Optional[list[int]]

get_group(mesh_dim=None)

Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.

Parameters

Returns

A ProcessGroup object.

Return type

ProcessGroup

get_local_rank(mesh_dim=None)

Returns the local rank of the given mesh_dim of the DeviceMesh.

Parameters

Returns

An integer denoting the local rank.

Return type

int

The following program runs on each process/rank in an SPMD manner. In this example, we have 2 hosts with 4 GPUs each. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.

Example::

    from torch.distributed.device_mesh import DeviceMesh

    # Initialize device mesh as (2, 4) to represent the topology
    # of cross-host (dim 0), and within-host (dim 1).
    mesh_2d = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
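The local ranks listed above are just the indices of each global rank within the nested mesh list. A plain-Python model, assuming a 2D mesh and using the hypothetical helper coordinate_of, makes that explicit without constructing a DeviceMesh:

```python
def coordinate_of(mesh, rank):
    """Model of get_coordinate() for a 2D mesh: [dim 0 index, dim 1 index]."""
    for i, row in enumerate(mesh):
        if rank in row:
            return [i, row.index(rank)]
    return None  # rank is not part of the mesh


mesh = [[0, 1, 2, 3], [4, 5, 6, 7]]
for rank in (0, 3, 4, 6):
    coord = coordinate_of(mesh, rank)
    # coord[d] corresponds to get_local_rank(mesh_dim=d) on that rank.
    print(rank, coord)
```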

get_rank()

Returns the current global rank.

Return type

int

Point-to-point communication

torch.distributed.send(tensor, dst=None, group=None, tag=0, group_dst=None)

Send a tensor synchronously.

Warning

tag is not supported with the NCCL backend.

Parameters

torch.distributed.recv(tensor, src=None, group=None, tag=0, group_src=None)

Receives a tensor synchronously.

Warning

tag is not supported with the NCCL backend.

Parameters

Returns

Sender rank, or -1 if not part of the group.

Return type

int

isend() and irecv() return distributed request objects when used. In general, the type of this object is unspecified, as they should never be created manually, but they are guaranteed to support two methods:

torch.distributed.isend(tensor, dst=None, group=None, tag=0, group_dst=None)

Send a tensor asynchronously.

Warning

Modifying tensor before the request completes causes undefined behavior.

Warning

tag is not supported with the NCCL backend.

Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.

Parameters

Returns

A distributed request object. None, if not part of the group

Return type

Optional[Work]

torch.distributed.irecv(tensor, src=None, group=None, tag=0, group_src=None)

Receives a tensor asynchronously.

Warning

tag is not supported with the NCCL backend.

Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.

Parameters

Returns

A distributed request object. None, if not part of the group

Return type

Optional[Work]

torch.distributed.send_object_list(object_list, dst=None, group=None, device=None, group_dst=None)

Sends picklable objects in object_list synchronously.

Similar to send(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be sent.

Parameters

Returns

None.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

send_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling send_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using send() instead.

Example::

    # Note: Process group initialization omitted on each rank.
    import torch
    import torch.distributed as dist

    # Assumes backend is not NCCL
    device = torch.device("cpu")
    if dist.get_rank() == 0:
        # Assumes world_size of 2.
        objects = ["foo", 12, {1: 2}]  # any picklable object
        dist.send_object_list(objects, dst=1, device=device)
    else:
        objects = [None, None, None]
        dist.recv_object_list(objects, src=0, device=device)
    print(objects)  # ['foo', 12, {1: 2}] on both ranks
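Whether an object qualifies as picklable can be checked locally with the standard pickle module before any communication takes place. The sketch below round-trips the same list used in the example and shows one kind of object that fails; is_picklable is a hypothetical helper, not part of torch.distributed.

```python
import pickle

objects = ["foo", 12, {1: 2}]

# Round-trip through pickle, the serialization these functions rely on.
payload = pickle.dumps(objects)
restored = pickle.loads(payload)
print(restored == objects)


def is_picklable(obj):
    """Return True if the stdlib pickle module can serialize `obj`."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


print(is_picklable(objects))      # plain containers are fine
print(is_picklable(lambda x: x))  # lambdas cannot be pickled by the stdlib
```

Note that pickle.loads is exactly the step the security warning above refers to, so it should only ever be applied to payloads from trusted peers.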

torch.distributed.recv_object_list(object_list, src=None, group=None, device=None, group_src=None)

Receives picklable objects in object_list synchronously.

Similar to recv(), but can receive Python objects.

Parameters

Returns

Sender rank. -1 if rank is not part of the group. If rank is part of the group, object_list will contain the sent objects from src rank.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

recv_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling recv_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using recv() instead.

Example::

    # Note: Process group initialization omitted on each rank.
    import torch
    import torch.distributed as dist

    # Assumes backend is not NCCL
    device = torch.device("cpu")
    if dist.get_rank() == 0:
        # Assumes world_size of 2.
        objects = ["foo", 12, {1: 2}]  # any picklable object
        dist.send_object_list(objects, dst=1, device=device)
    else:
        objects = [None, None, None]
        dist.recv_object_list(objects, src=0, device=device)
    print(objects)  # ['foo', 12, {1: 2}] on both ranks

torch.distributed.batch_isend_irecv(p2p_op_list)

Send or receive a batch of tensors asynchronously and return a list of requests.

Process each of the operations in p2p_op_list and return the corresponding requests. The NCCL, Gloo, and UCC backends are currently supported.

Parameters

p2p_op_list (list[torch.distributed.distributed_c10d.P2POp]) – A list of point-to-point operations (the type of each operator is torch.distributed.P2POp). The order of the isend/irecv in the list matters and it needs to match the corresponding isend/irecv on the remote end.

Returns

A list of distributed request objects returned by calling the corresponding op in the op_list.

Return type

list[torch.distributed.distributed_c10d.Work]

Examples

    # Assumes `rank` and `world_size` are set and the process group is initialized.
    send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
    recv_tensor = torch.randn(2, dtype=torch.float32)
    send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
    recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size)
    reqs = dist.batch_isend_irecv([send_op, recv_op])
    for req in reqs:
        req.wait()
    print(recv_tensor)
    # tensor([2., 3.])  # Rank 0
    # tensor([0., 1.])  # Rank 1
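The ring pattern in this example (send to rank + 1, receive from rank - 1, both modulo world_size) can be checked without any communication by simulating every rank in one process. This is a plain-Python model with a hypothetical ring_exchange helper, using lists in place of tensors:

```python
def ring_exchange(world_size):
    """Simulate the example: each rank sends to rank+1, receives from rank-1."""
    send = {rank: [0 + 2 * rank, 1 + 2 * rank] for rank in range(world_size)}
    recv = {}
    for rank in range(world_size):
        src = (rank - 1 + world_size) % world_size
        recv[rank] = send[src]
    return recv


received = ring_exchange(world_size=2)
print(received)  # {0: [2, 3], 1: [0, 1]}
```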

Note

Note that when this API is used with the NCCL PG backend, users must set the current GPU device with torch.cuda.set_device, otherwise it will lead to unexpected hang issues.

In addition, if this API is the first collective call in the group passed to dist.P2POp, all ranks of the group must participate in this API call; otherwise, the behavior is undefined. If this API call is not the first collective call in the group, batched P2P operations involving only a subset of ranks of the group are allowed.

class torch.distributed.P2POp(op, tensor, peer=None, group=None, tag=0, group_peer=None)

A class to build point-to-point operations for batch_isend_irecv.

This class builds the type of P2P operation, communication buffer, peer rank, Process Group, and tag. Instances of this class will be passed to batch_isend_irecv for point-to-point communications.

Parameters

Synchronous and asynchronous collective operations

Every collective operation function supports the following two kinds of operations, depending on the setting of the async_op flag passed into the collective:

Synchronous operation - the default mode, when async_op is set to False. When the function returns, it is guaranteed that the collective operation is performed. In the case of CUDA operations, it is not guaranteed that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream synchronization, see CUDA Semantics. See the below script to see examples of differences in these semantics for CPU and CUDA operations.

Asynchronous operation - when async_op is set to True. The collective operation function returns a distributed request object. In general, you don’t need to create it manually and it is guaranteed to support two methods:

Example

The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. It shows the explicit need to synchronize when using collective outputs on different CUDA streams:

    # Code runs on each rank.
    dist.init_process_group("nccl", rank=rank, world_size=2)
    output = torch.tensor([rank]).cuda(rank)
    s = torch.cuda.Stream()
    handle = dist.all_reduce(output, async_op=True)
    # Wait ensures the operation is enqueued, but not necessarily complete.
    handle.wait()
    # Using result on non-default stream.
    with torch.cuda.stream(s):
        s.wait_stream(torch.cuda.default_stream())
        output.add_(100)
    if rank == 0:
        # if the explicit call to wait_stream was omitted, the output below will be
        # non-deterministically 1 or 101, depending on whether the allreduce overwrote
        # the value after the add completed.
        print(output)

Collective functions

torch.distributed.broadcast(tensor, src=None, group=None, async_op=False, group_src=None)

Broadcasts the tensor to the whole group.

tensor must have the same number of elements in all processes participating in the collective.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

torch.distributed.broadcast_object_list(object_list, src=None, group=None, device=None, group_src=None)

Broadcasts picklable objects in object_list to the whole group.

Similar to broadcast(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be broadcast.

Parameters

Returns

None. If rank is part of the group, object_list will contain the broadcast objects from src rank.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Note

Note that this API differs slightly from the broadcast() collective since it does not provide an async_op handle and thus will be a blocking call.

Warning

broadcast_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling broadcast_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using broadcast() instead.

Example::

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
...     # Assumes world_size of 3.
...     objects = ["foo", 12, {1: 2}]  # any picklable object
... else:
...     objects = [None, None, None]
>>> # Assumes backend is not NCCL
>>> device = torch.device("cpu")
>>> dist.broadcast_object_list(objects, src=0, device=device)
>>> objects
['foo', 12, {1: 2}]
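The picklability requirement can be illustrated without a process group. The sketch below is a pure-Python model and does not call torch.distributed: it assumes the source rank serializes each object with pickle and every rank deserializes the resulting bytes. The helper name simulate_broadcast_objects is hypothetical.

```python
import pickle


def simulate_broadcast_objects(src_objects, world_size):
    """Model of an object broadcast: pickle on the source, unpickle on every rank.

    Hypothetical helper for illustration only; it does not use torch.distributed.
    """
    payload = [pickle.dumps(obj) for obj in src_objects]  # done once, on src
    # Every rank ends up with its own deserialized copy of the source list.
    return [[pickle.loads(blob) for blob in payload] for _ in range(world_size)]


per_rank = simulate_broadcast_objects(["foo", 12, {1: 2}], world_size=3)
for rank, objects in enumerate(per_rank):
    print(rank, objects)

# Objects that cannot be pickled (for example, a lambda) fail at serialization time.
try:
    pickle.dumps(lambda x: x)
    lambda_is_picklable = True
except (pickle.PicklingError, AttributeError, TypeError):
    lambda_is_picklable = False
print("lambda picklable:", lambda_is_picklable)
```

Because deserialization runs whatever the pickle stream describes, the same round trip is also why the warning above asks for trusted data only.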

torch.distributed.all_reduce(tensor, op=<RedOpType.SUM: 0>, group=None, async_op=False)

Reduces the tensor data across all machines in a way that all get the final result.

After the call tensor is going to be bitwise identical in all processes.

Complex tensors are supported.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Examples

>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor(
...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1
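The arithmetic in the examples above can be checked with a small pure-Python model. This is a sketch of the semantics only (plain lists stand in for tensors, and nothing here calls torch.distributed); the helper name simulate_all_reduce is hypothetical.

```python
from operator import add


def simulate_all_reduce(per_rank_tensors, op=add):
    """Combine element i of every rank's list with op; every rank gets the result."""
    reduced = list(per_rank_tensors[0])
    for tensor in per_rank_tensors[1:]:
        reduced = [op(a, b) for a, b in zip(reduced, tensor)]
    # Each rank receives its own copy of the same reduced values.
    return [list(reduced) for _ in per_rank_tensors]


int_result = simulate_all_reduce([[1, 2], [3, 4]])
complex_result = simulate_all_reduce([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
print(int_result)
print(complex_result)
```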

torch.distributed.reduce(tensor, dst=None, op=<RedOpType.SUM: 0>, group=None, async_op=False, group_dst=None)

Reduces the tensor data across all machines.

Only the process with rank dst is going to receive the final result.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False)

Gathers tensors from the whole group in a list.

Complex and uneven sized tensors are supported.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Examples

>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_list = [
...     torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
[tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
>>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0
[tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1

>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [
...     torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
... ]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
[tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
>>> tensor = torch.tensor(
...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0
[tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1

torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False)

Gather tensors from all ranks and put them in a single output tensor.

This function requires all tensors to be the same size on each process.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Examples

>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> # Output in concatenation form
>>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
>>> dist.all_gather_into_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
>>> # Output in stack form
>>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
>>> dist.all_gather_into_tensor(tensor_out2, tensor_in)
>>> tensor_out2
tensor([[1, 2],
        [3, 4]], device='cuda:0') # Rank 0
tensor([[1, 2],
        [3, 4]], device='cuda:1') # Rank 1

Warning

The Gloo backend does not support this API.
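The two output layouts shown in the examples for all_gather_into_tensor differ only in shape. A minimal pure-Python sketch of that difference follows, with nested lists standing in for tensors; the helper name simulate_all_gather_into_tensor is hypothetical and nothing here calls torch.distributed.

```python
def simulate_all_gather_into_tensor(per_rank_inputs, stack=False):
    """Return the output every rank would see, as (nested) Python lists."""
    if stack:
        # Stack form: one row per rank, shape (world_size, input_size).
        return [list(t) for t in per_rank_inputs]
    # Concatenation form: shape (world_size * input_size,).
    return [x for t in per_rank_inputs for x in t]


inputs = [[1, 2], [3, 4]]  # rank 0 and rank 1
concat_out = simulate_all_gather_into_tensor(inputs)
stack_out = simulate_all_gather_into_tensor(inputs, stack=True)
print(concat_out)
print(stack_out)
```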

torch.distributed.all_gather_object(object_list, obj, group=None)

Gathers picklable objects from the whole group into a list.

Similar to all_gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

Parameters

Returns

None. If the calling rank is part of this group, the output of the collective will be populated into the input object_list. If the calling rank is not part of the group, the passed in object_list will be unmodified.

Note

Note that this API differs slightly from the all_gather() collective since it does not provide an async_op handle and thus will be a blocking call.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

all_gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling all_gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using all_gather() instead.

Example::

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]

torch.distributed.gather(tensor, gather_list=None, dst=None, group=None, async_op=False, group_dst=None)

Gathers a list of tensors in a single process.

This function requires all tensors to be the same size on each process.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Note

Note that all Tensors in gather_list must have the same size.

Example::

>>> # We have 2 process groups, 2 ranks.
>>> tensor_size = 2
>>> device = torch.device(f'cuda:{rank}')
>>> tensor = torch.ones(tensor_size, device=device) + rank
>>> if dist.get_rank() == 0:
...     gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)]
... else:
...     gather_list = None
>>> dist.gather(tensor, gather_list, dst=0)
>>> # Rank 0 gets gathered data.
>>> gather_list
[tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0
None # Rank 1

torch.distributed.gather_object(obj, object_gather_list=None, dst=None, group=None, group_dst=None)

Gathers picklable objects from the whole group in a single process.

Similar to gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

Parameters

Returns

None. On the dst rank, object_gather_list will contain the output of the collective.

Note

Note that this API differs slightly from the gather collective since it does not provide an async_op handle and thus will be a blocking call.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling gather_object() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using gather() instead.

Example::

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
...     gather_objects[dist.get_rank()],
...     output if dist.get_rank() == 0 else None,
...     dst=0
... )
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]

torch.distributed.scatter(tensor, scatter_list=None, src=None, group=None, async_op=False, group_src=None)

Scatters a list of tensors to all processes in a group.

Each process will receive exactly one tensor and store its data in the tensor argument.

Complex tensors are supported.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Note

Note that all Tensors in scatter_list must have the same size.

Example::

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> tensor_size = 2
>>> device = torch.device(f'cuda:{rank}')
>>> output_tensor = torch.zeros(tensor_size, device=device)
>>> if dist.get_rank() == 0:
...     # Assumes world_size of 2.
...     # Only tensors, all of which must be the same size.
...     t_ones = torch.ones(tensor_size, device=device)
...     t_fives = torch.ones(tensor_size, device=device) * 5
...     scatter_list = [t_ones, t_fives]
... else:
...     scatter_list = None
>>> dist.scatter(output_tensor, scatter_list, src=0)
>>> # Rank i gets scatter_list[i].
>>> output_tensor
tensor([1., 1.], device='cuda:0') # Rank 0
tensor([5., 5.], device='cuda:1') # Rank 1
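Scatter and gather move data in opposite directions: scatter hands entry i of a list on the source rank to rank i, and gather collects one tensor per rank into a list on the destination rank. The following pure-Python sketch models only that index mapping (lists stand in for tensors; the helper names are hypothetical and nothing here calls torch.distributed).

```python
def simulate_scatter(scatter_list):
    """Rank i receives scatter_list[i]; returns a rank -> output mapping."""
    return {rank: list(t) for rank, t in enumerate(scatter_list)}


def simulate_gather(per_rank_tensors):
    """The destination rank receives one tensor per rank, ordered by rank."""
    return [list(per_rank_tensors[rank]) for rank in sorted(per_rank_tensors)]


outputs = simulate_scatter([[1.0, 1.0], [5.0, 5.0]])
round_trip = simulate_gather(outputs)
print(outputs)
print(round_trip)
```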

torch.distributed.scatter_object_list(scatter_object_output_list, scatter_object_input_list=None, src=None, group=None, group_src=None)

Scatters picklable objects in scatter_object_input_list to the whole group.

Similar to scatter(), but Python objects can be passed in. On each rank, the scattered object will be stored as the first element of scatter_object_output_list. Note that all objects in scatter_object_input_list must be picklable in order to be scattered.

Parameters

Returns

None. If rank is part of the group, scatter_object_output_list will have its first element set to the scattered object for this rank.

Note

Note that this API differs slightly from the scatter collective since it does not provide an async_op handle and thus will be a blocking call.

Warning

scatter_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Warning

Calling scatter_object_list() with GPU tensors is not well supported and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled. Please consider using scatter() instead.

Example::

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
...     # Assumes world_size of 3.
...     objects = ["foo", 12, {1: 2}]  # any picklable object
... else:
...     # Can be any list on non-src ranks, elements are not used.
...     objects = [None, None, None]
>>> output_list = [None]
>>> dist.scatter_object_list(output_list, objects, src=0)
>>> # Rank i gets objects[i]. For example, on rank 2:
>>> output_list
[{1: 2}]

torch.distributed.reduce_scatter(output, input_list, op=<RedOpType.SUM: 0>, group=None, async_op=False)

Reduces, then scatters a list of tensors to all processes in a group.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

torch.distributed.reduce_scatter_tensor(output, input, op=<RedOpType.SUM: 0>, group=None, async_op=False)

Reduces, then scatters a tensor to all ranks in a group.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Examples

>>> # All tensors below are of torch.int64 dtype and on CUDA devices.
>>> # We have two ranks.
>>> device = torch.device(f"cuda:{rank}")
>>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
>>> # Input in concatenation form
>>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
>>> tensor_in
tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
>>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1
>>> # Input in stack form
>>> tensor_in = torch.reshape(tensor_in, (world_size, 2))
>>> tensor_in
tensor([[0, 1],
        [2, 3]], device='cuda:0') # Rank 0
tensor([[0, 1],
        [2, 3]], device='cuda:1') # Rank 1
>>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

Warning

The Gloo backend does not support this API.
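The reduce_scatter_tensor results above follow from two steps: reduce the inputs elementwise across ranks, then give chunk i of the reduced result to rank i. A minimal pure-Python sketch of those two steps, assuming the default SUM reduction and equally sized chunks (the helper name simulate_reduce_scatter is hypothetical and nothing here calls torch.distributed):

```python
def simulate_reduce_scatter(per_rank_inputs):
    """Sum inputs elementwise across ranks, then hand chunk i to rank i."""
    world_size = len(per_rank_inputs)
    # Step 1: elementwise sum over ranks.
    summed = [sum(column) for column in zip(*per_rank_inputs)]
    # Step 2: split the reduced vector into world_size equal chunks.
    chunk = len(summed) // world_size
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


result = simulate_reduce_scatter([[0, 1, 2, 3], [0, 1, 2, 3]])
print(result)
```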

torch.distributed.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, group=None, async_op=False)

Split input tensor and then scatter the split list to all processes in a group.

Later the received tensors are concatenated from all the processes in the group and returned as a single output tensor.

Complex tensors are supported.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Warning

all_to_all_single is experimental and subject to change.

Examples

>>> input = torch.arange(4) + rank * 4
>>> input
tensor([0, 1, 2, 3]) # Rank 0
tensor([4, 5, 6, 7]) # Rank 1
tensor([8, 9, 10, 11]) # Rank 2
tensor([12, 13, 14, 15]) # Rank 3
>>> output = torch.empty([4], dtype=torch.int64)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([0, 4, 8, 12]) # Rank 0
tensor([1, 5, 9, 13]) # Rank 1
tensor([2, 6, 10, 14]) # Rank 2
tensor([3, 7, 11, 15]) # Rank 3

>>> # Essentially, it is similar to the following operation:
>>> scatter_list = list(input.chunk(world_size))
>>> gather_list = list(output.chunk(world_size))
>>> for i in range(world_size):
...     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

>>> # Another example with uneven split
>>> input
tensor([0, 1, 2, 3, 4, 5]) # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1
tensor([20, 21, 22, 23, 24]) # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3
>>> input_splits
[2, 2, 1, 1] # Rank 0
[3, 2, 2, 2] # Rank 1
[2, 1, 1, 1] # Rank 2
[2, 2, 2, 1] # Rank 3
>>> output_splits
[2, 3, 2, 2] # Rank 0
[2, 2, 1, 2] # Rank 1
[1, 2, 1, 2] # Rank 2
[1, 2, 1, 1] # Rank 3
>>> output = ...
>>> dist.all_to_all_single(output, input, output_splits, input_splits)
>>> output
tensor([ 0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0
tensor([ 2, 3, 13, 14, 22, 32, 33]) # Rank 1
tensor([ 4, 15, 16, 23, 34, 35]) # Rank 2
tensor([ 5, 17, 18, 24, 36]) # Rank 3
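The uneven-split example can be reproduced with a pure-Python model, which also makes the relation between the two split lists explicit: the number of elements rank d receives from rank r equals the number rank r sends to rank d. This is a sketch of the semantics only; the helper name simulate_all_to_all_single is hypothetical, lists stand in for tensors, and nothing here calls torch.distributed.

```python
def simulate_all_to_all_single(inputs, input_splits):
    """inputs[r] is rank r's flat input; input_splits[r][d] is how many
    elements rank r sends to rank d. Returns each rank's flat output."""
    world_size = len(inputs)
    # chunks[r][d] is the piece rank r sends to rank d.
    chunks = []
    for data, splits in zip(inputs, input_splits):
        pieces, start = [], 0
        for size in splits:
            pieces.append(data[start:start + size])
            start += size
        chunks.append(pieces)
    # Rank d concatenates what it receives, ordered by source rank.
    return [
        [x for r in range(world_size) for x in chunks[r][d]]
        for d in range(world_size)
    ]


inputs = [
    [0, 1, 2, 3, 4, 5],
    [10, 11, 12, 13, 14, 15, 16, 17, 18],
    [20, 21, 22, 23, 24],
    [30, 31, 32, 33, 34, 35, 36],
]
input_splits = [[2, 2, 1, 1], [3, 2, 2, 2], [2, 1, 1, 1], [2, 2, 2, 1]]

# output_splits is the transpose of input_splits.
output_splits = [list(column) for column in zip(*input_splits)]
outputs = simulate_all_to_all_single(inputs, input_splits)
for rank, out in enumerate(outputs):
    print(rank, output_splits[rank], out)
```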

>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor(
...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input
tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0
tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1
tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2
tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3
>>> # The output buffer must have the same dtype as the input.
>>> output = torch.empty([4], dtype=torch.cfloat)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0
tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1
tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2
tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3

torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False)

Scatters a list of input tensors to all processes in a group and returns the gathered list of tensors in the output list.

Complex tensors are supported.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Warning

all_to_all is experimental and subject to change.

Examples

>>> input = torch.arange(4) + rank * 4
>>> input = list(input.chunk(4))
>>> input
[tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0
[tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1
[tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2
[tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
>>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0
[tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1
[tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2
[tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3

>>> # Essentially, it is similar to the following operation:
>>> scatter_list = input
>>> gather_list = output
>>> for i in range(world_size):
...     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

>>> input
tensor([0, 1, 2, 3, 4, 5]) # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1
tensor([20, 21, 22, 23, 24]) # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3
>>> input_splits
[2, 2, 1, 1] # Rank 0
[3, 2, 2, 2] # Rank 1
[2, 1, 1, 1] # Rank 2
[2, 2, 2, 1] # Rank 3
>>> output_splits
[2, 3, 2, 2] # Rank 0
[2, 2, 1, 2] # Rank 1
[1, 2, 1, 2] # Rank 2
[1, 2, 1, 1] # Rank 3
>>> input = list(input.split(input_splits))
>>> input
[tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0
[tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
[tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2
[tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3
>>> output = ...
>>> dist.all_to_all(output, input)
>>> output
[tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0
[tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1
[tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3

>>> # Another example with tensors of torch.cfloat type.
>>> input = torch.tensor(
...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>> input = list(input.chunk(4))
>>> input
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0
[tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1
[tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2
[tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3
>>> # The output buffers must have the same dtype as the inputs.
>>> output = list(torch.empty([4], dtype=torch.cfloat).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0
[tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1
[tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2
[tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3

torch.distributed.barrier(group=None, async_op=False, device_ids=None)

Synchronize all processes.

This collective blocks each process until the whole group has entered this function. The blocking happens in the call itself if async_op is False, or in wait() on the returned async work handle otherwise.

Parameters

Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Note

ProcessGroupNCCL now blocks the CPU thread until the barrier collective completes.
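The blocking behavior of a barrier can be illustrated with a thread-based analogy. The sketch below uses the standard library's threading.Barrier in a single process, with threads standing in for ranks; it is an analogy under that assumption and does not call torch.distributed.barrier().

```python
import threading

world_size = 3  # number of threads standing in for ranks
barrier = threading.Barrier(world_size)
events = []
lock = threading.Lock()


def worker(rank):
    with lock:
        events.append(("before", rank))
    barrier.wait()  # no thread proceeds until all world_size threads arrive
    with lock:
        events.append(("after", rank))


threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every "before" entry is recorded ahead of every "after" entry.
phases = [phase for phase, _ in events]
print(phases)
```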

torch.distributed.monitored_barrier(group=None, timeout=None, wait_all_ranks=False)

Synchronize processes similar to torch.distributed.barrier, but with a configurable timeout.

It is able to report ranks that did not pass this barrier within the provided timeout. Specifically, non-zero ranks will block until a send/recv is processed from rank 0. Rank 0 will block until all send/recv from other ranks are processed, and will report failures for ranks that failed to respond in time. Note that if one rank does not reach the monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier.

This collective will block all processes/ranks in the group, until the whole group exits the function successfully, making it useful for debugging and synchronizing. However, it can have a performance impact and should only be used for debugging or scenarios that require full synchronization points on the host-side. For debugging purposes, this barrier can be inserted before the application’s collective calls to check if any ranks are desynchronized.

Note

Note that this collective is only supported with the GLOO backend.

Parameters

Returns

None.

Example::

>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() != 1:
...     dist.monitored_barrier()  # Raises exception indicating that
...     # rank 1 did not call into monitored_barrier.
>>> # Example with wait_all_ranks=True
>>> if dist.get_rank() == 0:
...     dist.monitored_barrier(wait_all_ranks=True)  # Raises exception
...     # indicating that ranks 1, 2, ... world_size - 1 did not call into
...     # monitored_barrier.

class torch.distributed.Work

A Work object represents the handle to a pending asynchronous operation in PyTorch’s distributed package. It is returned by non-blocking collective operations, such as dist.all_reduce(tensor, async_op=True).

boxed(self: torch._C._distributed_c10d.Work) → object

exception(self: torch._C._distributed_c10d.Work) → std::__exception_ptr::exception_ptr

get_future(self: torch._C._distributed_c10d.Work) → torch.Future

Returns

A torch.futures.Future object which is associated with the completion of the Work. As an example, a future object can be retrieved by fut = process_group.allreduce(tensors).get_future().

Example::

Below is an example of a simple allreduce DDP communication hook that uses the get_future API to retrieve a Future associated with the completion of allreduce.

>>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket) -> torch.futures.Future:
...     group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD
...     tensor = bucket.buffer().div_(group_to_use.size())
...     return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
>>> ddp_model.register_comm_hook(state=None, hook=allreduce)

Warning

get_future API supports NCCL, and partially GLOO and MPI backends (no support for peer-to-peer operations like send/recv) and will return a torch.futures.Future.

In the example above, allreduce work will be done on GPU using the NCCL backend. fut.wait() will return after synchronizing the appropriate NCCL streams with PyTorch’s current device streams to ensure we can have asynchronous CUDA execution, and it does not wait for the entire operation to complete on GPU. Note that CUDAFuture does not support the TORCH_NCCL_BLOCKING_WAIT flag or NCCL’s barrier(). In addition, if a callback function was added by fut.then(), it will wait until WorkNCCL’s NCCL streams synchronize with ProcessGroupNCCL’s dedicated callback stream and invoke the callback inline after running the callback on the callback stream. fut.then() will return another CUDAFuture that holds the return value of the callback and a CUDAEvent that recorded the callback stream.

  1. For CPU work, fut.done() returns true when work has been completed and value() tensors are ready.
  2. For GPU work, fut.done() returns true once the operation has been enqueued; the operation may not have completed on the GPU yet.
  3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), fut.done() returns true when tensors have arrived on respective nodes, but not yet necessarily synched on respective GPUs (similarly to GPU work).

get_future_result(self: torch._C._distributed_c10d.Work) → torch.Future

Returns

A torch.futures.Future object of int type which maps to the enum type of WorkResult. As an example, a future object can be retrieved by fut = process_group.allreduce(tensor).get_future_result().

Example::

Users can call fut.wait() to block until the work completes and then get the WorkResult from fut.value(). Users can also call fut.then(call_back_func) to register a callback function that is called when the work completes, without blocking the current thread.
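The two usage patterns, blocking on the result versus registering a callback, can be illustrated with the standard library's concurrent.futures.Future. This is only an analogy: result() plays the role of fut.wait() followed by fut.value(), and add_done_callback() plays the role of fut.then(). The string "SUCCESS" is a hypothetical placeholder for a WorkResult value, and nothing here calls torch.distributed.

```python
from concurrent.futures import Future
import threading

fut = Future()
seen = []

# Callback pattern: runs when the result is set, without blocking this thread.
fut.add_done_callback(lambda f: seen.append(f.result()))

# A worker thread stands in for the asynchronous work completing.
worker = threading.Thread(target=lambda: fut.set_result("SUCCESS"))
worker.start()

# Blocking pattern: wait (with a timeout) until the result is available.
value = fut.result(timeout=5)
worker.join()
print(value, seen)
```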

Warning

get_future_result API supports NCCL

is_completed(self: torch._C._distributed_c10d.Work) → bool

is_success(self: torch._C._distributed_c10d.Work) → bool

result(self: torch._C._distributed_c10d.Work) → list[torch.Tensor]

source_rank(self: torch._C._distributed_c10d.Work) → int

synchronize(self: torch._C._distributed_c10d.Work) → None

static unbox(arg0: object) → torch._C._distributed_c10d.Work

wait(self: torch._C._distributed_c10d.Work, timeout: datetime.timedelta = datetime.timedelta(0)) → bool

Returns

true/false.

Example::

>>> try:
...     work.wait(timeout)
... except Exception:
...     pass  # some handling

Warning

In normal cases, users do not need to set the timeout. Calling wait() is the same as calling synchronize(): it lets the current stream block on the completion of the NCCL work. However, if timeout is set, it will block the CPU thread until the NCCL work is completed or timed out. On timeout, an exception is thrown.

class torch.distributed.ReduceOp

An enum-like class for available reduction operations: SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR, and PREMUL_SUM.

BAND, BOR, and BXOR reductions are not available when using the NCCL backend.

AVG divides values by the world size before summing across ranks. AVG is only available with the NCCL backend, and only for NCCL versions 2.10 or later.

PREMUL_SUM multiplies inputs by a given scalar locally before reduction. PREMUL_SUM is only available with the NCCL backend, and only for NCCL versions 2.11 or later. Users are supposed to use torch.distributed._make_nccl_premul_sum.

Additionally, MAX, MIN and PRODUCT are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ReduceOp.SUM. They are used in specifying strategies for reduction collectives, e.g., reduce().

This class does not support __members__ property.
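The arithmetic behind each reduction can be sketched in pure Python with one integer per rank. The table below uses standard-library operators as stand-ins for the reductions named above; it is not the ReduceOp enum, the names REDUCTIONS and simulate_reduce are hypothetical, and backend availability limits are not modeled.

```python
from functools import reduce
import operator

# Pure-Python stand-ins for the reductions named above (not the ReduceOp enum).
REDUCTIONS = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MIN": min,
    "MAX": max,
    "BAND": operator.and_,
    "BOR": operator.or_,
    "BXOR": operator.xor,
}


def simulate_reduce(values, op, scale=1):
    """Reduce one value per rank; scale models a local pre-multiplication."""
    return reduce(REDUCTIONS[op], [v * scale for v in values])


per_rank = [6, 3, 9]  # one value on each of three ranks
results = {name: simulate_reduce(per_rank, name) for name in REDUCTIONS}

# PREMUL_SUM: multiply locally by a scalar, then sum.
premul_sum = simulate_reduce(per_rank, "SUM", scale=2)
# AVG: divide by the world size, then sum.
avg = simulate_reduce([v / len(per_rank) for v in per_rank], "SUM")
print(results, premul_sum, avg)
```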

class torch.distributed.reduce_op

Deprecated enum-like class for reduction operations: SUM, PRODUCT, MIN, and MAX.

Use ReduceOp instead.

Distributed Key-Value Store

The distributed package comes with a distributed key-value store, which can be used to share information between processes in the group as well as to initialize the distributed package in torch.distributed.init_process_group() (by explicitly creating the store as an alternative to specifying init_method). There are 3 choices for Key-Value Stores: TCPStore, FileStore, and HashStore.

class torch.distributed.Store

Base class for all store implementations, such as the 3 provided by PyTorch distributed: (TCPStore, FileStore, and HashStore).

__init__(self: torch._C._distributed_c10d.Store) → None

add(self: torch._C._distributed_c10d.Store, arg0: str, arg1: int) → int

The first call to add for a given key creates a counter associated with key in the store, initialized to amount. Subsequent calls to add with the same key increment the counter by the specified amount. Calling add() with a key that has already been set in the store by set() will result in an exception.

Parameters

Example::

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.add("first_key", 1)
>>> store.add("first_key", 6)
>>> # Should return 7
>>> store.get("first_key")
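The counter behavior of add() can be sketched with a dictionary-backed stand-in. This is a model for illustration, not the real store: the class name CounterStore is hypothetical, and it assumes the counter is kept as the decimal text of the number, encoded as bytes, so that get() returns bytes like the other store methods.

```python
class CounterStore:
    """Dictionary-backed stand-in for the add()/get() behavior described above."""

    def __init__(self):
        self._data = {}

    def add(self, key, amount):
        # First call creates the counter at `amount`; later calls increment it.
        current = int(self._data.get(key, b"0"))
        self._data[key] = str(current + amount).encode()
        return current + amount

    def get(self, key):
        return self._data[key]  # values are stored as bytes (assumption)


store = CounterStore()
first = store.add("first_key", 1)
second = store.add("first_key", 6)
print(first, second, store.get("first_key"))
```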

append(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) → None

Append the key-value pair into the store based on the supplied key and value. If key does not exist in the store, it will be created.

Parameters

Example::

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.append("first_key", "po")
>>> store.append("first_key", "tato")
>>> # Should return "potato"
>>> store.get("first_key")

check(self: torch._C._distributed_c10d.Store, arg0: list[str]) → bool

Checks whether each key in a given list has a value stored in the store. This call returns immediately in normal cases but can still deadlock in some edge cases, e.g., calling check after the TCPStore has been destroyed.

Parameters

keys (list[str]) – The keys to check for in the store.

Example::

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.add("first_key", 1)
>>> # Should return True
>>> store.check(["first_key"])

compare_set(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str, arg2: str) → bytes

Inserts the key-value pair into the store based on the supplied key and performs comparison between expected_value and desired_value before inserting. desired_value will only be set if expected_value for the key already exists in the store or if expected_value is an empty string.

Parameters

Example::

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("key", "first_value")
>>> store.compare_set("key", "first_value", "second_value")
>>> # Should return "second_value"
>>> store.get("key")
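The compare-and-set rule can be sketched against a plain dictionary. This is a model under stated assumptions, not the store's implementation: the function below is a hypothetical stand-in that treats an empty expected_value as "the key must not exist yet", and it returns whatever is stored under key after the call (None if the key is still absent) rather than modeling the real method's return value in every case.

```python
def compare_set(store, key, expected_value, desired_value):
    """Set store[key] to desired_value only if the expectation holds."""
    current = store.get(key)
    key_missing_and_no_expectation = current is None and expected_value == ""
    if key_missing_and_no_expectation or current == expected_value:
        store[key] = desired_value
    # Report what is stored after the call (None if the key is still absent).
    return store.get(key)


store = {}
created = compare_set(store, "key", "", "first_value")               # key absent, no expectation
swapped = compare_set(store, "key", "first_value", "second_value")   # expectation matches
rejected = compare_set(store, "key", "stale_value", "third_value")   # expectation fails
absent = compare_set(store, "other_key", "anything", "value")        # key absent, expectation fails
print(created, swapped, rejected, absent)
```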

delete_key(self: torch._C._distributed_c10d.Store, arg0: str) → bool

Deletes the key-value pair associated with key from the store. Returns true if the key was successfully deleted, and false if it was not.

Warning

The delete_key API is only supported by the TCPStore and HashStore. Using this API with the FileStore will result in an exception.

Parameters

key (str) – The key to be deleted from the store

Returns

True if key was deleted, otherwise False.

Example::

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, HashStore can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # This should return true
>>> store.delete_key("first_key")
>>> # This should return false
>>> store.delete_key("bad_key")

get(self: torch._C._distributed_c10d.Store, arg0: str) → bytes

Retrieves the value associated with the given key in the store. If key is not present in the store, the function will wait for timeout, which is defined when initializing the store, before throwing an exception.

Parameters

key (str) – The function will return the value associated with this key.

Returns

Value associated with key if key is in the store.

Example::

>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return "first_value"
>>> store.get("first_key")
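The waiting behavior of get() on a missing key can be sketched with a condition variable: the caller blocks until another thread sets the key or the store's timeout expires. The class WaitingStore below is a hypothetical in-process stand-in, and the choice of TimeoutError is an assumption of this sketch, not necessarily the exception type the real stores raise.

```python
import threading


class WaitingStore:
    """In-process stand-in for a store whose get() waits for missing keys."""

    def __init__(self, timeout):
        self._timeout = timeout  # seconds, fixed when the store is created
        self._data = {}
        self._cond = threading.Condition()

    def set(self, key, value):
        with self._cond:
            self._data[key] = value.encode()
            self._cond.notify_all()  # wake up any get() waiting on a key

    def get(self, key):
        with self._cond:
            # wait_for returns False if the key is still missing at timeout.
            if not self._cond.wait_for(lambda: key in self._data, self._timeout):
                raise TimeoutError(f"key {key!r} not set within {self._timeout}s")
            return self._data[key]


store = WaitingStore(timeout=1.0)

# Another thread sets the key shortly after get() starts waiting.
threading.Timer(0.05, store.set, args=("first_key", "first_value")).start()
value = store.get("first_key")

# A key that is never set makes get() raise once the timeout expires.
try:
    store.get("missing_key")
    timed_out = False
except TimeoutError:
    timed_out = True
print(value, timed_out)
```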

has_extended_api(self: torch._C._distributed_c10d.Store) → bool

Returns true if the store supports extended operations.

multi_get(self: torch._C._distributed_c10d.Store, arg0: list[str]) → list[bytes]

Retrieves all values in keys. If any key in keys is not present in the store, the function will wait for the timeout.

Parameters

keys (List[str]) – The keys to be retrieved from the store.

Example::

import torch.distributed as dist
from datetime import timedelta
store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
store.set("first_key", "po")
store.set("second_key", "tato")
# Should return [b"po", b"tato"]
store.multi_get(["first_key", "second_key"])

multi_set(self: torch._C._distributed_c10d.Store, arg0: list[str], arg1: list[str]) → None

Inserts a list of key-value pairs into the store based on the supplied keys and values.

Parameters

Example::

import torch.distributed as dist
from datetime import timedelta
store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
store.multi_set(["first_key", "second_key"], ["po", "tato"])
# Should return b"po"
store.get("first_key")

num_keys(self: torch._C._distributed_c10d.Store) → int

Returns the number of keys set in the store. Note that this number will typically be one greater than the number of keys added by set() and add() since one key is used to coordinate all the workers using the store.

Warning

When used with the FileStore, num_keys returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained.

Returns

The number of keys present in the store.

Example::

import torch.distributed as dist
from datetime import timedelta
# Using TCPStore as an example, other store types can also be used
store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
store.set("first_key", "first_value")
# This should return 2
store.num_keys()

set(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) → None

Inserts the key-value pair into the store based on the supplied key and value. If key already exists in the store, it will overwrite the old value with the new supplied value.

Parameters

Example::

import torch.distributed as dist
from datetime import timedelta
store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
store.set("first_key", "first_value")
# Should return b"first_value" (values are returned as bytes)
store.get("first_key")

set_timeout(self: torch._C._distributed_c10d.Store, arg0: datetime.timedelta) → None

Sets the store’s default timeout. This timeout is used during initialization and in wait() and get().

Parameters

timeout (timedelta) – timeout to be set in the store.

Example::

import torch.distributed as dist
from datetime import timedelta
# Using TCPStore as an example, other store types can also be used
store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
store.set_timeout(timedelta(seconds=10))
# This will throw an exception after 10 seconds
store.wait(["bad_key"])

property timeout

Gets the timeout of the store.

wait(*args, **kwargs)

Overloaded function.

  1. wait(self: torch._C._distributed_c10d.Store, arg0: list[str]) -> None

Waits for each key in keys to be added to the store. If not all keys are set before the timeout (set during store initialization), then wait will throw an exception.

Parameters

keys (list) – List of keys on which to wait until they are set in the store.

Example::

import torch.distributed as dist
from datetime import timedelta
# Using TCPStore as an example, other store types can also be used
store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
# This will throw an exception after 30 seconds
store.wait(["bad_key"])

  2. wait(self: torch._C._distributed_c10d.Store, arg0: list[str], arg1: datetime.timedelta) -> None

Waits for each key in keys to be added to the store, and throws an exception if the keys have not been set by the supplied timeout.

Parameters

Example::

import torch.distributed as dist
from datetime import timedelta
# Using TCPStore as an example, other store types can also be used
store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
# This will throw an exception after 10 seconds
store.wait(["bad_key"], timedelta(seconds=10))
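To make the two wait() overloads more concrete, here is a minimal in-memory sketch of "block until every key is set, or fail after a timeout". It is a stand-in written with the standard library only: the `ToyStore` class is hypothetical, and raising `TimeoutError` is an assumption, since the text above only says that an exception is thrown.

```python
import threading
from datetime import timedelta


class ToyStore:
    """In-memory stand-in for a store; illustrates wait() semantics only."""

    def __init__(self, timeout: timedelta = timedelta(seconds=30)):
        self._data = {}
        self._cond = threading.Condition()
        self._timeout = timeout  # default timeout, fixed at construction

    def set(self, key: str, value: bytes) -> None:
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()  # wake any thread blocked in wait()

    def wait(self, keys, timeout: timedelta = None) -> None:
        # Use the per-call timeout if given, else the store's default.
        chosen = self._timeout if timeout is None else timeout
        limit = chosen.total_seconds()
        with self._cond:
            ready = self._cond.wait_for(
                lambda: all(k in self._data for k in keys), timeout=limit
            )
        if not ready:
            raise TimeoutError(f"keys not set within {limit} s: {keys}")


store = ToyStore(timeout=timedelta(seconds=5))

# A second thread sets the key shortly after the main thread starts waiting.
threading.Timer(0.05, store.set, args=("ready", b"1")).start()
store.wait(["ready"])  # returns once the key appears

try:
    store.wait(["bad_key"], timedelta(seconds=0.1))  # per-call timeout
    timed_out = False
except TimeoutError:
    timed_out = True
```

The first call mirrors overload 1 (store default timeout), the second mirrors overload 2 (explicit timeout).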

class torch.distributed.TCPStore

A TCP-based distributed key-value store implementation. The server store holds the data, while the client stores can connect to the server store over TCP and perform actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc. There should always be one server store initialized because the client store(s) will wait for the server to establish a connection.

Parameters

Example::

import torch.distributed as dist
from datetime import timedelta
# Run on process 1 (server)
server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
# Run on process 2 (client)
client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
# Use any of the store methods from either the client or server after initialization
server_store.set("first_key", "first_value")
client_store.get("first_key")

__init__(self: torch._C._distributed_c10d.TCPStore, host_name: str, port: int, world_size: Optional[int] = None, is_master: bool = False, timeout: datetime.timedelta = datetime.timedelta(seconds=300), wait_for_workers: bool = True, multi_tenant: bool = False, master_listen_fd: Optional[int] = None, use_libuv: bool = True) → None

Creates a new TCPStore.

property host

Gets the hostname on which the store listens for requests.

property libuvBackend

Returns True if it’s using the libuv backend.

property port

Gets the port number on which the store listens for requests.

class torch.distributed.HashStore

A thread-safe store implementation based on an underlying hashmap. This store can be used within the same process (for example, by other threads), but cannot be used across processes.

Example::

import torch.distributed as dist
store = dist.HashStore()
# store can be used from other threads
# Use any of the store methods after initialization
store.set("first_key", "first_value")

__init__(self: torch._C._distributed_c10d.HashStore) → None

Creates a new HashStore.

class torch.distributed.FileStore

A store implementation that uses a file to store the underlying key-value pairs.

Parameters

Example::

import torch.distributed as dist
store1 = dist.FileStore("/tmp/filestore", 2)
store2 = dist.FileStore("/tmp/filestore", 2)
# Use any of the store methods from either the client or server after initialization
store1.set("first_key", "first_value")
store2.get("first_key")

__init__(self: torch._C._distributed_c10d.FileStore, file_name: str, world_size: int = -1) → None

Creates a new FileStore.

property path

Gets the path of the file used by FileStore to store key-value pairs.

class torch.distributed.PrefixStore

A wrapper around any of the 3 key-value stores (TCPStore, FileStore, and HashStore) that adds a prefix to each key inserted to the store.

Parameters

__init__(self: torch._C._distributed_c10d.PrefixStore, prefix: str, store: torch._C._distributed_c10d.Store) → None

Creates a new PrefixStore.

property underlying_store

Gets the underlying store object that PrefixStore wraps around.
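The idea behind key prefixing can be sketched without any distributed setup. The toy wrapper below namespaces keys on top of a plain dict so that two components can share one store without colliding. The class name `ToyPrefixStore` is hypothetical, and the "/" separator is an assumption chosen for readability, not a statement about how PrefixStore builds its keys.

```python
class ToyPrefixStore:
    """Illustrative wrapper: namespaces keys of an underlying dict-like store."""

    def __init__(self, prefix: str, store: dict):
        self._prefix = prefix
        self._store = store

    def _full_key(self, key: str) -> str:
        # Assumption: "/" as separator, used here only for illustration.
        return f"{self._prefix}/{key}"

    def set(self, key: str, value: bytes) -> None:
        self._store[self._full_key(key)] = value

    def get(self, key: str) -> bytes:
        return self._store[self._full_key(key)]

    @property
    def underlying_store(self) -> dict:
        return self._store


shared = {}
trainer = ToyPrefixStore("trainer", shared)
evaluator = ToyPrefixStore("evaluator", shared)
trainer.set("step", b"10")
evaluator.set("step", b"7")  # same key name, different namespace
```

Both wrappers write a key called "step", yet neither overwrites the other because each write lands under its own prefix in the shared store.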

Profiling Collective Communication

Note that you can use torch.profiler (recommended, only available after 1.8.1) or torch.autograd.profiler to profile collective communication and point-to-point communication APIs mentioned here. All out-of-the-box backends (gloo, nccl, mpi) are supported and collective communication usage will be rendered as expected in profiling output/traces. Profiling your code is the same as any regular torch operator:

import torch
import torch.distributed as dist
# torch.profiler is a module; the context manager is torch.profiler.profile
with torch.profiler.profile():
    tensor = torch.randn(20, 10)
    dist.all_reduce(tensor)

Please refer to the profiler documentation for a full overview of profiler features.

Multi-GPU collective functions

Warning

The multi-GPU functions (which stand for multiple GPUs per CPU thread) are deprecated. As of today, PyTorch Distributed’s preferred programming model is one device per thread, as exemplified by the APIs in this document. If you are a backend developer and want to support multiple devices per thread, please contact PyTorch Distributed’s maintainers.

Third-party backends

Besides the builtin GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends through a run-time register mechanism. For references on how to develop a third-party backend through C++ Extension, please refer to Tutorials - Custom C++ and CUDA Extensions and test/cpp_extensions/cpp_c10d_extension.cpp. The capabilities of third-party backends are determined by their own implementations.

The new backend derives from c10d::ProcessGroup and registers the backend name and the instantiating interface through torch.distributed.Backend.register_backend() when imported.

When manually importing this backend and invoking torch.distributed.init_process_group() with the corresponding backend name, the torch.distributed package runs on the new backend.

Warning

The support of third-party backend is experimental and subject to change.

Launch utility

The torch.distributed package also provides a launch utility in torch.distributed.launch. This helper utility can be used to launch multiple processes per node for distributed training.

Module torch.distributed.launch.

torch.distributed.launch is a module that spawns up multiple distributed training processes on each of the training nodes.

Warning

This module is going to be deprecated in favor of torchrun.

The utility can be used for single-node distributed training, in which one or more processes per node will be spawned. The utility can be used for either CPU training or GPU training. If the utility is used for GPU training, each distributed process will be operating on a single GPU, which can substantially improve single-node training performance. It can also be used in multi-node distributed training, by spawning multiple processes on each node, to improve multi-node training performance as well. This is especially beneficial for systems with multiple InfiniBand interfaces that have direct-GPU support, since all of them can be utilized for aggregated communication bandwidth.

In both cases of single-node distributed training or multi-node distributed training, this utility will launch the given number of processes per node (--nproc-per-node). If used for GPU training, this number needs to be less than or equal to the number of GPUs on the current system (nproc_per_node), and each process will be operating on a single GPU from GPU 0 to GPU (nproc_per_node - 1).

How to use this module:

  1. Single-Node multi-process distributed training

python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)

  2. Multi-Node multi-process distributed training: (e.g. two nodes)

Node 1: (IP: 192.168.1.1, and has a free port: 1234)

python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)

Node 2:

python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE --nnodes=2 --node-rank=1 --master-addr="192.168.1.1" --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of your training script)

  3. To look up what optional arguments this module offers:

python -m torch.distributed.launch --help
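As a rough illustration of how the per-node launches above fit together, the sketch below computes the global ranks hosted on each node. It assumes ranks are numbered contiguously node by node (global rank = node_rank * nproc_per_node + local rank). The helper name `global_ranks` and the choice of four processes per node are hypothetical.

```python
def global_ranks(nnodes: int, nproc_per_node: int, node_rank: int):
    """Return (world_size, [global ranks]) for the processes on one node."""
    world_size = nnodes * nproc_per_node
    ranks = [
        node_rank * nproc_per_node + local_rank
        for local_rank in range(nproc_per_node)
    ]
    return world_size, ranks


# Two nodes, as in the multi-node example above, with four processes each.
world_size, node0 = global_ranks(nnodes=2, nproc_per_node=4, node_rank=0)
_, node1 = global_ranks(nnodes=2, nproc_per_node=4, node_rank=1)
```

With these inputs the first node hosts global ranks 0 to 3 and the second hosts 4 to 7, while the local rank on each node always runs from 0 to 3.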

Important Notices:

1. This utility and multi-process distributed (single-node or multi-node) GPU training currently only achieves the best performance using the NCCL distributed backend. Thus NCCL backend is the recommended backend to use for GPU training.

2. In your training program, you must parse the command-line argument: --local-rank=LOCAL_PROCESS_RANK, which will be provided by this module. If your training program uses GPUs, you should ensure that your code only runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--local-rank", "--local_rank", type=int)
args = parser.parse_args()

Set your device to local rank using either

torch.cuda.set_device(args.local_rank) # before your code runs

or

with torch.cuda.device(args.local_rank):
    # your code to run
    ...

Changed in version 2.0.0: The launcher will pass the --local-rank=<rank> argument to your script. From PyTorch 2.0.0 onwards, the dashed --local-rank is preferred over the previously used underscored --local_rank.

For backward compatibility, it may be necessary for users to handle both cases in their argument parsing code. This means including both "--local-rank" and "--local_rank" in the argument parser. If only "--local_rank" is provided, the launcher will trigger an error: “error: unrecognized arguments: --local-rank=<rank>”. For training code that only supports PyTorch 2.0.0+, including "--local-rank" should be sufficient.

3. In your training program, you are supposed to call the following function at the beginning to start the distributed backend. It is strongly recommended that init_method=env://. Other init methods (e.g. tcp://) may work, but env:// is the one that is officially supported by this module.

torch.distributed.init_process_group(backend='YOUR BACKEND', init_method='env://')

4. In your training program, you can either use regular distributed functions or use torch.nn.parallel.DistributedDataParallel() module. If your training program uses GPUs for training and you would like to use torch.nn.parallel.DistributedDataParallel() module, here is how to configure it.

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

Please ensure that the device_ids argument is set to the only GPU device id that your code will be operating on. This is generally the local rank of the process. In other words, device_ids needs to be [args.local_rank], and output_device needs to be args.local_rank in order to use this utility.

5. Another way to pass local_rank to the subprocesses is via the environment variable LOCAL_RANK. This behavior is enabled when you launch the script with --use-env=True. You must adjust the subprocess example above to replace args.local_rank with os.environ['LOCAL_RANK']; the launcher will not pass --local-rank when you specify this flag.
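Notices 2 and 5 describe two ways the local rank can reach a training script. One possible way to support both is a small helper that prefers the LOCAL_RANK environment variable and otherwise falls back to the command-line flag, accepting either spelling. This is a sketch using only the standard library; the helper name `resolve_local_rank` and the default of 0 are assumptions for illustration.

```python
import argparse
import os


def resolve_local_rank(argv=None, environ=None) -> int:
    """Return the local rank from LOCAL_RANK if set, else from the CLI flag."""
    environ = os.environ if environ is None else environ
    if "LOCAL_RANK" in environ:
        return int(environ["LOCAL_RANK"])  # environment values are strings

    parser = argparse.ArgumentParser()
    # Accept both spellings; argparse stores either one as args.local_rank.
    parser.add_argument("--local-rank", "--local_rank", type=int, default=0)
    # parse_known_args leaves the script's other arguments untouched.
    args, _unknown = parser.parse_known_args(argv)
    return args.local_rank


from_env = resolve_local_rank(argv=[], environ={"LOCAL_RANK": "3"})
from_dashed = resolve_local_rank(argv=["--local-rank=1"], environ={})
from_underscored = resolve_local_rank(
    argv=["--local_rank=2", "--lr=0.1"], environ={}
)
```

In a real script the function would be called without arguments so that it reads `sys.argv` and `os.environ`; the explicit arguments here only make the three cases easy to see.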

Warning

local_rank is NOT globally unique: it is only unique per process on a machine. Thus, don’t use it to decide if you should, e.g., write to a networked filesystem. See https://github.com/pytorch/pytorch/issues/12042 for an example of how things can go wrong if you don’t do this correctly.
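A common way to follow this warning is to gate shared-filesystem writes on the global rank instead. The sketch below assumes the job was started with the env:// init method, so that a RANK environment variable holding the global rank is available; the helper name `should_write_shared_output` is hypothetical. Inside an initialized process group, torch.distributed.get_rank() serves the same purpose.

```python
import os


def should_write_shared_output(environ=None) -> bool:
    """True only for the single process whose global rank is 0."""
    environ = os.environ if environ is None else environ
    # RANK is unique across the whole job; LOCAL_RANK repeats on every node.
    return int(environ.get("RANK", "0")) == 0


# First process on a second node: local rank 0 but global rank 4.
node1_first = should_write_shared_output({"RANK": "4", "LOCAL_RANK": "0"})
# First process on the first node: global rank 0, the designated writer.
node0_first = should_write_shared_output({"RANK": "0", "LOCAL_RANK": "0"})
```

Had the check used LOCAL_RANK, both processes above would have tried to write the same file.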

Debugging torch.distributed applications

Debugging distributed applications can be challenging due to hard-to-understand hangs, crashes, or inconsistent behavior across ranks. torch.distributed provides a suite of tools to help debug training applications in a self-serve fashion:

Python Breakpoint

It is extremely convenient to use python’s debugger in a distributed environment, but because it does not work out of the box many people do not use it at all. PyTorch offers a customized wrapper around pdb that streamlines the process.

torch.distributed.breakpoint makes this process easy. Internally, it customizes pdb’s breakpoint behavior in the following ways but otherwise behaves as normal pdb.
1. Attaches the debugger only on one rank (specified by the user).
2. Ensures all other ranks stop, by using a torch.distributed.barrier() that will release once the debugged rank issues a continue.
3. Reroutes stdin from the child process such that it connects to your terminal.

To use it, simply issue torch.distributed.breakpoint(rank) on all ranks, using the same value for rank in each case.

Monitored Barrier

As of v1.10, torch.distributed.monitored_barrier() exists as an alternative to torch.distributed.barrier() which fails with helpful information about which rank may be faulty when crashing, i.e. not all ranks calling into torch.distributed.monitored_barrier() within the provided timeout. torch.distributed.monitored_barrier() implements a host-side barrier using send/recv communication primitives in a process similar to acknowledgements, allowing rank 0 to report which rank(s) failed to acknowledge the barrier in time. As an example, consider the following function where rank 1 fails to call into torch.distributed.monitored_barrier() (in practice this could be due to an application bug or hang in a previous collective):

import os
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    # monitored barrier requires gloo process group to perform host-side sync.
    group_gloo = dist.new_group(backend="gloo")
    if rank not in [1]:
        dist.monitored_barrier(group=group_gloo, timeout=timedelta(seconds=2))

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    mp.spawn(worker, nprocs=2, args=())

The following error message is produced on rank 0, allowing the user to determine which rank(s) may be faulty and investigate further:

RuntimeError: Rank 1 failed to pass monitoredBarrier in 2000 ms
 Original exception:
[gloo/transport/tcp/pair.cc:598] Connection closed by peer [2401:db00:eef0:1100:3560:0:1c05:25d]:8594
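The acknowledgement idea behind the monitored barrier can be imitated with threads and a queue. The following is a toy, single-process model and not how torch.distributed implements it: threads stand in for ranks, a queue stands in for send/recv, the timeout applies to each acknowledgement separately, and the function name `toy_monitored_barrier` is hypothetical.

```python
import queue
import threading


def toy_monitored_barrier(world_size: int, absent_ranks, timeout_s: float):
    """Rank 0 collects acknowledgements; returns ranks that never sent one."""
    acks = queue.Queue()

    def worker(rank: int) -> None:
        if rank not in absent_ranks:
            acks.put(rank)  # stands in for a send to rank 0

    threads = [
        threading.Thread(target=worker, args=(r,)) for r in range(1, world_size)
    ]
    for t in threads:
        t.start()

    pending = set(range(1, world_size))
    while pending:
        try:
            # Stands in for a recv; simplification: timeout is per message.
            pending.discard(acks.get(timeout=timeout_s))
        except queue.Empty:
            break  # nobody else acknowledged in time
    for t in threads:
        t.join()
    return sorted(pending)


missing = toy_monitored_barrier(world_size=4, absent_ranks={1}, timeout_s=0.2)
all_present = toy_monitored_barrier(world_size=3, absent_ranks=set(), timeout_s=0.2)
```

Because rank 0 tracks which acknowledgements are still outstanding, it can name the ranks that failed to arrive instead of simply hanging.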

TORCH_DISTRIBUTED_DEBUG

With TORCH_CPP_LOG_LEVEL=INFO, the environment variable TORCH_DISTRIBUTED_DEBUG can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks are synchronized appropriately. TORCH_DISTRIBUTED_DEBUG can be set to either OFF (default), INFO, or DETAIL depending on the debugging level required. Please note that the most verbose option, DETAIL may impact the application performance and thus should only be used when debugging issues.
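Since both variables are read from the environment, they need to be in place before the process groups are created. A minimal sketch of a helper that validates the level and sets both variables follows; the name `enable_distributed_debug` is hypothetical, and pairing any non-OFF level with TORCH_CPP_LOG_LEVEL=INFO simply follows the sentence above.

```python
import os

VALID_DEBUG_LEVELS = ("OFF", "INFO", "DETAIL")


def enable_distributed_debug(level: str, environ=None) -> None:
    """Set the debug-related variables before any process group is created."""
    environ = os.environ if environ is None else environ
    level = level.upper()
    if level not in VALID_DEBUG_LEVELS:
        raise ValueError(f"expected one of {VALID_DEBUG_LEVELS}, got {level!r}")
    environ["TORCH_DISTRIBUTED_DEBUG"] = level
    if level != "OFF":
        # The text above pairs the debug variable with TORCH_CPP_LOG_LEVEL=INFO.
        environ["TORCH_CPP_LOG_LEVEL"] = "INFO"


env = {}  # a plain dict stands in for os.environ in this demonstration
enable_distributed_debug("detail", env)

try:
    enable_distributed_debug("verbose", env)  # not a documented level
    rejected = False
except ValueError:
    rejected = True
```

Calling the helper with no second argument writes to `os.environ`, which is what a launcher script would normally want.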

Setting TORCH_DISTRIBUTED_DEBUG=INFO will result in additional debug logging when models trained with torch.nn.parallel.DistributedDataParallel() are initialized, and TORCH_DISTRIBUTED_DEBUG=DETAIL will additionally log runtime performance statistics for a select number of iterations. These runtime statistics include data such as forward time, backward time, gradient communication time, etc. As an example, given the following application:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

class TwoLinLayerNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10, bias=False)
        self.b = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)

def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    print("init model")
    model = TwoLinLayerNet().cuda()
    print("init ddp")
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    inp = torch.randn(10, 10).cuda()
    print("train")

    for _ in range(20):
        output = ddp_model(inp)
        loss = output[0] + output[1]
        loss.sum().backward()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # set to DETAIL for runtime logging.
    mp.spawn(worker, nprocs=2, args=())

The following logs are rendered at initialization time:

I0607 16:10:35.739390 515217 logger.cpp:173] [Rank 0]: DDP Initialized with:
broadcast_buffers: 1
bucket_cap_bytes: 26214400
find_unused_parameters: 0
gradient_as_bucket_view: 0
is_multi_device_module: 0
iteration: 0
num_parameter_tensors: 2
output_device: 0
rank: 0
total_parameter_size_bytes: 440
world_size: 2
backend_name: nccl
bucket_sizes: 440
cuda_visible_devices: N/A
device_ids: 0
dtypes: float
master_addr: localhost
master_port: 29501
module_name: TwoLinLayerNet
nccl_async_error_handling: N/A
nccl_blocking_wait: N/A
nccl_debug: WARN
nccl_ib_timeout: N/A
nccl_nthreads: N/A
nccl_socket_ifname: N/A
torch_distributed_debug: INFO

The following logs are rendered during runtime (when TORCH_DISTRIBUTED_DEBUG=DETAIL is set):

I0607 16:18:58.085681 544067 logger.cpp:344] [Rank 1 / 2] Training TwoLinLayerNet unused_parameter_size=0
Avg forward compute time: 40838608
Avg backward compute time: 5983335
Avg backward comm. time: 4326421
Avg backward comm/comp overlap time: 4207652
I0607 16:18:58.085693 544066 logger.cpp:344] [Rank 0 / 2] Training TwoLinLayerNet unused_parameter_size=0
Avg forward compute time: 42850427
Avg backward compute time: 3885553
Avg backward comm. time: 2357981
Avg backward comm/comp overlap time: 2234674

In addition, TORCH_DISTRIBUTED_DEBUG=INFO enhances crash logging in torch.nn.parallel.DistributedDataParallel() due to unused parameters in the model. Currently, find_unused_parameters=True must be passed into torch.nn.parallel.DistributedDataParallel() initialization if there are parameters that may be unused in the forward pass, and as of v1.10, all model outputs are required to be used in loss computation as torch.nn.parallel.DistributedDataParallel() does not support unused parameters in the backwards pass. These constraints are challenging especially for larger models, thus when crashing with an error, torch.nn.parallel.DistributedDataParallel() will log the fully qualified name of all parameters that went unused. For example, in the above application, if we modify loss to be instead computed as loss = output[1], then TwoLinLayerNet.a does not receive a gradient in the backwards pass, and thus results in DDP failing. On a crash, the user is passed information about parameters which went unused, which may be challenging to manually find for large models:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by making sure all forward function outputs participate in calculating loss. If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 0: a.weight
Parameter indices which did not receive grad for rank 0: 0

Setting TORCH_DISTRIBUTED_DEBUG=DETAIL will trigger additional consistency and synchronization checks on every collective call issued by the user either directly or indirectly (such as DDP allreduce). This is done by creating a wrapper process group that wraps all process groups returned by torch.distributed.init_process_group() and torch.distributed.new_group() APIs. As a result, these APIs will return a wrapper process group that can be used exactly like a regular process group, but performs consistency checks before dispatching the collective to an underlying process group. Currently, these checks include a torch.distributed.monitored_barrier(), which ensures all ranks complete their outstanding collective calls and reports ranks which are stuck. Next, the collective itself is checked for consistency by ensuring all collective functions match and are called with consistent tensor shapes. If this is not the case, a detailed error report is included when the application crashes, rather than a hang or uninformative error message. As an example, consider the following function which has mismatched input shapes into torch.distributed.all_reduce():

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    tensor = torch.randn(10 if rank == 0 else 20).cuda()
    dist.all_reduce(tensor)
    torch.cuda.synchronize(device=rank)

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    mp.spawn(worker, nprocs=2, args=())

With the NCCL backend, such an application would likely result in a hang which can be challenging to root-cause in nontrivial scenarios. If the user enables TORCH_DISTRIBUTED_DEBUG=DETAIL and reruns the application, the following error message reveals the root cause:

work = default_pg.allreduce([tensor], opts)
RuntimeError: Error when verifying shape tensors for collective ALLREDUCE on rank 0. This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes: 10
[ torch.LongTensor{1} ]
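The kind of check that produces this error can be pictured as comparing the shape each rank is about to pass into the collective. The sketch below is a local, single-process illustration only: in the real wrapper the shapes are exchanged between ranks, whereas here they are handed in as a dict. The function name `verify_shapes` and the error wording are hypothetical.

```python
def verify_shapes(shapes_by_rank: dict) -> None:
    """Raise if ranks disagree on the tensor shape passed to a collective."""
    distinct = set(shapes_by_rank.values())
    if len(distinct) > 1:
        details = ", ".join(
            f"rank {rank}: {shape}"
            for rank, shape in sorted(shapes_by_rank.items())
        )
        raise RuntimeError(f"mismatched input shapes across ranks ({details})")


verify_shapes({0: (10,), 1: (10,)})  # consistent shapes, no error

try:
    verify_shapes({0: (10,), 1: (20,)})  # the mismatch from the example above
    message = ""
except RuntimeError as exc:
    message = str(exc)
```

Failing fast with the offending shapes in the message is what turns a silent hang into an actionable error.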

Note

For fine-grained control of the debug level during runtime the functions torch.distributed.set_debug_level(), torch.distributed.set_debug_level_from_env(), and torch.distributed.get_debug_level() can also be used.

In addition, TORCH_DISTRIBUTED_DEBUG=DETAIL can be used in conjunction with TORCH_SHOW_CPP_STACKTRACES=1 to log the entire callstack when a collective desynchronization is detected. These collective desynchronization checks will work for all applications that use c10d collective calls backed by process groups created with the torch.distributed.init_process_group() and torch.distributed.new_group() APIs.

Logging

In addition to explicit debugging support via torch.distributed.monitored_barrier() and TORCH_DISTRIBUTED_DEBUG, the underlying C++ library of torch.distributed also outputs log messages at various levels. These messages can be helpful to understand the execution state of a distributed training job and to troubleshoot problems such as network connection failures. The following matrix shows how the log level can be adjusted via the combination of TORCH_CPP_LOG_LEVEL and TORCH_DISTRIBUTED_DEBUG environment variables.

TORCH_CPP_LOG_LEVEL TORCH_DISTRIBUTED_DEBUG Effective Log Level
ERROR ignored Error
WARNING ignored Warning
INFO ignored Info
INFO INFO Debug
INFO DETAIL Trace (a.k.a. All)
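The matrix above can be read as a simple lookup. The sketch below encodes it as a function; the name `effective_log_level` is hypothetical, and treating any TORCH_DISTRIBUTED_DEBUG value other than INFO or DETAIL (for example OFF) as leaving the level at Info is an assumption about the third row.

```python
def effective_log_level(cpp_log_level: str, distributed_debug: str = "OFF") -> str:
    """Look up the effective log level from the matrix above."""
    cpp = cpp_log_level.upper()
    debug = distributed_debug.upper()
    if cpp == "ERROR":
        return "Error"    # TORCH_DISTRIBUTED_DEBUG is ignored
    if cpp == "WARNING":
        return "Warning"  # TORCH_DISTRIBUTED_DEBUG is ignored
    if cpp == "INFO":
        # Assumption: values other than INFO/DETAIL leave the level at Info.
        return {"INFO": "Debug", "DETAIL": "Trace"}.get(debug, "Info")
    raise ValueError(f"combination not covered by the matrix: {cpp_log_level!r}")


levels = [
    effective_log_level("ERROR", "DETAIL"),
    effective_log_level("INFO"),
    effective_log_level("INFO", "INFO"),
    effective_log_level("INFO", "DETAIL"),
]
```

The first entry shows that TORCH_DISTRIBUTED_DEBUG has no effect unless TORCH_CPP_LOG_LEVEL is INFO.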

Distributed components raise custom Exception types derived from RuntimeError:

class torch.distributed.DistError

Exception raised when an error occurs in the distributed library

class torch.distributed.DistBackendError

Exception raised when a backend error occurs in distributed

class torch.distributed.DistNetworkError

Exception raised when a network error occurs in distributed

class torch.distributed.DistStoreError

Exception raised when an error occurs in the distributed store

If you are running single node training, it may be convenient to interactively breakpoint your script. We offer a way to conveniently breakpoint a single rank:

torch.distributed.breakpoint(rank=0, skip=0)

Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing.

Parameters