Modify a PyTorch Training Script

In this section, you learn how to modify PyTorch training scripts to configure the SageMaker model parallelism library for auto-partitioning and manual partitioning.

Note that auto-partitioning is enabled by default. Unless otherwise specified, the following scripts use auto-partitioning.

Topics

  * Automated splitting with PyTorch
  * Manual splitting with PyTorch
  * Considerations
  * Unsupported framework features

Automated splitting with PyTorch

The following changes are required to run a PyTorch training script with SageMaker's model parallelism library:

  1. Import and initialize the library with smdistributed.modelparallel.torch.init().
  2. Wrap the model with smdistributed.modelparallel.torch.DistributedModel. Be mindful that any tensors returned from the forward method of the underlying nn.Module object are broadcast across model-parallel devices, incurring communication overhead, so any tensors that are not needed outside the call method (such as intermediate activations) should not be returned.
  3. Wrap the optimizer with smdistributed.modelparallel.torch.DistributedOptimizer.
  4. Use the returned DistributedModel object instead of the user model.
  5. Put the forward and backward logic in a step function and decorate it with smdistributed.modelparallel.torch.step.
  6. Restrict each process to its own device through torch.cuda.set_device(smp.local_rank()).
  7. Move the input tensors to the GPU using the .to() API before the smp.step call (see the example below).
  8. Replace torch.Tensor.backward and torch.autograd.backward with DistributedModel.backward.
  9. Perform post-processing on the outputs across microbatches using StepOutput methods such as reduce_mean.
  10. If there is an evaluation step, similarly place the forward logic inside an smp.step-decorated function and post-process the outputs using the StepOutput API (see the evaluation sketch after the example script below).
  11. Set drop_last=True in DataLoader. Alternatively, manually skip a batch in the training loop if the batch size is not divisible by the number of microbatches.

To learn more about the SageMaker model parallelism library API, refer to the API documentation.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchnet.dataset import SplitDataset
from torchvision import datasets, transforms

import smdistributed.modelparallel.torch as smp

class GroupedNet(nn.Module):
    def __init__(self):
        super(GroupedNet, self).__init__()
        # define layers, for example:
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # define forward pass and return model outputs, for example:
        x = F.relu(self.fc1(torch.flatten(x, 1)))
        return F.log_softmax(self.fc2(x), dim=1)


# smdistributed: Define smp.step. Return any tensors needed outside.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)

        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()

        optimizer.step()

# smdistributed: initialize the backend
smp.init()

# smdistributed: Set the device to the GPU ID used by the current process.
# Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
device = torch.device("cuda")

# smdistributed: Download only on a single process per instance.
# When this is not present, the file is corrupted by multiple processes trying
# to download and extract at the same time
dataset = datasets.MNIST("../data", train=True, download=False,
                         transform=transforms.ToTensor())

# smdistributed: Shard the dataset based on data-parallel ranks
if smp.dp_size() > 1:
    partitions_dict = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
    dataset = SplitDataset(dataset, partitions=partitions_dict)
    dataset.select(f"{smp.dp_rank()}")

# smdistributed: Set drop_last=True to ensure that batch size is always divisible
# by the number of microbatches
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

model = GroupedNet()
optimizer = optim.Adadelta(model.parameters(), lr=4.0)

# smdistributed: Use the DistributedModel container to provide the model
# to be partitioned across different ranks. For the rest of the script,
# the returned DistributedModel object should be used in place of
# the model provided for DistributedModel class instantiation.
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

train(model, device, train_loader, optimizer)
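
The script above covers only training. If your script also has an evaluation step (item 10 in the list above), its forward logic goes inside its own smp.step-decorated function and the returned StepOutput objects are post-processed the same way. The following is a minimal sketch; the test_step and test names, the separate test_loader, and the loss averaging are illustrative assumptions rather than part of the example above.

# smdistributed: An evaluation step also goes inside an smp.step-decorated
# function so that microbatches are pipelined across the model partitions.
# The names test_step, test, and test_loader are illustrative placeholders.
@smp.step
def test_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    return output, loss


def test(model, device, test_loader):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            # Move the input tensors to the GPU used by the current process
            # before the smp.step call, as in the training loop.
            data, target = data.to(device), target.to(device)
            _, loss_mb = test_step(model, data, target)
            # smdistributed: Average the loss across microbatches using StepOutput.
            total_loss += loss_mb.reduce_mean().item()
    return total_loss / len(test_loader)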

Manual splitting with PyTorch

Use smp.partition context managers to place modules on specific devices. Modules created within a given smp.partition context are placed on the corresponding partition, while any module not created within an smp.partition context is placed on the default_partition. The default_partition must be provided if auto_partition is set to False.
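
The auto_partition and default_partition settings are not part of the training script itself; they are passed through the library's configuration when you launch the training job. The following is a minimal sketch of how they might be set in a SageMaker PyTorch estimator's distribution configuration; the entry point, IAM role, instance type, framework versions, and the partitions, microbatches, and processes_per_host values are placeholder assumptions, so check the SageMaker documentation for the values that apply to your setup.

# A minimal sketch of the launcher side (not part of the training script):
# auto_partition and default_partition live in the modelparallel "parameters"
# dictionary of the estimator's distribution configuration. The entry point,
# instance type, framework versions, and numeric values are placeholders.
from sagemaker.pytorch import PyTorch

smp_options = {
    "enabled": True,
    "parameters": {
        "partitions": 2,
        "microbatches": 4,
        "auto_partition": False,    # disable auto-partitioning
        "default_partition": 0,     # required when auto_partition is False
    },
}

mpi_options = {
    "enabled": True,
    "processes_per_host": 2,
}

estimator = PyTorch(
    entry_point="train.py",
    role="<your-iam-role>",
    instance_type="ml.p3.16xlarge",
    instance_count=1,
    framework_version="1.8.1",
    py_version="py36",
    distribution={
        "smdistributed": {"modelparallel": smp_options},
        "mpi": mpi_options,
    },
)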

To learn more about the SageMaker model parallelism library API, refer to the API documentation.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchnet.dataset import SplitDataset
from torchvision import datasets, transforms

import smdistributed.modelparallel.torch as smp

class GroupedNet(nn.Module):
    def __init__(self):
        super(GroupedNet, self).__init__()
        with smp.partition(0):
            # define child modules on device 0, for example:
            self.fc1 = nn.Linear(784, 128)
        with smp.partition(1):
            # define child modules on device 1, for example:
            self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # define forward pass and return model outputs, for example:
        x = F.relu(self.fc1(torch.flatten(x, 1)))
        return F.log_softmax(self.fc2(x), dim=1)


# smdistributed: Define smp.step. Return any tensors needed outside.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)

        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()

        optimizer.step()

# smdistributed: initialize the backend
smp.init()

# smdistributed: Set the device to the GPU ID used by the current process.
# Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
device = torch.device("cuda")

# smdistributed: Download only on a single process per instance.
# When this is not present, the file is corrupted by multiple processes trying
# to download and extract at the same time
dataset = datasets.MNIST("../data", train=True, download=False,
                         transform=transforms.ToTensor())

# smdistributed: Shard the dataset based on data-parallel ranks
if smp.dp_size() > 1:
    partitions_dict = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
    dataset = SplitDataset(dataset, partitions=partitions_dict)
    dataset.select(f"{smp.dp_rank()}")

# smdistributed: Set drop_last=True to ensure that batch size is always divisible
# by the number of microbatches
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

model = GroupedNet()
optimizer = optim.Adadelta(model.parameters(), lr=4.0)

# smdistributed: Use the DistributedModel container to provide the model
# to be partitioned across different ranks. For the rest of the script,
# the returned DistributedModel object should be used in place of
# the model provided for DistributedModel class instantiation.
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

train(model, device, train_loader, optimizer)

Considerations

When you configure a PyTorch training script using SageMaker's model parallelism library, you should be aware of the following:

If you launch a training job using an ml.p4d instance type (such as ml.p4d.24xlarge), set the data loader variable num_workers=0. For example, you can define your DataLoader as follows:

dataloader = torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
            shuffle=shuffle,
        )

The inputs to an smp.step-decorated function must be the model inputs generated by the DataLoader, because smp.step internally splits the input tensors along the batch dimension into microbatches. For example, if you define a DataLoader as follows:

train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, drop_last=True)

You should access the model inputs generated by train_loader and pass those to an smp.step-decorated function. Do not pass train_loader directly to the smp.step function:

def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        ...
        _, loss_mb = train_step(model, data, target)
        ...

@smp.step
def train_step(model, data, target):
    ...
    return output, loss
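
For contrast, the following sketch shows the pattern to avoid; the function name train_step_incorrect is a made-up placeholder. Because smp.step splits its tensor arguments into microbatches, it cannot do so for a DataLoader that is passed in directly.

# Anti-pattern (do not do this): passing the DataLoader itself into the
# smp.step-decorated function. smp.step splits its tensor arguments along
# the batch dimension into microbatches, so the DataLoader must be iterated
# outside the decorated function and only the resulting tensors passed in.
@smp.step
def train_step_incorrect(model, train_loader):
    for data, target in train_loader:
        output = model(data)
        loss = F.nll_loss(output, target, reduction="mean")
        model.backward(loss)
    return output, loss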

The input tensors to an smp.step-decorated function must be moved to the current device using the .to() API before the smp.step call. For example, you can define the train function as follows:

def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # smdistributed: Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Return value, loss_mb is a StepOutput object
        _, loss_mb = train_step(model, data, target)

        # smdistributed: Average the loss across microbatches.
        loss = loss_mb.reduce_mean()

        optimizer.step()

The input tensors to this smp.step-decorated function have been moved to the current device in the train function above. The model does not need to be moved to the current device; the library automatically moves the part of the model assigned to a rank to its GPU.

@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    model.backward(loss)
    return output, loss

Unsupported framework features

The following PyTorch features are unsupported by SageMaker's model parallelism library: