Merge branch 'main' of https://github.com/EvolvingLMMs-Lab/lmms-eval … · EvolvingLMMs-Lab/lmms-eval@465bd42

```diff
@@ -8,7 +8,7 @@
 from accelerate import Accelerator, DistributedType
 from accelerate.state import AcceleratorState
 from typing import List, Optional, Union, Tuple
-from transformers import LlavaForConditionalGeneration, AutoProcessor
+from transformers import LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, AutoProcessor
 
 import warnings
 
@@ -31,10 +31,10 @@ class LlavaHf(lmms):
 
     Example usage:
 
-    accelerate launch --num_processes=8 -m lmms_eval \
+    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
         --model llava_hf \
         --model_args pretrained=llava-hf/llava-1.5-7b-hf \
-        --tasks mme \
+        --tasks seedbench \
         --batch_size 1 \
         --output_path ./logs/ \
         --log_samples
@@ -67,7 +67,16 @@ def __init__(
         self.device_map = device_map
         if isinstance(dtype, str) and dtype != "auto":
             dtype = getattr(torch, dtype)
-        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+
+        if "1.5" in pretrained:
+            self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+        elif "1.6" in pretrained:
+            self._model = LlavaNextForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+        else:
+            eval_logger.info("Not sure whether you use 1.5 or 1.6. Use 1.5 by default. This might cause bugs if you are actually using 1.6")
+            self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
+
+        self.pretrained = pretrained
         self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
         # Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
         self._image_processor.tokenizer.padding_side = "left"
```
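The `__init__` hunk above selects the model class from the checkpoint name: `1.5` checkpoints load `LlavaForConditionalGeneration`, `1.6` checkpoints load `LlavaNextForConditionalGeneration`, and anything else logs a note and falls back to the 1.5 class. A minimal standalone sketch of that selection logic (the `load_llava` helper and the checkpoint names are illustrative, not part of this diff; it assumes a transformers release that ships `LlavaNextForConditionalGeneration`):

```python
# Hypothetical helper mirroring the checkpoint-name branching introduced above.
import torch
from transformers import (
    AutoProcessor,
    LlavaForConditionalGeneration,
    LlavaNextForConditionalGeneration,
)

def load_llava(pretrained: str, dtype: torch.dtype = torch.float16):
    if "1.5" in pretrained:          # e.g. "llava-hf/llava-1.5-7b-hf"
        model_cls = LlavaForConditionalGeneration
    elif "1.6" in pretrained:        # LLaVA-NeXT (1.6) checkpoints
        model_cls = LlavaNextForConditionalGeneration
    else:
        # Mirror the diff's default branch: assume 1.5 when the name is ambiguous.
        model_cls = LlavaForConditionalGeneration
    model = model_cls.from_pretrained(pretrained, torch_dtype=dtype, device_map="auto")
    processor = AutoProcessor.from_pretrained(pretrained)
    processor.tokenizer.padding_side = "left"  # pad left for batched generation
    return model, processor
```

Keying the class off the checkpoint name is why the diff also stores `self.pretrained`: it is reused later in `generate_until` to pick the answer delimiter.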

```diff
@@ -106,6 +115,7 @@ def __init__(
             self.model.to(self._device)
             self._rank = 0
             self._word_size = 1
+        self.accelerator = accelerator
 
     @property
     def config(self):
@@ -199,8 +209,8 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
             labels[: len(contxt_id)] = -100
 
             if self.accelerator.is_main_process and doc_id % 100 == 0:
-                eval_logger.info(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
-                eval_logger.info(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")
+                eval_logger.debug(f"Prompt for doc ID {doc_id}:\n\n{formatted_contexts[0]}\n")
+                eval_logger.debug(f"Prompt and continuation for doc ID {doc_id}:\n\n{formatted_continuation[0]}\n")
 
             with torch.inference_mode():
                 outputs = self.model(**model_inputs, labels=labels)
@@ -268,7 +278,9 @@ def _collate(x):
 
             # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
             if DEFAULT_IMAGE_TOKEN not in context:
-                context = f"{DEFAULT_IMAGE_TOKEN}\n{context}"
+                image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
+                image_tokens = " ".join(image_tokens)
+                context = f"{image_tokens}\n{context}"
             # Apply chat template
             messages = [{"role": "user", "content": context}]
             if self.chat_template is not None:
@@ -281,7 +293,7 @@ def _collate(x):
                 text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
-                eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
+                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
 
             inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, self.model.dtype)
 
@@ -303,15 +315,21 @@ def _collate(x):
                     num_beams=gen_kwargs["num_beams"],
                     max_new_tokens=gen_kwargs["max_new_tokens"],
                     use_cache=self.use_cache,
+                    pad_token_id=self.tokenizer.eos_token_id,
                 )
             except Exception as e:
                 eval_logger.error(f"Error {e} in generating")
                 cont = ""
             text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
-            text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
+            if "1.5" in self.pretrained:
+                text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
+            elif "mistral" in self.pretrained:
+                text_outputs = text_outputs.split("[/INST]")[-1].strip()
+            else:
+                text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
 
             if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
-                eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
+                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
 
             res.append(text_outputs)
             self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
```
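The `generate_until` hunks change two behaviours: when a prompt contains no image token, one placeholder per visual is prepended instead of a single one, and the decoded output is split on a checkpoint-specific assistant marker. A minimal sketch of both, assuming `<image>` as the placeholder token and using hypothetical helper names:

```python
# Hypothetical helpers illustrating the generate_until changes above.
DEFAULT_IMAGE_TOKEN = "<image>"  # assumed placeholder expected by the HF LLaVA processors

def prepend_image_tokens(context: str, num_visuals: int) -> str:
    # Benchmarks such as MME ship prompts without image tokens; add one per image.
    if DEFAULT_IMAGE_TOKEN not in context:
        image_tokens = " ".join([DEFAULT_IMAGE_TOKEN] * num_visuals)
        context = f"{image_tokens}\n{context}"
    return context

def extract_answer(decoded: str, pretrained: str) -> str:
    # Vicuna-based 1.5 checkpoints answer after "ASSISTANT:"; Mistral-based
    # checkpoints use the "[/INST]" marker. The default mirrors the 1.5 behaviour.
    if "mistral" in pretrained:
        return decoded.split("[/INST]")[-1].strip()
    return decoded.split("ASSISTANT:")[-1].strip()
```

The added `pad_token_id=self.tokenizer.eos_token_id` in the `generate` call appears intended to avoid the warning transformers emits on each call when no pad token is defined for open-ended generation.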