torch_geometric.nn.conv.GatedGraphConv — pytorch_geometric documentation


class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

Bases: MessagePassing

The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper.

\[
\begin{aligned}
\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0} \\
\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)} \\
\mathbf{h}_i^{(l+1)} &= \textrm{GRU} \left( \mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)} \right)
\end{aligned}
\]

up to representation \(\mathbf{h}_i^{(L)}\). The number of input channels of \(\mathbf{x}_i\) must be less than or equal to out_channels. \(e_{j,i}\) denotes the edge weight from source node \(j\) to target node \(i\) (default: 1).
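The propagation rule above can be sketched in plain NumPy. This is a simplified, non-learned illustration of the math, not the library's implementation: the weight matrix \(\mathbf{\Theta}\) and the GRU parameters below are hypothetical random values (a real GatedGraphConv learns them), biases are omitted, and the GRU follows the standard torch.nn.GRUCell update equations.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, in_channels, out_channels, num_layers = 4, 2, 3, 2

# Node features; in_channels must be <= out_channels.
x = rng.standard_normal((num_nodes, in_channels))

# Directed edges as (source j, target i) pairs; edge weights default to 1.
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]

# h^(0) = x_i || 0 : pad features with zeros up to out_channels.
h = np.concatenate(
    [x, np.zeros((num_nodes, out_channels - in_channels))], axis=1)

# Hypothetical parameters: shared Theta plus GRU gate weights.
theta = rng.standard_normal((out_channels, out_channels))
w_z, u_z, w_r, u_r, w_n, u_n = [
    rng.standard_normal((out_channels, out_channels)) for _ in range(6)]

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

for _ in range(num_layers):
    # m_i = sum over in-neighbours j of e_{j,i} * Theta @ h_j (weights = 1).
    m = np.zeros_like(h)
    for j, i in edges:
        m[i] += h[j] @ theta.T
    # GRU update: h_i <- GRU(m_i, h_i), biases omitted.
    z = sigmoid(m @ w_z.T + h @ u_z.T)          # update gate
    r = sigmoid(m @ w_r.T + h @ u_r.T)          # reset gate
    n = np.tanh(m @ w_n.T + r * (h @ u_n.T))    # candidate state
    h = (1.0 - z) * n + z * h

print(h.shape)  # (4, 3): one out_channels-sized vector per node
```

Note how the node dimension never changes; only the feature dimension is lifted from in_channels to out_channels by the zero padding in \(\mathbf{h}_i^{(0)}\).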

Parameters:

- out_channels (int) – Size of each output sample.
- num_layers (int) – The number of propagation steps \(L\).
- aggr (str, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "add")
- bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
- **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:

- input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)
- output: node features \((|\mathcal{V}|, F_{out})\)

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) → Tensor[source]

Runs the forward pass of the module.

Return type:

Tensor

reset_parameters()[source]

Resets all learnable parameters of the module.