Generate: New Cache abstraction and Attention Sinks support by tomaarsen · Pull Request #26681 · huggingface/transformers
@ArthurZucker It seems that even with Llama 2, passing a `SinkCache` to `generate` causes errors. I'm using transformers 4.39.3, and the Llama 2 model was loaded with the following code:
```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
model = LlamaForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
```
The `SinkCache` was passed to `generate` as in @fayejf's script. I am not sure if this is the correct way to use `SinkCache`:
```python
from transformers import SinkCache

prefix = "Hello world!"
# Note: the original snippet used an undefined `device`; with device_map="auto",
# model.device points at the right placement.
inputs = tokenizer(prefix, return_tensors="pt").to(model.device)
cache = SinkCache(window_length=1024, num_sink_tokens=4)
gen_out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=256,
    use_cache=True,
    past_key_values=cache,
    pad_token_id=tokenizer.pad_token_id,
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
```
Running this raised

```
TypeError: object of type 'SinkCache' has no len()
```

from this `DynamicCache.from_legacy_cache` call (see the stack trace below). It looks like you are familiar with the `StaticCache` work; any suggestions on how to get around this? Thanks in advance!

```
File /miniconda3/envs/pytorch2/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:977, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    975 if use_cache:  # kept for BC (cache positions)
    976     if not isinstance(past_key_values, StaticCache):
--> 977         past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    978     past_seen_tokens = past_key_values.get_seq_length()
    980 if cache_position is None:

File /miniconda3/envs/pytorch2/lib/python3.10/site-packages/transformers/cache_utils.py:181, in DynamicCache.from_legacy_cache(cls, past_key_values)
    179 cache = cls()
    180 if past_key_values is not None:
--> 181     for layer_idx in range(len(past_key_values)):
    182         key_states, value_states = past_key_values[layer_idx]
    183         cache.update(key_states, value_states, layer_idx)
```
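To make the failure path concrete, here is a minimal sketch (my own illustration, not the library code verbatim) of why a `SinkCache` ends up in `from_legacy_cache`: the forward pass only exempts `StaticCache` from the legacy-cache conversion, and the conversion starts by calling `len()` on whatever it receives.

```python
from transformers import SinkCache, StaticCache

cache = SinkCache(window_length=1024, num_sink_tokens=4)

# LlamaModel.forward only exempts StaticCache, so a SinkCache falls through
# to DynamicCache.from_legacy_cache, which expects a tuple of per-layer
# (key, value) pairs and begins by taking its length.
if not isinstance(cache, StaticCache):  # True: SinkCache is not a StaticCache
    len(cache)  # TypeError: object of type 'SinkCache' has no len()
```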
I got the same error with Llama 2 on transformers 4.38.2, following the code here. I also tried transformers==4.39.0 and got the same error.
File "/opt/conda/envs/dn/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 982, in forward
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
File "/opt/conda/envs/dn/lib/python3.10/site-packages/transformers/cache_utils.py", line 166, in from_legacy_cache
for layer_idx in range(len(past_key_values)):
TypeError: object of type 'SinkCache' has no len()
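For anyone checking whether their installed version is affected, here is a quick probe (an illustrative check of mine, not an official recommendation): if `len()` fails on a `SinkCache` instance, the `DynamicCache.from_legacy_cache` path above will raise this same `TypeError` during `generate`.

```python
from transformers import SinkCache

cache = SinkCache(window_length=1024, num_sink_tokens=4)
try:
    len(cache)
except TypeError:
    # Affected: the legacy-cache conversion in generate() will crash as above.
    print("SinkCache has no __len__ in this version; expect the TypeError with Llama.")
else:
    # len() works, so this particular TypeError should not trigger here.
    print("len(SinkCache) works in this version.")
```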