Added the ability to set SDXL Micro-Conditioning embeddings as 0 · Issue #4208 · huggingface/diffusers
Is your feature request related to a problem? Please describe.
During SDXL training, it may be necessary to pass in a zero embedding in place of the Micro-Conditioning embeddings. The following lines randomly set an embedding to zero when ucg_rate > 0:
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
    emb = (
        expand_dims_like(
            torch.bernoulli(
                (1.0 - embedder.ucg_rate)
                * torch.ones(emb.shape[0], device=emb.device)
            ),
            emb,
        )
        * emb
    )
SDXL sets the ucg_rate of the original_size_as_tuple embedder to 0.1, so during training we sometimes need to pass a zero embedding as the added time embedding of the UNet:
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
  outdim: 256  # multiplied by two
The current SDXL UNet2DConditionModel accepts encoder_hidden_states, time_ids and add_text_embeds as conditions:
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
    raise ValueError(
        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
    )
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
To correctly finetune the SDXL model, we need to randomly set the condition embeddings to 0 with a suitable probability.
While it is easy to set encoder_hidden_states and add_text_embeds to zero embeddings, it is impossible to zero time_embeds, because it is computed inside the model from time_ids (line 849).
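One reason this cannot be worked around by passing time_ids of 0: a sinusoidal (Fourier) projection of the value 0 is not a zero vector, since every cosine component equals 1. The sketch below illustrates this with a simplified, hypothetical fourier_features helper. It is not the diffusers or sgm implementation, and the dimension and max_period values are assumptions chosen for illustration.

```python
import math


def fourier_features(value, dim=8, max_period=10000.0):
    """Simplified sinusoidal embedding of one scalar (hypothetical helper)."""
    half = dim // 2
    # Geometrically spaced frequencies, as in common timestep embeddings.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    sin_part = [math.sin(value * f) for f in freqs]
    cos_part = [math.cos(value * f) for f in freqs]
    return sin_part + cos_part


emb_of_zero = fourier_features(0.0)
# sin(0) = 0 for every frequency, but cos(0) = 1,
# so the embedding of a zero id is not an all-zero vector.
print(emb_of_zero)
```

Zeroing therefore has to happen after the projection, on the embedding itself, rather than on the raw ids.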
The original SDXL uses a separate embedder to convert each micro-condition into Fourier features. During training, these Fourier features are set to 0 randomly and independently of one another. Therefore, UNet2DConditionModel needs to allow each part of time_embeds to be zeroed independently.
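A minimal sketch of what independent zeroing could look like on the training side, assuming time_embeds has already been projected and flattened to shape (batch, 1536), with one contiguous 512-wide slice per micro-condition (original size, crop coordinates, target size). The function name drop_micro_conditions is hypothetical, and using 0.1 as the rate for all three conditions is an assumption; the text above only states it for original_size_as_tuple.

```python
import torch


def drop_micro_conditions(time_embeds, num_conditions=3, ucg_rate=0.1):
    """Independently zero each micro-condition's slice of time_embeds.

    time_embeds: (batch, num_conditions * features_per_condition)
    """
    batch = time_embeds.shape[0]
    # One row of features per micro-condition: (batch, num_conditions, features)
    chunks = time_embeds.reshape(batch, num_conditions, -1)
    # Draw an independent keep/drop decision per sample and per condition.
    keep_prob = torch.full(
        (batch, num_conditions), 1.0 - ucg_rate, device=time_embeds.device
    )
    keep = torch.bernoulli(keep_prob).to(time_embeds.dtype).unsqueeze(-1)
    return (chunks * keep).reshape(batch, -1)


time_embeds = torch.randn(4, 1536)  # stand-in for projected time ids
dropped = drop_micro_conditions(time_embeds)
print(dropped.shape)
```

The result could only be fed to the UNet if the model accepted precomputed time embeddings, which is what this request asks for.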
Describe the solution you'd like
Add the ability to set the SDXL Micro-Conditioning embeddings to 0.
Describe alternatives you've considered
Perhaps diffusers users could be allowed to pass in time_embeds directly, so that if time_embeds exists, time_ids is no longer used?
if "time_embeds" in added_cond_kwargs:
    time_embeds = added_cond_kwargs.get("time_embeds")
else:
    time_ids = added_cond_kwargs.get("time_ids")
    time_embeds = self.add_time_proj(time_ids.flatten())
    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
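The same idea can be tried outside the model as a standalone function. This is only a sketch under stated assumptions: resolve_time_embeds and toy_proj are hypothetical names, toy_proj is a stand-in for the model's add_time_proj layer, and the time_embeds key does not exist in added_cond_kwargs today.

```python
import torch


def resolve_time_embeds(added_cond_kwargs, add_time_proj, batch_size):
    """Prefer precomputed time_embeds; otherwise project time_ids."""
    if "time_embeds" in added_cond_kwargs:
        # Caller supplied embeddings (possibly zeroed); use them as-is.
        return added_cond_kwargs["time_embeds"]
    if "time_ids" not in added_cond_kwargs:
        raise ValueError(
            "Pass either `time_embeds` or `time_ids` in `added_cond_kwargs`."
        )
    time_ids = added_cond_kwargs["time_ids"]
    time_embeds = add_time_proj(time_ids.flatten())
    return time_embeds.reshape((batch_size, -1))


def toy_proj(ids):
    # Illustration only: maps each scalar id to a constant 256-dim vector.
    return ids.float().unsqueeze(-1).expand(-1, 256)


time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2)
from_ids = resolve_time_embeds({"time_ids": time_ids}, toy_proj, batch_size=2)

zeros = torch.zeros(2, 1536)
from_embeds = resolve_time_embeds(
    {"time_embeds": zeros, "time_ids": time_ids}, toy_proj, batch_size=2
)
print(from_ids.shape, from_embeds.shape)
```

When both keys are present, the precomputed embeddings win, which keeps existing callers that pass only time_ids unaffected.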