LSTMCell: PyTorch 2.7 documentation

class torch.nn.LSTMCell(input_size, hidden_size, bias=True, device=None, dtype=None)

A long short-term memory (LSTM) cell.

$$
\begin{array}{ll}
i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
c' = f \odot c + i \odot g \\
h' = o \odot \tanh(c') \\
\end{array}
$$

where $\sigma$ is the sigmoid function and $\odot$ is the Hadamard product.
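The equations can be checked against the module directly. The sketch below is a minimal re-implementation of one step, assuming that `weight_ih` and `weight_hh` stack the four gate matrices row-wise in the order (i, f, g, o) and that the biases are stacked the same way; `lstm_cell_step` is a hypothetical helper written only for this illustration, and the sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, batch = 10, 20, 3
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
c = torch.randn(batch, hidden_size)


def lstm_cell_step(x, h, c, cell):
    """One LSTM step computed directly from the equations (hypothetical helper)."""
    # W_i* x + b_i* + W_h* h + b_h* for all four gates at once: (batch, 4 * hidden_size)
    gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    # Assumed row order of the stacked blocks: i, f, g, o
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_next = f * c + i * g            # c' = f ⊙ c + i ⊙ g
    h_next = o * torch.tanh(c_next)   # h' = o ⊙ tanh(c')
    return h_next, c_next


with torch.no_grad():
    h_manual, c_manual = lstm_cell_step(x, h, c, cell)
    h_ref, c_ref = cell(x, (h, c))

print(torch.allclose(h_manual, h_ref, atol=1e-5),
      torch.allclose(c_manual, c_ref, atol=1e-5))  # True True
```

Computing both affine terms as one (batch, 4*hidden_size) product and then splitting it is what makes the stacked layout convenient: all four gates cost two matrix multiplications instead of eight.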

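Parameters
- input_size (int): the number of expected features in the input x
- hidden_size (int): the number of features in the hidden state h
- bias (bool): if False, the layer does not use the bias weights b_ih and b_hh. Default: True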
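A short sketch of how the constructor arguments show up on the module, assuming the parameter attribute names `weight_ih`, `weight_hh`, `bias_ih` and `bias_hh` that `nn.LSTMCell` registers; the sizes 10 and 20 are arbitrary.

```python
import torch
from torch import nn

# bias=False drops both bias vectors, leaving only the two weight matrices.
cell_no_bias = nn.LSTMCell(10, 20, bias=False)
names = [name for name, _ in cell_no_bias.named_parameters()]
print(names)  # ['weight_ih', 'weight_hh']

# dtype (like device) is a factory argument applied when the parameters are created.
cell_f64 = nn.LSTMCell(10, 20, dtype=torch.float64)
print(cell_f64.weight_ih.dtype)  # torch.float64
```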

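Inputs: input, (h_0, c_0)
- input of shape (batch, input_size) or (input_size): tensor containing input features
- h_0 of shape (batch, hidden_size) or (hidden_size): tensor containing the initial hidden state
- c_0 of shape (batch, hidden_size) or (hidden_size): tensor containing the initial cell state

If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.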
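The state tuple is optional, and a 1-D input is treated as a single unbatched sample. A minimal sketch of both conventions, with arbitrary sizes:

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.LSTMCell(10, 20)

# Batched input (batch, input_size): omitting the state should match passing zeros.
x = torch.randn(3, 10)
with torch.no_grad():
    h_default, c_default = cell(x)
    zeros = torch.zeros(3, 20)
    h_zero, c_zero = cell(x, (zeros, zeros))
print(torch.allclose(h_default, h_zero))  # True

# Unbatched input (input_size,): outputs have shape (hidden_size,).
with torch.no_grad():
    h1, c1 = cell(torch.randn(10))
print(h1.shape)  # torch.Size([20])
```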

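Outputs: (h_1, c_1)
- h_1 of shape (batch, hidden_size) or (hidden_size): tensor containing the next hidden state
- c_1 of shape (batch, hidden_size) or (hidden_size): tensor containing the next cell state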
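Because $h' = o \odot \tanh(c')$ with every entry of $o$ in $(0, 1)$, the returned hidden state stays inside $(-1, 1)$, while the cell state has no such bound. A small sketch with arbitrary sizes checks the output shapes and that bound:

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.LSTMCell(10, 20)
x = torch.randn(3, 10)

with torch.no_grad():
    h1, c1 = cell(x)

print(h1.shape, c1.shape)  # torch.Size([3, 20]) torch.Size([3, 20])
# h_1 = o ⊙ tanh(c_1), with o in (0, 1) and tanh in (-1, 1)
print(bool((h1.abs() < 1).all()))  # True
```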

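Variables
- weight_ih (torch.Tensor): the learnable input-hidden weights, of shape (4*hidden_size, input_size)
- weight_hh (torch.Tensor): the learnable hidden-hidden weights, of shape (4*hidden_size, hidden_size)
- bias_ih: the learnable input-hidden bias, of shape (4*hidden_size)
- bias_hh: the learnable hidden-hidden bias, of shape (4*hidden_size)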
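The four gates share each stacked weight matrix, which is why every leading dimension is 4*hidden_size. The sketch below prints the shapes for arbitrary sizes and, assuming the row blocks are stacked in the order (i, f, g, o), slices out the block corresponding to $W_{if}$.

```python
import torch
from torch import nn

input_size, hidden_size = 10, 20
cell = nn.LSTMCell(input_size, hidden_size)

for name, p in cell.named_parameters():
    print(name, tuple(p.shape))
# weight_ih (80, 10)
# weight_hh (80, 20)
# bias_ih (80,)
# bias_hh (80,)

# Assumed layout: rows [0:20], [20:40], [40:60], [60:80] hold the i, f, g, o blocks.
W_if = cell.weight_ih[hidden_size:2 * hidden_size]
print(tuple(W_if.shape))  # (20, 10)
```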

Note

All the weights and biases are initialized from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$, where $k = \frac{1}{\text{hidden\_size}}$.

On certain ROCm devices, when using float16 inputs, this module will use different precision for backward.
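The initialization rule in the note can be checked on a freshly built module. This sketch uses arbitrary sizes; the loop at the end simply re-applies the documented $\mathcal{U}(-\sqrt{k}, \sqrt{k})$ distribution with `torch.nn.init.uniform_` and is not claimed to be the module's own reset code.

```python
import math

import torch
from torch import nn

hidden_size = 20
cell = nn.LSTMCell(10, hidden_size)

# k = 1 / hidden_size, so sqrt(k) = 1 / sqrt(hidden_size)
bound = 1.0 / math.sqrt(hidden_size)
print(all(bool((p.abs() <= bound).all()) for p in cell.parameters()))  # True

# Re-draw every weight and bias from the same documented distribution.
with torch.no_grad():
    for p in cell.parameters():
        nn.init.uniform_(p, -bound, bound)
```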

Examples:

```python
import torch
from torch import nn

rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
hx = torch.randn(3, 20)  # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
# Feed one time step at a time, carrying (hx, cx) forward
for i in range(input.size(0)):
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)
output = torch.stack(output, dim=0)  # (time_steps, batch, hidden_size)
```
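The loop in the example unrolls the cell over the time dimension by hand. Assuming a single layer, the default `batch_first=False` layout and a zero initial state, the same loop should reproduce `nn.LSTM` once the parameters are copied over. This is a sanity-check sketch with arbitrary sizes, not a description of how `nn.LSTM` is implemented internally.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, steps, batch = 10, 20, 5, 3

cell = nn.LSTMCell(input_size, hidden_size)
lstm = nn.LSTM(input_size, hidden_size)  # one layer, input shape (time, batch, features)

# Copy the cell's parameters into the single-layer nn.LSTM.
with torch.no_grad():
    lstm.weight_ih_l0.copy_(cell.weight_ih)
    lstm.weight_hh_l0.copy_(cell.weight_hh)
    lstm.bias_ih_l0.copy_(cell.bias_ih)
    lstm.bias_hh_l0.copy_(cell.bias_hh)

x = torch.randn(steps, batch, input_size)
h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)

with torch.no_grad():
    outputs = []
    for t in range(steps):
        h, c = cell(x[t], (h, c))
        outputs.append(h)
    unrolled = torch.stack(outputs, dim=0)  # (time, batch, hidden_size)
    full, (h_n, c_n) = lstm(x)              # zero initial state by default

print(torch.allclose(unrolled, full, atol=1e-5))  # True
```

For whole sequences `nn.LSTM` is usually the more efficient choice; the cell form is useful when each time step needs custom logic between updates.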