[WIP] SD3.5 IP-Adapter Pipeline Integration by guiyrt · Pull Request #9987 · huggingface/diffusers (original) (raw)

Would you mind posting the traceback? Maybe there's something we can do but if the issue is in Siglip we may need to raise it with Transformers team. We probably don't want to keep Siglip on GPU, it's relatively heavy like CLIP Vision right?

This happens with enable_sequential_cpu_offload() but not enable_model_cpu_offload() (I can only verify enable_model_cpu_offload() until the transformer is called, then I get OOM error). As enable_model_cpu_offload()works on model level, it copies the entire image_encoder to the GPU and no issues there. But enable_sequential_cpu_offload() works on submodule level, and only moves to GPU the parameters of that submodule when its forward() is called. In MultiHeadAttention, out_proj.weightand out_proj.bias are parameters of out_proj, so they would only be moved to GPU when out_proj.forward() was invoked, which doesn't happend. Instead, these are accessed directly outside the expected scope.

SigLIP from "google/siglip-so400m-patch14-384" has about 430M params and takes about 1GB of VRAM in torch.float16 (quick testing, just loaded to GPU).

Traceback when trying to include image_encoder in CPU offloading

Traceback (most recent call last):
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/guiyrt/diffusers/run_test.py", line 41, in <module>
    images = pipe(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 1028, in __call__
    ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 750, in prepare_ip_adapter_image_embeds
    single_image_embeds = self.encode_image(ip_adapter_image, device)
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 715, in encode_image
    return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1190, in forward
    return self.vision_model(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1101, in forward
    pooler_output = self.head(last_hidden_state) if self.use_head else None
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1128, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1368, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/functional.py", line 6251, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
    result = fn(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 83, in inner
    r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 1525, in addmm
    return out + beta * self
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
    result = fn(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 141, in _fn
    result = fn(**bound.arguments)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1099, in add
    output = prims.add(a, b)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 93, in meta_kernel
    return fake_impl_holder.kernel(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_library/utils.py", line 20, in __call__
    return self.func(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/library.py", line 1151, in inner
    return func(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 614, in fake_impl
    return self._abstract_fn(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims/__init__.py", line 402, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 742, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device cuda:0!

Code to reproduce

Make sure to comment out `_exclude_from_cpu_offload` in `StableDiffusion3Pipeline` (line 186)

import torch from PIL import Image

from diffusers import StableDiffusion3Pipeline from transformers import SiglipVisionModel, SiglipImageProcessor

model_path = "stabilityai/stable-diffusion-3.5-large" image_encoder_path = "google/siglip-so400m-patch14-384" ip_adapter_path = "InstantX/SD3.5-Large-IP-Adapter"

feature_extractor = SiglipImageProcessor.from_pretrained( image_encoder_path, torch_dtype=torch.bfloat16 )

image_encoder = SiglipVisionModel.from_pretrained( image_encoder_path, torch_dtype=torch.bfloat16 )

pipe = StableDiffusion3Pipeline.from_pretrained( model_path, torch_dtype=torch.bfloat16, feature_extractor=feature_extractor, image_encoder=image_encoder, ) pipe.load_ip_adapter(ip_adapter_path, revision="f1f54ca369ae759f9278ae9c87d46def9f133c78") pipe.set_ip_adapter_scale(0.6) pipe.enable_sequential_cpu_offload()

ref_img = Image.open("image.jpg").convert('RGB')

please note that SD3.5 Large is sensitive to highres generation like 1536x1536

image = pipe( width=1024, height=1024, prompt="a cat", negative_prompt="lowres, low quality, worst quality", num_inference_steps=24, guidance_scale=5.0, generator=torch.manual_seed(42), ip_adapter_image=ref_img ).images[0]

image.save("result.jpg")

Has passing image_embeds fixed the issue or no? I think we can still pass it through kwargs instead of changing the signature.

This happened with enable_model_cpu_offload() but not enable_sequential_cpu_offload(). As enable_model_cpu_offload() moves entire models to GPU when their forward() is called, image_proj as part of transformer model, would only be moved to GPU when transformer() is called, and I was accessing transformer.image_proj before that. It wasn't a problem with enable_sequential_cpu_offload() because we were still calling the forward() of image_proj, and enable_sequential_cpu_offload() works on submodule level, not model.

Yes, it fixed the issue with pipe.enable_model_cpu_offload(). I thought of using kwargs at first as well, but joint_attention_kwargs is passed all the way to the attention processor, which only expects ip_hidden_states and temb, not the image embeds from image encoder. This raises a warning when unexpected kwargs are passed to the attention processor, so if we do it that way, we need to remove it or pass new kwargs without the image encoder embeds. Unless you mean to create **kwargs for SD3Transformer2DModel.forward()?

But as long as we don't access self.transformer.image_proj directly from StableDiffusion3Pipeline, it now works with using pipe.enable_model_cpu_offload().

Traceback accessing `image_proj` from `StableDiffusion3Pipeline` (fixed in last commit)

Traceback (most recent call last):
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/guiyrt/diffusers/run_test.py", line 41, in <module>
    images = pipe(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 1049, in __call__
    ip_hidden_states, temb = self.transformer.image_proj(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/models/embeddings.py", line 2564, in forward
    timestep_emb = self.time_embedding(timestep_emb)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/models/embeddings.py", line 1304, in forward
    sample = self.linear_1(sample)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

TL;DR: