[LoRA] introduce LoraBaseMixin to promote reusability. by sayakpaul · Pull Request #8774 · huggingface/diffusers

Nice initiative 👍🏽. There's a lot to unpack here, so perhaps it's best to start bit by bit. I've just gone over the pipeline-related components here.

Regarding the LoraBaseMixin, at the moment I think it might be doing a bit too much.

There are quite a few methods in there that make assumptions about the class inheriting them, which isn't really how a base class should behave. Loading methods tied to specific model components, e.g. load_lora_into_text_encoder, are better left out. If such a method is used across different pipelines without changes, it's better to turn it into a utility function and call it from the inheriting class, or to redefine the method in the inheriting class and use # Copied from.
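As a rough sketch of the utility-function approach (the helper name, its signature, and the `lora_state_dict` return value below are just illustrative, not an existing API):

```python
# Hypothetical module-level helper; name and signature are illustrative only.
def _load_lora_into_text_encoder(state_dict, network_alphas, text_encoder, adapter_name=None, **kwargs):
    # shared PEFT loading logic lives here exactly once
    ...


class StableDiffusionLoraLoaderMixin:
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name=None, **kwargs):
        # assumes a `lora_state_dict` classmethod that returns (state_dict, network_alphas)
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        # the pipeline-specific class decides which of its components the helper is applied to
        _load_lora_into_text_encoder(state_dict, network_alphas, self.text_encoder, adapter_name=adapter_name)
```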

I would assume that these are the methods that need to be defined for managing LoRAs across all pipelines?

```python
from typing import Callable, Dict

import torch


class LoraBaseMixin:
    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        raise NotImplementedError()

    @classmethod
    def _fetch_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict,
        weight_name,
        use_safetensors,
        local_files_only,
        cache_dir,
        force_download,
        resume_download,
        proxies,
        token,
        revision,
        subfolder,
        user_agent,
        allow_pickle,
    ):
        raise NotImplementedError()

    @classmethod
    def _best_guess_weight_name(
        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
    ):
        raise NotImplementedError()

    @classmethod
    def save_lora_weights(cls, **kwargs):
        raise NotImplementedError("`save_lora_weights()` is not implemented.")

    @classmethod
    def lora_state_dict(cls, **kwargs):
        raise NotImplementedError("`lora_state_dict()` is not implemented.")

    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")

    def unload_lora_weights(self, **kwargs):
        raise NotImplementedError("`unload_lora_weights()` is not implemented.")

    def fuse_lora(self, **kwargs):
        raise NotImplementedError("`fuse_lora()` is not implemented.")

    def unfuse_lora(self, **kwargs):
        raise NotImplementedError("`unfuse_lora()` is not implemented.")

    def disable_lora(self):
        raise NotImplementedError("`disable_lora()` is not implemented.")

    def enable_lora(self):
        raise NotImplementedError("`enable_lora()` is not implemented.")

    def get_active_adapters(self):
        raise NotImplementedError("`get_active_adapters()` is not implemented.")

    def delete_adapters(self, adapter_names):
        raise NotImplementedError("`delete_adapters()` is not implemented.")

    def set_lora_device(self, adapter_names):
        raise NotImplementedError("`set_lora_device()` is not implemented.")

    @staticmethod
    def pack_weights(layers, prefix):
        raise NotImplementedError()

    @staticmethod
    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        raise NotImplementedError()

    @property
    def lora_scale(self) -> float:
        raise NotImplementedError()
```

Quite a few of these methods probably cannot be defined in the base class, such as load_lora_weights and unload_lora_weights, or fuse_lora and unfuse_lora, since they deal with specific pipeline components. They might also require arguments specific to the pipeline type or its components.

I think it might be better to define these methods in a pipeline-specific class that inherits from LoraBaseMixin, or just as its own mixin class; I don't have a strong feeling about either approach. e.g. StableDiffusionLoraLoaderMixin could look like:

```python
from typing import Dict, Optional, Union

import torch
from transformers import PreTrainedModel

from diffusers import ModelMixin


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        _load_lora_into_unet(**kwargs)
        _load_lora_into_text_encoder(**kwargs)

    def fuse_lora(self, components=["unet", "text_encoder"], **kwargs):
        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError(f"{fuse_component} is not a LoRA-loadable module of this pipeline.")

            model = getattr(self, fuse_component)
            # check if it is a diffusers model
            if isinstance(model, ModelMixin):
                model.fuse_lora()
            # handle transformers models
            elif isinstance(model, PreTrainedModel):
                fuse_text_encoder()
```

I saw this comment about using the term "fuse_denoiser" in the fusing methods. I'm not so sure about that. If we want to fuse the LoRA in a specific component, I think it's better to pass in the actual name of the component as used in the pipeline, rather than track another attribute such as denoiser.
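From the user's side that would look something like this (assuming the components argument from the fuse_lora sketch above):

```python
# `pipe` is an already loaded pipeline with a LoRA loaded; fuse by the attribute
# names actually used on the pipeline, no separate "denoiser" bookkeeping needed
pipe.fuse_lora(components=["unet"])                  # only the UNet
pipe.fuse_lora(components=["unet", "text_encoder"])  # UNet + text encoder
```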

I also think the constants and class attributes such as TEXT_ENCODER_NAME and is_unet_denoiser might not be needed if we use a single class attribute holding the list of names of the LoRA-loadable components.
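Rough sketch of what I mean (the class and attribute names beyond the SD one are just examples):

```python
class LoraBaseMixin:
    # single source of truth for which pipeline attributes can carry LoRA layers
    _lora_loadable_modules = []


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]


class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]


# any generic logic can then just do something like:
# components = [getattr(pipe, name) for name in pipe._lora_loadable_modules]
```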