Support latest FA3 API changes · JerryWu-code/diffusers@de4c6f1

```diff
@@ -621,24 +621,32 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
-        q=q,
-        k=k,
-        v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        qv=qv,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-        window_size=window_size,
-        attention_chunk=attention_chunk,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
-    )
+
+    kwargs = {
+        "q": q,
+        "k": k,
+        "v": v,
+        "softmax_scale": softmax_scale,
+        "causal": causal,
+        "qv": qv,
+        "q_descale": q_descale,
+        "k_descale": k_descale,
+        "v_descale": v_descale,
+        "window_size": window_size,
+        "attention_chunk": attention_chunk,
+        "softcap": softcap,
+        "num_splits": num_splits,
+        "pack_gqa": pack_gqa,
+        "deterministic": deterministic,
+        "sm_margin": sm_margin,
+    }
+
+    # For backward compatibility with early flash-attn-3 APIs.
+    if "return_attn_probs" in inspect.signature(flash_attn_3_func).parameters:
+        kwargs["return_attn_probs"] = True
+
+    out, lse, *_ = flash_attn_3_func(**kwargs)
+
     lse = lse.permute(0, 2, 1)
     return out, lse
 
@@ -1504,17 +1512,29 @@ def _flash_varlen_attention_3(
     key_packed = torch.cat(key_valid, dim=0)
     value_packed = torch.cat(value_valid, dim=0)
 
-    out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        softmax_scale=scale,
-        causal=is_causal,
-    )
+    kwargs = {
+        "q": query_packed,
+        "k": key_packed,
+        "v": value_packed,
+        "cu_seqlens_q": cu_seqlens_q,
+        "cu_seqlens_k": cu_seqlens_k,
+        "max_seqlen_q": max_seqlen_q,
+        "max_seqlen_k": max_seqlen_k,
+        "softmax_scale": scale,
+        "causal": is_causal,
+    }
+
+    if "return_attn_probs" in inspect.signature(flash_attn_3_varlen_func).parameters:
+        kwargs["return_attn_probs"] = return_lse
+        out = flash_attn_3_varlen_func(**kwargs)
+        if return_lse:
+            out, lse = out[0], out[1]
+        else:
+            lse = None
+    else:
+        # For backward compatibility with early flash-attn-3 APIs.
+        out, lse, *_ = flash_attn_3_varlen_func(**kwargs)
+
     out = out.unflatten(0, (batch_size, -1))
 
     return (out, lse) if return_lse else out
```
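Both hunks rely on the same idea: build the call as a `kwargs` dict, then probe the installed function's signature with `inspect.signature` before passing the newer `return_attn_probs` keyword. Below is a minimal, self-contained sketch of that feature-detection pattern; `attn_old`, `attn_new`, and `call_attn` are hypothetical stand-ins for illustration, not functions from flash-attn or diffusers.

```python
import inspect

# Stand-ins for two hypothetical flash-attn-3 builds: an older API that
# always returns (out, lse, ...) and a newer one gated by return_attn_probs.
def attn_old(q, k, v):
    return q + k + v, "lse", "extra"

def attn_new(q, k, v, return_attn_probs=False):
    out = q + k + v
    return (out, "lse") if return_attn_probs else out

def call_attn(fn, q, k, v, return_lse=True):
    kwargs = {"q": q, "k": k, "v": v}
    # Feature-detect the newer keyword instead of pinning a version.
    if "return_attn_probs" in inspect.signature(fn).parameters:
        kwargs["return_attn_probs"] = return_lse
        out = fn(**kwargs)
        out, lse = (out[0], out[1]) if return_lse else (out, None)
    else:
        # Older builds always return the log-sum-exp alongside the output.
        out, lse, *_ = fn(**kwargs)
    return (out, lse) if return_lse else out

print(call_attn(attn_old, 1, 2, 3))  # (6, 'lse')
print(call_attn(attn_new, 1, 2, 3))  # (6, 'lse')
```

Detecting the keyword at call time keeps one code path working across flash-attn-3 builds, at the cost of a cheap signature inspection per call.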