[LoRA] fix: lora loading when using with a device_mapped model. by sayakpaul · Pull Request #9449 · huggingface/diffusers (original) (raw)

What does this PR do?

Fixes LoRA loading behaviour when used with a model that is sharded into multiple devices.

Minimal code

""" Minimal example to show how to load a LoRA into the Flux transformer that is sharded in two GPUs.

Limitation:

from diffusers import FluxTransformer2DModel, FluxPipeline import torch

ckpt_id = "black-forest-labs/FLUX.1-dev" dtype = torch.bfloat16 transformer = FluxTransformer2DModel.from_pretrained( ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype ) print(transformer.hf_device_map) pipeline = FluxPipeline.from_pretrained( ckpt_id, text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None, vae=None, transformer=transformer, torch_dtype=dtype ) pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")

print(pipeline.transformer.hf_device_map)

Essentially you'd pre-compute these embeddings beforehand.

Reference: https://gist.github.com/sayakpaul/a9266fe2d0d510ec44a9cdc385b3dd74.

example_inputs = { "prompt_embeds": torch.randn(1, 512, 4096, dtype=dtype, device="cuda"), "pooled_projections": torch.randn(1, 768, dtype=dtype, device="cuda"), }

_ = pipeline( prompt_embeds=example_inputs["prompt_embeds"], pooled_prompt_embeds=example_inputs["pooled_projections"], num_inference_steps=50, guidance_scale=3.5, height=1024, width=1024, output_type="latent", )

Some internal discussions:

Cc: @philschmid for awareness as you were interested in this feature.

TODOs

Once I get a sanity review from Marc and Benjamin, will request a review from Yiyi.