[WIP] Sample images when checkpointing. by LucasSloan · Pull Request #2157 · huggingface/diffusers

Unfortunately, this code doesn't work at present, and I'm not sure why. I get the error `RuntimeError: Input type (c10::Half) and bias type (float) should be the same`; full stack trace:

```
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:757 in    │
│ <module>                                                                                         │
│                                                                                                  │
│   754                                                                                            │
│   755                                                                                            │
│   756 if __name__ == "__main__":                                                                 │
│ ❱ 757 │   main()                                                                                 │
│   758                                                                                            │
│                                                                                                  │
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:720 in    │
│ main                                                                                             │
│                                                                                                  │
│   717 │   │   │   │   │   │   │                                                                  │
│   718 │   │   │   │   │   │   │   # run inference                                                │
│   719 │   │   │   │   │   │   │   prompt = [args.validation_prompt]                              │
│ ❱ 720 │   │   │   │   │   │   │   images = pipeline(prompt, num_images_per_prompt=args.num_val   │
│   721 │   │   │   │   │   │   │                                                                  │
│   722 │   │   │   │   │   │   │   for i, image in enumerate(images):                             │
│   723 │   │   │   │   │   │   │   │   image.save(os.path.join(args.output_dir, f"sample-{globa   │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/utils/_contextlib.py:115 in                 │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_sta │
│ ble_diffusion.py:611 in __call__                                                                 │
│                                                                                                  │
│   608 │   │   │   │   latent_model_input = self.scheduler.scale_model_input(latent_model_input   │
│   609 │   │   │   │                                                                              │
│   610 │   │   │   │   # predict the noise residual                                               │
│ ❱ 611 │   │   │   │   noise_pred = self.unet(                                                    │
│   612 │   │   │   │   │   latent_model_input,                                                    │
│   613 │   │   │   │   │   t,                                                                     │
│   614 │   │   │   │   │   encoder_hidden_states=prompt_embeds,                                   │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl     │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:482 in      │
│ forward                                                                                          │
│                                                                                                  │
│   479 │   │   │   emb = emb + class_emb                                                          │
│   480 │   │                                                                                      │
│   481 │   │   # 2. pre-process                                                                   │
│ ❱ 482 │   │   sample = self.conv_in(sample)                                                      │
│   483 │   │                                                                                      │
│   484 │   │   # 3. down                                                                          │
│   485 │   │   down_block_res_samples = (sample,)                                                 │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl     │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463 in forward           │
│                                                                                                  │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
│ ❱  463 │   │   return self._conv_forward(input, self.weight, self.bias)                          │
│    464                                                                                           │
│    465 class Conv3d(_ConvNd):                                                                    │
│    466 │   __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu  │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459 in _conv_forward     │
│                                                                                                  │
│    456 │   │   │   return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel  │
│    457 │   │   │   │   │   │   │   weight, bias, self.stride,                                    │
│    458 │   │   │   │   │   │   │   _pair(0), self.dilation, self.groups)                         │
│ ❱  459 │   │   return F.conv2d(input, weight, bias, self.stride,                                 │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (c10::Half) and bias type (float) should be the same
```
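My best guess at the cause: with mixed-precision training, the VAE and text encoder are cast to `weight_dtype`, so the prompt embeddings (and therefore the latents the pipeline creates from them) are fp16, while accelerate keeps the UNet's trainable parameters in fp32, hence the Half-input/float-bias mismatch at `conv_in`. A sketch of a possible workaround (untested; I'm assuming the flag is named `args.num_validation_images` here) is to run the inference call under autocast so the dtypes can mix:

```python
import torch

# Untested sketch: autocast lets the fp16 latents/prompt embeddings flow
# through the fp32 UNet that accelerate keeps for mixed-precision training.
# Assumes a CUDA device.
with torch.autocast("cuda"):
    images = pipeline(
        prompt, num_images_per_prompt=args.num_validation_images
    ).images
```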

I tried to fix it on line 713 by setting `torch_dtype=weight_dtype` when constructing the `StableDiffusionPipeline`, but that didn't work.
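If I'm reading `DiffusionPipeline.from_pretrained` correctly, `torch_dtype` is only applied to the components it loads from disk, not to modules passed in, which would explain why that change had no effect: the live training UNet keeps its fp32 parameters. An alternative sketch (also untested; `weight_dtype`, `accelerator`, and the loading call are assumptions based on the training script, not the exact code in this PR) is to cast a throwaway copy of the unwrapped UNet, so the fp32 weights used for training are left untouched:

```python
import copy

from diffusers import StableDiffusionPipeline

# Cast a *copy* of the UNet for inference; calling .to(weight_dtype) on the
# training model itself would downcast the fp32 master weights in place.
unet_for_inference = copy.deepcopy(accelerator.unwrap_model(unet)).to(dtype=weight_dtype)

pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet_for_inference,
    torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
```

The deepcopy costs an extra UNet's worth of memory for the duration of sampling, so the autocast route above is probably the lighter option.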