Flux fp16 inference fix by latentCall145 · Pull Request #9097 · huggingface/diffusers (original) (raw)

As this issue mentions, FP16 significantly changes the result of the images. This issue surprisingly has to do with the text encoders (and not the clipping). Specifically, some activations in the text encoders have to be clipped when running in FP16 (it's a dynamic range problem, not a precision one). Forcing FP32 inference on the text encoders thus allows FP16 DiT + VAE inference to be similar to FP32/BF16.

Reproduction

from diffusers import FluxPipeline import matplotlib.pyplot as plt import torch import time torch.backends.cudnn.benchmark = True

DTYPE = torch.float16

ckpt_id = "black-forest-labs/FLUX.1-schnell" pipe = FluxPipeline.from_pretrained( ckpt_id, torch_dtype=torch.bfloat16, ) pipe.enable_sequential_cpu_offload() pipe.vae.enable_tiling() pipe.to(DTYPE)

images = pipe( 'A laptop whose screen displays a picture of a black forest gateau cake spelling out the words "FLUX SCHNELL". The laptop screen, keyboard, and the table is on fire. no watermark, photograph', num_inference_steps=1, num_images_per_prompt=1, guidance_scale=0.0, height=1024, width=1024, generator=torch.Generator(device='cuda').manual_seed(0), # device='cpu' results in different random tensors across different dtypes? ).images

plt.imshow(images[0]) plt.show()

Prompt

A laptop whose screen displays a picture of a black forest gateau cake spelling out the words "FLUX SCHNELL". The laptop screen, keyboard, and the table is on fire. no watermark, photograph

Other

num_inference_steps = 1
height = width = 1024

Outputs (clipped)

catted
left to right: fp32, bf16, fp16

Outputs (clipped, fp32 text encoders)

catted
left to right: fp32, bf16, fp16