Unload lora weights after they have been loaded · Issue #4027 · huggingface/diffusers
Describe the bug
There is no method to unload LoRA weights after they have been loaded. An issue was already opened (#3689), but it doesn't solve the problem.
Reproduction
- Generate without LoRA, using a freshly loaded pipeline
- Generate with LoRA
- Remove LoRA by resetting the attention processors:
self.text_to_image_pipeline.unet.set_attn_processor(AttnProcessor2_0())
All the settings are kept the same across the three runs. It would be really great to have a built-in method in the pipeline to unload the LoRA weights; this would be really helpful at inference time!
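For reference, a minimal standalone reproduction might look like the sketch below. The model id and the LoRA path are placeholders to substitute with your own, and AttnProcessor2_0 assumes PyTorch 2.x:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor2_0

# Placeholder model id and LoRA path, substitute your own
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
prompt = "a photo of an astronaut riding a horse"

# 1. Generate without LoRA, on a freshly loaded pipeline
baseline = pipeline(prompt, generator=torch.manual_seed(0)).images[0]

# 2. Generate with LoRA
pipeline.load_lora_weights("path/to/lora.safetensors")
with_lora = pipeline(prompt, generator=torch.manual_seed(0)).images[0]

# 3. "Remove" the LoRA by resetting the UNet attention processors
pipeline.unet.set_attn_processor(AttnProcessor2_0())
after_reset = pipeline(prompt, generator=torch.manual_seed(0)).images[0]
# With all settings unchanged, after_reset is expected to match baseline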
If needed, this is the whole code:
"""Text to image component class"""
import random import sys from typing import List
import torch from compel import Compel from diffusers import DiffusionPipeline, StableDiffusionPipeline from diffusers.models.attention_processor import AttnProcessor2_0 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from src.image_generation.components import utils from src.image_generation.components.postprocessor import Postprocessor
class TextToImage: """Initialize the text to image pipeline, and expose methods to perform inference"""
def __init__(
self,
model_pipeline: DiffusionPipeline,
postprocessor: Postprocessor = None,
enable_xformers: bool = False,
) -> None:
"""Initalize the text to image pipeline
Parameters
----------
model_pipeline : DiffusionPipeline
The base model pipeline to copy the components from
postprocessor : Postprocessor, optional
The postprocessor class, initialized and with the model loaded
Defaults to None
enable_xformers: bool, optional
If true, it enables xformers optimization for the pipelines. It needs the
package `xformer` to be installed for it to work.
Defaults to False
"""
self.device = utils.get_device()
self.postprocessor = postprocessor
self.text_to_image_pipeline = StableDiffusionPipeline(
**model_pipeline.components,
).to(self.device)
self.compel_processor = Compel(
tokenizer=self.text_to_image_pipeline.tokenizer,
text_encoder=self.text_to_image_pipeline.text_encoder,
)
if enable_xformers:
utils.enable_xformers(self.text_to_image_pipeline)
def generate(
self,
prompt: str,
negative_prompt: str = "",
width: int = 512,
height: int = 512,
num_inference_steps: int = 20,
guidance_scale: float = 7.5,
num_images_per_prompt: int = 1,
seed: int = None,
apply_gfpgan: bool = False,
lora_path: str = None,
) -> StableDiffusionPipelineOutput:
"""Performs text to image inference
Parameters
----------
prompt : str
The prompt or prompts to guide the image generation. It supports prompt
weight using `+` and `-`, according to https://github.com/damian0815/compel
negative_prompt : str, optional
The prompt or prompts not to guide the image generation
Defaults to ""
width : int, optional
The width of the generated image.
Defaults to 512
height : int, optional
The height of the generated image.
Defaults to 512
num_inference_steps : int, optional
The number of denoising steps. More denoising steps usually lead to a higher
quality image at the expense of slower inference.
Defaults to 20
guidance_scale : float, optional
Higher guidance scale encourages to generate images that are closely linked
to the text prompt, usually at the expense of lower image quality.
Defaults to 7.5
num_images_per_prompt : int, optional
The number of generated images per prompt
Defaults to 1
seed : int, optional
An integer used as random seed to have reproducible results
If None, random seeds will be generated.
Defaults to None
apply_gfpgan : bool, optional
If true, applies GPFGAN on the generated images. GFPGAN must be initialized
in the constructor, otherwise a warning is raised.
Defaults to False
lora_path : str, optional
If specified, it loads a LoRA weight before running inference
Defaults to None
Returns
-------
StableDiffusionPipelineOutput:
An object containing the list of generated images, and a flag expressing if
NSFW was detected.
The images can be accessed using the `images` property, and the NSFW flag
using the `nsfw_content_detected` property.
"""
if seed is None: # set random seed
seed = random.randint(0, sys.maxsize)
if lora_path is not None: # load LoRA
self.text_to_image_pipeline.load_lora_weights(lora_path)
else: # unload LoRA
self.text_to_image_pipeline.unet.set_attn_processor(AttnProcessor2_0())
result = self.text_to_image_pipeline(
prompt_embeds=self.compel_processor(prompt),
negative_prompt_embeds=self.compel_processor(negative_prompt),
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
generator=torch.manual_seed(seed),
)
# apply postprocessing if model is initialized and flag set to true
if apply_gfpgan and self.postprocessor:
result.images = self.postprocessor.apply_gfpgan(result)
return result
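Until a built-in unload method exists, a possible stop-gap (a sketch only, not part of the diffusers API) is to snapshot the UNet attention processors before calling load_lora_weights and restore that snapshot afterwards; set_attn_processor also accepts a dict of processors keyed by module name:

# Hypothetical helpers, not part of diffusers; `pipeline` is any
# StableDiffusionPipeline, e.g. the one from the sketch above
def snapshot_attn_processors(pipeline):
    """Return a copy of the UNet's current attention processor mapping."""
    return dict(pipeline.unet.attn_processors)

def restore_attn_processors(pipeline, snapshot):
    """Put a previously saved attention processor mapping back on the UNet."""
    pipeline.unet.set_attn_processor(snapshot)

# Usage: snapshot before loading the LoRA, restore it to "unload" the LoRA
snapshot = snapshot_attn_processors(pipeline)
pipeline.load_lora_weights("path/to/lora.safetensors")
# ... run inference with the LoRA ...
restore_attn_processors(pipeline, snapshot)

Note that this only touches the UNet; if a LoRA checkpoint also patches the text encoder, resetting the attention processors alone may not fully restore the original behaviour, which is another argument for a built-in unload method.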
Logs
No response
System Info
Diffusers 0.18.1, Python 3.10.11