[P/D][NixlConnector] Enable FlashInfer backend by NickLucche · Pull Request #19090 · vllm-project/vllm (original) (raw)

This PR enables the use of VLLM_ATTENTION_BACKEND=FLASHINFER in disaggregated prefill setups leveraging NixlConnector (which is currently allowed but broken on main).

The main difference wrt default FA backend is that FlashInfer swaps the cache first two dims (K/V and num_blocks) resulting in [num_blocks, KV(2), N,H,D]. The easiest approach here is to maintain the layout and just transfer the whole region (for each layer) instead of trying to split the K/V dim.

As a result, the message size should be twice as big when running FlashInfer, resulting in an interesting trade-off that we should monitor to ensure optimal transfer size.

As a side note, this will also enable the TRITON_MLA_VLLM_V1 backend when an MLA model is detected.
Note that for MLA model the behavior should be unchanged as the kv shape is not backend-dependent.

Test with

VLLM_ATTENTION_BACKEND=FLASHINFER NUM_DECODE_INSTANCES=1 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh