GLU — PyTorch 2.7 documentation

class torch.nn.GLU(dim=-1)

Applies the gated linear unit function.

$\mathrm{GLU}(a, b) = a \otimes \sigma(b)$, where $a$ is the first half of the input split along dim, $b$ is the second half, $\sigma$ is the sigmoid function, and $\otimes$ is the element-wise product.
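The formula can be checked by hand: split the input into halves a and b along the chosen dimension, then multiply a element-wise by sigmoid(b). The sketch below does this with torch.chunk and torch.sigmoid and compares the result against nn.GLU. The function name manual_glu is hypothetical and exists only for illustration.

```python
import torch
import torch.nn as nn

def manual_glu(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical re-implementation of GLU for illustration only."""
    # nn.GLU requires an even size along `dim`; check explicitly, because
    # torch.chunk would otherwise produce unequal halves that may broadcast.
    if x.size(dim) % 2 != 0:
        raise ValueError(f"size along dim {dim} must be even, got {x.size(dim)}")
    # a is the first half, b is the second half along `dim`.
    a, b = x.chunk(2, dim=dim)
    # Element-wise product of a with the sigmoid of b.
    return a * torch.sigmoid(b)

x = torch.randn(3, 8)
reference = nn.GLU()(x)
sketch = manual_glu(x)
print(torch.allclose(reference, sketch))
```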

Parameters

dim (int) – the dimension on which to split the input. Default: -1
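The dim argument selects which dimension is halved; every other dimension is left alone. The sketch below applies nn.GLU to the same tensor with the default dim=-1 and with dim=0, and then shows that a dimension of odd size is rejected with a RuntimeError.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 6)

# Default dim=-1: the last dimension (size 6) is split into two halves of 3.
out_last = nn.GLU()(x)
# dim=0: the first dimension (size 4) is split into two halves of 2 instead.
out_first = nn.GLU(dim=0)(x)
print(out_last.shape)   # torch.Size([4, 3])
print(out_first.shape)  # torch.Size([2, 6])

# The chosen dimension must have an even size; an odd size is rejected.
odd_rejected = False
try:
    nn.GLU(dim=0)(torch.randn(3, 6))
except RuntimeError:
    odd_rejected = True
print(odd_rejected)
```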

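Shape:

Input: (∗₁, N, ∗₂), where ∗ means any number of additional dimensions and N is the size along dim
Output: (∗₁, M, ∗₂), where M = N/2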
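Because GLU splits the input into two equal halves along dim, the output has half the input's size along that dimension and the same size everywhere else. The sketch below predicts the output shape with a hypothetical helper, glu_output_shape (not part of PyTorch), and compares it with what nn.GLU actually returns for a 4-D tensor split along dim=1.

```python
import torch
import torch.nn as nn

def glu_output_shape(shape, dim=-1):
    """Hypothetical helper: predict nn.GLU's output shape from its input shape."""
    shape = list(shape)
    n = shape[dim]
    if n % 2 != 0:
        raise ValueError(f"size {n} along dim {dim} is not even")
    shape[dim] = n // 2  # halve the split dimension; all others are unchanged
    return torch.Size(shape)

# Leading dims (2,), split dim of size 10, trailing dims (7, 3).
x = torch.randn(2, 10, 7, 3)
predicted = glu_output_shape(x.shape, dim=1)
actual = nn.GLU(dim=1)(x).shape
print(predicted, actual)  # both torch.Size([2, 5, 7, 3])
```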

Examples:

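```python
import torch
import torch.nn as nn

m = nn.GLU()  # splits along the last dimension (dim=-1) by default
input = torch.randn(4, 2)
output = m(input)  # output shape: (4, 1)
```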
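Since GLU halves the split dimension, one common way to use it is after a linear layer that projects to twice the desired width, so the layer's output is gated by a learned sigmoid. The module below is a hypothetical sketch of that pattern (GatedProjection is not a PyTorch class); it only combines nn.Linear and nn.GLU.

```python
import torch
import torch.nn as nn

class GatedProjection(nn.Module):
    """Hypothetical module: a linear layer whose output is gated by nn.GLU."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Project to twice the target width so GLU can split it into (a, b).
        self.proj = nn.Linear(in_features, 2 * out_features)
        self.glu = nn.GLU(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (..., in_features) -> (..., 2 * out_features) -> (..., out_features)
        return self.glu(self.proj(x))

layer = GatedProjection(in_features=16, out_features=8)
y = layer(torch.randn(4, 16))
print(y.shape)  # torch.Size([4, 8])
```

Doubling the projection width keeps the output at exactly out_features after GLU discards half of the channels into the gate.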