
This document is relevant for: Inf2, Trn1, Trn2

Comparison of Traced Inference versus XLA Lazy Tensor Inference (torch-neuronx)#

Introduction#

Using torch-neuronx, there are two ways that a model can be executed for inference:

- XLA Lazy Tensor inference, in which operations are recorded on an XLA device and compiled just in time (JIT) at runtime.
- Traced inference, in which torch_neuronx.trace() compiles the model ahead of time (AOT) into a TorchScript Module.

XLA Lazy Tensor Inference Mechanics#

XLA Lazy Tensor inference uses Just-In-Time (JIT) compilation for Neuron execution.

XLA Device execution uses the built-in torch-xla functionality with Lazy Tensors to record torch operations executed against the xm.xla_device(). The graph of operations is sent to the neuronx-cc compiler upon calling xm.mark_step(). Finally, the compiled graph is transferred to a NeuronCore and executed in the Neuron backend.

The initial model inference will be very slow since the model binary file in the Neuron Executable File Format (NEFF) must first be generated by the compiler. Upon each subsequent call to the model, the application will re-execute the Python code, rebuild the graph, and check a cache to see if an existing NEFF is available for the given graph before attempting to recompile.

The process of recording graph operations in Python can become a bottleneck for otherwise fast models. This overhead affects performance regardless of model size, but it may be less noticeable on larger models. Note that XLA Lazy Tensor execution performance may improve significantly with new torch features in the future.
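
As a rough way to observe this overhead, the sketch below times a few consecutive calls on a small model: the first step includes compilation, while later steps still pay the Python graph-recording cost on every call. The timing loop is an illustrative addition, and absolute numbers will vary with the model and instance type.

import time

import torch
import torch_neuronx
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Small model where graph-recording overhead is most visible
model = torch.nn.Sequential(
    torch.nn.Linear(784, 120),
    torch.nn.ReLU(),
    torch.nn.Linear(120, 10),
)
model.eval()
model.to(device)
example = torch.rand((1, 784), device=device)

with torch.no_grad():
    for step in range(3):
        start = time.perf_counter()
        result = model(example)   # Operations are recorded lazily
        xm.mark_step()            # Graph is cut here; compiled only on the first step
        output = result.cpu()     # Device-to-host copy waits for execution to finish
        print(f"step {step}: {time.perf_counter() - start:.3f}s")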

Example#

Fixed Shape Example

import torch
import torch_neuronx
import torch_xla.core.xla_model as xm

# Create XLA device
device = xm.xla_device()

# Load example model and inputs to Neuron device
model = torch.nn.Sequential(
    torch.nn.Linear(784, 120),
    torch.nn.ReLU(),
    torch.nn.Linear(120, 10),
    torch.nn.Softmax(dim=-1),
)
model.eval()
model.to(device)
example = torch.rand((1, 784), device=device)

# Inference
with torch.no_grad():
    result = model(example)
    xm.mark_step()  # Compilation occurs here
    print(result.cpu())

Dynamic Shape Example

The following is an example of a model that dynamically changes the sequence length and batch size of the input token ID tensor to trigger recompilations. This kind of workflow would require padding when using traced inference.
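
For contrast, the sketch below shows one way such variable-shape inputs could be padded up to a single fixed shape so that a traced model could serve them; the fixed [4, 8] shape, the pad id of 0, and the pad_token_ids helper are illustrative assumptions rather than part of the torch-neuronx API.

import torch

MAX_BATCH = 4  # assumed batch size used when tracing
MAX_LEN = 8    # assumed sequence length used when tracing
PAD_ID = 0     # assumed padding token id

def pad_token_ids(token_ids: torch.Tensor) -> torch.Tensor:
    # Pad a [batch, seq] tensor of token ids up to the fixed traced shape
    batch, seq = token_ids.shape
    padded = torch.full((MAX_BATCH, MAX_LEN), PAD_ID, dtype=token_ids.dtype)
    padded[:batch, :seq] = token_ids
    return padded

token_ids_2 = torch.tensor([
    [1, 13087, 10439, 1990, 18912, 0],
    [1, 12009, 7849, 2509, 3500, 0],
])  # shape: [2, 6]
fixed_input = pad_token_ids(token_ids_2)  # shape: [4, 8]; outputs for padded rows/positions should be discarded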

import torch
import torch_neuronx
import torch_xla.core.xla_model as xm

# Create XLA device
device = xm.xla_device()

# Load example model and inputs to Neuron device
model = torch.nn.Sequential(
    torch.nn.Embedding(num_embeddings=30522, embedding_dim=512),
    torch.nn.Linear(512, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
    torch.nn.Softmax(dim=-1),
)
model.eval()
model.to(device)

token_ids_1 = torch.tensor([
    [1, 28, 748, 0],
], device=device)  # shape: [1, 4]
token_ids_2 = torch.tensor([
    [1, 13087, 10439, 1990, 18912, 0],
    [1, 12009, 7849, 2509, 3500, 0],
], device=device)  # shape: [2, 6]

# Inference

with torch.no_grad():

    # First compilation/inference
    result = model(token_ids_1)
    xm.mark_step()
    print(result.cpu())  # shape: [1, 4, 2]

    # Recompilation occurs here since token_ids_2 is a different shape. This inference
    # would have failed if the model had been traced with shape [1, 4]
    result = model(token_ids_2)
    xm.mark_step()
    print(result.cpu())  # shape: [2, 6, 2]

Traced Inference Mechanics#

Traced inference uses Ahead-Of-Time (AOT) compilation for Neuron execution.

Similar to XLA Lazy Tensor inference, trace() uses the operation recording mechanisms provided by torch-xla to build the graph structure. This graph structure is also sent to the neuronx-cc compiler to produce a binary (NEFF) that is executable on Neuron.

The main difference is that the call to trace() returns a fully compiled graph as a new TorchScript Module. Upon calling this new Module, rather than re-executing the Python code, rebuilding the graph, and checking the cache for a matching model, the Module simply executes the precompiled graph that was loaded during tracing. This is a significantly more optimized runtime since it avoids the Python operator tracing, graph building, and cache lookups.
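
Because the returned Module is TorchScript, it can be compiled once, serialized, and later reloaded for inference without invoking the compiler again. A minimal sketch, assuming the same small model used in the example below and a hypothetical file name:

import torch
import torch_neuronx

model = torch.nn.Sequential(
    torch.nn.Linear(784, 120),
    torch.nn.ReLU(),
    torch.nn.Linear(120, 10),
    torch.nn.Softmax(dim=-1),
)
model.eval()
example = torch.rand((1, 784))

# Compile ahead of time and serialize the resulting TorchScript Module
trace = torch_neuronx.trace(model, example)
torch.jit.save(trace, 'model_neuron.pt')

# Later, e.g. in a deployment process that has torch_neuronx imported:
loaded = torch.jit.load('model_neuron.pt')
result = loaded(example)  # Executes the precompiled graph; no recompilation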

One disadvantage of this interface is that a model will never dynamically recompile after a trace. This means that dynamic control flow is not supported within a function/module. Tensor input/output shapes are fixed to the shapes passed to the trace() API. Dynamic batching and bucketing can be used to avoid the pitfalls of static shapes.
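
One way to apply bucketing, sketched below, is to trace the same model once per supported sequence length, pad each request up to the nearest bucket, and dispatch to the matching trace. The bucket sizes, the infer helper, and the slicing of padded outputs are illustrative assumptions rather than a built-in torch-neuronx feature.

import torch
import torch_neuronx

model = torch.nn.Sequential(
    torch.nn.Embedding(num_embeddings=30522, embedding_dim=512),
    torch.nn.Linear(512, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
    torch.nn.Softmax(dim=-1),
)
model.eval()

# One fixed-shape trace per bucket (batch 1, sequence lengths 8 and 16)
buckets = {
    seq_len: torch_neuronx.trace(model, torch.zeros((1, seq_len), dtype=torch.long))
    for seq_len in (8, 16)
}

def infer(token_ids: torch.Tensor) -> torch.Tensor:
    # Pad to the smallest bucket that fits, run the matching trace, drop padded positions
    seq_len = token_ids.shape[1]
    bucket = min(b for b in buckets if b >= seq_len)
    padded = torch.zeros((1, bucket), dtype=token_ids.dtype)
    padded[:, :seq_len] = token_ids
    return buckets[bucket](padded)[:, :seq_len]

result = infer(torch.tensor([[1, 28, 748, 0]]))  # shape: [1, 4, 2]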

Example#

import torch
import torch_neuronx

# Create example model and inputs
model = torch.nn.Sequential(
    torch.nn.Linear(784, 120),
    torch.nn.ReLU(),
    torch.nn.Linear(120, 10),
    torch.nn.Softmax(dim=-1),
)
model.eval()
example = torch.rand((1, 784))

# Create fixed model trace
trace = torch_neuronx.trace(model, example)

# Inference
result = trace(example)  # No recompilation. Input shapes must not change
print(result)

Traced Inference Advantages#

Traced inference should be used for nearly all deployment purposes since it provides some key advantages over XLA Lazy Tensor execution:

- Performance: the precompiled graph is executed directly, avoiding the per-call Python operator recording, graph building, and cache lookups of Lazy Tensor execution.
- Serialization: the resulting TorchScript Module can be saved with torch.jit.save and reloaded with torch.jit.load without recompilation.
- C++ usage: the serialized TorchScript Module can be loaded and executed from C++ applications via LibTorch.

Summary#

                 XLA Device Inference   Traced Inference
Compilation      JIT                    AOT
Serialization    N/A                    TorchScript
Performance      Slower                 Faster
Dynamic shapes   Yes                    No
C++ Usage        No                     Yes

This document is relevant for: Inf2, Trn1, Trn2