Allow DDPMPipeline half precision by sbinnee · Pull Request #9222 · huggingface/diffusers (original) (raw)

I found that DDPMPipeline couldn't run half precision.

  0%|                                                                                           | 0/1000 [00:00<?, ?it/s]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 1
----> 1 out = pipeline()

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/workspace/diffusers/src/diffusers/pipelines/ddpm/pipeline_ddpm.py:114, in DDPMPipeline.__call__(self, batch_size, generator, num_inference_steps, output_type, return_dict)
    110 self.scheduler.set_timesteps(num_inference_steps)
    112 for t in self.progress_bar(self.scheduler.timesteps):
    113     # 1. predict noise model_output
--> 114     model_output = self.unet(image, t).sample
    116     # 2. compute previous image: x_t -> x_t-1
    117     image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/workspace/diffusers/src/diffusers/models/unets/unet_2d.py:303, in UNet2DModel.forward(self, sample, timestep, class_labels, return_dict)
    301 # 2. pre-process
    302 skip_sample = sample
--> 303 sample = self.conv_in(sample)
    305 # 3. down
    306 down_block_res_samples = (sample,)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/conv.py:458, in Conv2d.forward(self, input)
    457 def forward(self, input: Tensor) -> Tensor:
--> 458     return self._conv_forward(input, self.weight, self.bias)

File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/conv.py:454, in Conv2d._conv_forward(self, input, weight, bias)
    450 if self.padding_mode != 'zeros':
    451     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    452                     weight, bias, self.stride,
    453                     _pair(0), self.dilation, self.groups)
--> 454 return F.conv2d(input, weight, bias, self.stride,
    455                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (float) and bias type (c10::Half) should be the same

I compared the implementation of DDPM to that of DDIM since they are alike. DDPM simply does not take unet.dtype when initializing the noise.

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.