torch.nn.attention.flex_attention

BlockMask is our format for representing a block-sparse attention mask. It is somewhat of a cross between BCSR (block compressed sparse row) and a non-sparse format.

Basics

A block-sparse mask means that instead of representing the sparsity of individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is considered sparse only if every element within that block is sparse. This aligns well with hardware, which generally expects to perform contiguous loads and computation.
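As a rough illustration of this rule, the sketch below derives a block-level mask from a dense element-level mask. The helper `dense_to_block_mask` is hypothetical (it is not part of torch.nn.attention.flex_attention) and assumes both sequence lengths are exact multiples of the block sizes. A block is kept if any element in it is unmasked, so it is skipped only when every element is masked out.

```python
import torch


def dense_to_block_mask(dense_mask: torch.Tensor, q_block: int, kv_block: int) -> torch.Tensor:
    """Hypothetical helper: reduce a [Q_LEN, KV_LEN] bool mask to [num_q_blocks, num_kv_blocks].

    A block is True (must be computed) if at least one element in it is True,
    i.e. it counts as sparse only when every element in it is masked out.
    Assumes Q_LEN % q_block == 0 and KV_LEN % kv_block == 0.
    """
    q_len, kv_len = dense_mask.shape
    assert q_len % q_block == 0 and kv_len % kv_block == 0
    tiles = dense_mask.reshape(q_len // q_block, q_block, kv_len // kv_block, kv_block)
    # Collapse the within-block axes (KV first, then Q), leaving one flag per tile
    return tiles.any(dim=3).any(dim=1)


# Causal element mask for 8 queries x 8 keys: keep (q, kv) when q >= kv
q_idx = torch.arange(8)[:, None]
kv_idx = torch.arange(8)[None, :]
causal = q_idx >= kv_idx

block_mask = dense_to_block_mask(causal, q_block=4, kv_block=4)
print(block_mask)
```

Only the upper-right tile lies entirely in the future, so it is the only block treated as sparse; the two diagonal tiles are kept even though they are only partly unmasked.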

This format is primarily optimized for (1) simplicity and (2) kernel efficiency. Notably, it is not optimized for size, since the block-level mask is already smaller than the dense mask by a factor of KV_BLOCK_SIZE * Q_BLOCK_SIZE. If size is a concern, the tensors can be made smaller by increasing the block size.
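To make the size trade-off concrete, the arithmetic below compares a dense mask with the worst-case size of the index tensor. The sequence lengths are made-up example values, 128 matches the default BLOCK_SIZE of from_kv_blocks, and the sketch assumes the lengths divide evenly and ignores the small ROWS-length count tensor.

```python
def mask_sizes(q_len: int, kv_len: int, q_block: int, kv_block: int):
    dense_entries = q_len * kv_len       # one entry per (query, key) pair
    rows = q_len // q_block              # ROWS: one row per query block
    max_cols = kv_len // kv_block        # the most blocks any row can list
    index_entries = rows * max_cols      # worst-case size of the index tensor
    return dense_entries, index_entries


dense, blocks = mask_sizes(8192, 8192, 128, 128)
print(dense, blocks, dense // blocks)    # 67108864 4096 16384

# Doubling the block size in both dimensions shrinks the index tensor by another 4x
_, bigger_blocks = mask_sizes(8192, 8192, 256, 256)
print(blocks // bigger_blocks)           # 4
```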

The essentials of our format are:

num_blocks_in_row: Tensor[ROWS]: Describes the number of blocks present in each row.

col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: col_indices[i] is the sequence of block positions for row i. The values of this row from col_indices[i][num_blocks_in_row[i]] onward are undefined.

For example, to reconstruct the original tensor from this format:

```python
import torch

dense_mask = torch.zeros(ROWS, COLS)
for row in range(ROWS):
    for block_idx in range(num_blocks_in_row[row]):
        # Mark each block listed for this row as present
        dense_mask[row, col_indices[row, block_idx]] = 1
```
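The reconstruction loop only decodes. A sketch of the reverse direction plus a round-trip check might look like this. The names `encode` and `decode` are hypothetical, and filling the undefined tail of col_indices with the remaining (absent) column indices is just one arbitrary choice.

```python
import torch


def encode(dense_mask: torch.Tensor):
    """Hypothetical encoder: [ROWS, COLS] bool mask -> (num_blocks_in_row, col_indices)."""
    num_blocks_in_row = dense_mask.sum(dim=1)  # number of present blocks per row
    # A stable descending sort of the 0/1 values moves the present columns to the
    # front in ascending order; the tail (absent columns) is the "undefined" part.
    col_indices = torch.argsort(dense_mask.to(torch.int32), dim=1, descending=True, stable=True)
    return num_blocks_in_row, col_indices


def decode(num_blocks_in_row: torch.Tensor, col_indices: torch.Tensor, cols: int) -> torch.Tensor:
    """Rebuild the dense mask, reading only the defined prefix of each row."""
    rows = num_blocks_in_row.shape[0]
    dense_mask = torch.zeros(rows, cols)
    for row in range(rows):
        for block_idx in range(int(num_blocks_in_row[row])):
            dense_mask[row, int(col_indices[row, block_idx])] = 1
    return dense_mask


mask = torch.tensor([[1, 0, 1],
                     [0, 0, 0],
                     [0, 1, 1]], dtype=torch.bool)
num_blocks_in_row, col_indices = encode(mask)
print(num_blocks_in_row)    # tensor([2, 0, 2])
print(col_indices[0, :2])   # tensor([0, 2])
assert torch.equal(decode(num_blocks_in_row, col_indices, cols=3), mask.float())
```

Sorting instead of looping keeps the encoder vectorized; any padding works for the tail because decoding never reads past num_blocks_in_row[i].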

Notably, this format makes it easier to implement a reduction along the rows of the mask.
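As a simplified stand-in for such a row reduction, the sketch below sums a score matrix along each row while visiting only the listed blocks, and checks the result against a dense masked sum. Real attention kernels do much more work per block; `sparse_row_sum` is a hypothetical name used only for illustration.

```python
import torch


def sparse_row_sum(values: torch.Tensor, num_blocks_in_row: torch.Tensor,
                   col_indices: torch.Tensor) -> torch.Tensor:
    """Sum values[row, col] over only the defined columns of each row."""
    out = torch.zeros(values.shape[0], dtype=values.dtype)
    for row in range(values.shape[0]):
        n = int(num_blocks_in_row[row])
        cols = col_indices[row, :n]          # entries past n are undefined and never read
        out[row] = values[row, cols].sum()   # each row touches only its own index list
    return out


values = torch.arange(9.0).reshape(3, 3)
num_blocks_in_row = torch.tensor([2, 0, 2])
col_indices = torch.tensor([[0, 2, 1],
                            [0, 1, 2],
                            [1, 2, 0]])
result = sparse_row_sum(values, num_blocks_in_row, col_indices)  # row sums: 2.0, 0.0, 15.0
```

Each row is independent and only needs its own count and index prefix, which is why a per-row (here, per-query-block) reduction maps naturally onto this layout.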

Details

The basics of our format require only kv_num_blocks and kv_indices. However, this object holds up to 8 tensors, which form 4 pairs:

1. (kv_num_blocks, kv_indices): Used for the forward pass of attention, as we reduce along the KV dimension.

2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This pair is purely an optimization. As it turns out, applying masking to every block is quite expensive! If we specifically know which blocks are “full” and don’t require masking at all, then we can skip applying mask_mod to these blocks. This requires the user to split out a separate mask_mod from the score_mod. For causal masks, this is about a 15% speedup.

3. [GENERATED] (q_num_blocks, q_indices): Required for the backward pass, as computing dKV requires iterating over the mask along the Q dimension. These are autogenerated from 1.

4. [GENERATED] (full_q_num_blocks, full_q_indices): The backward-pass counterpart of pair 2, listing the full blocks along the Q dimension. These are autogenerated from 2.
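The relationship between the four pairs can be sketched for a causal mask. This is not PyTorch's internal code (it does not use _transposed_ordered). It assumes, following pair 2, that the kv_* pair holds only the partial blocks that still need mask_mod while full_kv_* holds the blocks where every element is kept, and that the Q pairs are the same block masks transposed and re-encoded row by row. The helper names are hypothetical.

```python
import torch


def to_ordered(block_mask: torch.Tensor):
    """Encode a [rows, cols] bool block mask as (num_blocks, indices) per row."""
    num_blocks = block_mask.sum(dim=-1)
    indices = torch.argsort(block_mask.to(torch.int32), dim=-1, descending=True, stable=True)
    return num_blocks, indices


def causal_full_and_partial(seq_len: int, block: int):
    """Split the tiles of a causal mask into full (no masking) and partial tiles."""
    q_idx = torch.arange(seq_len)[:, None]
    kv_idx = torch.arange(seq_len)[None, :]
    dense = q_idx >= kv_idx
    n = seq_len // block
    tiles = dense.reshape(n, block, n, block)
    any_kept = tiles.any(dim=3).any(dim=1)   # tile has at least one kept element
    all_kept = tiles.all(dim=3).all(dim=1)   # every element kept: mask_mod can be skipped
    return all_kept, any_kept & ~all_kept


full, partial = causal_full_and_partial(seq_len=512, block=128)

# Pairs 1 and 2: one row per Q block, listing KV blocks (forward pass)
kv_num_blocks, kv_indices = to_ordered(partial)
full_kv_num_blocks, full_kv_indices = to_ordered(full)

# Pairs 3 and 4: the same tiles transposed, one row per KV block listing Q blocks
q_num_blocks, q_indices = to_ordered(partial.T)
full_q_num_blocks, full_q_indices = to_ordered(full.T)

print(kv_num_blocks, full_kv_num_blocks)  # tensor([1, 1, 1, 1]) tensor([0, 1, 2, 3])
print(q_num_blocks, full_q_num_blocks)    # tensor([1, 1, 1, 1]) tensor([3, 2, 1, 0])
```

Only the four diagonal tiles need mask_mod; the six strictly-lower tiles are full, which is where the causal-mask speedup from pair 2 comes from.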

BLOCK_SIZE: tuple[int, int]

as_tuple(flatten=True)

Returns a tuple of the attributes of the BlockMask.

Parameters

flatten (bool) – If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE)

classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None, compute_q_blocks=True)

Creates a BlockMask instance from key-value block information.

Parameters

Returns

Instance with full Q information generated via _transposed_ordered

Return type

BlockMask

Raises

full_kv_indices: Optional[Tensor]

full_kv_num_blocks: Optional[Tensor]

full_q_indices: Optional[Tensor]

full_q_num_blocks: Optional[Tensor]

kv_indices: Tensor

kv_num_blocks: Tensor

mask_mod: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]

numel()

Returns the number of elements (not accounting for sparsity) in the mask.

q_indices: Optional[Tensor]

q_num_blocks: Optional[Tensor]

seq_lengths: tuple[int, int]

property shape

sparsity()

Computes the percentage of blocks that are sparse (i.e., not computed).

Return type

float
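The percentage can be worked out by hand from the block counts. The sketch below is one way to do that arithmetic, assuming sequence lengths that are exact multiples of the block size and counting both partial and full blocks as computed. It is not the library's implementation, and `block_sparsity` is a hypothetical name.

```python
from typing import Optional

import torch


def block_sparsity(kv_num_blocks: torch.Tensor, num_kv_blocks: int,
                   full_kv_num_blocks: Optional[torch.Tensor] = None) -> float:
    """Percentage of (Q block, KV block) tiles that are skipped entirely."""
    total = kv_num_blocks.numel() * num_kv_blocks    # every tile in every row
    computed = int(kv_num_blocks.sum())              # partial blocks are computed
    if full_kv_num_blocks is not None:
        computed += int(full_kv_num_blocks.sum())    # full blocks are computed too
    return 100.0 * (1 - computed / total)


# Causal mask with 4 x 4 tiles: 4 partial (diagonal) + 6 full tiles are computed
kv_num_blocks = torch.tensor([1, 1, 1, 1])
full_kv_num_blocks = torch.tensor([0, 1, 2, 3])
print(block_sparsity(kv_num_blocks, 4, full_kv_num_blocks))  # 37.5
```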

to(device)

Moves the BlockMask to the specified device.

Parameters

device (torch.device or str) – The target device to move the BlockMask to. Can be a torch.device object or a string (e.g., ‘cpu’, ‘cuda:0’).

Returns

A new BlockMask instance with all tensor components moved to the specified device.

Return type

BlockMask

Note

This method does not modify the original BlockMask in-place. Instead, it returns a new BlockMask instance where individual tensor attributes may or may not be moved to the specified device, depending on their current device placement.

to_dense()

Returns a dense (block-level) mask that is equivalent to the block mask.

Return type

Tensor

to_string(grid_size=(20, 20), limit=4)

Returns a string representation of the block mask. Quite nifty.

If grid_size is -1, it renders an uncompressed version. Warning: it can be quite big!