Merge pull request #113 from teowu/main · dadwadw233/lmms-eval@ac3a66f
import json
import logging
import random
import re
from collections import Counter, defaultdict

import numpy as np

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file


def q_bench_doc_to_text(doc, model_specific_prompt_kwargs):
    # Options are stored as option0..option3; unused slots hold the string "N/A".
    candidates = []
    for i in range(4):
        candidate = doc.get(f"option{i}")
        if candidate != "N/A":
            candidates.append(candidate)

    # Label the remaining options as "A. ...", "B. ...", etc.
    question = doc["question"] + "\n" + "\n".join([". ".join([chr(ord("A") + i), candidate]) for i, candidate in enumerate(candidates)])
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}\n{post_prompt}"


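# Illustration (not part of the original file): with a hypothetical doc such as
#   {"question": "How is the clarity of this image?", "option0": "High", "option1": "Medium",
#    "option2": "Low", "option3": "N/A"}
# and model_specific_prompt_kwargs = {"pre_prompt": "", "post_prompt": "Answer with the option's letter."},
# q_bench_doc_to_text would return:
#   How is the clarity of this image?
#   A. High
#   B. Medium
#   C. Low
#   Answer with the option's letter.

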
def q_bench_doc_to_visual(doc):
    if "image2" not in doc:
        return [doc["image"].convert("RGB")]
    else:
        return [doc["image1"].convert("RGB"), doc["image2"].convert("RGB")]


def get_multi_choice_info(options):
    """
    Given the list of options for a multiple-choice question,
    return index2ans (letter -> option text) and all_choices (the letters).
    """
    start_chr = "A"
    all_choices = []
    index2ans = {}
    for i, option in enumerate(options):
        index2ans[chr(ord(start_chr) + i)] = option
        all_choices.append(chr(ord(start_chr) + i))

    return index2ans, all_choices


45
`+
def parse_multi_choice_response(response, all_choices, index2ans):
`
``
46
`+
"""
`
``
47
`+
Parse the prediction from the generated response.
`
``
48
`+
Return the predicted index e.g., A, B, C, D.
`
``
49
`+
`
``
50
`+
"""
`
``
51
`+
for char in [",", ".", "!", "?", ";", ":", "'"]:
`
``
52
`+
response = response.strip(char)
`
``
53
`+
response = " " + response + " " # add space to avoid partial match
`
``
54
+
``
55
`+
index_ans = True
`
``
56
`+
ans_with_brack = False
`
``
57
`+
candidates = []
`
``
58
`+
for choice in all_choices: # e.g., (A) (B) (C) (D)
`
``
59
`+
if f"({choice})" in response:
`
``
60
`+
candidates.append(choice)
`
``
61
`+
ans_with_brack = True
`
``
62
+
``
63
`+
if len(candidates) == 0:
`
``
64
`+
for choice in all_choices: # e.g., A B C D
`
``
65
`+
if f"{choice} " in response:
`
``
66
`+
candidates.append(choice)
`
``
67
+
``
68
`+
if len(candidates) == 0:
`
``
69
`+
for choice in all_choices: # e.g., A. B. C. D.
`
``
70
`+
if f"{choice}." in response:
`
``
71
`+
candidates.append(choice)
`
``
72
+
``
73
`+
if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
`
``
74
`+
if len(candidates) == 0 and len(response.split()) > 5:
`
``
75
`+
for index, ans in index2ans.items():
`
``
76
`+
if ans.lower() in response.lower():
`
``
77
`+
candidates.append(index)
`
``
78
`+
index_ans = False # it's content ans.
`
``
79
+
``
80
`+
if len(candidates) == 0: # still not get answer, randomly choose one.
`
``
81
`+
pred_index = random.choice(all_choices)
`
``
82
`+
elif len(candidates) > 1:
`
``
83
`+
start_indexes = []
`
``
84
`+
if index_ans:
`
``
85
`+
if ans_with_brack:
`
``
86
`+
for can in candidates:
`
``
87
`+
index = response.rfind(f"({can})")
`
``
88
`+
start_indexes.append(index) # -1 will be ignored anyway
`
``
89
`+
start_indexes = [generated_response.index(f'({can})') for can in candidates]
`
``
90
`+
else:
`
``
91
`+
for can in candidates:
`
``
92
`+
index = response.rfind(f" {can} ")
`
``
93
`+
start_indexes.append(index)
`
``
94
`+
else:
`
``
95
`+
for can in candidates:
`
``
96
`+
index = response.lower().rfind(index2ans[can].lower())
`
``
97
`+
start_indexes.append(index)
`
``
98
`+
get the last one
`
``
99
`+
pred_index = candidates[np.argmax(start_indexes)]
`
``
100
`+
else: # if only one candidate, use it.
`
``
101
`+
pred_index = candidates[0]
`
``
102
+
``
103
`+
return pred_index
`
``
104
+
``
105
+
``
106
`+
def evaluate_q_bench(samples):
`
``
107
`+
pred_correct = 0
`
``
108
`+
judge_dict = dict()
`
``
109
`+
for sample in samples:
`
``
110
`+
gold_i = sample["answer"]
`
``
111
`+
pred_i = sample["parsed_pred"]
`
``
112
`+
correct = eval_multi_choice(gold_i, pred_i)
`
``
113
+
``
114
`+
if correct:
`
``
115
`+
judge_dict[sample["id"]] = "Correct"
`
``
116
`+
pred_correct += 1
`
``
117
`+
else:
`
``
118
`+
judge_dict[sample["id"]] = "Wrong"
`
``
119
+
``
120
`+
if len(samples) == 0:
`
``
121
`+
return {"acc": 0}
`
``
122
`+
return judge_dict, {"acc": pred_correct / len(samples)}
`
``
123
+
``
124
`+
def eval_multi_choice(gold_i, pred_i):
`
``
125
`+
correct = False
`
``
126
`+
only they are exactly the same, we consider it as correct
`
``
127
`+
if isinstance(gold_i, list):
`
``
128
`+
for answer in gold_i:
`
``
129
`+
if answer == pred_i:
`
``
130
`+
correct = True
`
``
131
`+
break
`
``
132
`+
else: # gold_i is a string
`
``
133
`+
if gold_i == pred_i:
`
``
134
`+
correct = True
`
``
135
`+
return correct
`
``
136
+
``
137
`+
def calculate_ins_level_acc(results):
`
``
138
`+
"""Calculate the instruction level accuracy for given Subject results
`
``
139
`+
`
``
140
`+
"""
`
``
141
`+
acc = 0
`
``
142
`+
ins_num = 0
`
``
143
`+
for cat_results in results.values():
`
``
144
`+
acc += cat_results["acc"] * cat_results["num_example"]
`
``
145
`+
ins_num += cat_results["num_example"]
`
``
146
`+
if ins_num == 0:
`
``
147
`+
return 0
`
``
148
`+
return acc / ins_num
`
``
149
+
``
150
+
``
151
`+
def q_bench_process_results(doc, results):
`
``
152
`+
pred = results[0]
`
``
153
`+
all_choices = []
`
``
154
`+
index2ans = {}
`
``
155
`+
for i in range(4):
`
``
156
`+
option = doc.get(f"option{i}")
`
``
157
`+
if option == "N/A":
`
``
158
`+
break
`
``
159
`+
index2ans[chr(ord("A") + i)] = option
`
``
160
`+
all_choices.append(chr(ord("A") + i))
`
``
161
+
``
162
`+
parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
`
``
163
`+
id = doc["id"]
`
``
164
`+
qbench_acc = {"id": id, "question_concern": doc["question_concern"], "question_type": doc["question_type"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
`
``
165
`+
return {
`
``
166
`+
"qbench_acc": qbench_acc,
`
``
167
`+
"submission": {
`
``
168
`+
id: pred,
`
``
169
`+
},
`
``
170
`+
}
`
``
171
+
``
172
+
``
173
`+
concern_list = ["Global Distortion", "Global Others", "Local Distortion", "Local Others"]
`
``
174
`+
question_list = ["Yes/No", "How", "What"]
`
``
175
+
``
176
`+
def q_bench_aggregate_results(results):
`
``
177
`+
evaluation_result = {}
`
``
178
`+
subset_to_eval_samples = defaultdict(list)
`
``
179
`+
for result in results:
`
``
180
`+
subset_to_eval_samples[concern_list[result["question_concern"]]].append(result)
`
``
181
`+
subset_to_eval_samples[question_list[result["question_type"]]].append(result)
`
``
182
`+
for subset, sub_eval_samples in subset_to_eval_samples.items():
`
``
183
`+
judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
`
``
184
`+
metric_dict.update({"num_example": len(sub_eval_samples)})
`
``
185
`+
evaluation_result[subset] = metric_dict
`
``
186
`+
printable_results = {}
`
``
187
+
``
188
`+
for cat_name, cat_results in evaluation_result.items():
`
``
189
`+
printable_results[cat_name] = {
`
``
190
`+
"num": int(cat_results["num_example"]),
`
``
191
`+
"acc": round(cat_results["acc"], 5),
`
``
192
`+
}
`
``
193
`+
all_ins_acc = calculate_ins_level_acc(evaluation_result)
`
``
194
`+
printable_results["Overall"] = {
`
``
195
`+
"num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
`
``
196
`+
"acc": round(all_ins_acc, 5),
`
``
197
`+
}
`
``
198
`+
print(printable_results)
`
``
199
`+
return printable_results["Overall"]["acc"]
`
``
200
+
``
201
`+
def a_bench_process_results(doc, results):
`
``
202
`+
pred = results[0]
`
``
203
`+
all_choices = []
`
``
204
`+
index2ans = {}
`
``
205
`+
for i in range(4):
`
``
206
`+
option = doc.get(f"option{i}")
`
``
207
`+
if option == "N/A":
`
``
208
`+
break
`
``
209
`+
index2ans[chr(ord("A") + i)] = option
`
``
210
`+
all_choices.append(chr(ord("A") + i))
`
``
211
+
``
212
`+
parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
`
``
213
`+
id = doc["id"]
`
``
214
`+
abench_acc = {"id": id, "category": doc["category"], "answer": doc["correct_choice"], "parsed_pred": parsed_pred}
`
``
215
`+
return {
`
``
216
`+
"abench_acc": abench_acc,
`
``
217
`+
"submission": {
`
``
218
`+
id: pred,
`
``
219
`+
},
`
``
220
`+
}
`
``
221
+
``
222
+
``
223
+
``
224
`+
def a_bench_aggregate_results(results):
`
``
225
`+
evaluation_result = {}
`
``
226
`+
subset_to_eval_samples = defaultdict(list)
`
``
227
`+
for result in results:
`
``
228
`+
subset_to_eval_samples[result["category"]].append(result)
`
``
229
`+
for subset, sub_eval_samples in subset_to_eval_samples.items():
`
``
230
`+
judge_dict, metric_dict = evaluate_q_bench(sub_eval_samples)
`
``
231
`+
metric_dict.update({"num_example": len(sub_eval_samples)})
`
``
232
`+
evaluation_result[subset] = metric_dict
`
``
233
`+
printable_results = {}
`
``
234
+
``
235
`+
for cat_name, cat_results in evaluation_result.items():
`
``
236
`+
printable_results[cat_name] = {
`
``
237
`+
"num": int(cat_results["num_example"]),
`
``
238
`+
"acc": round(cat_results["acc"], 5),
`
``
239
`+
}
`
``
240
`+
all_ins_acc = calculate_ins_level_acc(evaluation_result)
`
``
241
`+
printable_results["Overall"] = {
`
``
242
`+
"num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
`
``
243
`+
"acc": round(all_ins_acc, 5),
`
``
244
`+
}
`
``
245
`+
print(printable_results)
`
``
246
`+
return printable_results["Overall"]["acc"]
`
``
247
+