First Block Cache by a-r-r-o-w · Pull Request #11180 · huggingface/diffusers (original) (raw)

This is a great branch, but I encountered this problem when testing it. The first method will fail, perhaps because some logic is not implemented.

model: CogVIew4-6B
image

Detailed Error


apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

  0%|                                                                                                                            | 0/50 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[12], line 4
      1 apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.3))
      3 prompt = "A photo of an astronaut riding a horse on mars"
----> 4 image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
      5 image.save("output.png")

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4.py:623, in CogView4Pipeline.__call__(self, prompt, negative_prompt, height, width, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, original_size, crops_coords_top_left, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    620 timestep = t.expand(latents.shape[0])
    622 with self.transformer.cache_context("cond"):
--> 623     noise_pred_cond = self.transformer(
    624         hidden_states=latent_model_input,
    625         encoder_hidden_states=prompt_embeds,
    626         timestep=timestep,
    627         original_size=original_size,
    628         target_size=target_size,
    629         crop_coords=crops_coords_top_left,
    630         attention_kwargs=attention_kwargs,
    631         return_dict=False,
    632     )[0]
    634 # perform guidance
    635 if self.do_classifier_free_guidance:

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/models/transformers/transformer_cogview4.py:740, in CogView4Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, original_size, target_size, crop_coords, attention_kwargs, return_dict, attention_mask, image_rotary_emb)
    730         hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
    731             block,
    732             hidden_states,
   (...)
    737             attention_kwargs,
    738         )
    739     else:
--> 740         hidden_states, encoder_hidden_states = block(
    741             hidden_states,
    742             encoder_hidden_states,
    743             temb,
    744             image_rotary_emb,
    745             attention_mask,
    746             attention_kwargs,
    747         )
    749 # 4. Output norm & projection
    750 hidden_states = self.norm_out(hidden_states, temb)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/hooks.py:189, in HookRegistry.register_hook.<locals>.create_new_forward.<locals>.new_forward(module, *args, **kwargs)
    187 def new_forward(module, *args, **kwargs):
    188     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
--> 189     output = function_reference.forward(*args, **kwargs)
    190     return function_reference.post_forward(module, output)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/first_block_cache.py:89, in FBCHeadBlockHook.new_forward(self, module, *args, **kwargs)
     86 else:
     87     hidden_states_residual = output - original_hidden_states
---> 89 shared_state: FBCSharedBlockState = self.state_manager.get_state()
     90 hidden_states = encoder_hidden_states = None
     91 should_compute = self._should_compute_remaining_blocks(hidden_states_residual)

File /mnt/ceph/develop/jiawei/conda_env/diffuser/lib/python3.10/site-packages/diffusers/hooks/hooks.py:44, in StateManager.get_state(self)
     42 def get_state(self):
     43     if self._current_context is None:
---> 44         raise ValueError("No context is set. Please set a context before retrieving the state.")
     45     if self._current_context not in self._state_cache.keys():
     46         self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)

ValueError: No context is set. Please set a context before retrieving the state.