MultiheadAttention — PyTorch 1.8.1 documentation

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)[source]

Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need for more details.

MultiHead(Q, K, V) = Concat(head_1, …, head_h) W^O

where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V).
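The formula above can be sketched directly in NumPy. This is an illustrative sketch, not PyTorch's implementation: the per-head projections W_i^Q, W_i^K, W_i^V are realized here as slices of full (E, E) weight matrices, and the function names and argument shapes are assumptions for this example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o, num_heads):
    """Q: (L, E); K, V: (S, E); all weight matrices: (E, E)."""
    L, E = Q.shape
    S = K.shape[0]
    d = E // num_heads  # per-head dimension

    # Project, then split the feature axis into heads: (heads, seq, d).
    q = (Q @ W_q).reshape(L, num_heads, d).transpose(1, 0, 2)
    k = (K @ W_k).reshape(S, num_heads, d).transpose(1, 0, 2)
    v = (V @ W_v).reshape(S, num_heads, d).transpose(1, 0, 2)

    # Scaled dot-product attention per head: (heads, L, S).
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    weights = softmax(scores, axis=-1)
    heads = weights @ v  # (heads, L, d)

    # Concat(head_1, …, head_h) followed by the output projection W^O.
    concat = heads.transpose(1, 0, 2).reshape(L, E)
    return concat @ W_o
```

With num_heads=1 this reduces to ordinary scaled dot-product attention; the head split only reshapes the projected features, so the cost is the same as a single head of full width.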

Parameters

embed_dim – total dimension of the model.
num_heads – number of parallel attention heads. embed_dim will be split across num_heads, so embed_dim must be divisible by num_heads.
dropout – dropout probability applied to attn_output_weights. Default: 0.0.
bias – add bias as a module parameter to the input and output projections. Default: True.
add_bias_kv – add bias to the key and value sequences at dim=0. Default: False.
add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1. Default: False.
kdim – total number of features in key. Default: None.
vdim – total number of features in value. Default: None.

Note that if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
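The kdim/vdim behavior described in the note can be illustrated with a short sketch. The sizes below (embed_dim=8, kdim=16, vdim=32, and the sequence lengths) are arbitrary choices for this example, not values taken from this page.

```python
import torch
import torch.nn as nn

# key and value may have feature sizes different from embed_dim;
# internal projections map them back to embed_dim.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=16, vdim=32)

query = torch.randn(5, 3, 8)   # (L, N, embed_dim)
key   = torch.randn(7, 3, 16)  # (S, N, kdim)
value = torch.randn(7, 3, 32)  # (S, N, vdim)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([5, 3, 8]) — always (L, N, embed_dim)
```

Whatever kdim and vdim are, the output feature size is always embed_dim, because each head's value output is concatenated and passed through the W^O projection of shape (embed_dim, embed_dim).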

Examples:

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)[source]

Parameters

query, key, value – map a query and a set of key-value pairs to an output. See Attention Is All You Need for more details.
key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask, positions with True will be ignored.
need_weights – output attn_output_weights in addition to attn_output. Default: True.
attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast across the batch, while a 3D mask allows a different mask for each entry in the batch.

Shapes for inputs:

query: (L, N, E), where L is the target sequence length, N is the batch size, and E is the embedding dimension (embed_dim).
key: (S, N, E), where S is the source sequence length.
value: (S, N, E)
key_padding_mask: (N, S). If a BoolTensor is provided, positions with True will be ignored.
attn_mask: 2D mask of shape (L, S), or 3D mask of shape (N * num_heads, L, S). For a BoolTensor, positions with True are not allowed to attend; for a float mask, the values are added to the attention weights.

Shapes for outputs:

attn_output: (L, N, E)
attn_output_weights: (N, L, S), the attention weights averaged over the heads.
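The input and output shapes above can be exercised end to end with both mask types. This is a hedged sketch: the dimensions (L=4, S=6, N=2, E=8) are arbitrary, and the masks shown (padding the last key position, plus a causal-style float mask) are illustrative choices.

```python
import torch
import torch.nn as nn

L, S, N, E = 4, 6, 2, 8  # target len, source len, batch size, embed_dim
mha = nn.MultiheadAttention(embed_dim=E, num_heads=2)

query = torch.randn(L, N, E)  # (L, N, E)
key   = torch.randn(S, N, E)  # (S, N, E)
value = torch.randn(S, N, E)  # (S, N, E)

# (N, S) bool mask: True marks key positions to ignore (here, the last one).
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[:, -1] = True

# (L, S) float mask, broadcast across the batch; -inf entries block attention.
attn_mask = torch.triu(torch.full((L, S), float('-inf')), diagonal=1)

attn_output, attn_weights = mha(query, key, value,
                                key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask)
print(attn_output.shape)   # torch.Size([4, 2, 8])  — (L, N, E)
print(attn_weights.shape)  # torch.Size([2, 4, 6])  — (N, L, S), averaged over heads
```

Positions masked by either mechanism receive exactly zero attention weight, since -inf scores map to 0 under softmax.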