Allow DDPMPipeline half precision by sbinnee · Pull Request #9222 · huggingface/diffusers
I found that DDPMPipeline cannot run in half precision.
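A minimal reproduction (the checkpoint is just an example I picked; any DDPM checkpoint loaded with `torch_dtype=torch.float16` hits the same error):

```python
import torch
from diffusers import DDPMPipeline

# Any DDPM checkpoint works here; "google/ddpm-cat-256" is just an example.
pipeline = DDPMPipeline.from_pretrained("google/ddpm-cat-256", torch_dtype=torch.float16)
pipeline.to("cuda")

out = pipeline()  # fails with the traceback below
```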
0%| | 0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[16], line 1
----> 1 out = pipeline()
File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/workspace/diffusers/src/diffusers/pipelines/ddpm/pipeline_ddpm.py:114, in DDPMPipeline.__call__(self, batch_size, generator, num_inference_steps, output_type, return_dict)
110 self.scheduler.set_timesteps(num_inference_steps)
112 for t in self.progress_bar(self.scheduler.timesteps):
113 # 1. predict noise model_output
--> 114 model_output = self.unet(image, t).sample
116 # 2. compute previous image: x_t -> x_t-1
117 image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/workspace/diffusers/src/diffusers/models/unets/unet_2d.py:303, in UNet2DModel.forward(self, sample, timestep, class_labels, return_dict)
301 # 2. pre-process
302 skip_sample = sample
--> 303 sample = self.conv_in(sample)
305 # 3. down
306 down_block_res_samples = (sample,)
File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/conv.py:458, in Conv2d.forward(self, input)
457 def forward(self, input: Tensor) -> Tensor:
--> 458 return self._conv_forward(input, self.weight, self.bias)
File ~/.conda/envs/repaint/lib/python3.11/site-packages/torch/nn/modules/conv.py:454, in Conv2d._conv_forward(self, input, weight, bias)
450 if self.padding_mode != 'zeros':
451 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
452 weight, bias, self.stride,
453 _pair(0), self.dilation, self.groups)
--> 454 return F.conv2d(input, weight, bias, self.stride,
455 self.padding, self.dilation, self.groups)
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
I compared the implementation of DDPMPipeline to that of DDIMPipeline, since the two are very similar. The difference is that DDPMPipeline simply does not pass `self.unet.dtype` when initializing the noise, so the initial image stays in float32 even when the UNet weights are float16, which triggers the dtype mismatch above.
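For illustration, here is a sketch of the kind of change this implies, abridged from `DDPMPipeline.__call__` (treat it as the idea, not the verbatim diff): pass the UNet's dtype to `randn_tensor`, as DDIMPipeline already does.

```python
from diffusers.utils.torch_utils import randn_tensor

# Abridged from DDPMPipeline.__call__: sample the initial noise in the
# UNet's dtype instead of the float32 default, mirroring DDIMPipeline.
if self.device.type == "mps":
    # randn does not work reproducibly on mps
    image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
    image = image.to(self.device)
else:
    image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
```

With the noise created in `self.unet.dtype`, `conv_in` receives a half-precision input and the sampling loop runs end to end in fp16.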