This document is relevant for: Inf1
Developer Guide - PyTorch Neuron (torch-neuron) LSTM Support#
The torch-neuron package can support LSTM operations and yield high performance on both fixed-length and variable-length sequences. Most network configurations can be supported, with the exception of those that require PackedSequence usage outside of LSTM or pad_packed_sequence() operations. Neuron must guarantee that the shapes can remain fixed throughout the network.
The following sections describe which scenarios can and cannot be supported.
Supported Usage#
Fixed-Length Sequences#
In normal usage of an LSTM, the inputs and outputs are expected to have a fixed sequence length. This is the most basic usage of an LSTM but may not be applicable to applications where the input sequence length varies.
import torch
import torch_neuron
class Network(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)
def forward(self, inputs):
output, (ht, ct) = self.lstm(inputs)
return output, (ht, ct)
Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)
Trace
torch_neuron.trace(Network(), (inputs,))
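The object returned by torch_neuron.trace() can be called like the original module; a minimal usage sketch, assuming the trace result is assigned to trace:

trace = torch_neuron.trace(Network(), (inputs,))

# The traced module is a drop-in replacement for the original network,
# but the input shape is now fixed to (seq_len, batch_size, input_size)
output, (ht, ct) = trace(inputs)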
Packed Input, Padded Output, Pre-Sorted Inputs#
A common usage of an LSTM is when the input sequence lengths vary from one sequence to another (for example, with tokenized text).
For example, the following sentences could result in two different sequence lengths after tokenization:
Input
text = [
    'Hello, sailor',
    'Example',
]
... Tokenization ...
Result
tokens = [
    [101, 7592, 1010, 11803, 102],
    [101, 2742, 102, 0, 0],
]
lengths = [5, 3]
Because the lengths are different, the final LSTM state will be dependent upon the length of each sequence in the batch. Torch provides a way to deal with these types of sequences by densely packing batches into a PackedSequence. The most common way this is constructed is by using the pack_padded_sequence() utility function prior to feeding inputs into the LSTM.
Packing the above sequences would result in the following data and batch size tensors.
data = [101, 101, 7592, 2742, 1010, 102, 11803, 102]
batch_sizes = [2, 2, 2, 1, 1]
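For reference, this packing can be reproduced directly with pack_padded_sequence(); a minimal sketch, assuming the token values above and a batch-first layout:

import torch

padded = torch.tensor([
    [101, 7592, 1010, 11803, 102],
    [101, 2742, 102, 0, 0],
])
packed = torch.nn.utils.rnn.pack_padded_sequence(
    padded,
    lengths=[5, 3],
    batch_first=True,
    enforce_sorted=True,
)
print(packed.data)         # tensor([  101,   101,  7592,  2742,  1010,   102, 11803,   102])
print(packed.batch_sizes)  # tensor([2, 2, 2, 1, 1])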
In addition to correctly computing the final LSTM state, using a packed sequence instead of a padded sequence also improves model performance on CPU. On Neuron, where computation is fixed to the maximum sequence length ahead of time, this does not improve performance.
When an LSTM is processing a PackedSequence, it must do so in descending sorted-length order. To ensure that sequences are sorted, pack_padded_sequence() provides an enforce_sorted flag. When enforce_sorted is True, the input is expected to already contain sequences sorted by length in decreasing order along the batch dimension. Note that this sorting must be performed in application-level code, but it is only relevant when the batch size is greater than 1.
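The sorting itself can be done at the application level before packing; a minimal sketch, assuming the (seq_len, batch, input_size) layout used in the examples below:

import torch

inputs = torch.rand(5, 3, 3)       # (max_seq_len, batch, input_size)
lengths = torch.tensor([3, 5, 4])  # unsorted per-sequence lengths

# Sort by length, descending, so that enforce_sorted=True is valid
sorted_lengths, sort_idx = torch.sort(lengths, descending=True)
sorted_inputs = inputs[:, sort_idx, :]

# sort_idx can be used afterwards to restore the original batch order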
The following network can compile successfully because the input and output of the network are guaranteed to be a fixed shape. The input is expected to be a padded tensor, and the output is padded to the maximum sequence length using the pad_packed_sequence() function call:
import torch
import torch_neuron
class Network(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)
def forward(self, inputs, lengths):
packed_input = torch.nn.utils.rnn.pack_padded_sequence(
inputs,
lengths=lengths,
enforce_sorted=True,
)
packed_result, (ht, ct) = self.lstm(packed_input)
padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_result)
return padded_result, ht, ct
Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)
lengths = torch.tensor([seq_len] * batch_size)
Trace
torch_neuron.trace(Network(), (inputs, lengths))
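Once traced, the network accepts padded batches whose true lengths vary, provided the lengths remain sorted in decreasing order; a hedged usage sketch, assuming the trace result is assigned to trace:

trace = torch_neuron.trace(Network(), (inputs, lengths))

# Inputs stay padded to seq_len=5; the true lengths are sorted
# descending because the network was traced with enforce_sorted=True
variable_lengths = torch.tensor([5, 3])
padded, ht, ct = trace(inputs, variable_lengths)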
Packed Input, Padded Output, Unsorted Inputs#
When enforce_sorted is False, the input will be sorted unconditionally. This causes some CPU overhead on Neuron because unsupported operators such as aten::sort and aten::scatter_ will be inserted into the graph. The aten::lstm operation can still be supported, but it will be less efficient than when enforce_sorted is True.
The following code is able to be traced, but results in the sorting operations running on CPU. This is not problematic in this case because the aten::sort and aten::scatter_ operations are executed on CPU at the very beginning of the graph, just prior to Neuron execution.
Like the previous example, the call to pad_packed_sequence() ensures that the output has a fixed shape based on the maximum sequence length.
import torch
import torch_neuron
class Network(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)
def forward(self, inputs, lengths):
packed_input = torch.nn.utils.rnn.pack_padded_sequence(
inputs,
lengths=lengths,
enforce_sorted=False,
)
packed_result, (ht, ct) = self.lstm(packed_input)
padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_result)
return padded_result, ht, ct
Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)
lengths = torch.tensor([seq_len] * batch_size)
Trace
trace = torch_neuron.trace(Network(), (inputs, lengths))
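Because enforce_sorted is False, the traced network also accepts lengths in any order; a minimal usage sketch, where the sort runs on CPU ahead of the Neuron-compiled LSTM:

# aten::sort and aten::scatter_ execute on CPU before the
# compiled portion of the graph runs
unsorted_lengths = torch.tensor([3, 5])
padded, ht, ct = trace(inputs, unsorted_lengths)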
Packed Inputs, Final Hidden & Cell State Only#
When only the final LSTM hidden & cell state is used, it does not matter if the inputs are packed or unpacked, since these state tensors will not vary in size.
import torch
import torch_neuron
class Network(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)
def forward(self, inputs, lengths):
packed_input = torch.nn.utils.rnn.pack_padded_sequence(
inputs,
lengths=lengths,
enforce_sorted=True,
)
packed_output, (ht, ct) = self.lstm(packed_input)
return ht, ct
Example Inputs
seq_len, batch_size, input_size = 5, 2, 3
inputs = torch.rand(seq_len, batch_size, input_size)
lengths = torch.tensor([seq_len] * batch_size)
Trace
trace = torch_neuron.trace(Network(), (inputs, lengths))
Note that when the packed_output is unused, it does not need to be passed to pad_packed_sequence() for the LSTM to be compiled.
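A brief usage sketch; the returned states follow the standard torch convention of (num_layers * num_directions, batch, hidden_size):

ht, ct = trace(inputs, lengths)

# A single unidirectional layer yields (1, batch_size, hidden_size)
assert ht.shape == (1, 2, 7)
assert ct.shape == (1, 2, 7)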
Unsupported Usage#
Neuron does not support the use of a PackedSequence outside of the LSTM operation and the pad_packed_sequence() operation. This is because the shape of a PackedSequence can vary depending on the input data, which is incompatible with the Neuron restriction that all tensor sizes must be known at compilation time. When a PackedSequence is used only by an LSTM or pad_packed_sequence() operation, Neuron can guarantee the size of the intermediary tensors by padding on behalf of the application.
This means that if the PackedSequence is used by a different operation or returned from the network, either all of the LSTM operations will be executed on CPU or the network compilation will fail.
PackedSequence Returned#
The following is unsupported because the PackedSequence result of the LSTM is returned by the network:
import torch

class Network(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)
def forward(self, inputs, lengths):
packed_input = torch.nn.utils.rnn.pack_padded_sequence(
inputs,
lengths=lengths,
enforce_sorted=False,
)
packed_result, (ht, ct) = self.lstm(packed_input)
return packed_result.data, ht, ct
Behavior: In this case, compilation fails and the following warning is generated:
Operator "aten::lstm" consuming a PackedSequence input can only be supported when its corresponding PackedSequence output is unused or unpacked using "aten::_pad_packed_input". Found usage by "prim::Return"
Resolution: To avoid this error, the packed_result should be padded using pad_packed_sequence() prior to being returned from the network.
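A minimal sketch of the corrected forward(), matching the supported examples earlier in this guide:

def forward(self, inputs, lengths):
    packed_input = torch.nn.utils.rnn.pack_padded_sequence(
        inputs,
        lengths=lengths,
        enforce_sorted=False,
    )
    packed_result, (ht, ct) = self.lstm(packed_input)
    # Unpacking guarantees that every network output has a fixed shape
    padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_result)
    return padded_result, ht, ct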
Invalid PackedSequence Usage#
The following is unsupported because the PackedSequence result of the LSTM is used by a non-LSTM operator:
import torch

class Network(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=3, hidden_size=7)
def forward(self, inputs, lengths):
packed_input = torch.nn.utils.rnn.pack_padded_sequence(
inputs,
lengths=lengths,
enforce_sorted=False,
)
packed_result, (ht, ct) = self.lstm(packed_input)
return torch.max(packed_result.data)
Behavior: In this case, compilation fails and the following warning is generated:
Operator "aten::lstm" consuming a PackedSequence input can only be supported when its corresponding PackedSequence output is unused or unpacked using "aten::_pad_packed_input". Found usage by "aten::max"
Resolution: To avoid this error, the packed_result should be padded using pad_packed_sequence() prior to being used in the max() operation.
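A minimal sketch of that fix. One caution: pad_packed_sequence() fills padded positions with its padding_value (0.0 by default), which can change the result of reductions such as max() when the valid values are negative, so masking may be needed in that case:

def forward(self, inputs, lengths):
    packed_input = torch.nn.utils.rnn.pack_padded_sequence(
        inputs,
        lengths=lengths,
        enforce_sorted=False,
    )
    packed_result, (ht, ct) = self.lstm(packed_input)
    # Padded positions hold padding_value (0.0 by default), which can
    # affect reductions when valid values are negative
    padded_result, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_result)
    return torch.max(padded_result)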
This document is relevant for: Inf1