Support latest FA3 API changes · JerryWu-code/diffusers@de4c6f1
@@ -621,24 +621,32 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
-        q=q,
-        k=k,
-        v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        qv=qv,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-        window_size=window_size,
-        attention_chunk=attention_chunk,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
-    )
+
+    kwargs = {
+        "q": q,
+        "k": k,
+        "v": v,
+        "softmax_scale": softmax_scale,
+        "causal": causal,
+        "qv": qv,
+        "q_descale": q_descale,
+        "k_descale": k_descale,
+        "v_descale": v_descale,
+        "window_size": window_size,
+        "attention_chunk": attention_chunk,
+        "softcap": softcap,
+        "num_splits": num_splits,
+        "pack_gqa": pack_gqa,
+        "deterministic": deterministic,
+        "sm_margin": sm_margin,
+    }
+
+    # For backward compatibility with early flash-attn-3 APIs.
+    if "return_attn_probs" in inspect.signature(flash_attn_3_func).parameters:
+        kwargs["return_attn_probs"] = True
+
+    out, lse, *_ = flash_attn_3_func(**kwargs)
+
     lse = lse.permute(0, 2, 1)
     return out, lse
 
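The pattern in this hunk is signature-based feature detection: build the arguments as a dict, then add the newer `return_attn_probs` flag only if the installed flash-attn-3 build actually declares it. Below is a minimal, self-contained sketch of that pattern; `call_with_optional_kwarg`, `early_api`, and `current_api` are hypothetical stand-ins for illustration, not diffusers or flash-attn code.

```python
import inspect

def call_with_optional_kwarg(func, kwargs, name, value):
    # Hypothetical helper: include `name` only when `func` accepts it,
    # so one call site works across library versions that added the
    # parameter later.
    if name in inspect.signature(func).parameters:
        kwargs[name] = value
    return func(**kwargs)

# Stand-ins for an early and a current flash_attn_3_func signature.
def early_api(q, k, v):
    return q + k + v, "lse"  # early builds always return extra values

def current_api(q, k, v, return_attn_probs=False):
    out = q + k + v
    return (out, "lse") if return_attn_probs else out

# Both generations unpack the same way at the call site, as in the wrapper above.
out, lse, *_ = call_with_optional_kwarg(current_api, {"q": 1, "k": 2, "v": 3}, "return_attn_probs", True)
out, lse, *_ = call_with_optional_kwarg(early_api, {"q": 1, "k": 2, "v": 3}, "return_attn_probs", True)
```

The wrapper can keep the unconditional `out, lse, *_` unpacking because it forces `return_attn_probs=True` whenever the flag exists, so both API generations return at least `(out, lse)`.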
@@ -1504,17 +1512,29 @@ def _flash_varlen_attention_3(
     key_packed = torch.cat(key_valid, dim=0)
     value_packed = torch.cat(value_valid, dim=0)
 
-    out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        softmax_scale=scale,
-        causal=is_causal,
-    )
+    kwargs = {
+        "q": query_packed,
+        "k": key_packed,
+        "v": value_packed,
+        "cu_seqlens_q": cu_seqlens_q,
+        "cu_seqlens_k": cu_seqlens_k,
+        "max_seqlen_q": max_seqlen_q,
+        "max_seqlen_k": max_seqlen_k,
+        "softmax_scale": scale,
+        "causal": is_causal,
+    }
+
+    if "return_attn_probs" in inspect.signature(flash_attn_3_varlen_func).parameters:
+        kwargs["return_attn_probs"] = return_lse
+        out = flash_attn_3_varlen_func(**kwargs)
+        if return_lse:
+            out, lse = out[0], out[1]
+        else:
+            lse = None
+    else:
+        # For backward compatibility with early flash-attn-3 APIs.
+        out, lse, *_ = flash_attn_3_varlen_func(**kwargs)
+
     out = out.unflatten(0, (batch_size, -1))
 
     return (out, lse) if return_lse else out
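The varlen hunk needs an extra branch because the two API generations differ in what they return: when `return_attn_probs` exists, the function returns just `out` unless the flag is set, whereas early builds always return `(out, lse, *rest)`. A condensed sketch of that dual-path unpacking, with the same caveat that `varlen_compat_call` is a hypothetical name used only for illustration:

```python
import inspect

def varlen_compat_call(func, kwargs, return_lse):
    if "return_attn_probs" in inspect.signature(func).parameters:
        # Current API: returns `out`, or `(out, lse)` when probs are requested.
        kwargs["return_attn_probs"] = return_lse
        result = func(**kwargs)
        out, lse = (result[0], result[1]) if return_lse else (result, None)
    else:
        # Early flash-attn-3 APIs: always return `(out, lse, *rest)`.
        out, lse, *_ = func(**kwargs)
    return (out, lse) if return_lse else out
```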