
This document is relevant for: Inf2, Trn1, Trn2

Distributed Strategies APIs#

NeuronX Distributed Core (NxD Core) is an XLA-based library for distributed training and inference on Neuron devices. As part of this library, we support 3D parallelism: Tensor Parallelism, Pipeline Parallelism, and Data Parallelism. We also support the ZeRO-1 optimizer to shard the optimizer states. To support tensor parallelism on Neuron, we adopted the Apex library built for CUDA devices and modified its implementations to work with XLA. This document lists the different APIs and modules provided by the library.


Parallel Model State:#

Initialize Model Parallelism:#

def neuronx_distributed.parallel_state.initialize_model_parallel( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, )

This API initializes model parallelism for distributed training and allows users to set the tensor-parallel and pipeline-parallel world sizes.

Parameters:
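For reference, here is a minimal sketch of calling this API at the start of a training script. It assumes a torchrun launch and the standard PyTorch/XLA process-group setup; the tensor-parallel degree of 8 is arbitrary, and the import path follows the module path shown in the signature above.

import torch
import torch_xla.distributed.xla_backend  # registers the "xla" backend for torch.distributed (standard PyTorch/XLA setup)
from neuronx_distributed import parallel_state

# Assumes the script is launched with torchrun so RANK/WORLD_SIZE are set.
torch.distributed.init_process_group("xla")

# Split the world into tensor-parallel groups of 8 ranks each; the remaining
# ranks form the data-parallel dimension. Pipeline parallelism stays at 1
# (disabled) in this sketch.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=8,
    pipeline_model_parallel_size=1,
)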

Other helper APIs:#

Parallel Layers:#

The majority of parameters within a transformer-based model reside in the Embedding and Linear layers. Hence, to reduce the number of parameters these layers place on a single device, we provide sharded Embedding and Linear layers.

Parallel Embedding:#

class neuronx_distributed.parallel_layers.ParallelEmbedding( num_embeddings, embedding_dim, init_method=init.normal_, dtype=torch.float32, device=None)

This module is intended to replace torch.nn.Embedding. In cases where the vocabulary size is too large, we can shard the embedding table across workers. Note: the embedding table is sharded across all the tensor-parallel workers.

Parameters:
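A minimal sketch of using ParallelEmbedding in place of torch.nn.Embedding, assuming model parallelism has already been initialized as shown above; the vocabulary and hidden sizes are illustrative.

import torch
from neuronx_distributed.parallel_layers import ParallelEmbedding

# Each tensor-parallel rank holds num_embeddings / tp_degree rows of the table.
# Assumes initialize_model_parallel has already been called.
embedding = ParallelEmbedding(
    num_embeddings=50257,   # vocabulary size (illustrative)
    embedding_dim=4096,     # hidden size (illustrative)
)

token_ids = torch.randint(0, 50257, (1, 128))   # [batch, seq]
hidden_states = embedding(token_ids)            # [1, 128, 4096], full embeddings after the reduction across TP ranks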

ColumnParallel Linear Layer:#

class neuronx_distributed.parallel_layers.ColumnParallelLinear( input_size, output_size, bias=True, gather_output=True, sequence_parallel_enabled=False, dtype=torch.float32, device=None)

This module performs a column-wise partition of the weight matrix. The linear layer is defined as Y = XA + b, where A is parallelized along its second dimension as A = [A_1, A_2, ..., A_p]. Note: this layer is designed to operate on 3-dimensional inputs.

Parameters:
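A short sketch of constructing a ColumnParallelLinear, assuming model parallelism is already initialized; the layer name and dimensions are illustrative.

import torch
from neuronx_distributed.parallel_layers import ColumnParallelLinear

# Each tensor-parallel rank holds output_size / tp_degree columns of A.
# With gather_output=True (the default) the partial outputs are all-gathered
# so the caller sees the full output dimension; with gather_output=False the
# output stays sharded for downstream parallel layers.
qkv_proj = ColumnParallelLinear(
    input_size=4096,
    output_size=3 * 4096,
    bias=True,
    gather_output=False,
)

x = torch.randn(1, 128, 4096)   # [batch, seq, hidden]
sharded_out = qkv_proj(x)       # last dimension is sharded across TP ranks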

RowParallel Linear Layer:#

class neuronx_distributed.parallel_layers.RowParallelLinear( input_size, output_size, bias=True, input_is_parallel=False, sequence_parallel_enabled=False, dtype=torch.float32, device=None)

The linear layer is defined as Y = XA + b, where A is parallelized along its first dimension and X along its second dimension. Note: this layer is designed to operate on 3-dimensional inputs.

Parameters:
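A common pattern, sketched here under the same assumptions, pairs a ColumnParallelLinear (gather_output=False) with a RowParallelLinear (input_is_parallel=True), so a transformer MLP block needs only one collective at the end; the ParallelMLP class below is purely illustrative.

import torch
from neuronx_distributed.parallel_layers import ColumnParallelLinear, RowParallelLinear

class ParallelMLP(torch.nn.Module):
    # Hypothetical transformer MLP block used only for illustration.
    def __init__(self, hidden_size=4096, ffn_size=16384):
        super().__init__()
        # Column-parallel projection keeps its output sharded across TP ranks.
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size, ffn_size, gather_output=False)
        # Row-parallel projection consumes the sharded input and reduces the output.
        self.dense_4h_to_h = RowParallelLinear(
            ffn_size, hidden_size, input_is_parallel=True)

    def forward(self, x):
        # x: [batch, seq, hidden]; the intermediate activation stays sharded,
        # and the RowParallelLinear combines the partial results at the end.
        return self.dense_4h_to_h(torch.nn.functional.gelu(self.dense_h_to_4h(x)))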

Padding Tensor-Parallel Layers#

def neuronx_distributed.parallel_layers.pad.pad_model( model, tp_degree, n_heads, wrapped_classes=(), pad_hook_fn=None)

Pads a generic model so that it can run at a desired tensor-parallelism degree by padding the number of attention heads. Returns the original model, modified with padding. Uses a 1-axis padding strategy: the sharded dimension of the ParallelLinear layers is padded to the size it would have had with the padded number of heads.

Parameters:

Usage:

When modifying the Attention layer, you typically divide the number of heads by the TP degree like so:

self.num_heads = neuronx_dist_utils.divide(self.num_heads, get_tensor_model_parallel_size())

This line must be modified like so:

self.num_heads = neuronx_dist_utils.divide(
    self.num_heads + get_number_of_extra_heads(self.num_heads, get_tensor_model_parallel_size()),
    get_tensor_model_parallel_size())

Then, after initializing the model, you must call this wrapper:

model = get_model(config=desired_config)
model = pad_model(model, tp_degree=32, n_heads=desired_config.num_heads)
# Use the model as desired after this point

You can specify a particular layer or class of your model to pad, so you aren't padding unnecessarily. Typically, this will be your Attention layer:

model = pad_model(model, tp_degree=32, n_heads=desired_config.num_heads, wrapped_classes=[MyAttention])

You can also specify a pad_hook_fn, which is called whenever an instance of wrapped_class is encountered; it is passed that instance as a parameter, along with the tgt_src_ratio (num_heads_padded / num_heads).

def my_hook(attention_to_pad, tgt_src_ratio):
    attention_to_pad.split_size = int(model.split_size * tgt_src_ratio)

model = pad_model(
    model,
    tp_degree=32,
    n_heads=desired_config.num_heads,
    wrapped_classes=[MyAttention],
    pad_hook_fn=my_hook,
)

Loss functions:#

When you shard the final MLP layer using tensor parallelism, instead of gathering all the outputs from each TP rank, you can use the ParallelCrossEntropy loss function. This function takes the parallel logits produced by the final parallel MLP and computes the loss while taking into account that the logits are sharded across multiple workers.

def neuronx_distributed.parallel_layers.loss_functions.parallel_cross_entropy( parallel_logits, labels, label_smoothing=0.0)

Parameters:
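A hedged sketch of feeding sharded logits from a ColumnParallelLinear language-model head (gather_output=False) into this loss; the shapes and the .mean() reduction are illustrative choices, and model parallelism is assumed to be initialized.

import torch
from neuronx_distributed.parallel_layers import ColumnParallelLinear
from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy

lm_head = ColumnParallelLinear(
    input_size=4096,
    output_size=50257,      # vocabulary size (illustrative)
    bias=False,
    gather_output=False,    # keep the logits sharded along the vocab dimension
)

hidden_states = torch.randn(1, 128, 4096)      # [batch, seq, hidden]
labels = torch.randint(0, 50257, (1, 128))     # [batch, seq]

parallel_logits = lm_head(hidden_states)       # sharded along the vocab dimension
loss = parallel_cross_entropy(parallel_logits, labels).mean()   # per-token losses reduced to a scalar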

Pipeline parallelism:#

Neuron Distributed Pipeline Model#

class NxDPPModel(
    module: torch.nn.Module,
    transformer_layer_cls: Optional[Any] = None,
    num_microbatches: int = 1,
    virtual_pipeline_size: int = 1,
    output_loss_value_spec: Optional[Union[Dict, Tuple]] = None,
    return_mb_loss: bool = False,
    broadcast_and_average_loss: bool = False,
    pipeline_cuts: Optional[List[str]] = None,
    input_names: Optional[List[str]] = None,
    leaf_module_cls: Optional[List[Any]] = None,
    autowrap_functions: Optional[Tuple[Callable, ...]] = None,
    autowrap_modules: Optional[Tuple[ModuleType]] = None,
    tracer_cls: Optional[Union[str, Any]] = None,
    param_init_fn: Optional[Any] = None,
    trace_file_path: Optional[str] = None,
    use_zero1_optimizer: bool = False,
    auto_partition: Optional[bool] = False,
    deallocate_pipeline_outputs: bool = False,
)

Parameters:

Commonly used APIs#

NxDPPModel.run_train(**kwargs)

Trains the model with the pipeline-parallel (PP) schedule, running both the forward and backward passes in a PP manner. The kwargs should be the same as the input_names provided to the trace function. Outputs the loss specified by the user via output_loss_value_spec.

NxDPPModel.run_eval(**kwargs)

Evaluates the model with the pipeline-parallel schedule, running the forward pass only. The kwargs should be the same as the input_names provided to the trace function. Outputs the loss specified by the user via output_loss_value_spec.
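A minimal, hedged sketch of wrapping a model for pipeline parallelism and driving it with run_train. MyTransformer, MyTransformerLayer, config, and dataloader are placeholders, the import path is an assumption based on the class name above, output_loss_value_spec assumes the module returns (loss, logits), and optimizer setup is omitted.

from neuronx_distributed.pipeline import NxDPPModel

model = MyTransformer(config)   # placeholder model whose transformer layers are MyTransformerLayer

pp_model = NxDPPModel(
    model,
    transformer_layer_cls=MyTransformerLayer,
    num_microbatches=8,
    output_loss_value_spec=(True, False),   # assumes the module returns (loss, logits)
    input_names=["input_ids", "labels"],
    auto_partition=True,                    # let the library place the pipeline cuts
)

for batch in dataloader:                    # placeholder data loader
    # Runs forward and backward with the PP schedule; the optimizer step,
    # built over pp_model.local_named_parameters(), is omitted for brevity.
    loss = pp_model.run_train(input_ids=batch["input_ids"], labels=batch["labels"])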

NxDPPModel.local_named_parameters(**kwargs)

The parameters that are local to this PP rank. This must be called after the model is partitioned.

NxDPPModel.local_named_modules(**kwargs)

The modules that are local to this PP rank. This must be called after the model is partitioned.

This document is relevant for: Inf2, Trn1, Trn2