GRUCell — PyTorch 2.7 documentation

class torch.nn.GRUCell(input_size, hidden_size, bias=True, device=None, dtype=None)

A gated recurrent unit (GRU) cell.

$$
\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
h' = (1 - z) \odot n + z \odot h
\end{array}
$$

where $\sigma$ is the sigmoid function and $\odot$ is the Hadamard (element-wise) product.
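The equations above can be checked numerically against the module. The following is a minimal sketch, not the library's implementation: it assumes the stacked parameters `weight_ih`, `weight_hh`, `bias_ih`, and `bias_hh` are laid out in the order reset (r), update (z), new (n) along the first dimension, and the sizes and variable names are chosen purely for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, batch = 10, 20, 3
cell = nn.GRUCell(input_size, hidden_size)

x = torch.randn(batch, input_size)   # input x
h = torch.randn(batch, hidden_size)  # previous hidden state h

# Split the stacked parameters into three blocks.
# Assumption: the blocks are ordered r, z, n along dim 0.
W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3, dim=0)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3, dim=0)

with torch.no_grad():
    r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)     # reset gate
    z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)     # update gate
    n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))  # candidate state
    h_manual = (1 - z) * n + z * h                               # new hidden state h'
    h_module = cell(x, h)

# Largest element-wise difference between the manual result and the module.
max_abs_diff = (h_manual - h_module).abs().max().item()
```

If the assumed block order holds, `max_abs_diff` should be on the order of float32 rounding error. Note that the reset gate `r` multiplies the hidden-side term `W_hn h + b_hn` after the bias is added, exactly as written in the equation for `n`.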

Parameters

Inputs: input, hidden

Outputs: h’

Shape:

Variables

Note

All the weights and biases are initialized from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$, where $k = \frac{1}{\text{hidden\_size}}$.
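As a quick sanity check of the initialization range, the sketch below computes the bound $\sqrt{k}$ for an illustrative `hidden_size` and confirms that every parameter of a freshly constructed cell falls inside it. The small tolerance is an assumption added here to absorb float32 rounding at the edge of the interval.

```python
import math

import torch
from torch import nn

hidden_size = 20
cell = nn.GRUCell(10, hidden_size)

# Bound of the uniform distribution: sqrt(k) with k = 1 / hidden_size.
bound = math.sqrt(1.0 / hidden_size)

# Largest absolute value found in each parameter tensor, keyed by name.
max_abs = {
    name: p.detach().abs().max().item()
    for name, p in cell.named_parameters()
}

# Tolerance (an assumption of this sketch) covers float32 rounding.
within_bound = all(v <= bound + 1e-6 for v in max_abs.values())
```

With `hidden_size = 20` the bound is about 0.2236, so a larger hidden size gives a narrower initial range.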

On certain ROCm devices, when using float16 inputs this module will use different precision for backward.

Examples:

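>>> import torch
>>> from torch import nn
>>> rnn = nn.GRUCell(10, 20)  # input_size=10, hidden_size=20
>>> input = torch.randn(6, 3, 10)  # (time steps, batch, input_size)
>>> hx = torch.randn(3, 20)  # (batch, hidden_size)
>>> output = []
>>> for i in range(6):
...     hx = rnn(input[i], hx)  # one time step; hx keeps shape (3, 20)
...     output.append(hx)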
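A variant of the loop above, sketched here for illustration, collects the per-step hidden states into a single tensor and starts from no explicit hidden state. It relies on the documented behaviour that the hidden state defaults to zeros when it is not provided; the names `inputs`, `outputs`, and `stacked` are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)
rnn = nn.GRUCell(10, 20)
inputs = torch.randn(6, 3, 10)  # (time steps, batch, input_size)

hx = None  # no initial hidden state: the cell starts from zeros
outputs = []
for x_t in inputs:              # iterate over the time dimension
    hx = rnn(x_t, hx)           # hx has shape (batch, hidden_size)
    outputs.append(hx)

# Stack the list of (3, 20) tensors into one (6, 3, 20) tensor.
stacked = torch.stack(outputs)

# Passing zeros explicitly should match the first step computed with hx=None.
first_explicit = rnn(inputs[0], torch.zeros(3, 20))
```

Stacking along a new leading dimension gives the per-step hidden states in a (time steps, batch, hidden_size) layout.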