Added the ability to set SDXL Micro-Conditioning embeddings as 0 · Issue #4208 · huggingface/diffusers

Is your feature request related to a problem? Please describe.

During SDXL training, a zero embedding may need to be passed in place of a Micro-Conditioning embedding:

https://github.com/Stability-AI/generative-models/blob/e25e4c0df1d01fb9720f62c73b4feab2e4003e3f/sgm/modules/encoders/modules.py#L151-L161

These lines randomly set an embedder's output to zero if ucg_rate > 0:

            if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                emb = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - embedder.ucg_rate)
                            * torch.ones(emb.shape[0], device=emb.device)
                        ),
                        emb,
                    )
                    * emb
                )
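The dropout above can be reproduced in isolation. The following is a minimal sketch. It assumes expand_dims_like simply appends trailing singleton dimensions to the per-sample mask. ucg_dropout is a hypothetical name, not an sgm function.

```python
import torch


def expand_dims_like(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: append trailing singleton dims to x until it has as many dims as y,
    # so a (B,) mask broadcasts over a (B, ...) embedding.
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x


def ucg_dropout(emb: torch.Tensor, ucg_rate: float) -> torch.Tensor:
    # Hypothetical helper mirroring the sgm snippet: each sample in the batch keeps its
    # embedding with probability 1 - ucg_rate, otherwise its whole embedding becomes zero.
    if ucg_rate <= 0.0:
        return emb
    keep = torch.bernoulli((1.0 - ucg_rate) * torch.ones(emb.shape[0], device=emb.device))
    return expand_dims_like(keep.to(emb.dtype), emb) * emb


emb = torch.randn(8, 512)        # e.g. a batch of 8 micro-conditioning embeddings
dropped = ucg_dropout(emb, 0.1)  # on average about 10% of the rows are zeroed
```

The mask is drawn per sample, not per feature, so a dropped sample loses that embedder's whole contribution at once.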

https://github.com/Stability-AI/generative-models/blob/e25e4c0df1d01fb9720f62c73b4feab2e4003e3f/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml#L65

In the linked config, SDXL sets the ucg_rate of the original_size_as_tuple embedder to 0.1.

So during training, we need to pass a zero embedding as the added embedding that is combined with the UNet's time embedding. The relevant config entry:

        ucg_rate: 0.1
        input_key: original_size_as_tuple
        target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
        params:
          outdim: 256  # multiplied by two

The current SDXL UNet2DConditionModel accepts encoder_hidden_states, time_ids, and add_text_embeds (passed as text_embeds in added_cond_kwargs) as conditions:

    text_embeds = added_cond_kwargs.get("text_embeds")
    if "time_ids" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
        )
    time_ids = added_cond_kwargs.get("time_ids")
    time_embeds = self.add_time_proj(time_ids.flatten())
    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
    add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(emb.dtype)
    aug_emb = self.add_embedding(add_embeds)
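To make the shapes concrete, here is a shape-only sketch of this text_time path. It assumes SDXL's usual sizes: a 1280-dim pooled text embedding, six time_ids, and 256 features per time id. A hand-written sinusoidal function stands in for self.add_time_proj. This is not the diffusers implementation.

```python
import math

import torch


def sinusoidal_embedding(x: torch.Tensor, dim: int = 256, max_period: float = 10000.0) -> torch.Tensor:
    # Stand-in for add_time_proj: maps each scalar in the 1-D tensor x to `dim` features
    # (cos half followed by sin half).
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


batch = 2
text_embeds = torch.randn(batch, 1280)  # pooled text embedding (assumed size)
# time_ids = original_size + crops_coords_top_left + target_size, six scalars per sample
time_ids = torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]] * batch)

time_embeds = sinusoidal_embedding(time_ids.flatten())         # (12, 256)
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))  # (2, 1536)
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)  # (2, 2816)
```

The 1536 time features form the block that the original SDXL can drop independently of the 1280 text features.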

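To fine-tune the SDXL model correctly, we need to randomly set each condition embedding to 0 with a suitable probability.
Setting encoder_hidden_states and add_text_embeds to a zero embedding is easy, because the caller passes them in directly. time_embeds, however, cannot be zeroed: it is computed inside the model from time_ids (line 849). Passing time_ids of 0 does not help either, because the sinusoidal projection of 0 is not a zero vector (every cos(0) term is 1).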
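The point about zero time_ids can be checked directly. This is a minimal sketch that again uses a hand-written sinusoidal function as an assumed stand-in for add_time_proj. Projecting all-zero time_ids gives a mix of ones and zeros, not a zero vector, so the zeroing has to happen after the projection.

```python
import math

import torch


def sinusoidal_embedding(x: torch.Tensor, dim: int = 256, max_period: float = 10000.0) -> torch.Tensor:
    # Stand-in for add_time_proj (cos half followed by sin half).
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


zero_ids = torch.zeros(6)                       # one sample's time_ids, all set to 0
from_zero_ids = sinusoidal_embedding(zero_ids)  # (6, 256): cos(0) = 1, sin(0) = 0
print(from_zero_ids.sum().item())               # 768.0, i.e. 6 * 128 ones
```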

The original SDXL uses a separate embedder to convert each micro-condition into Fourier features, and during training the Fourier features of each embedder are independently set to 0 at random. Therefore, UNet2DConditionModel needs to be able to zero the time_embeds part independently.
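Independent zeroing per micro-condition could look like the following sketch. It assumes time_embeds follows the layout used by diffusers' SDXL pipelines: original_size, crops_coords_top_left, target_size, with two scalars each and 256 features per scalar, giving three blocks of 512. drop_micro_conditions and the per-condition rates are hypothetical. The 0.1 for original_size mirrors the config entry above.

```python
from typing import Sequence

import torch


def drop_micro_conditions(time_embeds: torch.Tensor, ucg_rates: Sequence[float]) -> torch.Tensor:
    # time_embeds: (B, n_conditions * per_condition_dim), e.g. (B, 3 * 512) for SDXL.
    batch = time_embeds.shape[0]
    parts = time_embeds.reshape(batch, len(ucg_rates), -1)  # (B, n_conditions, per_condition_dim)
    keep_prob = torch.tensor([1.0 - r for r in ucg_rates], device=time_embeds.device)
    # One Bernoulli draw per (sample, condition): each condition is dropped independently.
    keep = torch.bernoulli(keep_prob.repeat(batch, 1)).to(time_embeds.dtype)
    return (parts * keep[..., None]).reshape(batch, -1)


time_embeds = torch.randn(4, 1536)
# Hypothetical rates: drop original_size 10% of the time, never drop crop coords or target size.
out = drop_micro_conditions(time_embeds, [0.1, 0.0, 0.0])
```

The result would then be concatenated with text_embeds as in the current forward code. Because it is already projected, the UNet needs an entry point that accepts time_embeds rather than only time_ids.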

Describe the solution you'd like

Add the ability to set SDXL Micro-Conditioning embeddings to 0.

Describe alternatives you've considered

Perhaps diffusers could let users pass time_embeds directly in added_cond_kwargs, so that when time_embeds is present, time_ids is no longer used:

    if "time_embeds" in added_cond_kwargs:
        time_embeds = added_cond_kwargs.get("time_embeds")
    else:
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
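The proposed branch can be exercised outside the UNet. This is a minimal sketch, and build_add_embeds and toy_proj are hypothetical names. add_time_proj is passed in rather than read from self, and the error message is illustrative. In the real model, the projection would be the UNet's own add_time_proj.

```python
from typing import Callable, Dict

import torch


def build_add_embeds(
    added_cond_kwargs: Dict[str, torch.Tensor],
    add_time_proj: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # Hypothetical standalone version of the proposed branch.
    text_embeds = added_cond_kwargs["text_embeds"]
    if "time_embeds" in added_cond_kwargs:
        # Proposed path: precomputed (possibly partly zeroed) embeddings take priority over time_ids.
        time_embeds = added_cond_kwargs["time_embeds"]
    elif "time_ids" in added_cond_kwargs:
        # Current path: project the raw time_ids inside the model.
        time_ids = added_cond_kwargs["time_ids"]
        time_embeds = add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
    else:
        raise ValueError("`added_cond_kwargs` must contain `time_ids` or `time_embeds`")
    return torch.concat([text_embeds, time_embeds], dim=-1)


def toy_proj(t: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for add_time_proj: 4 features per scalar.
    return t[:, None].expand(-1, 4)


text = torch.randn(2, 8)
ids = torch.ones(2, 6)
from_ids = build_add_embeds({"text_embeds": text, "time_ids": ids}, toy_proj)  # (2, 8 + 6 * 4)
from_embeds = build_add_embeds(
    {"text_embeds": text, "time_ids": ids, "time_embeds": torch.zeros(2, 24)}, toy_proj
)  # time part is all zeros; time_ids is ignored
```

One consequence of this design is that the caller must build time_embeds with the same projection and layout the UNet uses. Otherwise, the added embedding silently changes meaning.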