Refactor cross attention and allow mechanism to tweak cross attention function by patrickvonplaten · Pull Request #1639 · huggingface/diffusers

Hi @evinpinar, thank you for your prompt response!

So, one thing I gather from your example is that set_attn_processor accepts a dictionary mapping layer names to processors, and will use the corresponding processor on that specific layer, right? The other example here just calls

processor = AttnEasyProc(5.0)
model.set_attn_processor(processor)

which I can only assume will call the same processor on every layer. Are there any other overloads of set_attn_processor? Anyway, these two should be enough!
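
To make sure I am reading this right, here is roughly how I picture the two call patterns. This is only a sketch: I am assuming the default processor can be imported as diffusers.models.cross_attention.CrossAttnProcessor, that the small UNet config below is valid, and that the dict keys follow the module paths of the CrossAttention layers with a ".processor" suffix (that last part is pure guesswork on my side):

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor

# A small UNet just to exercise the API (config values chosen for illustration only).
model = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

# 1) One processor instance shared by every attention layer.
model.set_attn_processor(CrossAttnProcessor())

# 2) A dict mapping layer names to per-layer processors.
#    The key format ("<module path>.processor") is my guess, not something I found documented.
per_layer = {
    f"{name}.processor": CrossAttnProcessor()
    for name, module in model.named_modules()
    if isinstance(module, CrossAttention)
}
model.set_attn_processor(per_layer)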

Another peculiarity of the API that I gather from the above example is that one can pass extra kwargs as a dictionary, like

model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

and said kwargs will apparently be passed along when the processor is called; in fact, the signature of __call__ there is

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
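
Concretely, the round trip I have in mind looks like the following, reusing the small model from the sketch above. Subclassing the default CrossAttnProcessor so that the attention math stays untouched is just my own way of illustrating the kwarg plumbing, not something I found in the PR:

import torch
from diffusers.models.cross_attention import CrossAttnProcessor

class KwargEchoProcessor(CrossAttnProcessor):
    # Delegate the actual attention computation to the default processor and only
    # show that the extra entry from cross_attention_kwargs arrives here as `number`.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
        print("received number =", number)
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)

model.set_attn_processor(KwargEchoProcessor())  # `model` is the small UNet from the sketch above

sample = torch.randn(1, 4, 32, 32)             # (batch, in_channels, height, width)
timestep = torch.tensor([1])
encoder_hidden_states = torch.randn(1, 8, 32)  # (batch, seq_len, cross_attention_dim)

out = model(sample, timestep, encoder_hidden_states, cross_attention_kwargs={"number": 123}).sample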

The last piece of the puzzle I am not sure about is to what extent one needs to reproduce the "normal" attention mechanism in processors. That is, even the simple processor in the example has the usual attention computation

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
    batch_size, sequence_length, _ = hidden_states.shape
    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

    query = attn.to_q(hidden_states)

    encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    query = attn.head_to_batch_dim(query)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    attention_probs = attn.get_attention_scores(query, key, attention_mask)
    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states

So I assume that the computation taking place in these processors replaces the default attention computation rather than, say, augmenting it in some way. In other words, every processor has to copy this boilerplate first and then modify the flow to achieve whatever it needs, instead of receiving the already computed attention maps and only having to modify them.

I think this is the case, but I want to make sure I am not grossly misunderstanding something.
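
For the kind of thing I want to do (touching the attention maps), I imagine the processor would then end up looking like the following: a verbatim copy of the default flow with a hook spliced in right after get_attention_scores. This is only my guess at the intended pattern, and the hook argument is something I made up for illustration:

import torch

class AttnMapHookProcessor:
    def __init__(self, hook):
        # `hook` is any callable that takes and returns the attention probabilities,
        # of shape (batch_size * heads, query_len, key_len).
        self.hook = hook

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # The only departure from the default flow: expose / modify the maps.
        attention_probs = self.hook(attention_probs)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

# e.g. just record the maps without changing them (`model` as in the earlier sketch):
saved = []

def save_and_passthrough(probs):
    saved.append(probs.detach().cpu())
    return probs

model.set_attn_processor(AttnMapHookProcessor(save_and_passthrough))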

Excuse me if my questions seem naive, but this monkey-patching of attention maps is delicate enough as it is, and without documentation of the exact API it is hard to get started.