torch.nn.attention.sdpa_kernel — PyTorch 2.7 documentation

torch.nn.attention.sdpa_kernel(backends, set_priority=False)[source]

Context manager to select which backend to use for scaled dot product attention.

Warning

This function is beta and subject to change.

Parameters

* **backends** (*Union[List[SDPBackend], SDPBackend]*) – A backend or list of backends to enable for scaled dot product attention.
* **set_priority** (*bool*) – Whether the ordering of the backends is interpreted as their priority order. Defaults to `False`.

Example:

```python
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

# Only enable the flash attention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    scaled_dot_product_attention(...)

# Enable the Math or Efficient attention backends
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    scaled_dot_product_attention(...)
```
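When a list of backends is passed with `set_priority=True`, the list order is interpreted as a priority order rather than a plain allow-list: the dispatcher attempts the earliest usable backend before falling back to later ones. A minimal sketch (the ordering and tensor shapes here are illustrative assumptions, not recommendations):

```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

# Arbitrary illustrative shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Try flash attention first, then efficient attention, and fall back
# to the math backend if neither is usable on this device/dtype.
with sdpa_kernel(
    [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
    set_priority=True,
):
    out = scaled_dot_product_attention(q, k, v)
```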

This context manager restricts which backends may be used for scaled dot product attention within its scope. Upon exiting the context manager, the previous state of the flags is restored, enabling all backends.
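To make the scoping behavior concrete, the sketch below (shapes are arbitrary; assumes a default CPU or CUDA build) runs attention once with only the math backend enabled and once after the context manager has exited and restored all backends:

```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 128, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Inside the block, only the math backend may be selected.
with sdpa_kernel(SDPBackend.MATH):
    out_math_only = scaled_dot_product_attention(q, k, v)

# After exiting, the previous flags are restored and the dispatcher is
# again free to pick any available backend.
out_default = scaled_dot_product_attention(q, k, v)
```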