RNNCell — PyTorch 2.7 documentation

class torch.nn.RNNCell(input_size, hidden_size, bias=True, nonlinearity='tanh', device=None, dtype=None)[source]

An Elman RNN cell with tanh or ReLU non-linearity.

h′=tanh⁡(Wihx+bih+Whhh+bhh)h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})

If nonlinearity is 'relu', then ReLU is used in place of tanh.
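The update rule above can be verified directly against the module: a minimal sketch that recomputes h' from the cell's registered parameters (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh`) and compares it to the module's output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(4, 3)        # input_size=4, hidden_size=3
x = torch.randn(2, 4)          # batch of 2 input vectors
h = torch.randn(2, 3)          # previous hidden state

# Manual Elman update: h' = tanh(x W_ih^T + b_ih + h W_hh^T + b_hh)
manual = torch.tanh(
    x @ cell.weight_ih.T + cell.bias_ih
    + h @ cell.weight_hh.T + cell.bias_hh
)

h_next = cell(x, h)
```

Both computations should agree to floating-point precision, which makes this a convenient sanity check when porting the cell to another framework.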

Parameters

  * input_size (int) – The number of expected features in the input x
  * hidden_size (int) – The number of features in the hidden state h
  * bias (bool) – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
  * nonlinearity (str) – The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'

Inputs: input, hidden

  * input: tensor containing the input features
  * hidden: tensor containing the initial hidden state. Defaults to zeros if not provided.

Outputs: h'

  * h': tensor containing the next hidden state for each element in the batch

Shape:

  * input: (N, H_in) or (H_in), where H_in = input_size
  * hidden: (N, H_out) or (H_out), where H_out = hidden_size
  * output: (N, H_out) or (H_out), the next hidden state

Variables

  * weight_ih (torch.Tensor) – the learnable input-hidden weights, of shape (hidden_size, input_size)
  * weight_hh (torch.Tensor) – the learnable hidden-hidden weights, of shape (hidden_size, hidden_size)
  * bias_ih – the learnable input-hidden bias, of shape (hidden_size)
  * bias_hh – the learnable hidden-hidden bias, of shape (hidden_size)
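The batched and unbatched call forms can be sketched as follows; when the hidden state is omitted, the cell initializes it to zeros.

```python
import torch
import torch.nn as nn

cell = nn.RNNCell(10, 20)      # H_in = 10, H_out = 20

# Batched: input (N, H_in), hidden (N, H_out) -> output (N, H_out)
x = torch.randn(3, 10)
h = torch.randn(3, 20)
batched = cell(x, h)

# Unbatched: input (H_in,) -> output (H_out,); hidden defaults to zeros
x1 = torch.randn(10)
unbatched = cell(x1)
```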

Note

All the weights and biases are initialized from \mathcal{U}(-\sqrt{k}, \sqrt{k}), where k = \frac{1}{\text{hidden\_size}}.

Examples:

import torch
import torch.nn as nn

rnn = nn.RNNCell(10, 20)
input = torch.randn(6, 3, 10)   # (time_steps, batch, input_size)
hx = torch.randn(3, 20)         # (batch, hidden_size)
output = []
for i in range(6):
    hx = rnn(input[i], hx)
    output.append(hx)
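Looping an RNNCell over a sequence is equivalent to a single-layer nn.RNN when the parameters match; a sketch of that equivalence, assuming a one-layer tanh RNN whose registered parameters (`weight_ih_l0`, etc.) are copied into the cell:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(10, 20)             # single-layer tanh RNN
cell = nn.RNNCell(10, 20)

# Share parameters so both modules compute the same function
with torch.no_grad():
    cell.weight_ih.copy_(rnn.weight_ih_l0)
    cell.weight_hh.copy_(rnn.weight_hh_l0)
    cell.bias_ih.copy_(rnn.bias_ih_l0)
    cell.bias_hh.copy_(rnn.bias_hh_l0)

seq = torch.randn(6, 3, 10)      # (seq_len, batch, input_size)
h0 = torch.zeros(3, 20)

# Whole-sequence pass with nn.RNN
out_rnn, _ = rnn(seq, h0.unsqueeze(0))

# Step-by-step pass with the cell
hx = h0
steps = []
for t in range(6):
    hx = cell(seq[t], hx)
    steps.append(hx)
out_cell = torch.stack(steps)
```

The cell form trades nn.RNN's fused whole-sequence kernel for per-step control, which is useful when each step depends on the previous output (e.g. sampling or attention).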