Avoiding graph break by changing the way we infer dtype in vae.decoder by ppadjinTT · Pull Request #12512 · huggingface/diffusers
What does this PR do?
This PR addresses the problem discussed in #12501, where using `upscale_dtype = next(iter(self.up_blocks.parameters())).dtype` to infer the dtype in the forward pass of the vae.decoder causes a graph break when compiling the model with torch.compile.
The issue is that `next(iter(...))` forces the lazy tensors in the initial compiled pass to materialize, resulting in a graph break, which degrades performance.
This PR proposes a simple fix by inferring the dtype as:
upscale_dtype = self.conv_out.weight.dtype
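To illustrate the difference between the two dtype-inference styles, here is a minimal, torch-free sketch. The `Fake*` classes below are hypothetical stand-ins for the decoder's submodules, not the real diffusers classes; they only mimic the attribute shapes involved.

```python
# Hypothetical stand-ins for the decoder's submodules (illustration only).
class FakeParam:
    def __init__(self, dtype):
        self.dtype = dtype


class FakeConvOut:
    def __init__(self, dtype):
        self.weight = FakeParam(dtype)


class FakeUpBlocks:
    def __init__(self, dtype):
        self._params = [FakeParam(dtype)]

    def parameters(self):
        # A generator, like torch.nn.Module.parameters(); under
        # torch.compile, consuming it in forward() forces lazy tensors
        # to materialize, which triggers a graph break.
        yield from self._params


class FakeDecoder:
    def __init__(self, dtype="float16"):
        self.up_blocks = FakeUpBlocks(dtype)
        self.conv_out = FakeConvOut(dtype)

    def upscale_dtype_before(self):
        # Before: iterate over parameters() and take the first one's dtype.
        return next(iter(self.up_blocks.parameters())).dtype

    def upscale_dtype_after(self):
        # After: read the dtype from a directly addressable weight,
        # avoiding parameter iteration entirely.
        return self.conv_out.weight.dtype


decoder = FakeDecoder()
# Both styles yield the same dtype; only the access pattern differs.
assert decoder.upscale_dtype_before() == decoder.upscale_dtype_after()
```

Since all of the decoder's weights share one dtype, reading it from a fixed attribute like `conv_out.weight` gives the same answer without iterating parameters at trace time.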
Fixes #12501