Cast precision for custom diffusion attention processor. · Issue #4139 · huggingface/diffusers

Describe the bug

In the LoRA layers, the dtypes are explicitly upcast:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L529

However, the custom diffusion attention processor does not do this. Its trainable weights need to stay in FP32 for mixed-precision training, so during validation it throws a dtype-mismatch error. As far as I know, autocast is not recommended for inference.
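The cast the title asks for could follow the same pattern as the LoRA layers. Below is a minimal sketch, assuming PyTorch. The class CastingKVProjection is hypothetical and is not the diffusers processor. It keeps FP32 weights named after the issue's to_k_custom_diffusion, casts the BF16 encoder_hidden_states to the weight dtype before projecting, and then casts the result back.

```python
import torch
import torch.nn as nn


class CastingKVProjection(nn.Module):
    """Hypothetical stand-in for the custom diffusion key/value projections.

    The trainable weights stay in FP32 for mixed-precision training. The input
    is cast to the weight dtype and the outputs are cast back, mirroring the
    explicit cast done in the LoRA layers. Not the actual diffusers class.
    """

    def __init__(self, cross_attention_dim: int, hidden_size: int):
        super().__init__()
        # Attribute names mirror the issue; the layer layout here is an assumption.
        self.to_k_custom_diffusion = nn.Linear(cross_attention_dim, hidden_size, bias=False)
        self.to_v_custom_diffusion = nn.Linear(cross_attention_dim, hidden_size, bias=False)

    def forward(self, encoder_hidden_states: torch.Tensor):
        orig_dtype = encoder_hidden_states.dtype
        weight_dtype = self.to_k_custom_diffusion.weight.dtype
        # Upcast the (possibly BF16) activations to the FP32 weight dtype.
        states = encoder_hidden_states.to(weight_dtype)
        key = self.to_k_custom_diffusion(states)
        value = self.to_v_custom_diffusion(states)
        # Cast back so the rest of the BF16 attention computation sees BF16 tensors.
        return key.to(orig_dtype), value.to(orig_dtype)


if __name__ == "__main__":
    proj = CastingKVProjection(cross_attention_dim=8, hidden_size=4)  # FP32 weights
    encoder_hidden_states = torch.randn(2, 3, 8, dtype=torch.bfloat16)
    key, value = proj(encoder_hidden_states)
    print(key.dtype, key.shape)  # torch.bfloat16 torch.Size([2, 3, 4])
```

Casting back to the input dtype leaves the rest of a BF16 attention computation unchanged. Only the two small projections run in FP32.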

Reproduction

Run the custom diffusion training example with BF16 mixed precision.
Make sure training runs for enough steps (or lower the validation interval) so that validation is triggered during training. The final inference step includes a unet.to(fp32) call, so it runs without problems.

Logs

An error along the lines of "expected BFloat16 but found Float", raised at key = self.to_k_custom_diffusion(encoder_hidden_states).
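The mismatch can be reproduced outside diffusers with a minimal sketch, assuming PyTorch. It applies an FP32 nn.Linear, used here as a hypothetical stand-in for to_k_custom_diffusion, to BF16 activations. The exact error text depends on the PyTorch version and device, which is why the log above is only approximate.

```python
import torch
import torch.nn as nn


def reproduce_dtype_mismatch() -> str:
    """Feed BF16 activations into an FP32 linear layer, as happens at
    key = self.to_k_custom_diffusion(encoder_hidden_states) during validation.
    Returns the error message, or "no error" if the call unexpectedly succeeds."""
    to_k = nn.Linear(8, 4, bias=False)  # trainable weights kept in FP32
    encoder_hidden_states = torch.randn(1, 3, 8, dtype=torch.bfloat16)
    try:
        to_k(encoder_hidden_states)
    except RuntimeError as err:
        return str(err)
    return "no error"


if __name__ == "__main__":
    # Prints a dtype-mismatch message; its wording varies across PyTorch versions.
    print(reproduce_dtype_mismatch())
```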

System Info

Who can help?

No response