Fix HunyuanVideo produces NaN on PyTorch<2.5 by hlky · Pull Request #10482 · huggingface/diffusers (original) (raw)

What does this PR do?

NaN tracked to

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, : -encoder_hidden_states.shape[1]],
hidden_states[:, -encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

Specifically, some elements of encoder_hidden_states.

The dimensions of query, key, value and mask are large which suggests versions <2.5 used 32-bit indexing, this tracks with #10314 if ROCm versions are still using 32-bit indexing, this may also close that issue, awaiting confirmation from user.

Tested on CUDA 2.4.1

output.mp4 Code

import torch from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo" transformer = HunyuanVideoTransformer3DModel.from_pretrained( model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ) pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to("cuda") pipe.vae.enable_tiling()

output = pipe( prompt="A cat walks on the grass, realistic", height=320, width=512, num_frames=61, num_inference_steps=30, ).frames[0] export_to_video(output, "output.mp4", fps=15)

There's also a small performance increase

2.4.1 with fix 2.5.1 2.5.1 with fix
30/30 [01:56<00:00, 3.88s/it] 30/30 [02:04<00:00, 4.16s/it] [01:56<00:00, 3.89s/it]

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @a-r-r-o-w