Merge pull request #73 from EvolvingLMMs-Lab/kc/qwen_vl_api · EvolvingLMMs-Lab/lmms-eval@caa5893
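This PR adds a Dashscope-backed `qwen-vl-api` model wrapper to lmms-eval. The new module, as added by the diff, is reproduced below: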

```python
from io import BytesIO
from copy import deepcopy
import os
import base64
from typing import List, Tuple, Union
from tqdm import tqdm
import requests as url_requests
import time
import logging

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval import utils

from PIL import Image

NUM_SECONDS_TO_SLEEP = 5
eval_logger = logging.getLogger("lmms-eval")

try:
    import dashscope
except ImportError:
    eval_logger.debug("Can not import Dashscope")

API_KEY = os.getenv("DASHSCOPE_API_KEY", "YOUR_API_KEY")


@register_model("qwen-vl-api")
class Qwen_VL_API(lmms):
    def __init__(
        self,
        model_version: str = "qwen-vl-max",
        image_token: str = "<image>",  # Used to separate interleaved image and text
        system_prompt: str = "",  # Optional special system prompt to prepend
        tmp_folder: str = "./tmp",  # Due to qwen's api restriction, images are written to local files first
        **kwargs,
    ) -> None:
        super().__init__()

        self.model_version = model_version
        self.image_token = image_token
        self.system_prompt = system_prompt
        self.tmp_folder = tmp_folder

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        os.makedirs(self.tmp_folder, exist_ok=True)

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # encode, pad, and truncate contexts for this batch
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)
            imgs = []

            # Dashscope expects image files on disk, so save each visual to the tmp folder
            for idx, visual in enumerate(visuals):
                tmp_path = os.path.join(self.tmp_folder, f"tmp_{idx}_{self.rank}_{self.world_size}.jpg")
                visual.save(tmp_path)
                imgs.append(tmp_path)

            messages = [{"role": "user", "content": []}]

            # Build one user turn: either all images followed by the full text,
            # or text and images interleaved at each image-token position
            if self.image_token not in contexts:
                for img in imgs:
                    messages[0]["content"].append({"image": img})
                messages[0]["content"].append({"text": contexts})
            else:
                contexts = contexts.split(self.image_token)

                for idx, img in enumerate(imgs):
                    messages[0]["content"].append({"text": contexts[idx]})
                    messages[0]["content"].append({"image": img})
                messages[0]["content"].append({"text": contexts[-1]})

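            # Fill in generation defaults; note that the Dashscope call below only
            # consumes max_new_tokens (passed as max_length)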
            if "max_new_tokens" not in gen_kwargs or gen_kwargs["max_new_tokens"] > 1500:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1

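            # Retry the API call up to 5 times, sleeping NUM_SECONDS_TO_SLEEP between attempts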
            response_data = None
            for attempt in range(5):
                try:
                    response_data = dashscope.MultiModalConversation.call(model=self.model_version, messages=messages, api_key=API_KEY, max_length=gen_kwargs["max_new_tokens"])
                    break  # success, stop retrying
                except Exception as e:
                    eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
                    if attempt < 5 - 1:  # If we have retries left, sleep and then continue to next attempt
                        time.sleep(NUM_SECONDS_TO_SLEEP)
                    else:  # If this was the last attempt, log the failure
                        eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}")

            if response_data is None:  # every attempt failed; record an empty response
                res.append("")
                pbar.update(1)
                continue

            try:
                res.append(response_data["output"]["choices"][0]["message"]["content"][0]["text"].strip())
            except Exception as e:
                eval_logger.error(f"Error {e} happens when parsing input.")
                eval_logger.error(f"{response_data}")
                res.append("")
            pbar.update(1)

        pbar.close()

        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        assert False, "Not supported for Qwen-VL API"

    def flatten(self, input):
        # Flatten a list of lists into a single flat list
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list
```
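For orientation, here is a minimal standalone sketch of the request payload the wrapper builds, mirroring the `dashscope.MultiModalConversation.call` invocation above. The image path and prompt text are illustrative placeholders, and `DASHSCOPE_API_KEY` is assumed to be set in the environment:

```python
import os
import dashscope

# Hypothetical example payload: one user turn interleaving a local image path
# and text, in the same shape generate_until assembles above.
messages = [
    {
        "role": "user",
        "content": [
            {"image": "./tmp/tmp_0_0_1.jpg"},  # placeholder: local file saved beforehand
            {"text": "Describe this image."},  # placeholder prompt
        ],
    }
]

response = dashscope.MultiModalConversation.call(
    model="qwen-vl-max",
    messages=messages,
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    max_length=1024,
)
# Parse the reply the same way the wrapper's generate_until does
print(response["output"]["choices"][0]["message"]["content"][0]["text"])
```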