PyTorch API — sagemaker 2.199.0 documentation


Supported versions: 1.7.1, 1.6.0

This API document assumes you use the following import statements in your training scripts.

import smdistributed.modelparallel.torch as smp

class smp.DistributedModel

A sub-class of torch.nn.Module which specifies the model to be partitioned. Accepts a torch.nn.Module object module which is the model to be partitioned. The returned DistributedModel object internally manages model parallelism and data parallelism. Only one model in the training script can be wrapped with smp.DistributedModel.

Example:

model = smp.DistributedModel(model)

Important: The __call__ and backward method calls on the smp.DistributedModel object (in the following example, the object is model) can only be made inside a smp.step-decorated function.

Since DistributedModel is a torch.nn.Module, a forward pass can be performed by calling the DistributedModel object on the input tensors.

predictions = model(inputs)   # model is a smp.DistributedModel object

For a backward pass, one needs to call the backward function on the DistributedModel object, with tensors and gradients as arguments, replacing the PyTorch operations torch.Tensor.backward or torch.autograd.backward.

The API for model.backward is very similar to torch.autograd.backward. For example, the following backward calls:

torch.autograd.backward(loss) or loss.backward()

should be replaced with:

model.backward(loss) # loss is a tensor with only one element as its data

Similarly, for non-scalar tensors, replace the following backward call containing incoming gradient arguments:

torch.autograd.backward(outputs, out_grads)

with the following line:

model.backward(outputs, out_grads)

In these examples, all __call__ and backward method calls on the model objects (model(inputs) and model.backward(loss)) must be made inside a smp.step-decorated function.
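For reference, these calls typically come together in a training step like the following minimal sketch. The loss function, argument names, and the reduction of the return value are illustrative assumptions, not part of this API reference:

import torch.nn.functional as F

@smp.step
def train_step(model, data, target):
    output = model(data)                  # forward pass on the DistributedModel
    loss = F.nll_loss(output, target)
    model.backward(loss)                  # replaces loss.backward()
    return loss

# Outside the decorated function, the return value aggregates the microbatch
# results; reducing it with reduce_mean() is an assumed usage pattern.
loss = train_step(model, data, target).reduce_mean()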

Using DDP

If DDP is enabled, do not place a PyTorch DistributedDataParallel wrapper around the DistributedModel because the DistributedModel wrapper will also handle data parallelism.

Unlike the original DDP wrapper, when you use DistributedModel, model parameters and buffers are not immediately broadcast across processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the smp.step-decorated function, when the partition is done.
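DDP is enabled through the library's configuration in the SageMaker Python SDK. A minimal sketch of such a configuration, with illustrative parameter values, might look like:

smp_options = {
    "enabled": True,
    "parameters": {
        "partitions": 2,   # number of model partitions
        "ddp": True,       # enable data parallelism alongside model parallelism
    },
}
# This dictionary is passed to the PyTorch estimator's distribution argument.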

Parameters

Properties

Methods

backward(tensors, grad_tensors)

Triggers a distributed backward pass across model partitions. Example usage is provided in the previous section. The API is very similar to torch.autograd.backward (https://pytorch.org/docs/stable/autograd.html#torch.autograd.backward). The retain_grad and create_graph flags are not supported.

local_buffers()

Returns an iterator over buffers for the modules in the partitioned model that have been assigned to the current process.

local_named_buffers()

Returns an iterator over buffers for the modules in the partitioned model that have been assigned to the current process. This yields both the name of the buffer as well as the buffer itself.

local_parameters()

Returns an iterator over parameters for the modules in the partitioned model that have been assigned to the current process.

local_named_parameters()

Returns an iterator over parameters for the modules in the partitioned model that have been assigned to the current process. This yields both the name of the parameter as well as the parameter itself.

local_modules()

Returns an iterator over the modules in the partitioned model that have been assigned to the current process.

local_named_modules()

Returns an iterator over the modules in the partitioned model that have been assigned to the current process. This yields both the name of the module as well as the module itself.
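These iterators can be used, for instance, to inspect what the current process holds. An illustrative sketch (it must run after the first smp.step call, once the partition exists):

def describe_local_partition(model):
    # Modules assigned to this process
    for name, module in model.local_named_modules():
        print(f"module {name}: {type(module).__name__}")
    # Parameters assigned to this process
    for name, param in model.local_named_parameters():
        print(f"param {name}: shape {tuple(param.shape)}")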

local_state_dict()

Returns the state_dict that contains the local parameters that belong to the current mp_rank. This state_dict contains a key _smp_is_partial, which indicates whether the state_dict contains elements corresponding to only the current partition or to the entire model.

state_dict()

Returns the state_dict that contains parameters for the entire model. It first collects the local_state_dict, then gathers and merges the local_state_dicts from all mp_ranks to create a full state_dict. Note that this needs to be called on all ranks with dp_rank()==0 to ensure the gather happens properly. If it is not called on all such ranks, it can hang.

load_state_dict()

Same as torch.nn.Module.load_state_dict(), except: it first gathers and merges the state_dicts across mp_ranks, if they are partial. The actual loading happens after the model partition, so that each rank knows its local parameters.

register_post_partition_hook(hook)

Registers a callable hook to be executed after the model is partitioned. This is useful in situations where an operation needs to be executed after the model partition during the first call to smp.step, but before the actual execution of the first forward pass. Returns a RemovableHandle object handle, which can be used to remove the hook by calling handle.remove().
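For example, a hook can announce that partitioning has finished. A minimal sketch (the zero-argument hook signature used here is an assumption):

handle = model.register_post_partition_hook(lambda: print("model partitioned"))
# The hook fires during the first smp.step call, after partitioning.
handle.remove()  # detach the hook once it is no longer needed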

cpu()

Allgathers parameters and buffers across all mp_ranks and moves them to the CPU.

join()

Available for PyTorch 1.7.1 only

A context manager to be used in conjunction with an instance of smp.DistributedModel to be able to train with uneven inputs across participating processes. This is only supported when ddp=True for smp.DistributedModel. It uses the join of the wrapped DistributedDataParallel instance. For more information, see join in the PyTorch documentation.
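A sketch of typical usage, assuming per-rank dataloaders whose lengths may differ across processes and the train_step function from the earlier sketch:

with model.join():
    for data, target in local_dataloader:   # lengths may differ across ranks
        loss = train_step(model, data, target)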

class smp.DistributedOptimizer

Parameters

optimizer: The optimizer to be wrapped.

An optimizer wrapper for saving/loading optimizer states. This wrapper returns an optimizer with the following methods overridden:

state_dict()

Returns the state_dict that contains optimizer state for the entire model. It first collects the local_state_dict, then gathers and merges the local_state_dicts from all mp_ranks to create a full state_dict.

load_state_dict()

Same as torch.optim.Optimizer.load_state_dict(), except: it first gathers and merges the local state_dicts across mp_ranks, if they are partial. The actual loading happens after the model partition, so that each rank knows its local optimizer state.

local_state_dict()

Returns the state_dict that contains the local optimizer state that belongs to the current mp_rank. This state_dict contains a key _smp_is_partial, which indicates whether the state_dict contains elements corresponding to only the current partition or to the entire model.
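As with the model, the optimizer is wrapped once, after the model wrapper. A minimal sketch (the base optimizer and its hyperparameters are illustrative):

model = smp.DistributedModel(MyModel())
optimizer = smp.DistributedOptimizer(torch.optim.Adam(model.parameters(), lr=1e-3))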

smp.partition(index)

A context manager which places all modules defined inside into the partition with ID index. The index argument must be less than the number of partitions.

Inputs

index (int): The index of the partition.

Use smp.partition to implement manual partitioning. If "auto_partition" is True, then the smp.partition contexts are ignored. Any module that is not placed in any smp.partition context is placed in the default_partition defined through the SageMaker Python SDK.

When smp.partition contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module __init__, and the partition assignment applies to the modules that are created inside the smp.partition context.

Example:

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        with smp.partition(1):
            self.child0 = Child0()            # child0 on partition 1
            with smp.partition(2):
                self.child1 = Child1()        # child1 on partition 2
            self.child2 = Child2()            # child2 on partition 1
        self.child3 = Child3()                # child3 on default_partition

smp.get_world_process_group()

Returns a torch.distributed ProcessGroup that consists of all processes, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.

smp.get_mp_process_group()

Returns a torch.distributed ProcessGroup that consists of the processes in the MP_GROUP which contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.

smp.get_dp_process_group()

Returns a torch.distributed ProcessGroup that consists of the processes in the DP_GROUP which contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.
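The returned groups can be passed directly to torch.distributed collectives. A minimal sketch that averages a tensor t across the data-parallel group (t is illustrative, and smp.dp_size() is assumed to return the DP_GROUP size):

import torch.distributed as dist

# Sum across the data-parallel group, then divide to get the mean.
dist.all_reduce(t, op=dist.ReduceOp.SUM, group=smp.get_dp_process_group())
t /= smp.dp_size()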

smp.is_initialized()

Returns True if smp.init has already been called for the process, and False otherwise.
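This can be used to guard against double initialization when calling smp.init manually, as in this minimal sketch:

if not smp.is_initialized():
    smp.init()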

smp.nn.FusedLayerNorm

Apex FusedLayerNorm is currently not supported by the library. smp.nn.FusedLayerNorm replaces apex FusedLayerNorm and provides the same functionality. This requires apex to be installed on the system.

smp.optimizers.FusedNovoGrad

The apex FusedNovoGrad optimizer is currently not supported by the library. smp.optimizers.FusedNovoGrad replaces the apex FusedNovoGrad optimizer and provides the same functionality. This requires apex to be installed on the system.

smp.optimizers.FusedLamb

The apex FusedLamb optimizer currently doesn't work with the library. smp.optimizers.FusedLamb replaces the apex FusedLamb optimizer and provides the same functionality. This requires apex to be installed on the system.

smp.amp.GradScaler

Torch AMP GradScaler currently doesn't work with the library. smp.amp.GradScaler replaces torch.amp.GradScaler and provides the same functionality.
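Each of these classes is intended as a drop-in replacement. Assuming their constructors mirror the apex and torch counterparts (an assumption, not confirmed by this document), usage might look like:

norm = smp.nn.FusedLayerNorm(hidden_size)                        # instead of apex FusedLayerNorm
opt = smp.optimizers.FusedNovoGrad(model.parameters(), lr=1e-3)  # instead of apex FusedNovoGrad
scaler = smp.amp.GradScaler()                                    # instead of torch.amp.GradScaler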

APIs for Saving and Loading

smp.save()

Saves an object. This operation is similar to torch.save(), except it has an additional keyword argument, partial, and accepts only string type for the argument f (file). If partial=True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to your saved file.

Parameters

obj: The object to save.
f (str): Path to the file to save to. Only string paths are accepted.
partial (bool): If True, each mp_rank saves a separate checkpoint file.

smp.load()

Loads an object saved with smp.save() from a file.

Similar to torch.load(), except it has an additional keyword argument, partial, and accepts only string type for the argument f (file). If partial=True, then each mp_rank loads a separate checkpoint file.

Parameters

f (str): Path to the file to load from. Only string paths are accepted.
partial (bool): If True, each mp_rank loads a separate partial checkpoint file.

General Instructions for Saving and Loading

The library can save partial or full checkpoints.

When saving using smp.save(), each rank only holds its own parameters. If you want to save the full model, there will be some communication between the ranks to create the full model. If you save checkpoints often, you should save partial checkpoints for best performance.

When loading using smp.load(), the library can load either partial checkpoints, full checkpoints, or full checkpoints saved by a non-model-parallel model. If you want to resume training with a non-model-parallel model or do inference, you need a full checkpoint.

The following is an example of how you can save and load a checkpoint:

# original model and optimizer
model = MyModel(...)
optimizer = MyOpt(...)

# model parallel wrapper
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

# To save, always save on dp_rank 0 to avoid data races.

# To save the partial model on each mp rank,
# the library will create checkpoint.pt_{mprank} for each mp rank
if save_partial_model:
    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()    # save the partial model
        opt_dict = optimizer.local_state_dict()  # save the partial optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=True,
        )

# To save the full model
if save_full_model:
    if smp.dp_rank() == 0:
        model_dict = model.state_dict()    # save the full model
        opt_dict = optimizer.state_dict()  # save the full optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=False,
        )

# To load, load on all ranks.
# The only difference for partial/full loading is the partial flag in smp.load.

# Load partial checkpoint
if partial_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Load full checkpoint
if full_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])