Adding Phi3v model. · EvolvingLMMs-Lab/lmms-eval@7f9fb6b

import torch
import logging

from accelerate import Accelerator, DistributedType
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor
from typing import List, Optional, Tuple, Union

eval_logger = logging.getLogger("lmms-eval")


@register_model("phi3v")
class Phi3v(lmms):
    """
    Phi-3-Vision model wrapper for lmms-eval (single image, batch size 1).

    TODO(vifragos): Document me!
    """
    def __init__(
        self,
        model_id_name: str = "microsoft/Phi-3-vision-128k-instruct",
        device: str = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: int = 1,
        trust_remote_code: Optional[bool] = True,
        use_cache: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now.
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
        # Setup accelerator.
        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            self._device = torch.device(
                f"cuda:{accelerator.local_process_index}")
        else:
            self._device = device
        # Load model.
        self._model = AutoModelForCausalLM.from_pretrained(
            model_id_name,
            device_map=device,
            trust_remote_code=trust_remote_code,
            torch_dtype=dtype)
        self._processor = AutoProcessor.from_pretrained(
            model_id_name,
            trust_remote_code=trust_remote_code)
        self._processor.tokenizer.padding_side = "left"
        self._tokenizer = self._processor.tokenizer
        self._config = self._model.config
        self.batch_size_per_gpu = int(batch_size)
        assert self.batch_size_per_gpu == 1, \
            "batch_size_per_gpu > 1 is not supported for now."
        self.use_cache = use_cache
        if accelerator.num_processes > 1:
            distributed_type_list = [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED
            ]
            assert accelerator.distributed_type in distributed_type_list, \
                "Unsupported distributed type provided. Only DDP and FSDP are supported."
            if accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(
                    self.model,
                    evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1


    @property
    def config(self):
        # Return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # Returns the model, unwrapping it if using Accelerate.
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # We use EOT because end of text is more accurate for what we're doing than end of sentence.
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Not implemented for Phi3v.")


    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
            # The negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        # We group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]

            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]

            visuals = self.flatten(visuals)
            # We assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # Set default values for until and max_new_tokens
            until = [self.tokenizer.decode(self.eot_token_id)]
            # Update values from gen_kwargs if present
            if "until" in gen_kwargs:
                until = gen_kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")

            if isinstance(contexts, tuple):
                contexts = list(contexts)
            for i in range(len(contexts)):
                # Replace the generic image placeholder with Phi-3-Vision's image tag.
                if "<image>" in contexts[i]:
                    query = contexts[i].replace("<image>", "<|image_1|>")
                else:
                    query = f"<|image_1|>\n{contexts[i]}"

                messages = [
                    {"role": "user", "content": query}
                ]
                contexts[i] = self._tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True)
            assert len(contexts) == 1
            # We always pass a single image given that the model only accepts one image (as of 5/21/24).
            context = contexts[0]
            pil_image = visuals[0]
            input_ids = self._processor(
                text=context,
                images=[pil_image],
                return_tensors="pt").to(self._device, self.model.dtype)
            # Setting default parameters.
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            # Generate answer.
            pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None \
                else self.tokenizer.eod_id
            generate_ids = self.model.generate(
                **input_ids,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
                do_sample=True if gen_kwargs["temperature"] > 0 else False,
                temperature=gen_kwargs["temperature"],
                top_p=gen_kwargs["top_p"],
                num_beams=gen_kwargs["num_beams"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                use_cache=self.use_cache,
            )
            generate_ids = generate_ids[:, input_ids['input_ids'].shape[1]:]
            response = self._processor.batch_decode(
                generate_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False)[0]
            res.append(response)
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), response)
            pbar.update(1)
        # Reorder this group of results back to original unsorted form.
        res = re_ords.get_original(res)
        pbar.close()
        return res
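
For reference, a minimal sketch of how the new class might be exercised directly. This is illustrative only: the module path lmms_eval/models/phi3v.py and the direct-instantiation style are assumptions (the evaluation harness normally builds registered models from their string name, here "phi3v" via the @register_model decorator, rather than by import).

# Minimal sketch, assuming the file above is added as lmms_eval/models/phi3v.py.
from lmms_eval.models.phi3v import Phi3v

# Instantiate with the defaults declared in __init__ above (downloads the
# Phi-3-Vision checkpoint from the Hugging Face Hub on first use).
model = Phi3v(
    model_id_name="microsoft/Phi-3-vision-128k-instruct",
    device="cuda",
    dtype="auto",
    batch_size=1,
)

# The lmms API surface used by the evaluator:
print(model.rank, model.world_size)   # 0, 1 on a single device
print(model.device)                   # "cuda"
# model.generate_until(requests) expects a list of Instance objects whose
# .args unpack to (context, gen_kwargs, doc_to_visual, doc_id, task, split);
# model.loglikelihood(...) is intentionally not implemented for Phi3v.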