Unload lora weights after they have been loaded · Issue #4027 · huggingface/diffusers

Describe the bug

There's no method to unload LoRA weights after they have been loaded. An issue was opened already (#3689), but it doesn't solve the problem.

Reproduction

  1. Generate without LoRA, with a freshly loaded pipeline

(generated image)

  2. Generate with LoRA

(generated image)

  3. Remove the LoRA by resetting the UNet attention processors (a standalone reproduction sketch follows below):
self.text_to_image_pipeline.unet.set_attn_processor(AttnProcessor2_0())

(generated image)
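
Putting the three steps together, this is a minimal standalone sketch of the reproduction; the checkpoint id, LoRA path and prompt are just placeholders, and a CUDA device is assumed:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor2_0

# Placeholder checkpoint; any Stable Diffusion model shows the same behaviour
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
prompt = "a photo of an astronaut riding a horse"

# 1. Fresh pipeline, no LoRA
image_base = pipe(prompt, generator=torch.manual_seed(42)).images[0]

# 2. Load LoRA weights (placeholder path) and generate with the same settings
pipe.load_lora_weights("path/to/lora")
image_lora = pipe(prompt, generator=torch.manual_seed(42)).images[0]

# 3. Try to remove the LoRA by resetting the UNet attention processors
pipe.unet.set_attn_processor(AttnProcessor2_0())
image_reset = pipe(prompt, generator=torch.manual_seed(42)).images[0]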

All the settings are kept the same across the three runs. It would be really great to have a built-in method in the pipeline to unload the LoRA weights; this would be really helpful at inference time!
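
Continuing from the sketch above, something along these lines is what I have in mind; `unload_lora_weights` is only a suggested name here, not an existing method in diffusers 0.18.1:

pipe.load_lora_weights("path/to/lora")  # existing API
images_with_lora = pipe(prompt).images

pipe.unload_lora_weights()  # proposed: restore the pre-LoRA attention processors / weights
images_without_lora = pipe(prompt).images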

If needed, this is the whole code:

"""Text to image component class"""

import random import sys from typing import List

import torch from compel import Compel from diffusers import DiffusionPipeline, StableDiffusionPipeline from diffusers.models.attention_processor import AttnProcessor2_0 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from src.image_generation.components import utils from src.image_generation.components.postprocessor import Postprocessor

class TextToImage: """Initialize the text to image pipeline, and expose methods to perform inference"""

def __init__(
    self,
    model_pipeline: DiffusionPipeline,
    postprocessor: Postprocessor = None,
    enable_xformers: bool = False,
) -> None:
    """Initalize the text to image pipeline

    Parameters
    ----------
    model_pipeline : DiffusionPipeline
        The base model pipeline to copy the components from
    postprocessor : Postprocessor, optional
        The postprocessor class, initialized and with the model loaded

        Defaults to None
    enable_xformers: bool, optional
        If True, it enables the xformers optimization for the pipeline. It needs the
        `xformers` package to be installed for it to work.

        Defaults to False
    """
    self.device = utils.get_device()
    self.postprocessor = postprocessor

    self.text_to_image_pipeline = StableDiffusionPipeline(
        **model_pipeline.components,
    ).to(self.device)

    self.compel_processor = Compel(
        tokenizer=self.text_to_image_pipeline.tokenizer,
        text_encoder=self.text_to_image_pipeline.text_encoder,
    )

    if enable_xformers:
        utils.enable_xformers(self.text_to_image_pipeline)

def generate(
    self,
    prompt: str,
    negative_prompt: str = "",
    width: int = 512,
    height: int = 512,
    num_inference_steps: int = 20,
    guidance_scale: float = 7.5,
    num_images_per_prompt: int = 1,
    seed: int = None,
    apply_gfpgan: bool = False,
    lora_path: str = None,
) -> StableDiffusionPipelineOutput:
    """Performs text to image inference

    Parameters
    ----------
    prompt : str
        The prompt or prompts to guide the image generation. It supports prompt
        weighting using `+` and `-`, according to https://github.com/damian0815/compel
    negative_prompt : str, optional
        The prompt or prompts not to guide the image generation

        Defaults to ""
    width : int, optional
        The width of the generated image.

        Defaults to 512
    height : int, optional
        The height of the generated image.

        Defaults to 512
    num_inference_steps : int, optional
        The number of denoising steps. More denoising steps usually lead to a higher
        quality image at the expense of slower inference.

        Defaults to 20
    guidance_scale : float, optional
        A higher guidance scale encourages the model to generate images that are closely
        linked to the text prompt, usually at the expense of lower image quality.

        Defaults to 7.5
    num_images_per_prompt : int, optional
        The number of generated images per prompt

        Defaults to 1
    seed : int, optional
        An integer used as a random seed to get reproducible results.
        If None, a random seed will be generated.

        Defaults to None
    apply_gfpgan : bool, optional
        If True, applies GFPGAN on the generated images. GFPGAN must be initialized
        in the constructor, otherwise a warning is raised.

        Defaults to False
    lora_path : str, optional
        If specified, it loads the LoRA weights before running inference

        Defaults to None

    Returns
    -------
    StableDiffusionPipelineOutput:
        An object containing the list of generated images and a flag indicating
        whether NSFW content was detected.

        The images can be accessed using the `images` property, and the NSFW flag
        using the `nsfw_content_detected` property.
    """
    if seed is None:  # set random seed
        seed = random.randint(0, sys.maxsize)
    
    if lora_path is not None:  # load LoRA
        self.text_to_image_pipeline.load_lora_weights(lora_path)
    else: # unload LoRA
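        # Note: this only resets the UNet attention processors; if the LoRA
        # checkpoint also patched the text encoder, that part is not reverted here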
        self.text_to_image_pipeline.unet.set_attn_processor(AttnProcessor2_0())

    result = self.text_to_image_pipeline(
        prompt_embeds=self.compel_processor(prompt),
        negative_prompt_embeds=self.compel_processor(negative_prompt),
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.manual_seed(seed),
    )

    # apply postprocessing if model is initialized and flag set to true
    if apply_gfpgan and self.postprocessor:
        result.images = self.postprocessor.apply_gfpgan(result)

    return result
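
For completeness, this is roughly how the class is driven, reusing the TextToImage class defined above; the base checkpoint and LoRA path are placeholders:

from diffusers import DiffusionPipeline

# Placeholder checkpoint, for illustration only
base_pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
text_to_image = TextToImage(model_pipeline=base_pipeline)

# First call loads the LoRA weights before generating
with_lora = text_to_image.generate("a photo of a cat++", lora_path="path/to/lora", seed=42)

# Second call passes no lora_path, so the `else` branch above resets the UNet
# attention processors to try to get back to the base model's behaviour
without_lora = text_to_image.generate("a photo of a cat++", seed=42)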

Logs

No response

System Info

Diffusers 0.18.1, Python 3.10.11

Who can help?

@sayakpaul