remove attention mask for self-attention · huggingface/diffusers@9214f4a

@@ -843,6 +843,8 @@ def __call__(
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        if prompt_embeds.ndim == 3:
+            prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -884,17 +886,9 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
-                if prompt_embeds.ndim == 3:
-                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-
-                # prepare attention_mask.
-                # b c t h w -> b t h w
-                attention_mask = torch.ones_like(latent_model_input)[:, 0]
-
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    attention_mask=attention_mask,
+                    hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
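
Below is a minimal, self-contained sketch (not the pipeline code itself; the tensor shapes and variable names are illustrative) of the two ideas behind this change: the b l d -> b 1 l d reshape of the text embeddings is hoisted out of the denoising loop so it runs once, and the all-ones self-attention mask over the latent tokens is dropped because masking nothing gives the same result as passing no mask at all.

import torch
import torch.nn.functional as F

# 1) Reshape done once, before the denoising loop
#    (toy shapes: batch 2, 16 tokens, 32 channels).
prompt_embeds = torch.randn(2, 16, 32)
if prompt_embeds.ndim == 3:
    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
print(prompt_embeds.shape)  # torch.Size([2, 1, 16, 32])

# 2) An all-ones (all-True) self-attention mask masks no positions, so the
#    attention output matches the unmasked call; this is why the mask built
#    from torch.ones_like(latent_model_input) could simply be removed.
q = k = v = torch.randn(2, 4, 8, 16)                      # b, heads, tokens, head_dim
all_ones_mask = torch.ones(2, 1, 8, 8, dtype=torch.bool)  # broadcast over heads

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=all_ones_mask)
out_unmasked = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_masked, out_unmasked, atol=1e-6)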