torch.neuron.DataParallel API — AWS Neuron Documentation


This document is relevant for: Inf1

torch.neuron.DataParallel API#

The torch.neuron.DataParallel() Python API implements data parallelism on ScriptModule models created by the PyTorch-Neuron trace Python API. This function is analogous to DataParallel in PyTorch. The Data Parallel Inference on Torch Neuron application note provides an overview of how torch.neuron.DataParallel() can be used to improve the performance of inference workloads on Inferentia.

torch.neuron.DataParallel(model, device_ids=None, dim=0)#

Applies data parallelism by replicating the model on available NeuronCores and distributing data across the different NeuronCores for parallelized inference.

By default, DataParallel will use all available NeuronCores allocated for the current process for parallelism. DataParallel will apply parallelism on dim=0 if dim is not specified.

DataParallel automatically enables dynamic batching on eligible models if dim=0. Dynamic batching can be disabled using torch.neuron.DataParallel.disable_dynamic_batching(). If dynamic batching is not enabled, the batch size at compilation-time must be equal to the batch size at inference-time divided by the number of NeuronCores being used. Specifically, the following must be true when dynamic batching is disabled: input.shape[dim] / len(device_ids) == compilation_input.shape[dim]. For example, a model compiled with batch size 1 and run on four NeuronCores must receive inputs with batch size 4. DataParallel will throw a warning if dynamic batching cannot be enabled.
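
The shape constraint above reduces to simple integer arithmetic on the input shape. The helper below is a minimal illustrative sketch, not part of the torch.neuron API; the name check_batch_dim and its arguments are made up for this example:

# Illustrative only: this helper is NOT part of torch.neuron; it restates the rule
# input.shape[dim] / len(device_ids) == compilation_input.shape[dim]
def check_batch_dim(input_shape, compiled_shape, num_cores, dim=0):
    # Each NeuronCore receives an equal slice along `dim`, so the runtime batch
    # must split evenly into the compile-time batch.
    return (input_shape[dim] % num_cores == 0
            and input_shape[dim] // num_cores == compiled_shape[dim])

# Compiled with batch size 1, running on 4 NeuronCores with dynamic batching disabled:
check_batch_dim([4, 3, 224, 224], [1, 3, 224, 224], num_cores=4)  # True
check_batch_dim([8, 3, 224, 224], [1, 3, 224, 224], num_cores=4)  # False (inference would fail)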

DataParallel will try to load all of a model’s NEFFs onto a single NeuronCore, but only if all of the NEFFs can fit on a single NeuronCore. DataParallel does not currently support models that have been compiled with NeuronCore Pipeline.

torch.neuron.DataParallel() requires PyTorch >= 1.8.

Required Arguments

Parameters:

model (ScriptModule) – Model created by the PyTorch-Neuron trace Python API to be parallelized.

Optional Arguments

Parameters:

device_ids (list) – List of int or 'nc:#' strings specifying the NeuronCores to use for parallelization (default: all NeuronCores allocated to the current process).

dim (int) – Dimension along which the input tensor is split across NeuronCores (default: 0).

Attributes

Parameters:

torch.neuron.DataParallel.disable_dynamic_batching()#

Disables automatic dynamic batching on the DataParallel module. See Dynamic batching disabled for an example of how DataParallel can be used with dynamic batching disabled. Use as follows:

model_parallel = torch.neuron.DataParallel(model_neuron)
model_parallel.disable_dynamic_batching()

Note

device_ids uses per-process NeuronCore granularity and zero-based indexing. Per-process granularity means that each Python process “sees” its own view of the world. Specifically, this means that device_ids only “sees” the NeuronCores that are allocated for the current process. Zero-based indexing means that each Python process will index its allocated NeuronCores starting at 0, regardless of the “global” index of the NeuronCores. Zero-based indexing makes it possible to redeploy the exact same code unchanged in different processes. This behavior is analogous to the device_ids argument in the PyTorch DataParallel function.

As an example, assume DataParallel is run on an inf1.6xlarge, which contains four Inferentia chips, each of which contains four NeuronCores; the sketch below illustrates how each process sees and indexes its allocated NeuronCores.
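
The following is a rough illustrative sketch, not taken from the original page: it assumes NEURON_RT_VISIBLE_CORES is used to allocate a four-core slice of the instance to each process, and the model path is hypothetical.

# Process A is launched with NEURON_RT_VISIBLE_CORES="0-3" and sees four
# NeuronCores; device_ids=[0, 1, 2, 3] refers to global NeuronCores 0-3.
# Process B is launched with NEURON_RT_VISIBLE_CORES="4-7" and also sees four
# NeuronCores; the same device_ids=[0, 1, 2, 3] now refers to global
# NeuronCores 4-7, because each process indexes its own cores starting at 0.

import torch
import torch_neuron

# 'model_neuron.pt' is a hypothetical path to a previously traced Neuron model
model_neuron = torch.jit.load('model_neuron.pt')
model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1, 2, 3])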

Examples#

The following sections provide example usages of the torch.neuron.DataParallel() module.

Default usage#

The default DataParallel use mode will replicate the model on all available NeuronCores in the current process. The inputs will be split on dim=0.

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)

# Create the DataParallel module
model_parallel = torch.neuron.DataParallel(model_neuron)

# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])

# Run inference with a batched input
output = model_parallel(image_batched)

Specifying NeuronCores#

The following example uses the device_ids argument to use the first three NeuronCores for DataParallel inference.

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)

# Create the DataParallel module, run on the first three NeuronCores
# Equivalent to model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1, 2])
model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=['nc:0', 'nc:1', 'nc:2'])

# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])

# Run inference with a batched input
output = model_parallel(image_batched)

DataParallel with dim != 0#

In this example we run DataParallel inference using four NeuronCores and dim = 2. Because dim != 0, dynamic batching is not enabled. Consequently, the DataParallel inference-time batch size must be four times the compile-time batch size. DataParallel will generate a warning that dynamic batching is disabled because dim != 0.

import torch
import torch_neuron

# Create an example model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x) + 1

model = Model()
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 8, 8])
model_neuron = torch.neuron.trace(model, image)

# Create the DataParallel module using 4 NeuronCores and dim = 2
model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1, 2, 3], dim=2)

# Create a batched input
# Note that image_batched.shape[dim] / len(device_ids) == image.shape[dim]
batch_size = 4 * 8
image_batched = torch.rand([1, 3, batch_size, 8])

# Run inference with a batched input
output = model_parallel(image_batched)

Dynamic batching#

In the following example, we use the torch.neuron.DataParallel() module to run inference using several different batch sizes without recompiling the Neuron model.

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)

# Create the DataParallel module
model_parallel = torch.neuron.DataParallel(model_neuron)

# Create batched inputs and run inference on the same model
batch_sizes = [2, 3, 4, 5, 6]
for batch_size in batch_sizes:
    image_batched = torch.rand([batch_size, 3, 224, 224])

    # Run inference with a batched input
    output = model_parallel(image_batched)

Dynamic batching disabled#

In the following example, we use torch.neuron.DataParallel.disable_dynamic_batching() to disable dynamic batching. We provide an example of a batch size that will not work when dynamic batching is disabled, as well as an example of a batch size that does work when dynamic batching is disabled.

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)

# Create the DataParallel module and use 4 NeuronCores
model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1, 2, 3], dim=0)

# Disable dynamic batching
model_parallel.disable_dynamic_batching()

# Create a batched input (this won't work)
batch_size = 8
image_batched = torch.rand([batch_size, 3, 224, 224])

# This will fail because dynamic batching is disabled and
# image_batched.shape[dim] / len(device_ids) != image.shape[dim]
output = model_parallel(image_batched)

# Create a batched input (this will work)
batch_size = 4
image_batched = torch.rand([batch_size, 3, 224, 224])

# This will work because
# image_batched.shape[dim] / len(device_ids) == image.shape[dim]
output = model_parallel(image_batched)

Full tutorial with torch.neuron.DataParallel#

For an end-to-end tutorial that uses DataParallel, see the PyTorch Resnet Tutorial.

This document is relevant for: Inf1