Comment out parse result in xcomposer · EvolvingLMMs-Lab/lmms-eval@662f05c

```python
from multiprocessing import context
import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from datetime import timedelta
import logging

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.utils import stop_sequences_criteria

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState

from typing import Optional, Sequence, List, Tuple, Union
import re
from tqdm import tqdm

pattern = re.compile(r"[A-Z]")

eval_logger = logging.getLogger("lmms-eval")


meta_instruction = """You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).
- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed
by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by
the user such as English and 中文.
- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses
effectively based on the provided image."""



@register_model("xcomposer2_4khd")
class XComposer2_4KHD(lmms):
    def __init__(
        self,
        pretrained: str = "internlm/internlm-xcomposer2-4khd-7b",
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        device_map="cuda:0",
        need_bos: bool = True,
        padding: bool = False,
        half: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()


        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        self.pretrained = pretrained
        self.need_bos = need_bos
        self.padding = padding
        self._model = AutoModel.from_pretrained(self.pretrained, device_map=self.device_map, trust_remote_code=True)
        self._tokenizer = AutoTokenizer.from_pretrained(self.pretrained, trust_remote_code=True)
        self.model.tokenizer = self.tokenizer
        self.batch_size_per_gpu = batch_size

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make prepare_model work.
            # I tried to set different parameters in the kwargs to make the default zero stage 2 work, but it didn't.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1


    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def flatten(self, input):
        # flatten a list of lists into a single flat list
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list


    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # encode, pad, and truncate contexts for this batch
            if "[UNUSED_TOKEN_146]" not in contexts:
                contexts = f"[UNUSED_TOKEN_146]user\n{contexts}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)

            if "hd_num" not in gen_kwargs:
                if listinstr(["docvqa_test", "infovqa_test"], task.lower()):
                    self.model.hd_num = 65
                elif listinstr(["docvqa_val", "infovqa_val", "OCRBench"], task.lower()):
                    self.model.hd_num = 55
                elif listinstr(["mmmu", "mmbench", "mmvet"], task.lower()):
                    self.model.hd_num = 16
                else:
                    self.model.hd_num = 25
            else:
                self.model.hd_num = gen_kwargs.pop("hd_num")

            pt1 = 0
            embeds = []
            im_mask = []
            images_loc = [0]
            need_bos = self.need_bos
            padding = self.padding
            for i, pts in enumerate(images_loc + [len(contexts)]):
                subtext = contexts[pt1:pts]
                if need_bos or len(subtext) > 0:
                    text_embeds = self.model.encode_text(subtext, add_special_tokens=need_bos).to(self.device)
                    embeds.append(text_embeds)
                    im_mask.append(torch.zeros(text_embeds.shape[:2]).to(self.device))
                    need_bos = False
                if i < len(visuals):
                    image = visuals[i]

                    image = HD_transform(image, im_num=self.model.hd_num)
                    image = self.model.vis_processor(image).unsqueeze(0).to(self.device)
                    image_embeds = self.model.encode_img(image)
                    embeds.append(image_embeds)
                    im_mask.append(torch.ones(image_embeds.shape[:2]).to(self.device))
                pt1 = pts
            embeds = torch.cat(embeds, dim=1)
            im_mask = torch.cat(im_mask, dim=1)
            im_mask = im_mask.bool()

            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            if "do_sample" not in gen_kwargs:
                gen_kwargs["do_sample"] = False
            if "repetition_penalty" not in gen_kwargs:
                gen_kwargs["repetition_penalty"] = 1.0

            outputs = self.model.generate(
                inputs_embeds=embeds,
                im_mask=im_mask,
                temperature=gen_kwargs["temperature"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                num_beams=gen_kwargs["num_beams"],
                do_sample=gen_kwargs["do_sample"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
            )
            output_token = outputs[0]
            if output_token[0] == 0 or output_token[0] == 1:
                output_token = output_token[1:]
            output_text = self.model.tokenizer.decode(output_token, add_special_tokens=False)
            output_text = output_text.split("[UNUSED_TOKEN_145]")[0].strip()
            output_text = output_text.split("<|im_end|>")[0].strip()
            if DATASET_TYPE(task) == "multi-choice":
                output_text = pattern.findall(output_text)
                if len(output_text) == 0:
                    print("Error:", output_text)
                    output_text = "Z"
                if type(output_text) == list:
                    output_text = output_text[0]
            res.append(output_text)
            pbar.update(1)
        pbar.close()
        return res


    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        return super().loglikelihood(requests)


def padding_336(b):
    # pad the image height up to the next multiple of 336, filling with white pixels
    width, height = b.size
    tar = int(np.ceil(height / 336) * 336)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])

    return b



def HD_transform(img, im_num=16):
    # resize so the width is a multiple of 336 while the total area stays within im_num 336x336 tiles;
    # portrait images are transposed first so the longer side is treated as the width, then transposed back
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= im_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)

    img = transforms.functional.resize(
        img,
        [new_h, new_w],
    )
    img = padding_336(img)
    width, height = img.size
    assert width * height <= im_num * 336 * 336
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img



def listinstr(lst, s):
    # return True if any item in lst occurs as a substring of s
    assert isinstance(lst, list)
    for item in lst:
        if item in s:
            return True
    return False



def DATASET_TYPE(dataset):
    # Dealing with Custom Dataset
    dataset = dataset.lower()
    if listinstr(["mmbench", "seedbench", "ccbench", "mmmu", "scienceqa", "ai2d", "mmstar"], dataset):
        return "multi-choice"
    elif listinstr(["mme", "hallusion"], dataset):
        return "Y/N"
    elif "coco" in dataset:
        return "Caption"
    elif listinstr(["ocrvqa", "textvqa", "chartqa", "mathvista", "docvqa", "infovqa", "llavabench", "mmvet", "ocrbench"], dataset):
        return "VQA"
    else:
        return "QA"
```
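
A quick way to sanity-check the tiling helpers above is to run `HD_transform` on synthetic images and confirm the output dimensions respect the `hd_num` budget. The sketch below assumes this file is importable as `lmms_eval.models.xcomposer2_4khd` (the usual location for lmms-eval model wrappers); adjust the import path to wherever the module actually lives.

```python
# Minimal sketch: exercise the image tiling helpers with synthetic images.
# Assumption: this module is importable as lmms_eval.models.xcomposer2_4khd.
from PIL import Image

from lmms_eval.models.xcomposer2_4khd import HD_transform

# One landscape and one portrait test image (width x height).
landscape = Image.new("RGB", (1600, 900), color=(128, 128, 128))
portrait = Image.new("RGB", (900, 1600), color=(128, 128, 128))

for name, img in [("landscape", landscape), ("portrait", portrait)]:
    for hd_num in (16, 25, 55):
        out = HD_transform(img, im_num=hd_num)
        w, h = out.size
        # Both sides come out as multiples of 336, and the tile count stays
        # within the hd_num budget enforced by the assert inside HD_transform.
        print(f"{name} hd_num={hd_num}: {w}x{h} -> {(w // 336) * (h // 336)} tiles")
```

Because `padding_336` only pads the height and `HD_transform` transposes portrait inputs before scaling, both output sides end up as multiples of 336 and the total number of 336x336 tiles never exceeds `im_num`, which is exactly what the assertion inside `HD_transform` checks.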