[LoRA] fix: lora loading when using with a device_mapped model. by sayakpaul · Pull Request #9449 · huggingface/diffusers
What does this PR do?
Fixes LoRA loading behaviour when used with a model that is sharded across multiple devices.
Minimal code
""" Minimal example to show how to load a LoRA into the Flux transformer that is sharded in two GPUs.
Limitation:
- Latency
- If the LoRA has text encoder layers then this needs to be revisited. """
from diffusers import FluxTransformer2DModel, FluxPipeline import torch
ckpt_id = "black-forest-labs/FLUX.1-dev" dtype = torch.bfloat16 transformer = FluxTransformer2DModel.from_pretrained( ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype ) print(transformer.hf_device_map) pipeline = FluxPipeline.from_pretrained( ckpt_id, text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None, vae=None, transformer=transformer, torch_dtype=dtype ) pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
print(pipeline.transformer.hf_device_map)
Since the text encoders are not loaded, you'd pre-compute the prompt embeddings beforehand (a sketch of how that could be done follows the snippet below).
Reference: https://gist.github.com/sayakpaul/a9266fe2d0d510ec44a9cdc385b3dd74.
```python
# Dummy embeddings stand in for the pre-computed ones.
example_inputs = {
    "prompt_embeds": torch.randn(1, 512, 4096, dtype=dtype, device="cuda"),
    "pooled_projections": torch.randn(1, 768, dtype=dtype, device="cuda"),
}

_ = pipeline(
    prompt_embeds=example_inputs["prompt_embeds"],
    pooled_prompt_embeds=example_inputs["pooled_projections"],
    num_inference_steps=50,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    output_type="latent",
)
```
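For completeness, here is a minimal sketch of how the prompt embeddings could be pre-computed in a separate step with only the text encoders loaded (the gist above has the exact recipe; the prompt and output path below are placeholders):

```python
# Sketch: pre-compute Flux prompt embeddings with only the text encoders loaded.
# Assumes FluxPipeline.encode_prompt returns (prompt_embeds, pooled_prompt_embeds, text_ids).
import torch
from diffusers import FluxPipeline

ckpt_id = "black-forest-labs/FLUX.1-dev"
text_pipeline = FluxPipeline.from_pretrained(
    ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16
).to("cuda")

with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, text_ids = text_pipeline.encode_prompt(
        prompt="jon snow eating pizza", prompt_2=None, max_sequence_length=512
    )

# Save for reuse with the device-mapped transformer pipeline above.
torch.save(
    {"prompt_embeds": prompt_embeds, "pooled_projections": pooled_prompt_embeds},
    "flux_prompt_embeds.pt",
)
```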
Some internal discussions:
- https://huggingface.slack.com/archives/C03UQJENJTV/p1725527760353639
- https://huggingface.slack.com/archives/C04L3MWLE6B/p1726470631333599
Cc: @philschmid for awareness as you were interested in this feature.
TODOs
- Tests
- Docs
Once I get a sanity review from Marc and Benjamin, I will request a review from Yiyi.