Merge branch 'main' of https://github.com/EvolvingLMMs-Lab/lmms-eval … · EvolvingLMMs-Lab/lmms-eval@465bd42
@@ -8,7 +8,7 @@
 from accelerate import Accelerator, DistributedType
 from accelerate.state import AcceleratorState
 from typing import List, Optional, Union, Tuple
-from transformers import LlavaForConditionalGeneration, AutoProcessor
+from transformers import LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, AutoProcessor
 
 import warnings
 
@@ -31,10 +31,10 @@ class LlavaHf(lmms):
 
     Example usage:
 
-    accelerate launch --num_processes=8 -m lmms_eval \
+    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
         --model llava_hf \
         --model_args pretrained=llava-hf/llava-1.5-7b-hf \
-        --tasks mme \
+        --tasks seedbench \
         --batch_size 1 \
         --output_path ./logs/ \
         --log_samples
@@ -67,7 +67,16 @@ def __init__(
         self.device_map = device_map
         if isinstance(dtype, str) and dtype != "auto":
             dtype = getattr(torch, dtype)
-        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+
+        if "1.5" in pretrained:
+            self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+        elif "1.6" in pretrained:
+            self._model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+        else:
+            eval_logger.info("Not sure whether you use 1.5 or 1.6. Use 1.5 by default. This might cause bugs if you are actually using 1.6")
+            self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+
+        self.pretrained = pretrained
         self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
         # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
         self._image_processor.tokenizer.padding_side = "left"
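This hunk dispatches on the checkpoint name: LLaVA-1.5 checkpoints load through `LlavaForConditionalGeneration`, LLaVA-1.6 (NeXT) checkpoints need `LlavaNextForConditionalGeneration`, and anything else falls back to 1.5 with a warning. A minimal sketch of the same dispatch outside the wrapper, assuming the public `llava-hf` Hub checkpoints (the helper name `load_llava` is illustrative, not part of the repo):

```python
# Minimal sketch of the class dispatch used in the hunk above.
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration


def load_llava(pretrained: str, dtype=torch.float16):
    # LLaVA-1.6 ("NeXT") weights require the LlavaNext class; 1.5 uses the original class.
    cls = LlavaNextForConditionalGeneration if "1.6" in pretrained else LlavaForConditionalGeneration
    model = cls.from_pretrained(pretrained, torch_dtype=dtype)
    processor = AutoProcessor.from_pretrained(pretrained)
    return model, processor


# model, processor = load_llava("llava-hf/llava-v1.6-mistral-7b-hf")
```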
@@ -106,6 +115,7 @@ def __init__(
             self.model.to(self._device)
             self._rank = 0
             self._word_size = 1
+        self.accelerator = accelerator
 
     @property
     def config(self):
@@ -199,8 +209,8 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
             labels[: len(contxt_id)] = -100
 
             if self.accelerator.is_main_process and doc_id % 100 == 0:
-                eval_logger.info(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
-                eval_logger.info(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")
+                eval_logger.debug(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
+                eval_logger.debug(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")
 
             with torch.inference_mode():
                 outputs = self.model(**model_inputs, labels=labels)
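This `loglikelihood` hunk only downgrades the per-document prompt logging from `info` to `debug`; the surrounding scoring masks the context positions of `labels` with `-100` so the Hugging Face loss counts continuation tokens only. A rough sketch of that pattern, reusing the names from the diff (`model_inputs`, `contxt_id`) and simplifying the batch handling:

```python
# Sketch of continuation-only scoring: label positions covered by the context
# are set to -100, which the HF cross-entropy loss ignores.
labels = model_inputs["input_ids"].clone()
labels[:, : len(contxt_id)] = -100

with torch.inference_mode():
    outputs = model(**model_inputs, labels=labels)

loss = outputs.loss  # mean negative log-likelihood over the unmasked (continuation) tokens
```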
@@ -268,7 +278,9 @@ def _collate(x):
 
             # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
             if DEFAULT_IMAGE_TOKEN not in context:
-                context = f"{DEFAULT_IMAGE_TOKEN}\n{context}"
+                image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
+                image_tokens = " ".join(image_tokens)
+                context = f"{image_tokens}\n{context}"
             # Apply chat template
             messages = [{"role": "user", "content": context}]
             if self.chat_template is not None:
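The old code always prepended a single image token; the new code prepends one placeholder per visual, so multi-image samples get a matching number of `<image>` tokens. A small self-contained sketch of the new behaviour (`DEFAULT_IMAGE_TOKEN` is `"<image>"` in this file; the example prompt below is made up):

```python
DEFAULT_IMAGE_TOKEN = "<image>"


def prepend_image_tokens(context: str, visuals: list) -> str:
    # Mirror of the hunk: only add placeholders when the task prompt has none,
    # and add one per visual so multi-image inputs line up with the processor.
    if DEFAULT_IMAGE_TOKEN not in context:
        image_tokens = " ".join([DEFAULT_IMAGE_TOKEN] * len(visuals))
        context = f"{image_tokens}\n{context}"
    return context


# prepend_image_tokens("Describe the differences.", visuals=[img_a, img_b])
# -> "<image> <image>\nDescribe the differences."
```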
@@ -281,7 +293,7 @@ def _collate(x):
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
-                eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
+                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
 
             inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)
 
@@ -303,15 +315,21 @@ def _collate(x):
                     num_beams=gen_kwargs["num_beams"],
                     max_new_tokens=gen_kwargs["max_new_tokens"],
                     use_cache=self.use_cache,
+                    pad_token_id=self.tokenizer.eos_token_id,
                 )
             except Exception as e:
                 eval_logger.error(f"Error {e} in generating")
                 cont = ""
             text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
-            text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
+            if "1.5" in self.pretrained:
+                text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
+            elif "mistral" in self.pretrained:
+                text_outputs = text_outputs.split("[/INST]")[-1].strip()
+            else:
+                text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
-                eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
+                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
 
             res.append(text_outputs)
             self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
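Besides passing `pad_token_id=self.tokenizer.eos_token_id` (typically done to avoid the missing pad-token warning from `generate` in batched decoding), this hunk extracts the answer from the decoded text by splitting on the assistant marker of the underlying chat template: `ASSISTANT:` for Vicuna-style LLaVA-1.5 prompts and `[/INST]` for the Mistral-based 1.6 checkpoint. A sketch of that post-processing, assuming only those two template families (the helper name `extract_answer` is illustrative):

```python
def extract_answer(decoded: str, pretrained: str) -> str:
    # The decoded sequence still contains the prompt, so keep only the text
    # after the assistant marker of the chat template in use.
    if "mistral" in pretrained:
        return decoded.split("[/INST]")[-1].strip()  # Mistral template closes the user turn with [/INST]
    return decoded.split("ASSISTANT:")[-1].strip()   # Vicuna-style LLaVA-1.5 prompts (also the fallback)


# extract_answer("USER: <image>\nWhat is shown? ASSISTANT: A red bus.", "llava-hf/llava-1.5-7b-hf")
# -> "A red bus."
```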