[WIP] Sample images when checkpointing. by LucasSloan · Pull Request #2157 · huggingface/diffusers
Unfortunately, this code doesn't work at present and I'm not sure why. I get the error `RuntimeError: Input type (c10::Half) and bias type (float) should be the same`. Full stack trace:
```
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:757 in │
│ <module> │
│ │
│ 754 │
│ 755 │
│ 756 if __name__ == "__main__": │
│ ❱ 757 │ main() │
│ 758 │
│ │
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:720 in │
│ main │
│ │
│ 717 │ │ │ │ │ │ │ │
│ 718 │ │ │ │ │ │ │ # run inference │
│ 719 │ │ │ │ │ │ │ prompt = [args.validation_prompt] │
│ ❱ 720 │ │ │ │ │ │ │ images = pipeline(prompt, num_images_per_prompt=args.num_val │
│ 721 │ │ │ │ │ │ │ │
│ 722 │ │ │ │ │ │ │ for i, image in enumerate(images): │
│ 723 │ │ │ │ │ │ │ │ image.save(os.path.join(args.output_dir, f"sample-{globa │
│ │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_sta │
│ ble_diffusion.py:611 in __call__ │
│ │
│ 608 │ │ │ │ latent_model_input = self.scheduler.scale_model_input(latent_model_input │
│ 609 │ │ │ │ │
│ 610 │ │ │ │ # predict the noise residual │
│ ❱ 611 │ │ │ │ noise_pred = self.unet( │
│ 612 │ │ │ │ │ latent_model_input, │
│ 613 │ │ │ │ │ t, │
│ 614 │ │ │ │ │ encoder_hidden_states=prompt_embeds, │
│ │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl │
│ │
│ 1485 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1486 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1487 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1488 │ │ │ return forward_call(*args, **kwargs) │
│ 1489 │ │ # Do not call functions when jit is used │
│ 1490 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1491 │ │ backward_pre_hooks = [] │
│ │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:482 in │
│ forward │
│ │
│ 479 │ │ │ emb = emb + class_emb │
│ 480 │ │ │
│ 481 │ │ # 2. pre-process │
│ ❱ 482 │ │ sample = self.conv_in(sample) │
│ 483 │ │ │
│ 484 │ │ # 3. down │
│ 485 │ │ down_block_res_samples = (sample,) │
│ │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl │
│ │
│ 1485 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1486 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1487 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1488 │ │ │ return forward_call(*args, **kwargs) │
│ 1489 │ │ # Do not call functions when jit is used │
│ 1490 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1491 │ │ backward_pre_hooks = [] │
│ │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463 in forward │
│ │
│ 460 │ │ │ │ │ │ self.padding, self.dilation, self.groups) │
│ 461 │ │
│ 462 │ def forward(self, input: Tensor) -> Tensor: │
│ ❱ 463 │ │ return self._conv_forward(input, self.weight, self.bias) │
│ 464 │
│ 465 class Conv3d(_ConvNd): │
│ 466 │ __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu │
│ │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459 in _conv_forward │
│ │
│ 456 │ │ │ return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel │
│ 457 │ │ │ │ │ │ │ weight, bias, self.stride, │
│ 458 │ │ │ │ │ │ │ _pair(0), self.dilation, self.groups) │
│ ❱ 459 │ │ return F.conv2d(input, weight, bias, self.stride, │
│ 460 │ │ │ │ │ │ self.padding, self.dilation, self.groups) │
│ 461 │ │
│ 462 │ def forward(self, input: Tensor) -> Tensor: │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (c10::Half) and bias type (float) should be the same
```
I tried to fix it on line 713 by setting `torch_dtype=weight_dtype` on the `StableDiffusionPipeline`, but that didn't work.
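
For what it's worth, one workaround that might sidestep the mismatch is to run the validation inference under `torch.autocast`, which reconciles the fp16 latents and the fp32 unet weights inside each op. Below is a minimal, untested sketch; `pipeline`, `args.validation_prompt`, and `args.output_dir` mirror the names in `train_text_to_image.py`, while `args.num_validation_images` and `global_step` are my guesses at the truncated names in the traceback:

```python
import os

import torch

# Hypothetical workaround, not verified: under autocast, autocast-eligible
# ops (e.g. conv2d) cast their inputs to fp16, so the fp32 conv_in bias no
# longer clashes with the half-precision latents.
prompt = [args.validation_prompt]
with torch.autocast("cuda"):
    # `num_validation_images` and `global_step` are assumed completions of
    # the truncated identifiers in the traceback above.
    images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images

for i, image in enumerate(images):
    image.save(os.path.join(args.output_dir, f"sample-{global_step}-{i}.png"))
```

If that does work, it might also explain why `torch_dtype=weight_dtype` had no effect: my assumption is that `torch_dtype` only applies to components loaded from the checkpoint, while the unet here is passed in already prepared by accelerate.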