@@ -843,6 +843,8 @@ def __call__(
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        if prompt_embeds.ndim == 3:
+            prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d

         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
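As a sanity check on the hoisted reshape above, here is a minimal runnable sketch; the batch, sequence, and hidden sizes are toy values chosen for illustration, not the pipeline's real ones:

import torch

# After the classifier-free-guidance concat, prompt_embeds is (b, l, d).
# Per the diff's own comment, a singleton dim is inserted to get the
# (b, 1, l, d) layout, now done once before the denoising loop instead of
# being re-checked on every step.
prompt_embeds = torch.randn(2, 77, 4096)        # toy (b, l, d)
if prompt_embeds.ndim == 3:
    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
print(prompt_embeds.shape)                      # torch.Size([2, 1, 77, 4096])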
@@ -884,17 +886,9 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                if prompt_embeds.ndim == 3:
-                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-
-                # prepare attention_mask.
-                # b c t h w -> b t h w
-                attention_mask = torch.ones_like(latent_model_input)[:, 0]
-
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    attention_mask=attention_mask,
+                    hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
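The deleted attention_mask was built with torch.ones_like, i.e. all ones over (b, t, h, w), so under the usual convention that an all-ones mask means "attend everywhere" it masked nothing. A runnable sketch with made-up latent shapes shows this:

import torch

# The removed lines reduced (b, c, t, h, w) latents to a (b, t, h, w) mask
# of all ones: every position unmasked, a no-op as a mask.
latent_model_input = torch.randn(2, 4, 16, 32, 32)  # toy (b, c, t, h, w)
attention_mask = torch.ones_like(latent_model_input)[:, 0]
print(attention_mask.shape)        # torch.Size([2, 16, 32, 32])
print(bool(attention_mask.all()))  # True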