Distributed Data Parallel — PyTorch 2.7 documentation

torch.nn.parallel.DistributedDataParallel (DDP) transparently performs distributed data parallel training. This page describes how it works and reveals implementation details.

Example

Let us start with a simple torch.nn.parallel.DistributedDataParallel example. This example uses a torch.nn.Linear as the local model, wraps it with DDP, and then runs one forward pass, one backward pass, and an optimizer step on the DDP model. After that, parameters on the local model will be updated, and all models on different processes should be exactly the same.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import os
from torch.nn.parallel import DistributedDataParallel as DDP

def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()

def main():
    world_size = 2
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
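The example relies on the claim that parameters end up identical on all processes after the optimizer step. As a minimal sketch of how to check that claim (the helper assert_params_in_sync is hypothetical and not part of the tutorial; it could be called at the end of example()), each parameter can be all-gathered and compared across ranks:

def assert_params_in_sync(ddp_model, world_size):
    # Sketch: gather every parameter from all ranks and compare the copies.
    # Copies are moved to CPU first, since gloo's all_gather works on CPU tensors.
    for p in ddp_model.parameters():
        local = p.detach().cpu()
        gathered = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(gathered, local)
        for other in gathered:
            assert torch.allclose(local, other), "parameters diverged across ranks"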

DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper before compiling the model, so that TorchDynamo can apply DDPOptimizer (graph-break optimizations) based on DDP bucket sizes. (See TorchDynamo DDPOptimizer for more information.)

ddp_model = DDP(model, device_ids=[rank])
ddp_model = torch.compile(ddp_model)

Internal Design

This section reveals how torch.nn.parallel.DistributedDataParallel works under the hood by diving into the details of every step in one iteration.

[Figure: ddp_grad_sync.png]

Note

DDP requires Reducer instances on all processes to invoke allreduce in exactly the same order, which is done by always running allreduce in the bucket index order instead of the actual bucket ready order. Mismatched allreduce order across processes can lead to wrong results or a DDP backward hang.
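Buckets are formed at DDP construction time, and their approximate size can be tuned through the bucket_cap_mb constructor argument. As a minimal sketch (reusing model and rank from the example above), with the default value spelled out:

# Sketch: bucket_cap_mb caps the approximate size of each gradient bucket (in MB);
# the Reducer still runs allreduce over these buckets in bucket index order.
ddp_model = DDP(model, device_ids=[rank], bucket_cap_mb=25)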

Implementation

Below are pointers to the DDP implementation components. The stacked graph shows the structure of the code.

ProcessGroup

DistributedDataParallel

[Figure: ddp_code.png]

TorchDynamo DDPOptimizer

DDP’s performance advantage comes from overlapping allreduce collectives with computations during backwards. AotAutograd prevents this overlap when used with TorchDynamo for compiling a whole forward and whole backward graph, because allreduce ops are launched by autograd hooks _after_ the whole optimized backwards computation finishes.

TorchDynamo’s DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP’s allreduce buckets during backwards. Note: the goal is to break the graph during backwards, and the simplest implementation is to break the forward graphs and then call AotAutograd and compilation on each section. This allows DDP’s allreduce hooks to fire in-between sections of backwards, and schedule communications to overlap with compute.

See this blog post for a more in-depth explanation and experimental results, or read the docs and code at torch/_dynamo/optimizations/distributed.py

To debug DDPOptimizer, set TORCH_LOGS="ddp_graphs" for full graph dumps. For logs without graphs, add any of "dynamo", "distributed", or "dist_ddp" to TORCH_LOGS (for basic info about bucket boundaries). To disable DDPOptimizer, set torch._dynamo.config.optimize_ddp=False. DDP and TorchDynamo should still work correctly without DDPOptimizer, but with performance degradation.
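As a short sketch of the knobs mentioned above (assuming the training script from the earlier example), graph dumps are enabled through the environment and DDPOptimizer is toggled from Python:

# Launch with the environment variable TORCH_LOGS="ddp_graphs" set to get
# full graph dumps from DDPOptimizer.
import torch._dynamo

# Disable DDPOptimizer: DDP + TorchDynamo still run correctly, but without
# bucket-aligned graph breaks there is less overlap of backward compute
# with allreduce communication, so expect some performance degradation.
torch._dynamo.config.optimize_ddp = False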