torch.nn.attention.bias.CausalBias - PyTorch 2.7 documentation

class torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)

A bias representing causal attention patterns. For an overview of the bias structure, see the CausalVariant enum.

This class defines causal (triangular) attention biases. To construct the bias, use one of the two factory functions: causal_upper_left() or causal_lower_right().
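The two variants only differ when seq_len_q != seq_len_kv, in where the diagonal of the triangle is anchored. As a hand-written illustration (not the library's own code), the sketch below builds both patterns as plain boolean lists, where True means the key position may be attended. It assumes upper-left keeps key j for query i when j <= i, and lower-right keeps it when j <= i + (seq_len_kv - seq_len_q). The helper names causal_mask and show are hypothetical.

```python
def causal_mask(seq_len_q: int, seq_len_kv: int, variant: str) -> list[list[bool]]:
    """Hypothetical helper: boolean causal mask, True = key position may be attended."""
    if variant == "upper_left":
        offset = 0  # diagonal starts at the top-left corner
    elif variant == "lower_right":
        offset = seq_len_kv - seq_len_q  # diagonal shifted so it ends at the bottom-right corner
    else:
        raise ValueError(f"unknown variant: {variant!r}")
    # Query row i may see every key column j on or below the (shifted) diagonal.
    return [[j <= i + offset for j in range(seq_len_kv)] for i in range(seq_len_q)]


def show(mask: list[list[bool]]) -> None:
    # Print each query row as a string of 1s (attend) and 0s (masked out).
    for row in mask:
        print("".join("1" if keep else "0" for keep in row))


show(causal_mask(2, 4, "upper_left"))
# 1000
# 1100
show(causal_mask(2, 4, "lower_right"))
# 1110
# 1111
```

When seq_len_q == seq_len_kv the offset is zero and both variants produce the same lower-triangular mask.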

Example:

import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

# Query/key/value tensors of shape (batch, heads, seq_len, head_dim) on the GPU
q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

# The bias is passed in the attn_mask position of scaled_dot_product_attention
out = F.scaled_dot_product_attention(q, k, v, attn_bias)
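To make the effect of a lower-right bias on the computation concrete, the following sketch evaluates softmax(q kᵀ · scale) v for a single batch element and head in plain Python. It follows the boolean-mask convention of scaled_dot_product_attention (True means the position takes part in attention) by setting disallowed scores to -inf before the softmax, and uses the default scale 1/sqrt(head_dim). It is a slow reference for illustration only, not PyTorch's implementation; the names lower_right_mask and attend are hypothetical.

```python
import math


def lower_right_mask(seq_len_q, seq_len_kv):
    # True = key j is visible to query i; diagonal anchored at the bottom-right corner.
    offset = seq_len_kv - seq_len_q
    return [[j <= i + offset for j in range(seq_len_kv)] for i in range(seq_len_q)]


def attend(q, k, v, mask):
    """Single-head reference: q is L x E, k is S x E, v is S x Ev (lists of lists).

    Assumes every query row has at least one visible key (true for lower-right with L <= S).
    """
    scale = 1.0 / math.sqrt(len(q[0]))  # default scale: 1 / sqrt(head_dim)
    out = []
    for i, qi in enumerate(q):
        # Scaled dot-product scores; masked-out positions get -inf.
        scores = [
            sum(a * b for a, b in zip(qi, kj)) * scale if mask[i][j] else -math.inf
            for j, kj in enumerate(k)
        ]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0, so masked keys get zero weight
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of value rows for this query.
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out


# Zero queries give equal scores, so each output is the mean of the visible values.
q = [[0.0, 0.0], [0.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
v = [[1.0], [2.0], [3.0], [4.0]]
out = attend(q, k, v, lower_right_mask(2, 4))
print([[round(x, 6) for x in row] for row in out])  # [[2.0], [2.5]]
```

With seq_len_q = 2 and seq_len_kv = 4, the first query sees keys 0 to 2 (mean of 1, 2, 3) and the last query sees all four keys (mean of 1 to 4), which is what anchoring the diagonal at the bottom-right corner means.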

Warning

This class is a prototype and subject to change.