PyTorch Neuron (torch-neuronx) Weight Replacement API for Inference - AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

PyTorch Neuron (torch-neuronx) Weight Replacement API for Inference

torch_neuronx.replace_weights(neuron_model, weights)

Replaces the weights in a Neuron model that was traced with separated (split) weights. The function emits a warning if the supplied Neuron model does not contain any separated weights.

Warning

This API applies only to models traced with inline_weights_to_neff=False. The parameter defaults to True, so it must be set explicitly at trace time. See torch_neuronx.trace() for details.

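Parameters:

neuron_model: the Neuron model whose weights are replaced. It must have been traced with inline_weights_to_neff=False.

weights: the replacement weights, for example a state_dict as in the examples below.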

Returns:

None. This function performs the weight replacement in place, modifying neuron_model directly.

Return type:

None
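Because the function returns nothing and changes the model directly, it can help to confirm beforehand that the replacement weights line up with the network that was traced. The sketch below is not part of torch_neuronx: find_weight_mismatches is a hypothetical helper that compares the parameter names and shapes of two state_dict-like mappings. It rests on the assumption, not stated in this document, that replacement weights should match the traced network's names and shapes. The stand-in objects carry only a .shape attribute so the snippet runs without a Neuron device; tensors returned by state_dict() expose the same attribute.

```python
from types import SimpleNamespace
from typing import Any, List, Mapping


def find_weight_mismatches(reference: Mapping[str, Any],
                           replacement: Mapping[str, Any]) -> List[str]:
    """Describe name and shape differences between two state_dict-like mappings.

    Hypothetical helper, not part of torch_neuronx. Values only need a
    ``.shape`` attribute, which torch tensors provide.
    """
    problems = []
    # Names present in the traced network but absent from the replacement.
    for name in sorted(set(reference) - set(replacement)):
        problems.append(f"missing key: {name}")
    # Names present only in the replacement.
    for name in sorted(set(replacement) - set(reference)):
        problems.append(f"unexpected key: {name}")
    # Shared names must agree on shape.
    for name in sorted(set(reference) & set(replacement)):
        ref_shape = tuple(reference[name].shape)
        new_shape = tuple(replacement[name].shape)
        if ref_shape != new_shape:
            problems.append(
                f"shape mismatch for {name}: {ref_shape} vs {new_shape}")
    return problems


# Stand-ins for tensors; with torch installed, pass network.state_dict()
# and network2.state_dict() instead.
traced_weights = {
    "layers.0.weight": SimpleNamespace(shape=(4, 4)),
    "layers.0.bias": SimpleNamespace(shape=(4,)),
}
good_weights = dict(traced_weights)
bad_weights = {"layers.0.weight": SimpleNamespace(shape=(8, 4))}

print(find_weight_mismatches(traced_weights, good_weights))
print(find_weight_mismatches(traced_weights, bad_weights))
```

An empty list means the two mappings agree on names and shapes; any entries describe what would need fixing before calling torch_neuronx.replace_weights.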

Examples

Using a model

    import torch
    import torch_neuronx

    class Network(torch.nn.Module):
        def __init__(self, hidden_size=4, layers=3) -> None:
            super().__init__()
            self.layers = torch.nn.Sequential(
                *(torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)))

        def forward(self, tensor):
            return self.layers(tensor)

    # initialize two networks
    network = Network()
    network2 = Network()
    network.eval()
    network2.eval()

    inp = torch.rand(2, 4)

    # trace weight separated model with first network
    weight_separated_trace = torch_neuronx.trace(network, inp, inline_weights_to_neff=False)

    # replace with weights from second network
    torch_neuronx.replace_weights(weight_separated_trace, network2.state_dict())

    # get outputs from neuron and cpu networks
    out_network2 = network2(inp)
    out_neuron = weight_separated_trace(inp)

    # check that they are equal
    print(out_network2, out_neuron)
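The example above prints both outputs so they can be compared by eye. A programmatic comparison might look like the sketch below. outputs_match is a hypothetical helper, and the tolerance defaults are illustrative assumptions rather than values from the Neuron documentation; suitable tolerances depend on the model and compiler settings. The demonstration uses synthetic CPU tensors in place of out_network2 and out_neuron, so it runs without a Neuron device.

```python
import torch


def outputs_match(expected: torch.Tensor, actual: torch.Tensor,
                  rtol: float = 1e-3, atol: float = 1e-3) -> bool:
    """Return True when two outputs agree within the given tolerances.

    Hypothetical helper; the tolerance defaults are illustrative only.
    """
    expected = expected.float()
    actual = actual.float()
    # Report the largest element-wise gap to make failures easier to judge.
    max_diff = (expected - actual).abs().max().item()
    print(f"max absolute difference: {max_diff:.6f}")
    return torch.allclose(expected, actual, rtol=rtol, atol=atol)


# CPU-only demonstration with synthetic tensors standing in for
# out_network2 and out_neuron.
expected = torch.tensor([[0.10, 0.20], [0.30, 0.40]])
nearly_equal = expected + 1e-5
different = expected + 0.5

print(outputs_match(expected, nearly_equal))
print(outputs_match(expected, different))
```

With the real outputs, the call would be outputs_match(out_network2, out_neuron).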

Using safetensors

The safetensors library is useful for storing/loading model tensors safely and quickly.

    import torch
    import torch_neuronx

    from safetensors import safe_open
    from safetensors.torch import save_model

    class Network(torch.nn.Module):
        def __init__(self, hidden_size=4, layers=3) -> None:
            super().__init__()
            self.layers = torch.nn.Sequential(
                *(torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)))

        def forward(self, tensor):
            return self.layers(tensor)

    # initialize two networks
    network = Network()
    network2 = Network()
    network.eval()
    network2.eval()

    inp = torch.rand(2, 4)

    # trace weight separated model with first network
    weight_separated_trace = torch_neuronx.trace(network, inp, inline_weights_to_neff=False)

    # save network2 weights to safetensors
    directory = "."  # placeholder: any writable directory
    safetensor_path = f"{directory}/network2.safetensors"
    save_model(network2, safetensor_path)

    # load safetensors from network2 for the weight separated model
    tensors = {}
    with safe_open(safetensor_path, framework="pt") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)

    # replace with weights from second network
    torch_neuronx.replace_weights(weight_separated_trace, tensors)

    # get outputs from neuron and cpu networks
    out_network2 = network2(inp)
    out_neuron = weight_separated_trace(inp)

    # check that they are equal
    print(out_network2, out_neuron)

Note

For non-safetensors models, use torch.load to load the model, then pass the model's state_dict to torch_neuronx.replace_weights, as in the first example.
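As a minimal sketch of the torch.load route, the snippet below assumes the checkpoint was written with torch.save(model.state_dict(), path), so that torch.load returns the state_dict directly. load_state_dict and the file name model.pt are hypothetical names chosen for illustration. The demonstration runs on CPU only; the final replace_weights call is left as a comment because it needs a Neuron device and a weight separated trace.

```python
import os
import tempfile

import torch


def load_state_dict(checkpoint_path: str) -> dict:
    """Load a state_dict saved with torch.save(model.state_dict(), path).

    Hypothetical helper. map_location="cpu" keeps the tensors on the host.
    """
    return torch.load(checkpoint_path, map_location="cpu")


# Demonstration: save and reload the weights of a small CPU model.
model = torch.nn.Linear(4, 4)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")
    torch.save(model.state_dict(), path)
    state_dict = load_state_dict(path)

print(sorted(state_dict))

# On a Neuron instance, the loaded weights could then be passed on:
# torch_neuronx.replace_weights(weight_separated_trace, state_dict)
```

If the checkpoint stores a whole pickled model instead of a state_dict, the loaded object's state_dict() would be passed instead.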
