Merge pull request #113 from teowu/main · dadwadw233/lmms-eval@ac3a66f

import json
import logging
import random  # needed by parse_multi_choice_response (random-guess fallback)
import re
from collections import Counter, defaultdict

import numpy as np  # needed by parse_multi_choice_response (np.argmax over match positions)

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

def q_bench_doc_to_text(doc, model_specific_prompt_kwargs):
    # Collect the non-empty options; "N/A" marks an unused option slot.
    candidates = []
    for i in range(4):
        candidate = doc.get(f"option{i}")
        if candidate != "N/A":
            candidates.append(candidate)

    # Render the options as "A. ...", "B. ..." and wrap them with the model-specific prompts.
    question = doc["question"] + "\n" + "\n".join([". ".join([chr(ord("A") + i), candidate]) for i, candidate in enumerate(candidates)])
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}\n{post_prompt}"
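A minimal illustrative sketch: assuming a doc that uses the field names read above (question, option0 to option3 with "N/A" for unused slots) and simple pre/post prompts, the builder yields an A./B./C. style multiple-choice prompt. The values below are invented for demonstration.

example_doc = {
    "question": "How would you rate the sharpness of this image?",
    "option0": "Very sharp",
    "option1": "Acceptable",
    "option2": "Blurry",
    "option3": "N/A",
}
example_kwargs = {"pre_prompt": "", "post_prompt": "Answer with the option's letter."}
print(q_bench_doc_to_text(example_doc, example_kwargs))
# How would you rate the sharpness of this image?
# A. Very sharp
# B. Acceptable
# C. Blurry
# Answer with the option's letter.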

def q_bench_doc_to_visual(doc):
    if "image2" not in doc:
        return [doc["image"].convert("RGB")]
    else:
        return [doc["image1"].convert("RGB"), doc["image2"].convert("RGB")]
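Illustrative sketch only: the visual helper returns a one-element list for single-image docs and a two-element list when an "image2" field is present. The synthetic PIL images below are stand-ins for the dataset's image fields.

from PIL import Image

single_doc = {"image": Image.new("L", (32, 32))}  # grayscale, converted to RGB by the helper
pair_doc = {"image1": Image.new("RGB", (32, 32)), "image2": Image.new("RGB", (32, 32))}

print(len(q_bench_doc_to_visual(single_doc)))  # 1
print(len(q_bench_doc_to_visual(pair_doc)))    # 2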

def get_multi_choice_info(options):
    """
    Given the list of options for a multiple-choice question,
    return the index2ans mapping and all_choices.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/data_utils.py#L54
    """
    start_chr = "A"
    all_choices = []
    index2ans = {}
    for i, option in enumerate(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))

    return index2ans, all_choices
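A quick sketch of the helper's output for an invented option list:

options = ["High quality", "Average quality", "Low quality"]
index2ans, all_choices = get_multi_choice_info(options)
print(all_choices)  # ['A', 'B', 'C']
print(index2ans)    # {'A': 'High quality', 'B': 'Average quality', 'C': 'Low quality'}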

def parse_multi_choice_response(response, all_choices, index2ans):
    """
    Parse the prediction from the generated response.
    Return the predicted index, e.g., A, B, C, D.
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10
    """
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # add spaces to avoid partial matches

    index_ans = True
    ans_with_brack = False
    candidates = []
    for choice in all_choices:  # e.g., (A) (B) (C) (D)
        if f"({choice})" in response:
            candidates.append(choice)
            ans_with_brack = True

    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A B C D
            if f"{choice} " in response:
                candidates.append(choice)

    if len(candidates) == 0:
        for choice in all_choices:  # e.g., A. B. C. D.
            if f"{choice}." in response:
                candidates.append(choice)

    # if none of the patterns above matched and the response is longer than 5 tokens,
    # try matching the option text itself
    if len(candidates) == 0 and len(response.split()) > 5:
        for index, ans in index2ans.items():
            if ans.lower() in response.lower():
                candidates.append(index)
                index_ans = False  # it's a content answer, not a letter

    if len(candidates) == 0:  # still no answer, randomly choose one
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        start_indexes = []
        if index_ans:
            if ans_with_brack:
                for can in candidates:
                    index = response.rfind(f"({can})")
                    start_indexes.append(index)  # -1 will be ignored anyway
                # start_indexes = [generated_response.index(f'({can})') for can in candidates]
            else:
                for can in candidates:
                    index = response.rfind(f" {can} ")
                    start_indexes.append(index)
        else:
            for can in candidates:
                index = response.lower().rfind(index2ans[can].lower())
                start_indexes.append(index)
        # keep the candidate that appears last in the response
        pred_index = candidates[np.argmax(start_indexes)]
    else:  # if only one candidate, use it
        pred_index = candidates[0]

    return pred_index
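An illustrative sketch of the successive matching stages (bracketed letter, a bare letter with punctuation, then option-content matching), using invented responses and assuming the function above is in scope:

choices = ["A", "B", "C", "D"]
mapping = {"A": "High", "B": "Medium", "C": "Low", "D": "Very low"}

print(parse_multi_choice_response("The answer is (B).", choices, mapping))  # B  (bracketed letter)
print(parse_multi_choice_response("B. Medium", choices, mapping))           # B  (letter followed by a period)
print(parse_multi_choice_response("I think the quality here is rather low overall", choices, mapping))  # C  (content match)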

def evaluate_q_bench(samples):
    pred_correct = 0
    judge_dict = dict()
    for sample in samples:
        gold_i = sample["answer"]
        pred_i = sample["parsed_pred"]
        correct = eval_multi_choice(gold_i, pred_i)

        if correct:
            judge_dict[sample["id"]] = "Correct"
            pred_correct += 1
        else:
            judge_dict[sample["id"]] = "Wrong"

    if len(samples) == 0:
        # keep the (judge_dict, metrics) return shape consistent so callers can always unpack
        return {}, {"acc": 0}
    return judge_dict, {"acc": pred_correct / len(samples)}

def eval_multi_choice(gold_i, pred_i):
    correct = False
    # only an exact match counts as correct
    if isinstance(gold_i, list):
        for answer in gold_i:
            if answer == pred_i:
                correct = True
                break
    else:  # gold_i is a string
        if gold_i == pred_i:
            correct = True
    return correct
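A small sketch of the scorer on two invented samples, using the keys evaluate_q_bench reads ("id", "answer", "parsed_pred"):

samples = [
    {"id": "q1", "answer": "A", "parsed_pred": "A"},
    {"id": "q2", "answer": "C", "parsed_pred": "B"},
]
judge_dict, metrics = evaluate_q_bench(samples)
print(judge_dict)  # {'q1': 'Correct', 'q2': 'Wrong'}
print(metrics)     # {'acc': 0.5}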

def calculate_ins_level_acc(results):
    """Calculate the instruction level accuracy for given Subject results
    https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L246
    """
    acc = 0
    ins_num = 0
    for cat_results in results.values():
        acc += cat_results["acc"] * cat_results["num_example"]
        ins_num += cat_results["num_example"]
    if ins_num == 0:
        return 0
    return acc / ins_num
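A worked example of the instance-weighted average, with invented per-subset numbers:

results = {
    "Global Distortion": {"acc": 0.80, "num_example": 50},
    "Local Distortion": {"acc": 0.60, "num_example": 150},
}
print(calculate_ins_level_acc(results))  # (0.80 * 50 + 0.60 * 150) / 200 = 0.65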

def q_bench_process_results(doc, results):
    pred = results[0]
    all_choices = []
    index2ans = {}
    for i in range(4):
        option = doc.get(f"option{i}")
        if option == "N/A":
            break
        index2ans[chr(ord("A") + i)] = option
        all_choices.append(chr(ord("A") + i))

    parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
    id = doc["id"]
    qbench_acc = {"id": id, "question_concern": doc["question_concern"], "question_type": doc["question_type"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
    return {
        "qbench_acc": qbench_acc,
        "submission": {
            id: pred,
        },
    }
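An illustrative per-document round trip; the doc values are invented, and question_concern / question_type are indices into the lists defined just below:

doc = {
    "id": "0001",
    "question_concern": 0,
    "question_type": 0,
    "correct_choice": "A",
    "option0": "Sharp",
    "option1": "Blurry",
    "option2": "N/A",
    "option3": "N/A",
}
out = q_bench_process_results(doc, ["The image looks (A) sharp to me."])
print(out["qbench_acc"]["parsed_pred"])  # A
print(out["submission"])                 # {'0001': 'The image looks (A) sharp to me.'}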

concern_list = ["Global Distortion", "Global Others", "Local Distortion", "Local Others"]
question_list = ["Yes/No", "How", "What"]


def q_bench_aggregate_results(results):
    evaluation_result = {}
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        # each sample is grouped twice: once by question concern, once by question type
        subset_to_eval_samples[concern_list[result["question_concern"]]].append(result)
        subset_to_eval_samples[question_list[result["question_type"]]].append(result)
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict
    printable_results = {}

    for cat_name, cat_results in evaluation_result.items():
        printable_results[cat_name] = {
            "num": int(cat_results["num_example"]),
            "acc": round(cat_results["acc"], 5),
        }
    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "acc": round(all_ins_acc, 5),
    }
    print(printable_results)
    return printable_results["Overall"]["acc"]
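A small aggregation sketch over two invented per-sample results. Note that each sample is counted once under its concern subset and once under its question-type subset, so the "Overall" count here is 4:

toy_results = [
    {"id": "0001", "question_concern": 0, "question_type": 0, "answer": "A", "parsed_pred": "A"},
    {"id": "0002", "question_concern": 2, "question_type": 1, "answer": "B", "parsed_pred": "C"},
]
overall_acc = q_bench_aggregate_results(toy_results)  # prints the per-subset table
print(overall_acc)  # 0.5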

def a_bench_process_results(doc, results):
    pred = results[0]
    all_choices = []
    index2ans = {}
    for i in range(4):
        option = doc.get(f"option{i}")
        if option == "N/A":
            break
        index2ans[chr(ord("A") + i)] = option
        all_choices.append(chr(ord("A") + i))

    parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
    id = doc["id"]
    abench_acc = {"id": id, "category": doc["category"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
    return {
        "abench_acc": abench_acc,
        "submission": {
            id: pred,
        },
    }

def a_bench_aggregate_results(results):
    evaluation_result = {}
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result["category"]].append(result)
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict
    printable_results = {}

    for cat_name, cat_results in evaluation_result.items():
        printable_results[cat_name] = {
            "num": int(cat_results["num_example"]),
            "acc": round(cat_results["acc"], 5),
        }
    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "acc": round(all_ins_acc, 5),
    }
    print(printable_results)
    return printable_results["Overall"]["acc"]