Autobucketing for Inference (torch-neuronx) - AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

Autobucketing for Inference (torch-neuronx)


Introduction

Autobucketing is a feature that lets you combine multiple bucket models, each accepting one static input shape, with a bucket kernel function that selects a bucket for each input. The bucket models are packaged into a single traced PyTorch model that accepts multiple input shapes.

This gives you more flexibility in the inputs a Neuron model accepts without having to manage multiple Neuron models yourself. Applications range from selecting a model based on image resolution to efficient sampling for token generation in language models.

Beyond flexibility, Autobucketing is also useful for latency-sensitive applications: the bucket kernel function routes small inputs to small bucket models and large inputs to large ones, so small inputs need not run on the largest model.

This Developer Guide discusses best practices for implementing Autobucketing for your use case. It uses a BERT model as the running example, bucketing on the sequence length dimension.

Before continuing, familiarize yourself with the Autobucketing APIs (torch_neuronx.BucketModelConfig and torch_neuronx.bucket_model_trace), which are described in the Autobucketing API documentation.

Bucket Kernels

Bucket kernels are user-defined functions that take the model input and return a tuple containing a potentially modified model input and a tensor holding the bucket index. The function must be convertible to TorchScript with torch.jit.script(): saving and loading a traced bucket model with torch.jit.save() and torch.jit.load() requires every element of the model to be in TorchScript. The example below shows a bucket kernel that can be converted in this way.

import torch
from typing import List

def sequence_length_bucket_kernel(tensor_list: List[torch.Tensor]):
    # The first tensor carries the sequence length on dimension 1
    x = tensor_list[0]
    bucket_dim = 1
    x_shape = x.shape
    tensor_sequence_length = x_shape[bucket_dim]
    batch_size = x_shape[bucket_dim - 1]
    buckets = [128, 512]
    idx = 0
    num_inputs = 3
    bucket = buckets[0]
    reshaped_tensors: List[torch.Tensor] = []
    bucket_idx = 0
    for idx, bucket in enumerate(buckets):
        # Pick the first (smallest) bucket that fits the input
        if tensor_sequence_length <= bucket:
            bucket_idx = idx
            for tensor in tensor_list:
                if num_inputs == 0:
                    break
                # Right-pad each tensor with zeros up to the bucket size
                delta = bucket - tensor_sequence_length
                padding_shape: List[int] = [batch_size, delta]
                zeros = torch.zeros(padding_shape, dtype=x.dtype)
                reshaped_tensors.append(torch.cat([tensor, zeros], dim=bucket_dim))
                num_inputs -= 1
            break
    return reshaped_tensors, torch.tensor([bucket_idx])

def get_bucket_kernel(*_):
    # Convert the kernel to TorchScript so it can be saved with the traced model
    bk = torch.jit.script(sequence_length_bucket_kernel)
    return bk

In the above example we define a bucket kernel that takes the inputs to a transformers model, [input_ids, attention_mask, token_type_ids]. We take the first tensor in that list, since it has sequence_length as a dimension, and read sequence_length and batch_size from its shape. We also define the sequence length buckets. The for loop then finds the first bucket that fits the sequence length and right-pads each tensor in the list to that bucket size. The kernel returns the padded inputs as a list of tensors together with a tensor containing the bucket index. Finally, get_bucket_kernel returns a version of the bucket kernel that has been converted to TorchScript using torch.jit.script(). With this bucket kernel, a tokenized input of sequence length 1 to 512 is padded up to the nearest bucket size.
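The rounding and padding arithmetic can be sketched in plain Python, without PyTorch or Neuron. The helper names `select_bucket` and `right_pad` are hypothetical and exist only for this illustration; the bucket sizes match the kernel above. Unlike the kernel, this sketch raises an error for inputs longer than the largest bucket, which is an assumption about the desired behaviour rather than something the kernel does.

```python
from typing import List, Sequence, Tuple


def select_bucket(seq_len: int, buckets: Sequence[int] = (128, 512)) -> Tuple[int, int]:
    """Return (bucket_idx, bucket_size) for the smallest bucket that fits."""
    for idx, bucket in enumerate(buckets):
        if seq_len <= bucket:
            return idx, bucket
    # Assumption for this sketch: reject inputs longer than the largest bucket.
    raise ValueError(f"sequence length {seq_len} exceeds largest bucket {buckets[-1]}")


def right_pad(token_ids: List[int], bucket_size: int, pad_value: int = 0) -> List[int]:
    """Right-pad a single sequence with pad_value up to bucket_size."""
    return token_ids + [pad_value] * (bucket_size - len(token_ids))


idx, size = select_bucket(122)
padded = right_pad(list(range(1, 123)), size)
print(idx, size, len(padded))  # 0 128 128
print(select_bucket(444))      # (1, 512)
```

A sequence of length 122 lands in bucket 0 and is padded by 6 positions, while a sequence of length 444 lands in bucket 1 and is padded by 68 positions.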

Note that we call torch.jit.script() instead of torch.jit.trace(). This is because we rely on control flow logic evaluating correctly for all inputs. This results in certain challenges when writing compatible and accurate bucket kernels. We discuss these challenges and resolutions in the next section.
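The difference between scripting and tracing can be seen with a toy function that branches on the sequence length. This is a minimal sketch, not part of the Autobucketing API: `add_bucket_index` is a hypothetical stand-in for a bucket kernel, and it only needs PyTorch on CPU.

```python
import warnings

import torch


def add_bucket_index(x: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for a bucket kernel: branch on the sequence length (dim 1).
    if x.shape[1] <= 128:
        return x + 0  # bucket 0
    return x + 1      # bucket 1


short_input = torch.zeros(1, 100)
long_input = torch.zeros(1, 300)

# Scripting compiles the Python source, so both branches are kept.
scripted = torch.jit.script(add_bucket_index)

# Tracing records only the operations executed for the example input,
# so the "bucket 0" branch is baked into the traced function.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing a Python-level branch emits warnings
    traced = torch.jit.trace(add_bucket_index, short_input)

print(scripted(long_input).sum().item())  # 300.0: took the bucket 1 branch
print(traced(long_input).sum().item())    # 0.0: still on the bucket 0 branch
```

The traced function silently reuses the branch taken for its example input, which is why a bucket kernel that must choose a bucket per input has to be scripted.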

TorchScript Best Practices for Bucket Kernels

Below are some recommendations when creating these Bucket Kernels:

Shared State Buffers

Autobucketing supports a shared buffer between bucket models. The shared_state_buffer_preprocessor defines how the shared buffer is transformed before it is fed as input to a bucket model.

The above recommendations also apply when defining a shared_state_buffer_preprocessor.

An example where a shared buffer is useful between bucket models is maintaining a KV Cache between bucket models for LLMs.

Below is an example of a KV Cache preprocessor for Autobucketing.

def state_preprocessor(shapes_collection: List[List[List[int]]], states: List[torch.Tensor], bucket_idx_tensor: torch.Tensor) -> List[torch.Tensor]:
    bucket_idx = torch.ops.aten.Int(bucket_idx_tensor)
    shapes = shapes_collection[bucket_idx]
    sliced_state_tensors = []

    for i in range(len(shapes)):
        expected_shape = shapes[i]
        state_tensor = states[i]
        state_tensor_shape = state_tensor.shape
        for j, npos in enumerate(expected_shape):
            state_tensor_dim_length = state_tensor_shape[j]
            state_tensor = torch.ops.aten.slice(state_tensor, dim=j, start=state_tensor_dim_length-npos, end=state_tensor_dim_length)
        sliced_state_tensors.append(state_tensor)

    return sliced_state_tensors

def get_state_preprocessor():
    sp = torch.jit.script(state_preprocessor)
    return sp

In this example, the preprocessor takes shapes_collection, states, and bucket_idx_tensor. The input shapes_collection lists, for each bucket, the expected shape of each state tensor. For example, with shapes_collection = [[[1,128],[1,128]],[[1,512],[1,512]]], shapes_collection[0][1] retrieves the expected shape of the second state tensor in the first bucket. The input states is the actual list of tensors in the shared buffer, which holds tensors of the largest shape. Finally, bucket_idx_tensor is the same tensor returned by the bucket kernel.

Note that we use two aten functions directly: aten::Int to convert bucket_idx_tensor to an integer, and aten::slice to perform slicing with parameters that are not constants or literals.
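To make the slicing arithmetic concrete, the sketch below reproduces only the index computation of state_preprocessor in plain Python, without tensors. The helper name `slice_ranges` and the `buffer_shapes` list are hypothetical; the shapes come from the shapes_collection example above, and the shared buffer is assumed to hold two tensors of shape [1, 512].

```python
from typing import List, Tuple

# Expected state shapes per bucket, as in the example above.
shapes_collection = [[[1, 128], [1, 128]], [[1, 512], [1, 512]]]

# Assumption: every tensor in the shared buffer has the largest shape.
buffer_shapes = [[1, 512], [1, 512]]


def slice_ranges(bucket_idx: int) -> List[List[Tuple[int, int]]]:
    """For each state tensor, return the (start, end) kept along each dimension."""
    ranges = []
    for expected_shape, actual_shape in zip(shapes_collection[bucket_idx], buffer_shapes):
        # Keep the trailing `npos` entries of each dimension, as the preprocessor does.
        ranges.append([(dim_len - npos, dim_len)
                       for npos, dim_len in zip(expected_shape, actual_shape)])
    return ranges


print(shapes_collection[0][1])  # [1, 128]
print(slice_ranges(0))  # [[(0, 1), (384, 512)], [(0, 1), (384, 512)]]
print(slice_ranges(1))  # [[(0, 1), (0, 512)], [(0, 1), (0, 512)]]
```

For the first bucket, each state tensor keeps only its last 128 positions along the second dimension; for the second bucket, the full buffer is passed through.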

Note

The above shared state function is not used in the BERT example.

Bucket Model Config

Given the above two examples, we can initialize a torch_neuronx.BucketModelConfig object as follows:

import torch
import torch_neuronx

from typing import List

# above code

bucket_config = torch_neuronx.BucketModelConfig(get_bucket_kernel, shared_state_buffer_preprocessor=get_state_preprocessor)

Putting it all Together

Here is a simple example using the BERT model:

import torch
import torch_neuronx

from transformers import AutoTokenizer, AutoModelForSequenceClassification

from typing import List

def encode(tokenizer, *inputs, max_length=128, batch_size=1):
    tokens = tokenizer.encode_plus(
        *inputs,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )
    return (
        torch.repeat_interleave(tokens['input_ids'], batch_size, 0),
        torch.repeat_interleave(tokens['attention_mask'], batch_size, 0),
    )

def get_bert_model(*args):
    name = "bert-base-cased-finetuned-mrpc"
    model = AutoModelForSequenceClassification.from_pretrained(name, torchscript=True)

    return model, {}

def sequence_length_bucket_kernel(tensor_list: List[torch.Tensor]):
    x = tensor_list[0]
    bucket_dim = 1
    x_shape = x.shape
    tensor_sequence_length = x_shape[bucket_dim]
    batch_size = x_shape[bucket_dim - 1]
    buckets = [128, 512]
    idx = 0
    num_inputs = 3
    bucket = buckets[0]
    reshaped_tensors: List[torch.Tensor] = []
    bucket_idx = 0
    for idx, bucket in enumerate(buckets):
        if tensor_sequence_length <= bucket:
            bucket_idx = idx
            for tensor in tensor_list:
                if num_inputs == 0:
                    break
                delta = bucket - tensor_sequence_length
                padding_shape: List[int] = [batch_size, delta]
                zeros = torch.zeros(padding_shape, dtype=x.dtype)
                reshaped_tensors.append(torch.cat([tensor, zeros], dim=bucket_dim))
                num_inputs -= 1
            break
    return reshaped_tensors, torch.tensor([bucket_idx])

def get_bucket_kernel(*_):
    bk = torch.jit.script(sequence_length_bucket_kernel)
    return bk

if __name__ == '__main__':

    name = "bert-base-cased-finetuned-mrpc"

    # Build tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name, torchscript=True)

    # Setup some example inputs
    sequence_0 = "The company HuggingFace is based in New York City"
    sequence_1 = "HuggingFace is named after the huggingface emoji"
    sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

    paraphrase_s128 = encode(tokenizer, sequence_0, sequence_2)
    paraphrase_s122 = encode(tokenizer, sequence_0, sequence_2, max_length=122)

    paraphrase_s512 = encode(tokenizer, sequence_0, sequence_1, max_length=512)
    paraphrase_s444 = encode(tokenizer, sequence_0, sequence_1, max_length=444)

    # Note: Run on CPU before trace. Avoids running with XLA allocated params
    paraphrase_expected_s128 = torch.argmax(model(*paraphrase_s128)[0])
    paraphrase_expected_s512 = torch.argmax(model(*paraphrase_s512)[0])

    # Trace model
    bucket_config = torch_neuronx.BucketModelConfig(get_bucket_kernel)
    bucket_trace_neuron = torch_neuronx.bucket_model_trace(get_bert_model, [paraphrase_s128, paraphrase_s512], bucket_config)

    # Run traced model with shorter inputs to test bucket rounding
    paraphrase_actual_s128 = torch.argmax(bucket_trace_neuron(*paraphrase_s122)[0])
    paraphrase_actual_s512 = torch.argmax(bucket_trace_neuron(*paraphrase_s444)[0])

    # Compare outputs
    assert paraphrase_expected_s128 == paraphrase_actual_s128
    assert paraphrase_expected_s512 == paraphrase_actual_s512

Autobucketing for Neuronx-Distributed

To see this same example applied to Neuronx-Distributed, see the corresponding section of the Neuronx-Distributed Inference Developer Guide.
