PyTorch-Neuron trace python API — AWS Neuron Documentation


This document is relevant for: Inf1

PyTorch-Neuron trace python API#

The PyTorch-Neuron trace Python API provides a method to generate PyTorch models for execution on Inferentia, which can be serialized as TorchScript. It is analogous to the torch.jit.trace() function in PyTorch.

torch_neuron.trace(model, example_inputs, **kwargs)#

The torch_neuron.trace() method sends operations to the Neuron Compiler (neuron-cc) for compilation and embeds the compiled artifacts in a TorchScript graph.

Compilation can be done on any EC2 machine with sufficient memory and compute resources; a c5.4xlarge instance or larger is recommended.

Options can be passed to the Neuron compiler via the compile function. See the Neuron compiler CLI Reference Guide (neuron-cc) for more information about compiler options.
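For example, compiler flags can be forwarded at trace time. The sketch below assumes the compiler_args keyword argument, which passes the listed flags through to neuron-cc; the specific flag shown is illustrative, not a required setting:

import torch
import torch_neuron
from torchvision import models

model = models.resnet50(pretrained=True)
model.eval()

image = torch.rand([1, 3, 224, 224])

# Forward flags to neuron-cc at compilation time (compiler_args and
# the flag below are illustrative assumptions)
model_neuron = torch_neuron.trace(
    model,
    image,
    compiler_args=['--neuroncore-pipeline-cores', '1'],
)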

This function partitions nodes into operations that are supported by Neuron and operations that are not. Operations that are not supported by Neuron are run on CPU. Graph partitioning can be controlled by the subgraph_builder_function, minimum_segment_size, and fallback parameters (see below). By default, all supported operations are compiled and run on Neuron.

The compiled graph can be saved using the torch.jit.save() function and restored using the torch.jit.load() function for inference on Inf1 instances. During inference, the previously compiled artifacts will be loaded into the Neuron Runtime for inference execution.
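As a minimal sketch of this round trip (model_neuron and image here are placeholders for a model compiled by torch_neuron.trace() and its example input):

import torch
import torch_neuron  # registers the Neuron ops needed to deserialize the graph

# Serialize the compiled TorchScript graph
torch.jit.save(model_neuron, 'model_neuron.pt')

# On an Inf1 instance, restore the model; the embedded compiled
# artifacts are loaded into the Neuron Runtime
model_neuron = torch.jit.load('model_neuron.pt')

# Run inference
results = model_neuron(image)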

Required Arguments

Parameters:

model (Module or callable) – The function or torch.nn.Module that will be run with example_inputs and traced.

example_inputs (tuple) – A tuple of example inputs that will be passed to the model during tracing.

Optional Keyword Arguments

Keyword Arguments:

subgraph_builder_function (callable) – A function that takes a graph node and returns True if the node should be included in the Neuron subgraph (see Manual Partitioning below).

minimum_segment_size (int) – The minimum number of operations a subgraph must contain in order to be compiled to Neuron.

fallback (bool) – Whether operations that are not supported by Neuron should fall back to CPU execution.

dynamic_batch_size (bool) – Allows the compiled model to be executed with variable-sized batches at inference time (see Dynamic Batching below).

separate_weights (bool) – Enables compilation of models whose parameters exceed 1.9 GB (see Separate Weights below).

compiler_args (list of str) – Command-line arguments forwarded to the Neuron compiler (neuron-cc); see the Neuron compiler CLI Reference Guide.

optimizations (list of Optimization) – Optimization passes to apply to the model (see torch_neuron.Optimization below).

**kwargs – Any other keyword arguments, such as strict, are passed through to the underlying torch.jit.trace() call.

Returns:

The traced ScriptModule with embedded compiled neuron sub-graphs. Operations in this module will run on Neuron unless they are not supported by Neuron or manually partitioned to run on CPU.

Note that in torch<1.8, this would return a ScriptFunction if the input was a function type.

Return type:

ScriptModule, ScriptFunction

class torch_neuron.Optimization#

A set of optimization passes that can be applied to the model.

FLOAT32_TO_FLOAT16#

A post-processing pass that converts all torch.float32 tensors to torch.float16 tensors. The advantage of this optimization pass is that input/output tensors will be type cast, which reduces the amount of data that is copied to and from Inferentia hardware. The resulting traced model will accept both torch.float32 and torch.float16 inputs where the model used torch.float32 inputs during tracing. It is only beneficial to enable this optimization if the throughput of a model is highly dependent upon data transfer speed. This optimization is not recommended if the final application will use torch.float32 inputs, since the torch.float16 type cast will occur on CPU during inference.
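As a sketch, assuming the optimizations keyword argument of torch_neuron.trace() accepts a list of Optimization values, the pass could be enabled as follows:

import torch
import torch_neuron
from torchvision import models

model = models.resnet50(pretrained=True)
model.eval()

image = torch.rand([1, 3, 224, 224])

# Apply the float32 -> float16 post-processing pass (the optimizations
# keyword shown here is an assumption; see the keyword arguments above)
model_neuron = torch_neuron.trace(
    model,
    image,
    optimizations=[torch_neuron.Optimization.FLOAT32_TO_FLOAT16],
)

# The traced model now accepts both float32 and float16 inputs
results = model_neuron(image.half())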

Example Usage#

Function Compilation#

import torch
import torch_neuron

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.neuron.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment
torch.jit.save(traced_foo, 'foo.pt')
traced_foo = torch.jit.load('foo.pt')

Module Compilation#

import torch
import torch_neuron
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x) + 1

n = Net()
n.eval()

inputs = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct a ScriptModule with
# a single forward method
neuron_forward = torch.neuron.trace(n.forward, inputs)

# Trace a module (implicitly traces `forward`) and construct a
# ScriptModule with a single forward method
neuron_net = torch.neuron.trace(n, inputs)

Pre-Trained Model Compilation#

The following is an example usage of the compilation Python API, with default compilation arguments, using a pre-trained torch.nn.Module:

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)

Compiling models with torch.jit.trace kwargs#

This example uses the strict=False flag to compile a model with dictionary outputs. Similarly, any other keyword argument of torch.jit.trace() can be passed directly to torch_neuron.trace() so that it is passed to the underlying trace call.

import torch
import torch_neuron
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return {'conv': self.conv(x) + 1}

model = Model()
model.eval()

inputs = torch.rand(1, 1, 3, 3)

# Use the strict=False kwarg to compile a model with dictionary outputs;
# the model output format does not change
model_neuron = torch.neuron.trace(model, inputs, strict=False)
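Because the output format is unchanged, the compiled model's dictionary result can be consumed exactly as before; for example:

# The compiled model still returns a dictionary keyed by 'conv'
output = model_neuron(inputs)
print(output['conv'].shape)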

Dynamic Batching#

This example uses the optional dynamic_batch_size argument in order to support variable-sized batches at inference time.

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input of batch size 1
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image, dynamic_batch_size=True)

# Execute with a batch of 7 images
batch = torch.rand([7, 3, 224, 224])
results = model_neuron(batch)

Manual Partitioning#

The following example uses the optional subgraph_builder_function parameter to ensure that only a specific convolution layer is compiled to Neuron. The remaining operations are executed on CPU.

import torch
import torch_neuron
import torch.nn as nn

class ExampleConvolutionLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x) + 1

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = ExampleConvolutionLayer()

    def forward(self, x):
        return self.layer(x) * 100

def subgraph_builder_function(node) -> bool:
    """Select if the node will be included in the Neuron graph"""
    # Node names are tuples of Module names.
    if 'ExampleConvolutionLayer' in node.name:
        return True
    # Ignore all operations not in the example convolution layer
    return False

model = Model()
model.eval()

inputs = torch.rand(1, 1, 3, 3)

# Log output shows that aten::_convolution and aten::add are compiled
# but aten::mul is not. This will seamlessly switch between Neuron/CPU
# execution in a single graph.
neuron_model = torch_neuron.trace(
    model,
    inputs,
    subgraph_builder_function=subgraph_builder_function
)

Separate Weights#

This example uses the optional separate_weights option in order to support compilation of models greater than 1.9 GB.

import torch
import torch_neuron
from torchvision import models

# Load the model
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
# The model's output format does not change
model_neuron = torch.neuron.trace(model, image, separate_weights=True)

This document is relevant for: Inf1