torch.nn.attention.bias.CausalBias — PyTorch 2.7 documentation
class torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)[source][source]¶
A bias representing causal attention patterns. For an overview of the bias structure, see the CausalVariant enum.
This class is used for defining causal (triangular) attention biases. For constructing the bias, there exist two factory functions: causal_upper_left() and causal_lower_right().
Example:
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, attn_bias)
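The two variants differ only in where the causal diagonal is anchored when seq_len_q != seq_len_kv: upper-left starts the diagonal at the top-left corner, while lower-right shifts it so that the last query attends to every key. A minimal sketch of the equivalent boolean masks, built here with torch.tril rather than the factory functions to make the alignment explicit (the diagonal offset of seq_len_kv - seq_len_q is an assumption about how the lower-right variant aligns):

```python
import torch

seq_len_q, seq_len_kv = 4, 12

# Upper-left variant: row i may attend to keys j <= i,
# so the diagonal is anchored at the top-left corner.
upper_left = torch.tril(torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool))

# Lower-right variant: the diagonal is shifted right by
# seq_len_kv - seq_len_q, ending at the bottom-right corner,
# so the final query attends to all keys.
lower_right = torch.tril(
    torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool),
    diagonal=seq_len_kv - seq_len_q,
)

print(upper_left.sum().item())   # 10 attended positions (1 + 2 + 3 + 4)
print(lower_right.sum().item())  # 42 attended positions (9 + 10 + 11 + 12)
```

The lower-right alignment is the common choice for incremental decoding with a KV cache, where the queries correspond to the most recent positions of a longer key/value sequence.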
Warning
This class is a prototype and subject to change.