Checkpointing and Fine-Tuning a Model with Model Parallelism

The SageMaker model parallelism library provides checkpointing APIs to save the model state and the optimizer state, split according to the various model parallelism strategies, and to load checkpoints so that you can resume training from where you left off or fine-tune. The APIs also support options to save the model and optimizer states partially or fully.

Topics

- Checkpointing a distributed model
- Fine-tuning a distributed model

Checkpointing a distributed model

Choose one of the following topics depending on the framework you use (PyTorch or TensorFlow) and the version of the SageMaker model parallelism library.

Topics

- Checkpointing a distributed PyTorch model (for the SageMaker model parallelism library v1.10.0 and later)
- Checkpointing a distributed PyTorch model (for the SageMaker model parallelism library between v1.6.0 and v1.9.0)
- Checkpointing a distributed TensorFlow model

Checkpointing a distributed PyTorch model (for the SageMaker model parallelism library v1.10.0 and later)

The SageMaker model parallelism library provides checkpoint APIs to save and load full or partial checkpoints of the distributed model state and its optimizer state.

Note

This checkpointing method is recommended if you use PyTorch and the SageMaker model parallelism library v1.10.0 or later.

Partial checkpointing

To save checkpoints of a model trained with model parallelism, use the smdistributed.modelparallel.torch.save_checkpoint API with the partial checkpointing option set to true (partial=True). This saves each model partition individually. In addition to the model and the optimizer state, you can also save any additional custom data through the user_content argument. The checkpointed model, optimizer, and user content are saved as separate files. The save_checkpoint API call creates checkpoint folders in the following structure.

- path
  - ${tag}_partial (folder for partial checkpoints)
    - model_rankinfo.pt
    - optimizer_rankinfo.pt
    - fp16_states_rankinfo.pt
    - user_content.pt
  - $tag (checkpoint file for full checkpoints)
  - user_content_$tag (user_content file for full checkpoints)
  - newest (a file that indicates the newest checkpoint)

To resume training from partial checkpoints, use the smdistributed.modelparallel.torch.resume_from_checkpoint API with partial=True, and specify the checkpoint directory and the tag used while saving the partial checkpoints. Note that the actual loading of model weights happens after model partitioning, during the first run of the smdistributed.modelparallel.torch.step-decorated training step function.

When saving a partial checkpoint, the library also saves the model partition decision as files with the .pt file extension. When resuming from the partial checkpoint, the library loads these partition decision files along with the checkpoint. After the partition decision is loaded, you can't change the partition.

The following code snippet shows how to set the checkpoint APIs in a PyTorch training script.

import smdistributed.modelparallel.torch as smp

model = ...
model = smp.DistributedModel(model)
optimizer = ...
optimizer = smp.DistributedOptimizer(optimizer)
user_content = ...     # additional custom data
checkpoint_path = "/opt/ml/checkpoint/model_parallel"

# Save a checkpoint.
smp.save_checkpoint(
    path=checkpoint_path,
    tag=f"total_steps{total_steps}",
    partial=True,
    model=model,
    optimizer=optimizer,
    user_content=user_content,
    num_kept_partial_checkpoints=5
)

# Load a checkpoint.
# This automatically loads the most recently saved checkpoint.
smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path, 
    partial=True
)
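
As noted above, the model weights in a partial checkpoint are loaded only after model partitioning, during the first run of the training step function decorated with smdistributed.modelparallel.torch.step. The following sketch illustrates such a step function; the loss function and the train_loader are placeholders for your own training logic.

import torch.nn.functional as F

# The first run of this smp.step-decorated function after
# smp.resume_from_checkpoint(partial=True) is when the partitioned
# model weights are actually loaded.
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target)
    model.backward(loss)  # inside smp.step, call model.backward() instead of loss.backward()
    return loss

for data, target in train_loader:  # train_loader is a placeholder data loader
    loss_mb = train_step(model, data, target)
    loss = loss_mb.reduce_mean()   # average the per-microbatch losses
    optimizer.step()
    optimizer.zero_grad()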

Full checkpointing

To save the final model artifact for inference purposes, use the smdistributed.modelparallel.torch.save_checkpoint API with partial=False, which combines the model partitions to create a single model artifact. Note that this does not combine the optimizer states.

To initialize training with particular weights, given a full model checkpoint, you can use the smdistributed.modelparallel.torch.resume_from_checkpoint API with partial=False. Note that this does not load optimizer states.

Note

With tensor parallelism, in general, the state_dict must be translated between the original model implementation and the DistributedModel implementation. Optionally, you can provide the state_dict translation function as an argument to smdistributed.modelparallel.torch.resume_from_checkpoint. However, for Supported Models Out of the Box, the library takes care of this translation automatically.

The following code shows an example of how to use the checkpoint APIs for fully checkpointing a PyTorch model trained with model parallelism.

import smdistributed.modelparallel.torch as smp

model = ...
model = smp.DistributedModel(model)
optimizer = ...
optimizer = smp.DistributedOptimizer(optimizer)
user_content = ...     # additional custom data
checkpoint_path = "/opt/ml/checkpoint/model_parallel"

# Save a checkpoint.
smp.save_checkpoint(
    path=checkpoint_path,
    tag=f"total_steps{total_steps}",
    partial=False,
    model=model,
    optimizer=optimizer,
    user_content=user_content,
    num_kept_partial_checkpoints=5
)

# Load a checkpoint.
# This automatically loads the most recently saved checkpoint.
smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path, 
    partial=False
)
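
As the preceding note mentions, with tensor parallelism you might need to translate the state_dict between the original model implementation and the DistributedModel implementation when resuming from a full checkpoint. The following is a minimal sketch of the idea; the translate_function argument name and the key-renaming logic are illustrative assumptions, so check the library's API reference for the exact signature and the mapping your model needs.

def translate_state_dict(state_dict):
    # Hypothetical example: rename keys from the original implementation
    # to the names the distributed implementation expects. The actual
    # mapping depends on your model.
    return {name.replace("transformer.", "module."): param
            for name, param in state_dict.items()}

smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path,
    partial=False,
    translate_function=translate_state_dict  # argument name is an assumption
)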

Checkpointing a distributed PyTorch model (for the SageMaker model parallelism library between v1.6.0 and v1.9.0)

The SageMaker model parallelism library provides Python functions for saving partial or full checkpoints for training jobs with tensor parallelism. The following procedure shows how to use smp.save() and smp.load() to save and load a checkpoint when you use tensor parallelism.

Note

This checkpointing method is recommended if you use PyTorch, Tensor Parallelism, and the SageMaker model parallelism library between v1.6.0 and v1.9.0.

  1. Prepare a model object and wrap it with the library's wrapper function smp.DistributedModel().
model = MyModel(...)  
model = smp.DistributedModel(model)  
  2. Prepare an optimizer for the model. A set of model parameters is an iterable argument required by optimizer functions. To prepare a set of model parameters, you must process model.parameters() to assign unique IDs to individual model parameters.
    If there are parameters with duplicated IDs in the model parameter iterable, loading the checkpointed optimizer state fails. To create an iterable of model parameters with unique IDs for your optimizer, see the following:
unique_params = []  
unique_params_set = set()  
for p in model.parameters():  
  if p not in unique_params_set:  
    unique_params.append(p)  
    unique_params_set.add(p)  
del unique_params_set  
optimizer = MyOpt(unique_params, ...)  
  3. Wrap the optimizer using the library's wrapper function smp.DistributedOptimizer().
optimizer = smp.DistributedOptimizer(optimizer)  
  4. Save the model and the optimizer state using smp.save(). Depending on how you want to save checkpoints, choose one of the following two options:
    • Option 1: Save a partial model on each mp_rank for a single MP_GROUP.
    model_dict = model.local_state_dict() # save a partial model  
    opt_dict = optimizer.local_state_dict() # save a partial optimizer state  
    # Save the dictionaries at rdp_rank 0 as a checkpoint  
    if smp.rdp_rank() == 0:  
        smp.save(  
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},  
            f"/checkpoint.pt",  
            partial=True,  
        )  

    With tensor parallelism, the library saves checkpoint files named in the following format: checkpoint.pt_{pp_rank}_{tp_rank}.

    Note

    With tensor parallelism, make sure you set the if statement as if smp.rdp_rank() == 0 instead of if smp.dp_rank() == 0. When the optimizer state is sharded with tensor parallelism, all reduced-data parallel ranks must save their own partition of the optimizer state. Using the wrong if statement for checkpointing might result in a stalling training job. For more information about using if smp.dp_rank() == 0 without tensor parallelism, see General Instruction for Saving and Loading in the SageMaker Python SDK documentation.

    • Option 2: Save the full model.
    if smp.rdp_rank() == 0:  
        model_dict = model.state_dict(gather_to_rank0=True) # save the full model  
        if smp.rank() == 0:  
            smp.save(  
                {"model_state_dict": model_dict},  
                "/checkpoint.pt",  
                partial=False,  
            )  
    Note

    Consider the following for full checkpointing:
    * If you set gather_to_rank0=True, all ranks other than 0 return empty dictionaries.
    * For full checkpointing, you can only checkpoint the model. Full checkpointing of optimizer states is currently not supported.
    * The full model only needs to be saved at smp.rank() == 0.

  5. Load the checkpoints using smp.load(). Depending on how you checkpointed in the previous step, choose one of the following two options:
    • Option 1: Load the partial checkpoints.
    checkpoint = smp.load("/checkpoint.pt", partial=True)  
    model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False)  
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  

    You can set same_partition_load=True in model.load_state_dict() for a faster load, if you know that the partition will not change.

    • Option 2: Load the full checkpoints.
    if smp.rdp_rank() == 0:  
        checkpoint = smp.load("/checkpoint.pt", partial=False)  
        model.load_state_dict(checkpoint["model_state_dict"])  

    The if smp.rdp_rank() == 0 condition is not required, but it can help avoid redundant loading among different MP_GROUPs. Loading a full checkpoint of the optimizer state dict is currently not supported with tensor parallelism.

Checkpointing a distributed TensorFlow model

To save a TensorFlow model while training with model parallelism, use the checkpointing functions provided by the TensorFlow module of the SageMaker model parallelism library.
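
The sketch below is only a rough illustration of the general pattern: it combines standard tf.train.Checkpoint utilities with the library's rank APIs to save from a single rank. It is not the library's dedicated TensorFlow checkpointing API, so refer to the library's TensorFlow API documentation for the exact functions to use.

import tensorflow as tf
import smdistributed.modelparallel.tensorflow as smp

# model and optimizer are assumed to be already created and wrapped for
# model parallelism; checkpoint_dir is a placeholder path.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_dir, max_to_keep=3)

# Saving from a single rank is an assumption here; the appropriate guard
# depends on how the model and optimizer states are sharded.
if smp.dp_rank() == 0:
    manager.save()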

Fine-tuning a distributed model

Fine-tuning needs to be configured in your training script. The following code snippet shows the example structure of a training script that uses the AutoModelForCausalLM class of Hugging Face Transformers with modifications for registering the smdistributed.modelparallel.torch modules and settings for fine-tuning.

Note

Fine-tuning a distributed transformer (a Transformer model wrapped by smp.DistributedModel()) with the smp.delayed_param_initialization function activated requires the fine-tuning job to be configured with an FSx for Lustre file system. If you want to fine-tune a large-scale model with the delayed parameter initialization option, set up an FSx for Lustre file system.

import argparse
import os

import torch
from transformers import AutoModelForCausalLM

import smdistributed.modelparallel
import smdistributed.modelparallel.torch as smp

def parse_args():

    parser = argparse.ArgumentParser()

    # set an arg group for model
    model_grp = parser.add_argument_group(
        title="model", description="arguments to describe model configuration"
    )

    ... # set up numerous args to parse from the configuration dictionary to the script for training

    # add arg for activating fine-tuning
    model_grp.add_argument(
        "--fine_tune",
        type=int,
        default=0,
        help="Fine-tune model from checkpoint or pretrained model",
    )

    return parser.parse_args()

def main():
    """Main function to train GPT."""
    args = parse_args()

    ... # parse numerous args

    if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0:
        pretrained_model = AutoModelForCausalLM.from_pretrained(
            args.model_name or args.model_dir
        )
        model_state_dict = pretrained_model.state_dict()
        path = os.path.join(args.model_dir, "fullmodel.pt")
        torch.save(model_state_dict, path)

    # create a Transformer model and wrap by smp.model_creation() 
    # with options to configure model parallelism parameters offered by SageMaker AI
    with smp.model_creation(
        tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0,
        zero_init=args.use_distributed_transformer == 0,
        dtype=dtype,
        distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1,
        use_alibi=args.alibi > 0,
        attention_in_fp32=args.attention_in_fp32 > 0,
        fp32_residual_addition=args.residual_addition_in_fp32 > 0,
        query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1,
        fused_softmax=args.fused_softmax > 0,
        fused_dropout=args.fused_dropout > 0,
        fused_bias_gelu=args.fused_bias_gelu > 0,
        flash_attention=args.flash_attention > 0,
    ):
        if args.fine_tune > 0 and args.delayed_param == 0:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name or args.model_dir
            )
        else:
            model = AutoModelForCausalLM.from_config(model_config)

    # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism
    model = smp.DistributedModel(
        model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation
    )

    # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism
    optimizer = ...  # define an optimizer
    optimizer = smp.DistributedOptimizer(
        optimizer,
        static_loss_scale=None,
        dynamic_loss_scale=True,
        dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2},
    )

    # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model
    if args.fine_tune > 0 and args.delayed_param > 0:
        smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)
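
To run the fine-tuning script as a SageMaker training job, pass the model parallelism configuration through the distribution argument of the SageMaker Python SDK PyTorch estimator. The following is a minimal sketch; the entry point name, instance settings, and the specific modelparallel parameters are illustrative assumptions that you should adapt to your job and verify against the library's configuration reference for your version.

from sagemaker.pytorch import PyTorch

smp_options = {
    "enabled": True,
    "parameters": {                              # example values only
        "pipeline_parallel_degree": 1,
        "tensor_parallel_degree": 4,
        "ddp": True,
        "delayed_parameter_initialization": True,
    },
}

estimator = PyTorch(
    entry_point="train_gpt.py",                  # placeholder script name
    role="<your-iam-role>",
    instance_type="ml.p4d.24xlarge",
    instance_count=2,
    framework_version="1.13.1",
    py_version="py39",
    hyperparameters={"fine_tune": 1, "delayed_param": 1},
    distribution={
        "smdistributed": {"modelparallel": smp_options},
        "mpi": {"enabled": True, "processes_per_host": 8},
    },
)
estimator.fit("s3://<bucket>/<path-to-training-data>")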

For a complete example of training scripts and Jupyter notebooks, see the GPT-2 examples for PyTorch in the SageMaker AI Examples GitHub repository.