First Block Cache by a-r-r-o-w · Pull Request #11180 · huggingface/diffusers
This is a great branch, but I ran into the following problem while testing it. The first method fails, perhaps because some logic is not implemented yet.
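For clarity, here is a minimal sketch of the two usage routes I am comparing, assuming the second one is the enable_cache path on the model's CacheMixin (the import path matches the file shown in the traceback below; pipe is an already-loaded CogView4Pipeline):

from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache

# first method: apply the hook helper directly to the transformer (this is the call that fails below)
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))

# second method (assumed): enable the cache through the model's CacheMixin interface instead
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.3))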
Detailed Error
import torch
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache

# pipe is a CogView4Pipeline that was loaded beforehand; setup omitted here
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))
prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")
0%| | 0/50 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[12], line 4
1 apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))
3 prompt = "A photo of an astronaut riding a horse on mars"
----> 4 image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
5 image.save("output.png")
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4.py:623, in CogView4Pipeline.__call__(self, prompt, negative_prompt, height, width, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, original_size, crops_coords_top_left, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
620 timestep = t.expand(latents.shape[0])
622 with self.transformer.cache_context("cond"):
--> 623 noise_pred_cond = self.transformer(
624 hidden_states=latent_model_input,
625 encoder_hidden_states=prompt_embeds,
626 timestep=timestep,
627 original_size=original_size,
628 target_size=target_size,
629 crop_coords=crops_coords_top_left,
630 attention_kwargs=attention_kwargs,
631 return_dict=False,
632 )[0]
634 # perform guidance
635 if self.do_classifier_free_guidance:
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/models/transformers/transformer_cogview4.py:740, in CogView4Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, original_size, target_size, crop_coords, attention_kwargs, return_dict, attention_mask, image_rotary_emb)
730 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
731 block,
732 hidden_states,
(...)
737 attention_kwargs,
738 )
739 else:
--> 740 hidden_states, encoder_hidden_states = block(
741 hidden_states,
742 encoder_hidden_states,
743 temb,
744 image_rotary_emb,
745 attention_mask,
746 attention_kwargs,
747 )
749 # 4. Output norm & projection
750 hidden_states = self.norm_out(hidden_states, temb)
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/hooks.py:189, in HookRegistry.register_hook.<locals>.create_new_forward.<locals>.new_forward(module, *args, **kwargs)
187 def new_forward(module, *args, **kwargs):
188 args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
--> 189 output = function_reference.forward(*args, **kwargs)
190 return function_reference.post_forward(module, output)
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/first_block_cache.py:89, in FBCHeadBlockHook.new_forward(self, module, *args, **kwargs)
86 else:
87 hidden_states_residual = output - original_hidden_states
---> 89 shared_state: FBCSharedBlockState = self.state_manager.get_state()
90 hidden_states = encoder_hidden_states = None
91 should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/hooks.py:44, in StateManager.get_state(self)
42 def get_state(self):
43 if self._current_context is None:
---> 44 raise ValueError("No context is set. Please set a context before retrieving the state.")
45 if self._current_context not in self._state_cache.keys():
46 self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
ValueError: No context is set. Please set a context before retrieving the state.
