Hybrid sharded data parallelism - Amazon SageMaker AI
Sharded data parallelism is a memory-saving distributed training technique that splits the state of a model (model parameters, gradients, and optimizer states) across devices. You can use the freed-up GPU memory to fit a larger model or to increase the batch size. The SMP library runs sharded data parallelism with PyTorch Fully Sharded Data Parallel (FSDP). By default, PyTorch FSDP shards across all of the GPUs being used. In SMP v2, the library offers sharded data parallelism on top of PyTorch FSDP by extending PyTorch hybrid sharding (HYBRID_SHARD), one of the sharding strategies that PyTorch FSDP provides alongside FULL_SHARD, SHARD_GRAD_OP, and _HYBRID_SHARD_ZERO2. Extending hybrid sharding in this manner helps implement scale-aware sharding as described in the blog post Near-linear scaling of gigantic-model training on AWS for PyTorch FSDP.
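To make the memory saving concrete, the following back-of-the-envelope sketch estimates how much model state each GPU holds as the state is split over more GPUs. It is not an SMP API. It assumes mixed-precision training with Adam at roughly 16 bytes of state per parameter, ignores activations and temporary buffers, and uses a hypothetical 7B-parameter model and hypothetical GPU counts.

```python
# Back-of-the-envelope estimate, not an SMP API.
# ASSUMPTION: mixed-precision Adam keeps about 16 bytes of state per parameter:
# 2 B bf16 weights + 2 B bf16 gradients + 4 B fp32 master weights
# + 8 B fp32 Adam moments. Activations and temporary buffers are ignored.
BYTES_PER_PARAM = 16
GIB = 1024 ** 3


def state_bytes_per_gpu(num_params: int, shard_degree: int,
                        bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Model-state bytes one GPU holds when the state is split over shard_degree GPUs."""
    if shard_degree < 1:
        raise ValueError("shard_degree must be >= 1")
    return num_params * bytes_per_param / shard_degree


num_params = 7_000_000_000  # hypothetical 7B-parameter model
scenarios = [
    ("no sharding", 1),
    ("sharding within one 8-GPU node", 8),
    ("sharding over 16 GPUs", 16),
    ("sharding over 64 GPUs", 64),
]
for label, degree in scenarios:
    gib = state_bytes_per_gpu(num_params, degree) / GIB
    print(f"{label:32s} {gib:6.1f} GiB of model state per GPU")
```

Communication pulls the other way: a smaller sharding group keeps the all-gather and reduce-scatter collectives on fewer GPUs (and fewer nodes) but leaves more state on each GPU. Choosing the sharding degree is how you pick a point on that trade-off.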
The SMP library makes it easy to use HYBRID_SHARD and _HYBRID_SHARD_ZERO2 across any configurable number of GPUs, extending native PyTorch FSDP, which supports sharding either within a single node (HYBRID_SHARD) or across all GPUs (FULL_SHARD). Your PyTorch FSDP calls can stay as they are; you only need to add the hybrid_shard_degree argument to the SMP configuration, as shown in the following code example. You don't need to change the value of the sharding_strategy argument in the PyTorch FSDP wrapper around your PyTorch model. You can pass ShardingStrategy.HYBRID_SHARD as the value. Alternatively, if you set hybrid_shard_degree to a value of 2 or greater, the SMP library overrides the strategy in the script and sets it to ShardingStrategy.HYBRID_SHARD.
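The following sketch illustrates what a configurable sharding degree means in terms of GPU ranks. It is an illustration under stated assumptions, not SMP's internal implementation: it assumes that contiguous ranks form a sharding group (as PyTorch HYBRID_SHARD does when it shards within a node) and that the world size is a multiple of hybrid_shard_degree. The function name hybrid_shard_groups is hypothetical.

```python
from typing import List, Tuple


def hybrid_shard_groups(world_size: int,
                        hybrid_shard_degree: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: split ranks into sharding groups and replication groups.

    ASSUMPTIONS: contiguous ranks form a sharding group, and world_size is a
    multiple of hybrid_shard_degree. This is not SMP's internal implementation.
    """
    if hybrid_shard_degree < 1 or world_size % hybrid_shard_degree != 0:
        raise ValueError("world_size must be a positive multiple of hybrid_shard_degree")
    num_replicas = world_size // hybrid_shard_degree

    # Each sharding group splits one full copy of the model state across its ranks.
    shard_groups = [
        list(range(r * hybrid_shard_degree, (r + 1) * hybrid_shard_degree))
        for r in range(num_replicas)
    ]
    # Ranks holding the same shard in different copies form a replication group;
    # after the reduce-scatter inside a sharding group, their gradient shards
    # are all-reduced across this group.
    replicate_groups = [
        list(range(s, world_size, hybrid_shard_degree))
        for s in range(hybrid_shard_degree)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(32, 16)
print("sharding groups:", shard)
print("first replication groups:", replicate[:3])
```

For example, on four 8-GPU nodes (32 GPUs), hybrid_shard_groups(32, 16) yields two sharding groups, ranks 0 to 15 and 16 to 31, each spanning two nodes, plus 16 replication groups such as [0, 16]. Under the same assumptions, a degree of 8 reproduces node-local HYBRID_SHARD, and a degree of 32 matches FULL_SHARD across all GPUs.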
The following code snippets show how to add the SMP initialization call torch.sagemaker.init() to your training script and how to set up the SMP configuration dictionary in JSON format for the training job launcher, following the two-step process introduced in Use the SageMaker model parallelism library v2. You don't need to make any changes to your PyTorch model or PyTorch FSDP configuration. For more information about the hybrid_shard_degree parameter, see SMP v2 core feature configuration parameters.
SMP configuration dictionary
{ "hybrid_shard_degree": 16 }
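For reference, the following sketch shows one way this dictionary could reach the training job launcher. It assumes the SageMaker Python SDK's PyTorch estimator, where SMP v2 options sit under "parameters" in the smdistributed modelparallel entry of the estimator's distribution argument; the launch steps in Use the SageMaker model parallelism library v2 remain the authoritative form. The entry point name train.py is hypothetical.

```python
import json

# SMP configuration dictionary from above.
smp_config = {"hybrid_shard_degree": 16}

# ASSUMPTION: layout of the SageMaker Python SDK estimator's `distribution`
# argument for SMP v2 (torch_distributed launch plus the modelparallel entry).
distribution = {
    "torch_distributed": {"enabled": True},
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": smp_config,
        },
    },
}

print(json.dumps(distribution, indent=2))

# With the SDK installed and AWS credentials configured, the dictionary would be
# passed to the estimator, for example (train.py is a hypothetical entry point):
#   from sagemaker.pytorch import PyTorch
#   estimator = PyTorch(entry_point="train.py", ..., distribution=distribution)
```

Keeping the SMP settings in their own dictionary means the training script itself does not change when you tune hybrid_shard_degree; only the launcher configuration does.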
In training script
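import torch.sagemaker as tsm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Initialize the SMP library before setting up the model
tsm.init()

# Set up a PyTorch model
model = ...

# Wrap the PyTorch model using the PyTorch FSDP module.
# The existing FSDP arguments, including sharding_strategy, can stay as they are.
model = FSDP(
    model,
    ...
)

# Create the optimizer after the FSDP wrapper so that it references
# the sharded parameters
optimizer = ...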