Comment out parse result in xcomposer · EvolvingLMMs-Lab/lmms-eval@662f05c
from multiprocessing import context
import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from datetime import timedelta
import logging

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.utils import stop_sequences_criteria

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState

from typing import Optional, Sequence, List, Tuple, Union
import re
from tqdm import tqdm

pattern = re.compile(r"[A-Z]")

eval_logger = logging.getLogger("lmms-eval")

meta_instruction = """You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).
- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.
- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image."""

@register_model("xcomposer2_4khd")
class XComposer2_4KHD(lmms):
    def __init__(
        self,
        pretrained: str = "internlm/internlm-xcomposer2-4khd-7b",
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        device_map="cuda:0",
        need_bos: bool = True,
        padding: bool = False,
        half: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        self.pretrained = pretrained
        self.need_bos = need_bos
        self.padding = padding
        self._model = AutoModel.from_pretrained(self.pretrained, device_map=self.device_map, trust_remote_code=True)
        self._tokenizer = AutoTokenizer.from_pretrained(self.pretrained, trust_remote_code=True)
        self.model.tokenizer = self.tokenizer
        self.batch_size_per_gpu = batch_size

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
            # Also, you have to select ZeRO stage 0 (equivalent to DDP) in order to make prepare_model work.
            # I tried setting different parameters in the kwargs to make the default ZeRO stage 2 work, but it didn't work.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list
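    # e.g. flatten([[img_a, img_b], [img_c]]) -> [img_a, img_b, img_c];
    # generate_until below uses this to merge the per-document visual lists.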

    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # encode, pad, and truncate contexts for this batch
            if "[UNUSED_TOKEN_146]" not in contexts:
                contexts = f"[UNUSED_TOKEN_146]user\n{contexts}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
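            # With the wrapper applied, a bare prompt such as "What is shown in the image?" becomes
            # "[UNUSED_TOKEN_146]user\nWhat is shown in the image?[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n",
            # i.e. a single chat turn with the assistant turn left open for generation.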
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)

            if "hd_num" not in gen_kwargs:
                if listinstr(["docvqa_test", "infovqa_test"], task.lower()):
                    self.model.hd_num = 65
                elif listinstr(["docvqa_val", "infovqa_val", "OCRBench"], task.lower()):
                    self.model.hd_num = 55
                elif listinstr(["mmmu", "mmbench", "mmvet"], task.lower()):
                    self.model.hd_num = 16
                else:
                    self.model.hd_num = 25
            else:
                self.model.hd_num = gen_kwargs.pop("hd_num")
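            # The hd_num chosen above is stored on the model and reused as the im_num budget for
            # HD_transform below: it caps how many 336x336 crops the resized image may occupy,
            # so document/OCR-style tasks get a larger budget than multi-choice benchmarks.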

            pt1 = 0
            embeds = []
            im_mask = []
            images_loc = [0]
            need_bos = self.need_bos
            padding = self.padding
            for i, pts in enumerate(images_loc + [len(contexts)]):
                subtext = contexts[pt1:pts]
                if need_bos or len(subtext) > 0:
                    text_embeds = self.model.encode_text(subtext, add_special_tokens=need_bos).to(self.device)
                    embeds.append(text_embeds)
                    im_mask.append(torch.zeros(text_embeds.shape[:2]).to(self.device))
                    need_bos = False
                if i < len(visuals):
                    image = visuals[i]

                    image = HD_transform(image, im_num=self.model.hd_num)
                    image = self.model.vis_processor(image).unsqueeze(0).to(self.device)
                    image_embeds = self.model.encode_img(image)
                    embeds.append(image_embeds)
                    im_mask.append(torch.ones(image_embeds.shape[:2]).to(self.device))
                pt1 = pts
            embeds = torch.cat(embeds, dim=1)
            im_mask = torch.cat(im_mask, dim=1)
            im_mask = im_mask.bool()
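            # embeds now interleaves text and image embeddings along the sequence dimension;
            # im_mask is True at image positions and False at text positions.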

            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            if "do_sample" not in gen_kwargs:
                gen_kwargs["do_sample"] = False
            if "repetition_penalty" not in gen_kwargs:
                gen_kwargs["repetition_penalty"] = 1.0
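            # The defaults amount to greedy decoding: do_sample=False, a single beam, and up to 1024 new tokens.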

            outputs = self.model.generate(
                inputs_embeds=embeds,
                im_mask=im_mask,
                temperature=gen_kwargs["temperature"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                num_beams=gen_kwargs["num_beams"],
                do_sample=gen_kwargs["do_sample"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
            )
            output_token = outputs[0]
            if output_token[0] == 0 or output_token[0] == 1:
                output_token = output_token[1:]
            output_text = self.model.tokenizer.decode(output_token, add_special_tokens=False)
            output_text = output_text.split("[UNUSED_TOKEN_145]")[0].strip()
            output_text = output_text.split("<|im_end|>")[0].strip()
            if DATASET_TYPE(task) == "multi-choice":
                output_text = pattern.findall(output_text)
                if len(output_text) == 0:
                    print("Error:", output_text)
                    output_text = "Z"
                if type(output_text) == list:
                    output_text = output_text[0]
            res.append(output_text)
            pbar.update(1)
        pbar.close()
        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        return super().loglikelihood(requests)

def padding_336(b):
    width, height = b.size
    tar = int(np.ceil(height / 336) * 336)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])

    return b
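# padding_336 pads only vertically, up to the next multiple of 336, splitting the white border evenly:
# e.g. a 1680x1260 image gets tar = 1344, so 42 rows are added on top and 42 on the bottom -> 1680x1344.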


def HD_transform(img, im_num=16):
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= im_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)

    img = transforms.functional.resize(
        img,
        [new_h, new_w],
    )
    img = padding_336(img)
    width, height = img.size
    assert width * height <= im_num * 336 * 336
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img
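# Worked example: a 1024x768 image with im_num=25 has ratio 4/3; the loop stops once
# 6 * ceil(6 / ratio) = 30 > 25, so scale = 5, the image is resized to 1680x1260 and padded to
# 1680x1344, i.e. a 5x4 grid of 336x336 tiles (20 <= 25). Portrait images are transposed first
# and transposed back at the end.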


def listinstr(lst, s):
    assert isinstance(lst, list)
    for item in lst:
        if item in s:
            return True
    return False
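# listinstr returns True if any list entry occurs as a substring of s,
# e.g. listinstr(["docvqa", "infovqa"], "docvqa_val") -> True.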


def DATASET_TYPE(dataset):
    # Dealing with Custom Dataset
    dataset = dataset.lower()
    if listinstr(["mmbench", "seedbench", "ccbench", "mmmu", "scienceqa", "ai2d", "mmstar"], dataset):
        return "multi-choice"
    elif listinstr(["mme", "hallusion"], dataset):
        return "Y/N"
    elif "coco" in dataset:
        return "Caption"
    elif listinstr(["ocrvqa", "textvqa", "chartqa", "mathvista", "docvqa", "infovqa", "llavabench", "mmvet", "ocrbench"], dataset):
        return "VQA"
    else:
        return "QA"
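

# Illustrative smoke test (a sketch; the image size and task names below are arbitrary examples,
# and it only exercises the GPU-free helpers above, not the model itself):
if __name__ == "__main__":
    demo = Image.new("RGB", (1024, 768), color=(0, 128, 255))  # synthetic landscape image
    print(HD_transform(demo, im_num=25).size)  # (1680, 1344): a 5x4 grid of 336px tiles
    print(listinstr(["docvqa", "infovqa"], "docvqa_val"))  # True
    print(DATASET_TYPE("mmmu_val"))  # multi-choice
    print(DATASET_TYPE("textvqa_val"))  # VQA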