[LoRA] introduce LoraBaseMixin to promote reusability. by sayakpaul · Pull Request #8774 · huggingface/diffusers
Nice initiative 👍🏽. A lot to unpack here, so perhaps it's best to start bit by bit. I just went over the pipeline-related components here.
Regarding the `LoraBaseMixin`, at the moment I think it might be doing a bit too much.

There are quite a few methods in there that make assumptions about the inheriting class, which isn't really how a base class should behave. Loading methods tied to specific model components, e.g. `load_lora_into_text_encoder`, are better left out. If such a method is used across different pipelines with no changes, then it's better to create a utility function that does this and call it from the inheriting class, or redefine the method in the inheriting class and use `# Copied from`.
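A minimal sketch of what I mean (the helper name, its signature, and the `lora_state_dict` call here are illustrative, not an actual API):

```python
from typing import Dict, Optional, Union

import torch


# Hypothetical module-level utility; name and signature are illustrative.
def load_lora_into_text_encoder(
    text_encoder, state_dict: Dict[str, torch.Tensor], adapter_name: Optional[str] = None
):
    # The shared loading logic lives here once, outside any class.
    ...


class SomePipelineLoraLoaderMixin:
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
    ):
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        # The inheriting class decides which of its components the utility applies to.
        load_lora_into_text_encoder(self.text_encoder, state_dict)
```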
I would assume that these are the methods that need to be defined for managing LoRAs across all pipelines?
```python
from typing import Callable, Dict

import torch


class LoraBaseMixin:
    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        raise NotImplementedError()

    @classmethod
    def _fetch_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict,
        weight_name,
        use_safetensors,
        local_files_only,
        cache_dir,
        force_download,
        resume_download,
        proxies,
        token,
        revision,
        subfolder,
        user_agent,
        allow_pickle,
    ):
        raise NotImplementedError()

    @classmethod
    def _best_guess_weight_name(
        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
    ):
        raise NotImplementedError()

    @classmethod
    def save_lora_weights(cls, **kwargs):
        raise NotImplementedError("`save_lora_weights()` is not implemented.")

    @classmethod
    def lora_state_dict(cls, **kwargs):
        raise NotImplementedError("`lora_state_dict()` is not implemented.")

    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")

    def unload_lora_weights(self, **kwargs):
        raise NotImplementedError("`unload_lora_weights()` is not implemented.")

    def fuse_lora(self, **kwargs):
        raise NotImplementedError("`fuse_lora()` is not implemented.")

    def unfuse_lora(self, **kwargs):
        raise NotImplementedError("`unfuse_lora()` is not implemented.")

    def disable_lora(self):
        raise NotImplementedError("`disable_lora()` is not implemented.")

    def enable_lora(self):
        raise NotImplementedError("`enable_lora()` is not implemented.")

    def get_active_adapters(self):
        raise NotImplementedError("`get_active_adapters()` is not implemented.")

    def delete_adapters(self, adapter_names):
        raise NotImplementedError("`delete_adapters()` is not implemented.")

    def set_lora_device(self, adapter_names):
        raise NotImplementedError("`set_lora_device()` is not implemented.")

    @staticmethod
    def pack_weights(layers, prefix):
        raise NotImplementedError()

    @staticmethod
    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        raise NotImplementedError()

    @property
    def lora_scale(self) -> float:
        raise NotImplementedError()
```
Quite a few of these methods probably cannot be defined in the base class, such as `load_lora_weights` and `unload_lora_weights`, or `fuse_lora` and `unfuse_lora`, since they deal with specific pipeline components. They might also require arguments specific to the pipeline type or pipeline components.
I think it might be better to define these methods in a pipeline-specific class that inherits from `LoraBaseMixin`, or just as its own mixin class; I don't have a strong feeling about either approach. e.g. `StableDiffusionLoraLoaderMixin` could look like:
```python
from typing import Dict, Optional, Union

import torch
from transformers import PreTrainedModel

from diffusers import ModelMixin


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        _load_lora_into_unet(**kwargs)
        _load_lora_into_text_encoder(**kwargs)

    def fuse_lora(self, components=["unet", "text_encoder"], **kwargs):
        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError(f"{fuse_component} is not LoRA-loadable.")

            model = getattr(self, fuse_component)

            # Check if it's a diffusers model.
            if isinstance(model, ModelMixin):
                model.fuse_lora()
            # Handle transformers models.
            if isinstance(model, PreTrainedModel):
                fuse_text_encoder()
```
I saw this comment about using the term "fuse_denoiser" in the fusing methods. I'm not so sure about that. If we want to fuse the LoRA in a specific component, I think it's better to pass in the actual name of the component used in the pipeline rather than track another attribute such as `denoiser`.
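For example, with the `components` argument sketched above:

```python
# Callers name the actual pipeline components, so the pipeline doesn't
# need to track a separate `denoiser` attribute.
pipeline.fuse_lora(components=["unet"])                  # just the denoiser
pipeline.fuse_lora(components=["text_encoder"])          # just the text encoder
pipeline.fuse_lora(components=["unet", "text_encoder"])  # everything loadable
```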
I also think the constants and class attributes such as `TEXT_ENCODER_NAME` and `is_unet_denoiser` might not be needed if we use a single class attribute with a list of the names of the LoRA-loadable components.
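Rough sketch of what that could look like (the `disable_lora` body is illustrative):

```python
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]

    def disable_lora(self):
        # Generic logic just iterates over the list; no per-component
        # constants (TEXT_ENCODER_NAME) or flags (is_unet_denoiser) needed.
        for name in self._lora_loadable_modules:
            component = getattr(self, name, None)
            if component is not None:
                ...  # disable the LoRA layers on this component
```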