MultiheadAttention — PyTorch 2.7 documentation

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source]

Allows the model to jointly attend to information from different representation subspaces.

Note

See this tutorial for an in-depth discussion of the performant building blocks PyTorch offers for building your own transformer layers.

Method described in the paper: Attention Is All You Need.

Multi-Head Attention is defined as:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O

where \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V).
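
As an illustration of this formula, here is a minimal sketch using plain tensor operations. It is not the code path nn.MultiheadAttention actually takes; the shapes and weight names (d_head, W_q, W_o, ...) are invented for the example.

    import torch

    # Illustrative-only sketch of the MultiHead formula above.
    embed_dim, num_heads, seq_len, batch = 16, 4, 5, 2
    d_head = embed_dim // num_heads

    Q = torch.randn(batch, seq_len, embed_dim)
    K = torch.randn(batch, seq_len, embed_dim)
    V = torch.randn(batch, seq_len, embed_dim)

    # Per-head projections W_i^Q, W_i^K, W_i^V and the output projection W^O.
    W_q = torch.randn(num_heads, embed_dim, d_head)
    W_k = torch.randn(num_heads, embed_dim, d_head)
    W_v = torch.randn(num_heads, embed_dim, d_head)
    W_o = torch.randn(num_heads * d_head, embed_dim)

    heads = []
    for i in range(num_heads):
        q, k, v = Q @ W_q[i], K @ W_k[i], V @ W_v[i]      # (batch, seq_len, d_head)
        scores = q @ k.transpose(-2, -1) / d_head ** 0.5  # scaled dot-product scores
        heads.append(scores.softmax(dim=-1) @ v)          # head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

    multihead = torch.cat(heads, dim=-1) @ W_o            # Concat(head_1, ..., head_h) W^O
    print(multihead.shape)                                # torch.Size([2, 5, 16])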

nn.MultiheadAttention will use the optimized implementations of scaled_dot_product_attention() when possible.

In addition to support for the new scaled_dot_product_attention() function, MHA will use fastpath inference (with support for Nested Tensors) to speed up inference, provided certain conditions on the inputs and the module configuration are met.

If the optimized inference fastpath implementation is in use, a NestedTensor can be passed for query/key/value to represent padding more efficiently than using a padding mask. In this case, a NestedTensor will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected.
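
As a rough sketch of that usage, assuming the fastpath conditions mentioned above hold in your environment (inference mode, eval(), batch_first=True, self-attention, and need_weights=False; the dimensions are illustrative):

    import torch
    import torch.nn as nn

    embed_dim, num_heads = 8, 2
    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()

    # Two sequences of different lengths packed as a NestedTensor instead of a padded batch.
    seqs = [torch.randn(3, embed_dim), torch.randn(5, embed_dim)]
    nt = torch.nested.nested_tensor(seqs)

    with torch.inference_mode():
        # Self-attention over the nested input; the output is expected to be nested as well.
        out, _ = mha(nt, nt, nt, need_weights=False)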

Parameters

embed_dim – Total dimension of the model.

num_heads – Number of parallel attention heads. embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).

bias – If specified, adds bias to input / output projection layers. Default: True.

add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.

add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).

vdim – Total number of features for values. Default: None (uses vdim=embed_dim).

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

Examples:

    multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
    attn_output, attn_output_weights = multihead_attn(query, key, value)
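
A more self-contained variant of the example above (the dimensions and tensor names are placeholders chosen for illustration):

    import torch
    import torch.nn as nn

    embed_dim, num_heads = 16, 4
    seq_len, batch_size = 10, 2

    multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)  # batch_first=False by default

    # With batch_first=False, inputs are (seq_len, batch_size, embed_dim).
    query = torch.randn(seq_len, batch_size, embed_dim)
    key = torch.randn(seq_len, batch_size, embed_dim)
    value = torch.randn(seq_len, batch_size, embed_dim)

    attn_output, attn_output_weights = multihead_attn(query, key, value)
    print(attn_output.shape)          # torch.Size([10, 2, 16])
    print(attn_output_weights.shape)  # torch.Size([2, 10, 10]) with the default average_attn_weights=True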

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source]

Compute attention outputs using query, key, and value embeddings.

Supports optional parameters for padding, masks and attention weights.

Parameters

query – Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False, or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim.

key – Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False, or (N, S, E_k) when batch_first=True, where S is the source sequence length and E_k is the key embedding dimension kdim.

value – Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False, or (N, S, E_v) when batch_first=True, where E_v is the value embedding dimension vdim.

key_padding_mask – If specified, a mask of shape (N, S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as "padding"). Binary and float masks are supported: for a binary mask, a True value indicates that the corresponding key will be ignored; for a float mask, the value will be added directly to the corresponding key.

need_weights – If True, returns attn_output_weights in addition to attn_output. Setting need_weights=False allows the optimized scaled_dot_product_attention() to be used and gives the best performance. Default: True.

attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N * num_heads, L, S). Binary and float masks are supported as for key_padding_mask.

average_attn_weights – If True, the returned attn_output_weights are averaged across heads; otherwise they are provided per head. Only has an effect when need_weights=True. Default: True.

is_causal – If specified, applies a causal mask as the attention mask. Default: False. is_causal is a hint that attn_mask is the causal mask; providing an incorrect hint can result in incorrect execution.

Return type

tuple[torch.Tensor, Optional[torch.Tensor]]

Outputs:

attn_output – Attention outputs of shape (L, E) for unbatched input, (L, N, E) when batch_first=False, or (N, L, E) when batch_first=True, where E is the embedding dimension embed_dim.

attn_output_weights – Only returned when need_weights=True. If average_attn_weights=True, attention weights averaged across heads, of shape (L, S) for unbatched input or (N, L, S); if average_attn_weights=False, per-head attention weights of shape (num_heads, L, S) for unbatched input or (N, num_heads, L, S). None when need_weights=False.

Note

The batch_first argument is ignored for unbatched inputs.
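
For illustration, here is a sketch of a forward call that combines a boolean key_padding_mask with a causal attn_mask; the shapes and values are invented for the example.

    import torch
    import torch.nn as nn

    embed_dim, num_heads = 16, 4
    batch, tgt_len, src_len = 2, 6, 6

    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    query = torch.randn(batch, tgt_len, embed_dim)
    key = value = torch.randn(batch, src_len, embed_dim)

    # Boolean masks: True marks positions that should be ignored.
    key_padding_mask = torch.zeros(batch, src_len, dtype=torch.bool)
    key_padding_mask[:, -2:] = True  # treat the last two source tokens as padding
    attn_mask = torch.triu(torch.ones(tgt_len, src_len, dtype=torch.bool), diagonal=1)  # causal mask

    attn_output, attn_weights = mha(
        query, key, value,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
    )
    print(attn_output.shape)   # torch.Size([2, 6, 16])
    print(attn_weights.shape)  # torch.Size([2, 6, 6]) (averaged over heads by default)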

merge_masks(attn_mask, key_padding_mask, query)[source]

Determine mask type and combine masks if necessary.

If only one mask is provided, that mask and the corresponding mask type will be returned. If both masks are provided, they will both be expanded to shape (batch_size, num_heads, seq_len, seq_len), combined with a logical or, and mask type 2 will be returned.

Parameters

attn_mask – attention mask of shape (seq_len, seq_len), mask type 0

key_padding_mask – padding mask of shape (batch_size, seq_len), mask type 1

query – query embeddings of shape (batch_size, seq_len, embed_dim)

Returns

merged_mask – merged mask

mask_type – merged mask type (0, 1, or 2)

Return type

tuple[Optional[Tensor], Optional[int]]
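
A small illustrative call, assuming boolean masks with the shapes given above (whether the merged mask comes back as a boolean or a float tensor may depend on the PyTorch version, so treat this as a sketch rather than a guaranteed contract):

    import torch
    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

    query = torch.randn(4, 5, 8)                            # (batch_size, seq_len, embed_dim)
    attn_mask = torch.zeros(5, 5, dtype=torch.bool)         # (seq_len, seq_len) -> mask type 0
    key_padding_mask = torch.zeros(4, 5, dtype=torch.bool)  # (batch_size, seq_len) -> mask type 1

    merged_mask, mask_type = mha.merge_masks(attn_mask, key_padding_mask, query)
    # With both masks given, merged_mask is expected to have shape
    # (batch_size, num_heads, seq_len, seq_len) and mask_type to be 2.
    print(merged_mask.shape, mask_type)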