Fix HunyuanVideo produces NaN on PyTorch<2.5 by hlky · Pull Request #10482 · huggingface/diffusers
What does this PR do?
The NaN output was tracked to this block in the HunyuanVideo attention processor:
```python
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

# 6. Output projection
if encoder_hidden_states is not None:
    hidden_states, encoder_hidden_states = (
        hidden_states[:, : -encoder_hidden_states.shape[1]],
        hidden_states[:, -encoder_hidden_states.shape[1] :],
    )

    if getattr(attn, "to_out", None) is not None:
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

    if getattr(attn, "to_add_out", None) is not None:
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
```
Specifically, some elements of `encoder_hidden_states` contain NaN after the attention call.
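A quick way to confirm this during debugging is to split the attention output the same way the processor does and check each slice for NaNs. This is a debugging sketch using the variables from the excerpt above, not part of this PR:

```python
import torch

# Debugging sketch: split the SDPA output exactly as the processor does and
# check each slice for NaNs. `hidden_states` / `encoder_hidden_states` are the
# variables from the excerpt above.
text_len = encoder_hidden_states.shape[1]
image_tokens = hidden_states[:, :-text_len]
text_tokens = hidden_states[:, -text_len:]
print("NaNs in image tokens:", torch.isnan(image_tokens).any().item())
print("NaNs in text tokens:", torch.isnan(text_tokens).any().item())
```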
The dimensions of `query`, `key`, `value`, and the attention mask are large, which suggests that PyTorch versions < 2.5 used 32-bit indexing. This tracks with #10314: if ROCm builds still use 32-bit indexing, this PR may also close that issue; awaiting confirmation from the user.
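For context only: a common workaround when a single `scaled_dot_product_attention` call misbehaves on very large inputs is to split it into smaller chunks along the batch dimension so each kernel launch sees fewer elements. The sketch below illustrates that idea with a hypothetical `chunked_sdpa` helper; it is not necessarily the change made in this PR.

```python
import torch
import torch.nn.functional as F

def chunked_sdpa(query, key, value, attn_mask=None, chunk_size=1):
    # Hypothetical helper: run SDPA on batch-sized chunks so each individual
    # call operates on smaller tensors.
    outputs = []
    for i in range(0, query.shape[0], chunk_size):
        mask = attn_mask[i : i + chunk_size] if attn_mask is not None else None
        outputs.append(
            F.scaled_dot_product_attention(
                query[i : i + chunk_size],
                key[i : i + chunk_size],
                value[i : i + chunk_size],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=False,
            )
        )
    return torch.cat(outputs, dim=0)
```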
Tested on CUDA with PyTorch 2.4.1.
output.mp4

Code:
```python
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to("cuda")
pipe.vae.enable_tiling()

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
```
There's also a small performance increase:

| 2.4.1 with fix | 2.5.1 | 2.5.1 with fix |
|---|---|---|
| 30/30 [01:56<00:00, 3.88s/it] | 30/30 [02:04<00:00, 4.16s/it] | 30/30 [01:56<00:00, 3.89s/it] |
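Since the issue only reproduces on older PyTorch builds, users can check their local version to see whether they are affected (assuming the < 2.5 boundary from the PR title):

```python
import torch
from packaging import version

# The NaNs were observed on torch < 2.5; 2.5.1 was reported unaffected.
if version.parse(torch.__version__) < version.parse("2.5"):
    print(f"torch {torch.__version__}: affected without this fix")
else:
    print(f"torch {torch.__version__}: unaffected")
```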
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.