
Training APIs#


Neuronx-Distributed Training APIs:#

In Neuronx-Distributed, we provide a series of APIs directly under neuronx_distributed that help users apply NxD Core optimizations easily. These APIs cover configuration, model/optimizer initialization, and checkpoint saving/loading.

Initialize NxD Core config:#

def neuronx_distributed.trainer.neuronx_distributed_config(
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    pipeline_config=None,
    optimizer_config=None,
    activation_checkpoint_config=None,
    pad_model=False,
    sequence_parallel=False,
    model_init_config=None,
    lora_config=None,
)

This method initializes the NxD Core training config and initializes model parallelism. The config maintains all optimization options of the distributed training run, and it is a global config (the same for all processes).

Parameters:
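As an illustration, a config might be created as in the minimal sketch below; the parallel degree and optimizer options shown are example values taken from the sample usage later in this document, not recommendations:

import neuronx_distributed as nxd

# minimal sketch: enable tensor parallelism and ZeRO-1 with gradient clipping
nxd_config = nxd.neuronx_distributed_config(
    tensor_parallel_size=8,
    optimizer_config={"zero_one_enabled": True, "grad_clipping": True, "max_grad_norm": 1.0},
)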

Initialize NxD Core Model Wrapper:#

def neuronx_distributed.trainer.initialize_parallel_model(nxd_config, model_fn, *model_args, **model_kwargs)

This method initializes the NxD Core model wrapper and returns a wrapped model that can be used as a regular torch.nn.Module, with all the model optimizations in the config applied. The wrapper is designed to hide the complexity of optimizations such as pipeline model parallelism, so that users can simply use the wrapped model like the unwrapped version.

Parameters:
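A minimal sketch of wrapping a model; the get_model function and the model it constructs are placeholders for any callable that returns a torch.nn.Module:

import neuronx_distributed as nxd

def get_model():
    # placeholder: build and return the (unwrapped) torch.nn.Module here
    return MyTransformerModel(my_model_config)

model = nxd.initialize_parallel_model(nxd_config, get_model)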

Model wrapper class and its methods:

class neuronx_distributed.trainer.model.NxDModel(torch.nn.Module):
    def local_module(self):
        # return the unwrapped local module

    def run_train(self, *args, **kwargs):
        # method to run one iteration; when pipeline parallel is enabled,
        # users have to use this instead of forward+backward

    def named_parameters(self, *args, **kwargs):
        # only return parameters on the local rank.
        # same for `parameters`, `named_buffers`, `buffers`

    def named_modules(self, *args, **kwargs):
        # only return modules on the local rank.
        # same for `modules`, `named_children`, `children`

Note

As a shortcut, users can call model.config or model.dtype on the wrapped model if the original model is a Hugging Face Transformers pre-trained model.
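For example, one training iteration with the wrapped model might look like the sketch below, where inputs is a placeholder batch already placed on the device:

# run one iteration through the wrapped model; with pipeline parallelism enabled,
# run_train must be used instead of an explicit forward + backward
loss = model.run_train(inputs)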

Initialize NxD Core Optimizer Wrapper:#

def neuronx_distributed.trainer.initialize_parallel_optimizer(nxd_config, optimizer_class, parameters, **defaults)

This method initializes the NxD Core optimizer wrapper and returns a wrapped optimizer that can be used as a regular torch.optim.Optimizer, with all the optimizer optimizations in the config applied.

The optimizer wrapper inherits from torch.optim.Optimizer. It takes in the nxd_config and configures the optimizer to work with the different distributed training regimes.

The step method of the wrapped optimizer contains the necessary all-reduce operations and gradient clipping. Other methods and variables work the same as in the unwrapped optimizer.

Parameters:
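A minimal sketch of creating the wrapped optimizer; AdamW and the learning rate are illustrative choices, mirroring the sample usage further below:

import neuronx_distributed as nxd
from torch.optim import AdamW

# the wrapper's step() performs the required all-reduces and gradient clipping internally
optimizer = nxd.initialize_parallel_optimizer(nxd_config, AdamW, model.parameters(), lr=1e-3)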

Enable LoRA fine-tuning:#

LoRA model wrapper

class LoRAModel(module, LoraConfig)

Parameters:

The flags in LoraConfig used to initialize the LoRA adapter:

Usage:

We first define the LoRA configuration for fine-tuning. Suppose target_modules is [q_proj, v_proj, k_proj]; this indicates that LoRA will be applied to modules whose names include any of these keywords. An example is:

lora_config = neuronx_distributed.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj"],
)

You can then enable LoRA fine-tuning as shown below:

nxd_config = neuronx_distributed.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = neuronx_distributed.initialize_parallel_model(nxd_config, ...)

The NxD model will then be initialized with the LoRA adapter enabled.

Save Checkpoint:#

Method to save a checkpoint; returns None.

This method saves checkpoints for the model, optimizer, scheduler, and user contents sequentially. Model states are saved on data parallel rank 0 only. When the ZeRO-1 optimizer is not enabled, optimizer states are saved the same way; when the ZeRO-1 optimizer is enabled, optimizer states are saved on all ranks. Scheduler and user contents are saved on the master rank only.

Users can pass use_xser=True to boost saving performance and avoid host OOM. This works by saving tensors one by one while keeping the original data structure; however, the resulting checkpoint cannot be loaded with PyTorch's load API. Users can also pass async_save=True to further boost saving performance, which saves tensors in separate processes alongside computation. Setting async_save to True uses more host memory and therefore increases the risk of the application crashing because the system runs out of memory.

def neuronx_distributed.trainer.save_checkpoint(
    path,
    tag="",
    model=None,
    optimizer=None,
    scheduler=None,
    user_content=None,
    num_workers=8,
    use_xser=False,
    num_kept_ckpts=None,
    async_save=False,
    zero1_optimizer=False,
)

Parameters:
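For example, a save with serialized tensor saving enabled might look like the sketch below; the path, tag, and user_content values are placeholders:

import neuronx_distributed as nxd

nxd.save_checkpoint(
    "ckpts",                 # checkpoint root directory (placeholder)
    tag="step_1000",         # sub-directory for this checkpoint (placeholder)
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    user_content={"total_steps": total_steps},
    use_xser=True,           # serialize tensors one by one to reduce host memory pressure
    num_kept_ckpts=3,        # keep only the most recent checkpoints (example value)
)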

Save LoRA Checkpoint:

NxD also uses neuronx_distributed.trainer.save_checkpoint() to save LoRA models, but users can set save_lora_base and merge_lora in LoraConfig to specify how the LoRA checkpoint is saved. There are three modes for LoRA checkpoint saving:

In addition to the adapter, NxD also needs to save the LoRA configuration file for LoRA loading. The configuration can be saved into the same checkpoint as the adapter, or saved as a separate JSON file.

Note that if the LoRA configuration file is saved separately, it is named lora_adapter/adapter_config.json.

A configuration example that saves only the LoRA adapter is:

lora_config = neuronx_distributed.modules.lora.LoraConfig(
    ...
    save_lora_base=False,
    merge_lora=False,
    save_lora_config_adapter=True,
)

Load Checkpoint:#

Method to load a checkpoint saved by save_checkpoint; returns the user contents if they exist, otherwise None. If tag is not provided, the newest tag tracked by save_checkpoint is used.

Note that the checkpoint being loaded must have the same model parallel degrees as are currently in use, and if the ZeRO-1 optimizer is used, the same data parallel degree as well.

def neuronx_distributed.trainer.load_checkpoint(
    path,
    tag=None,
    model=None,
    optimizer=None,
    scheduler=None,
    num_workers=8,
    strict=True,
)

Parameters:

- num_workers (int): number of workers used to load the checkpoint at the same time. This is done to avoid host OOM; range: 1-32.
- strict (bool): whether to use strict mode when loading the model checkpoint. Default: True.
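A minimal sketch of resuming from the newest checkpoint; the path and the total_steps key are placeholders:

import neuronx_distributed as nxd

user_content = nxd.load_checkpoint(
    "ckpts",        # same checkpoint root passed to save_checkpoint (placeholder)
    tag=None,       # None: use the newest tag tracked by save_checkpoint
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
)
if user_content is not None:
    total_steps = user_content["total_steps"]  # assumes this key was saved earlier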

Load LoRA Checkpoint:

NxD loads LoRA checkpoints by setting flags in LoraConfig.

An example is:

lora_config = LoraConfig(
    enable_lora=True,
    load_lora_from_ckpt=True,
    lora_save_dir=checkpoint_dir,  # checkpoint path
    lora_load_tag=tag,             # sub-directory under checkpoint path
)
nxd_config = nxd.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)

The NxD model will be initialized with LoRA enabled and the LoRA weights loaded. LoRA-related configurations are the same as those in the LoRA adapter checkpoint.

Sample usage:

import neuronx_distributed as nxd

# create config
nxd_config = nxd.neuronx_distributed_config(
    tensor_parallel_size=8,
    optimizer_config={"zero_one_enabled": True, "grad_clipping": True, "max_grad_norm": 1.0},
)

# wrap model
model = nxd.initialize_parallel_model(nxd_config, get_model)

# wrap optimizer
optimizer = nxd.initialize_parallel_optimizer(nxd_config, AdamW, model.parameters(), lr=1e-3)

# ... (training loop):
loss = model.run_train(inputs)
optimizer.step()
# ...

# loading checkpoint (auto-resume)
user_content = nxd.load_checkpoint(
    "ckpts",
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
)
# ...

# saving checkpoint
nxd.save_checkpoint(
    "ckpts",
    nxd_config=nxd_config,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    user_content={"total_steps": total_steps},
)

Modules:#

GQA-QKV Linear Module:#

class neuronx_distributed.modules.qkv_linear.GQAQKVColumnParallelLinear(
    input_size,
    output_size,
    bias=True,
    gather_output=True,
    sequence_parallel_enabled=False,
    dtype=torch.float32,
    device=None,
    kv_size_multiplier=1,
    fuse_qkv=True,
)

This module parallelizes the Q, K, and V linear projections using ColumnParallelLinear layers. Instead of using three different linear layers, we can replace them with a single QKV module. In a GQA module, the number of Q attention heads is N times the number of K and V attention heads, and the K and V attention heads are replicated after projection to match the number of Q attention heads. This reduces the size of the K and V weights and is useful especially during inference. However, when training these modules, it restricts the tensor-parallel degree that can be used, since the number of attention heads must be divisible by the tensor-parallel degree. To mitigate this bottleneck, GQAQKVColumnParallelLinear takes a kv_size_multiplier argument. The module replicates the K and V weights kv_size_multiplier times, thereby allowing you to use a higher tensor-parallel degree. Note: instead of replicating the projection N/tp_degree times, we end up replicating the weights kv_size_multiplier times. This produces the same result and allows you to use a higher tp_degree, but it consumes extra memory.

Parameters:
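As a worked example of the head arithmetic described above; the head counts are assumptions chosen for illustration, not defaults:

# Suppose an attention block has 32 query heads and 8 KV heads.
# Without replication, the tensor-parallel degree is limited to 8,
# because the KV heads must be divisible by the tensor-parallel degree.
num_q_heads = 32
num_kv_heads = 8
kv_size_multiplier = 4  # replicate the K and V weights 4 times
effective_kv_heads = num_kv_heads * kv_size_multiplier  # 32
# A tensor-parallel degree of 32 now becomes possible (32 % 32 == 0),
# at the cost of storing the K and V weights kv_size_multiplier times.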

Checkpointing:#

These are a set of APIs for saving and loading checkpoints. The APIs take care of saving and loading the shard depending on the tensor parallel rank of the worker.

Save Checkpoint:#

def neuronx_distributed.parallel_layers.save(state_dict, save_dir, save_serially=True, save_xser: bool=False, down_cast_bf16=False)

Note

This method will be deprecated, use neuronx_distributed.trainer.save_checkpoint instead.

This API saves the model from each tensor-parallel rank into save_dir. Only workers with data parallel rank equal to 0 save checkpoints. Each tensor parallel rank creates a tp_rank_ii_pp_rank_ii folder inside save_dir and saves its shard in that folder. If save_xser is enabled, the folder name is tp_rank_ii_pp_rank_ii.tensors and there is a reference data file named tp_rank_ii_pp_rank_ii in save_dir for each rank.

Parameters:

Load Checkpoint#

def neuronx_distributed.parallel_layers.load( load_dir, model_or_optimizer=None, model_key='model', load_xser=False, sharded=True)

Note

This method will be deprecated, use neuronx_distributed.trainer.load_checkpoint instead.

This API automatically loads the checkpoint depending on the tensor parallel rank. For large models, one should pass the model object to the load API to load the weights directly into the model. This avoids host OOM, as the load API loads the checkpoint for one tensor parallel rank at a time.

Parameters:
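A minimal usage sketch; the directory name is a placeholder, and passing the model object lets the API load one tensor parallel rank's shard at a time:

from neuronx_distributed.parallel_layers import load

# load this worker's shard directly into the model to avoid host OOM
load("ckpt_dir", model_or_optimizer=model, model_key="model")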

Gradient Clipping:#

With tensor parallelism, gradient clipping needs special handling because the total norm has to be accumulated from all the tensor parallel ranks. This is handled by the following API:

def neuronx_distributed.parallel_layers.clip_grad_norm( parameters, max_norm, norm_type=2)

Parameters:
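A usage sketch (the max_norm value is illustrative):

from neuronx_distributed.parallel_layers import clip_grad_norm

# accumulates the total norm across tensor parallel ranks before clipping
clip_grad_norm(model.parameters(), max_norm=1.0, norm_type=2)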

Neuron Zero1 Optimizer:#

In Neuronx-Distributed, we built a wrapper around the Zero1Optimizer available in torch-xla.

class NeuronZero1Optimizer(Zero1Optimizer)

This wrapper takes the tensor-parallel degree into account and computes the grad norm accordingly. It also provides two APIs: save_sharded_state_dict and load_sharded_state_dict. As the size of the model grows, saving the optimizer state from a single rank can result in OOMs; hence, save_sharded_state_dict allows saving states from each data-parallel rank. To load this sharded optimizer state, the corresponding load_sharded_state_dict allows each rank to pick its corresponding shard from the checkpoint directory.

optimizer_grouped_parameters = [
    {
        "params": [
            p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.01,
    },
    {
        "params": [
            p for n, p in param_optimizer if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]

optimizer = NeuronZero1Optimizer(
    optimizer_grouped_parameters,
    AdamW,
    lr=flags.lr,
    pin_layout=False,
    sharding_groups=parallel_state.get_data_parallel_group(as_list=True),
    grad_norm_groups=parallel_state.get_tensor_model_parallel_group(as_list=True),
)

The interface is the same as the Zero1Optimizer in torch-xla.

save_sharded_state_dict(output_dir, save_serially = True)

Note

This method will be deprecated, use neuronx_distributed.trainer.save_checkpoint instead.

Parameters:

load_sharded_state_dict(output_dir, num_workers_per_step = 8)

Note

This method will be deprecated, use neuronx_distributed.trainer.load_checkpoint instead.

Parameters:
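A minimal sketch of saving and later reloading the sharded optimizer state; the directory name is a placeholder:

# each data-parallel rank writes its own shard of the optimizer state
optimizer.save_sharded_state_dict("optim_ckpt", save_serially=True)

# ... later, each rank picks up its corresponding shard from the same directory
optimizer.load_sharded_state_dict("optim_ckpt", num_workers_per_step=8)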

Neuron PyTorch-Lightning#

Neuron PyTorch-Lightning is currently based on Lightning version 2.1.0, and will eventually be upstreamed to the Lightning-AI code base.

Neuron Lightning Module#

Inherited from LightningModule

class neuronx_distributed.lightning.NeuronLTModule(
    model_fn: Callable,
    nxd_config: Dict,
    opt_cls: Callable,
    scheduler_cls: Callable,
    model_args: Tuple = (),
    model_kwargs: Dict = {},
    opt_args: Tuple = (),
    opt_kwargs: Dict = {},
    scheduler_args: Tuple = (),
    scheduler_kwargs: Dict = {},
    grad_accum_steps: int = 1,
    log_rank0: bool = False,
    manual_opt: bool = True,
)

Parameters:
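A construction sketch under assumptions: get_model and get_scheduler are placeholders for user-provided callables, and the option values are examples only:

from torch.optim import AdamW
from neuronx_distributed.lightning import NeuronLTModule

lt_module = NeuronLTModule(
    model_fn=get_model,            # placeholder callable returning the torch.nn.Module
    nxd_config=nxd_config,
    opt_cls=AdamW,
    scheduler_cls=get_scheduler,   # placeholder callable returning an LR scheduler
    opt_kwargs={"lr": 1e-3},       # example value
    grad_accum_steps=1,
    manual_opt=True,
)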

Neuron XLA Strategy#

Inherited from XLAStrategy

class neuronx_distributed.lightning.NeuronXLAStrategy(
    nxd_config: Dict = None,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    save_load_xser: bool = True,
)

Parameters:

Neuron XLA Precision Plugin#

Inherited from XLAPrecisionPlugin

class neuronx_distributed.lightning.NeuronXLAPrecisionPlugin

Neuron TQDM Progress Bar#

Inherited from TQDMProgressBar

class neuronx_distributed.lightning.NeuronTQDMProgressBar

Neuron TensorBoard Logger#

Inherited from TensorBoardLogger

class neuronx_distributed.lightning.NeuronTensorBoardLogger(save_dir)

Parameters:
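A sketch of wiring the Neuron Lightning components into a Lightning Trainer; the trainer arguments and the lt_module / data_module objects are assumptions about a typical setup rather than Neuron requirements:

import pytorch_lightning as pl
from neuronx_distributed.lightning import (
    NeuronXLAStrategy,
    NeuronXLAPrecisionPlugin,
    NeuronTQDMProgressBar,
    NeuronTensorBoardLogger,
)

trainer = pl.Trainer(
    strategy=NeuronXLAStrategy(nxd_config=nxd_config),
    plugins=[NeuronXLAPrecisionPlugin()],
    callbacks=[NeuronTQDMProgressBar()],
    logger=NeuronTensorBoardLogger(save_dir="logs"),
    max_steps=1000,  # example value
)
trainer.fit(lt_module, datamodule=data_module)  # data_module: user-provided LightningDataModule (placeholder)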

This document is relevant for: Inf2, Trn1, Trn2