Generate: New Cache abstraction and Attention Sinks support by tomaarsen · Pull Request #26681 · huggingface/transformers

@ArthurZucker It seems that even with Llama2, passing a SinkCache to generate raises an error.

I'm using transformers 4.39.3 and loaded the Llama2 model with the following code:

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
model = LlamaForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

The SinkCache was passed to generate as in @fayejf's script. I am not sure if this is the correct way to use SinkCache:

from transformers import SinkCache

prefix = "Hello world!"
# move the inputs to the same device as the model
inputs = tokenizer(prefix, return_tensors="pt").to(model.device)

cache = SinkCache(window_length=1024, num_sink_tokens=4)
gen_out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=256,
    use_cache=True,
    past_key_values=cache,
    pad_token_id=tokenizer.pad_token_id,
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

The code raised TypeError: object of type 'SinkCache' has no len() from this DynamicCache.from_legacy_cache call (see the stack trace below). Since you are familiar with the StaticCache work, do you have any suggestions on how to get around this? Thanks in advance!

File /miniconda3/envs/pytorch2/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:977, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   975 if use_cache:  # kept for BC (cache positions)
   976     if not isinstance(past_key_values, StaticCache):
--> 977         past_key_values = DynamicCache.from_legacy_cache(past_key_values)
   978         past_seen_tokens = past_key_values.get_seq_length()
   980 if cache_position is None:

File /miniconda3/envs/pytorch2/lib/python3.10/site-packages/transformers/cache_utils.py:181, in DynamicCache.from_legacy_cache(cls, past_key_values)
   179 cache = cls()
   180 if past_key_values is not None:
--> 181     for layer_idx in range(len(past_key_values)):
   182         key_states, value_states = past_key_values[layer_idx]
   183         cache.update(key_states, value_states, layer_idx)
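
For what it's worth, the failure reproduces in isolation with just the cache object. Below is a minimal sketch (assuming transformers 4.39.x, where SinkCache does not implement __len__) showing why the legacy-cache conversion chokes on it:

from transformers import DynamicCache, SinkCache

cache = SinkCache(window_length=1024, num_sink_tokens=4)

# LlamaModel.forward only special-cases StaticCache, so any other cache object
# is routed through DynamicCache.from_legacy_cache, which iterates over
# range(len(past_key_values)). SinkCache has no __len__ here, hence the error.
try:
    DynamicCache.from_legacy_cache(cache)
except TypeError as err:
    print(err)  # object of type 'SinkCache' has no len()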

I got the same error with Llama2 on 4.38.2, following the code here. I also tried transformers==4.39.0 and got the same error.

File "/opt/conda/envs/dn/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 982, in forward
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  File "/opt/conda/envs/dn/lib/python3.10/site-packages/transformers/cache_utils.py", line 166, in from_legacy_cache
    for layer_idx in range(len(past_key_values)):
TypeError: object of type 'SinkCache' has no len()
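
My (unverified) reading of both tracebacks is that the guard in LlamaModel.forward only special-cases StaticCache, so every other Cache subclass, SinkCache included, is funneled into DynamicCache.from_legacy_cache, which only understands the legacy tuple-of-tuples format. Purely to illustrate that suspicion (this is not the library's actual code, and _maybe_convert_legacy_cache is a made-up helper name), a more general guard could look like:

from transformers.cache_utils import Cache, DynamicCache

def _maybe_convert_legacy_cache(past_key_values):
    # Hypothetical helper: only convert inputs that are actually in the legacy
    # tuple-of-tuples format; leave any Cache subclass (SinkCache, StaticCache,
    # DynamicCache, ...) untouched.
    if past_key_values is None or isinstance(past_key_values, Cache):
        return past_key_values
    return DynamicCache.from_legacy_cache(past_key_values)

If that reading is right, the fix has to come from the modeling code (checking against the Cache base class rather than StaticCache before converting) rather than from how we call generate.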