PyTorch Neuron (torch-neuronx) Weight Replacement API for Inference (AWS Neuron Documentation)
This document is relevant for: Inf2, Trn1, Trn2
torch_neuronx.replace_weights(neuron_model, weights)
Replaces the weights in a Neuron model that was compiled with split (separated) weights. This function emits a warning if the supplied Neuron model does not contain any separated weights.
Warning
The API below applies only to models traced with the parameter inline_weights_to_neff=False. This parameter is True by default. See torch_neuronx.trace() for details.
Parameters:
- neuron_model (RecursiveScriptModule) – A Neuron model compiled with split weights
- weights (Module, Dict[str, Tensor]) – Either the original model with the new weights, or the state_dict of a model.
Returns:
None. The weight replacement is performed in place on neuron_model.
Return type:
None
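Because the replacement happens in place, it can be useful to compare the new weights against the original ones before calling replace_weights. The sketch below assumes, which this page does not state explicitly, that the new weights should carry the same parameter names and shapes as the model that was traced. find_mismatches is a hypothetical helper, not part of torch_neuronx; it works on any mapping whose values expose .shape, such as a state_dict.

```python
from types import SimpleNamespace
from typing import Any, List, Mapping


def find_mismatches(reference: Mapping[str, Any],
                    candidate: Mapping[str, Any]) -> List[str]:
    """Describe differences between two state_dict-like mappings.

    Hypothetical helper: both arguments map parameter names to objects
    with a .shape attribute (for example torch tensors).
    """
    problems = []
    # Names present in the traced model's weights but absent from the new ones.
    for name in sorted(reference.keys() - candidate.keys()):
        problems.append(f"missing: {name}")
    # Names in the new weights that the traced model does not know.
    for name in sorted(candidate.keys() - reference.keys()):
        problems.append(f"unexpected: {name}")
    # Shared names whose shapes differ.
    for name in sorted(reference.keys() & candidate.keys()):
        ref_shape = tuple(reference[name].shape)
        new_shape = tuple(candidate[name].shape)
        if ref_shape != new_shape:
            problems.append(f"shape mismatch: {name} {ref_shape} vs {new_shape}")
    return problems


# Stand-in objects for illustration; real tensors expose .shape the same way.
old = {
    "layers.0.weight": SimpleNamespace(shape=(4, 4)),
    "layers.0.bias": SimpleNamespace(shape=(4,)),
}
new = {"layers.0.weight": SimpleNamespace(shape=(4, 8))}

issues = find_mismatches(old, new)
for issue in issues:
    print(issue)
```

With real models, the two arguments would be network.state_dict() and network2.state_dict(); an empty result suggests the new weights line up with the traced model.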
Examples
Using a model
import torch
import torch_neuronx


class Network(torch.nn.Module):
    def __init__(self, hidden_size=4, layers=3) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            *(torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)))

    def forward(self, tensor):
        return self.layers(tensor)


# initialize two networks
network = Network()
network2 = Network()
network.eval()
network2.eval()

inp = torch.rand(2, 4)

# trace weight separated model with first network
weight_separated_trace = torch_neuronx.trace(network, inp, inline_weights_to_neff=False)

# replace with weights from second network
torch_neuronx.replace_weights(weight_separated_trace, network2.state_dict())

# get outputs from neuron and cpu networks
out_network2 = network2(inp)
out_neuron = weight_separated_trace(inp)

# check that they are equal
print(out_network2, out_neuron)
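Printing both tensors leaves the comparison to the reader. A tolerance-based check makes it explicit. The sketch below uses torch.allclose on two stand-in tensors that play the roles of out_network2 and out_neuron; the tolerance values are illustrative assumptions, not figures from the Neuron documentation.

```python
import torch

# Stand-ins for the CPU output and the Neuron output from the example above.
out_cpu = torch.tensor([[0.1000, -0.2500], [0.7500, 0.3000]])
out_device = out_cpu + 1e-4  # simulate a small numerical difference

# Assumed tolerances; tighten or loosen them for the model at hand.
outputs_match = torch.allclose(out_cpu, out_device, rtol=1e-3, atol=1e-3)
max_abs_diff = (out_cpu - out_device).abs().max().item()

print(f"match={outputs_match}, max abs diff={max_abs_diff:.2e}")
```

Reporting the maximum absolute difference alongside the boolean helps when choosing a tolerance that suits a given model.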
Using safetensors
The safetensors library is useful for storing/loading model tensors safely and quickly.
import torch
import torch_neuronx

from safetensors import safe_open
from safetensors.torch import save_model


class Network(torch.nn.Module):
    def __init__(self, hidden_size=4, layers=3) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            *(torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)))

    def forward(self, tensor):
        return self.layers(tensor)


# initialize two networks
network = Network()
network2 = Network()
network.eval()
network2.eval()

inp = torch.rand(2, 4)

# trace weight separated model with first network
weight_separated_trace = torch_neuronx.trace(network, inp, inline_weights_to_neff=False)

# save network2 weights to safetensors
directory = "."  # placeholder: any writable directory
safetensor_path = f"{directory}/network2.safetensors"
save_model(network2, safetensor_path)

# load safetensors from network2 into a dict for the weight separated model
tensors = {}
with safe_open(safetensor_path, framework="pt") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

# replace with weights from second network
torch_neuronx.replace_weights(weight_separated_trace, tensors)

# get outputs from neuron and cpu networks
out_network2 = network2(inp)
out_neuron = weight_separated_trace(inp)

# check that they are equal
print(out_network2, out_neuron)
Note
For models not stored as safetensors, load the checkpoint with torch.load and pass the resulting state_dict to replace_weights, as in the first example.
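A minimal sketch of the torch.load route, assuming the checkpoint file holds a state_dict written with torch.save. The Linear layer and the file name network2.pt are stand-ins chosen for illustration. The final replace_weights call is left as a comment because it needs a traced model and a Neuron device.

```python
import os
import tempfile

import torch

# Stand-in for the model whose weights should be loaded into the traced model.
model = torch.nn.Linear(4, 4)

with tempfile.TemporaryDirectory() as directory:
    # Hypothetical checkpoint file name.
    checkpoint_path = os.path.join(directory, "network2.pt")

    # Save only the state_dict, then read it back onto the CPU.
    torch.save(model.state_dict(), checkpoint_path)
    state_dict = torch.load(checkpoint_path, map_location="cpu")

print(list(state_dict.keys()))

# With a weight separated trace available, the loaded dict would be passed as:
# torch_neuronx.replace_weights(weight_separated_trace, state_dict)
```

If the checkpoint stores a whole pickled model rather than a state_dict, calling .state_dict() on the loaded object yields the mapping to pass in.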