fix scale_shift_factor being on cpu for wan and ltx by vladmandic · Pull Request #12347 · huggingface/diffusers

Conversation


@vladmandic

The Wan transformer block creates scale_shift_table on the CPU and then adds it to temb regardless of where the temb tensor actually resides, which causes the typical cpu-vs-cuda device mismatch:

│  473 │   shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
│❱ 474 │       self.scale_shift_table + temb.float()
│  475 │   ).chunk(6, dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
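
For illustration, a minimal toy sketch of the failure mode and of the kind of device guard this PR applies (a stand-in with made-up shapes, not the actual Wan module code):

import torch
import torch.nn as nn

# Stand-in for the block state after offloading: the parameter stays on the CPU
# while the conditioning tensor lives on the accelerator.
scale_shift_table = nn.Parameter(torch.randn(1, 6, 8))  # created on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
temb = torch.randn(2, 6, 8, device=device)

# Unguarded add (what the traceback above hits when device is cuda):
#     scale_shift_table + temb.float()  -> RuntimeError: Expected all tensors ...

# Guarded add: cast the table onto temb's device before adding, then chunk.
chunks = (scale_shift_table.to(temb.device) + temb.float()).chunk(6, dim=1)
print([tuple(c.shape) for c in chunks])  # six tensors of shape (2, 1, 8)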

cc @sayakpaul @yiyixuxu @a-r-r-o-w @DN6

@DN6

@vladmandic Do you have a snippet to reproduce the error?

@vladmandic

@vladmandic Do you have a snippet to reproduce the error?

I don't have an easy reproduction, as it happens relatively randomly when offloading is enabled and on low-memory systems.
But if you look at other pipelines (e.g. sana, pixart, etc.) that implement similar methods, they all guard against this.
And I don't see any risk or side effects of the device cast.

@vladmandic changed the title wan fix scale_shift_factor being on cpu → fix scale_shift_factor being on cpu for wan and ltx

Sep 25, 2025

@vladmandic

Update: the same issue occurs for some of my users with ltxvideo, so I've extended the PR with the same fix in that pipeline.
See vladmandic/sdnext#4223 for details.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vladmandic

Gentle nudge - any particular reason why this PR is not merged?
I understand the lack of a clear reproduction, but on the other hand, the same guards exist in other pipelines, plus it's clear that the scale_shift_table tensor is created on the CPU, so even if it works the operation is performed on the CPU, which is undesirable.

@DN6

Hi @vladmandic, sorry for the delay. It's just that since the introduction of hooks for offloading and casting, we're trying to avoid direct assignment of device/dtype in the models in case of unexpected behaviour, which is why I asked for the repro: to check if there's something else going on.

Could you just share the type of offloading you're using? Is it group offloading?
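
For reference, a sketch of the stock diffusers offloading entry points the question is distinguishing between (illustrative only: sdnext applies its own offloading logic, the checkpoint name is a placeholder, and the group-offloading kwargs assume the current diffusers.hooks API):

import torch
from diffusers import WanPipeline
from diffusers.hooks import apply_group_offloading

pipe = WanPipeline.from_pretrained("<wan-checkpoint>", torch_dtype=torch.bfloat16)

# Option 1: pipeline-level CPU offload driven by accelerate hooks.
# pipe.enable_model_cpu_offload()

# Option 2: group offloading on a single module; groups of blocks are moved
# between the offload and onload devices as they are needed during forward.
apply_group_offloading(
    pipe.transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)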

@vladmandic

@DN6 I understand the general desire; however, if you don't want to merge this PR, then please propose an alternative, as right now my users are blocked.

@github-actions

Style bot fixed some files and pushed the changes.

DN6 approved these changes Oct 5, 2025

@DN6

DN6 commented Oct 5, 2025

@vladmandic Parameters in the module are subject to offloading hooks:

parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

The direct casts in Pixart and SANA were added before offloading hooks were introduced, and no other instances use this type of casting (it shouldn't be necessary).

Based on your description it seems to be an issue with custom offloading logic applied in non-CUDA environments. This really feels like an edge case to me and the issue doesn't seem related to diffusers code.

However, I've run the offloading tests on the models and they're passing, so I'm okay to merge this. But FYI, there is a chance we remove direct casts from models in the future if we see any issues. I'll give you a heads up if that's the case to avoid breaking anything downstream.

@vladmandic

thanks @DN6

Based on your description it seems to be an issue with custom offloading logic applied in non-CUDA environments.

The same custom offloading logic is applied everywhere in sdnext, not specific to rocm/zluda; it's just that this problem occurs only with those torch versions.

sayakpaul pushed a commit that referenced this pull request

Oct 15, 2025


Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>