Merge pull request #69 from hunterheiden/hsh/new_task/WebSRC · EvolvingLMMs-Lab/lmms-eval@eef3aeb
from collections import defaultdict
import re
import ast
import base64
import io
import random
import numpy as np
import os
import json
import logging
from PIL import Image

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

lmms_logger = logging.getLogger("lmms-eval")

OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."

def construct_prompt(doc):
    question = doc["question"]
    question = f"{OPEN_ENDED_PROMPT}\n{question}"
    return question


def websrc_doc_to_text(doc):
    question = construct_prompt(doc)
    return question
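For reference, a minimal sketch of what the prompt helpers produce for a toy document; the `doc` contents here are invented for illustration, not taken from the actual WebSRC data:

```python
doc = {"question": "What is the price of the phone?"}

print(websrc_doc_to_text(doc))
# Answer the question using a single word or phrase.
# What is the price of the phone?
```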
def websrc_doc_to_visual(doc):
    # The screenshot is stored as a base64-encoded string; decode it and drop the raw field from the doc.
    img_bs64 = doc["image"]
    img = Image.open(io.BytesIO(base64.b64decode(img_bs64)))
    del doc["image"]
    return [img]
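`websrc_doc_to_visual` expects the `image` field to hold a base64-encoded image. A minimal round-trip sketch, assuming only the field name used by the function above (the image itself is synthetic):

```python
import base64
import io
from PIL import Image

# Build a tiny in-memory PNG and base64-encode it, mimicking the dataset's "image" field.
buf = io.BytesIO()
Image.new("RGB", (8, 8), color="white").save(buf, format="PNG")
doc = {"image": base64.b64encode(buf.getvalue()).decode("utf-8")}

visuals = websrc_doc_to_visual(doc)
print(visuals[0].size)   # (8, 8)
print("image" in doc)    # False -- the raw base64 field is deleted
```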
def websrc_process_results(doc, results):
    pred = results[0]
    parsed_pred = pred
    id = doc["page_id"]
    websrc_ans = {"id": id, "domain": doc["domain"], "parsed_pred": parsed_pred}
    if "answer" in doc:
        websrc_ans["answer"] = doc["answer"]

    if "id" in doc:
        websrc_ans["question_id"] = doc["id"]

    # NOTE: the submission entry below assumes "id" is always present in the doc,
    # since "question_id" is only set inside the guard above.
    return {
        "websrc_squad_f1": websrc_ans,
        "submission": {
            websrc_ans["question_id"]: pred,
        },
    }
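A sketch of the per-sample dictionary `websrc_process_results` returns, using a hypothetical doc and model response (all IDs and values are invented):

```python
doc = {"page_id": "0100001", "id": "0100001-1", "domain": "auto", "answer": "$25,000"}
result = websrc_process_results(doc, ["$25,000"])
# result == {
#     "websrc_squad_f1": {"id": "0100001", "domain": "auto", "parsed_pred": "$25,000",
#                         "answer": "$25,000", "question_id": "0100001-1"},
#     "submission": {"0100001-1": "$25,000"},
# }
```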
def websrc_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("websrc_test_for_submission.json", args)
    with open(path, "w") as f:
        out = {}
        for result in results:
            out.update(result)
        json.dump(out, f, indent=4)
    lmms_logger.info(f"Results saved to {path}.")
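The submission file ends up as a flat JSON object mapping question IDs to raw predictions. A hypothetical two-entry example of what gets written (the IDs and answers are made up):

```python
import json

out = {}
for result in [{"0100001-1": "$25,000"}, {"0100001-2": "Yes"}]:
    out.update(result)
print(json.dumps(out, indent=4))
# {
#     "0100001-1": "$25,000",
#     "0100001-2": "Yes"
# }
```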
def websrc_aggregate_results(results):
    evaluation_result = {}

    # Group results by domain
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result["domain"]].append(result)

    # Evaluate each domain
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_websrc(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict

    # Aggregate results for all domains
    printable_results = {}
    for domain in DOMAINS:
        if domain not in evaluation_result:
            continue
        printable_results[domain] = {
            "num": int(evaluation_result[domain]["num_example"]),
            "f1": round(evaluation_result[domain]["f1"], 3),
        }
    all_ins_f1 = np.sum([cat_results["f1"] * cat_results["num_example"] for cat_results in evaluation_result.values()]) / sum(
        [cat_results["num_example"] for cat_results in evaluation_result.values()]
    )
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "f1": round(all_ins_f1, 3),
    }
    print(printable_results)
    return printable_results["Overall"]["f1"]
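The overall score is the example-weighted mean of the per-domain F1 scores. A minimal numeric sketch of that aggregation step, using made-up domain results in the same shape as `evaluation_result`:

```python
import numpy as np

evaluation_result = {
    "auto":  {"f1": 0.80, "num_example": 100},
    "hotel": {"f1": 0.60, "num_example": 50},
}
total = sum(v["num_example"] for v in evaluation_result.values())  # 150
overall = np.sum([v["f1"] * v["num_example"] for v in evaluation_result.values()]) / total
print(round(overall, 3))  # 0.733 -- (0.80*100 + 0.60*50) / 150
```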
##################
# Helper functions written by official MMMU repo.
##################
DOMAINS = [
    "auto",
    "book",
    "camera",
    "game",
    "jobs",
    "movie",
    "phone",
    "restaurant",
    "sports",
    "university",
    "hotel",
]

def evaluate_websrc(samples):
    # Character-level F1 between the normalized prediction and the normalized gold answer.
    def _normalize_str(string):
        # lowercase
        string = string.lower()

        # strip non-alphanumeric characters
        string = re.sub(r"[^a-zA-Z0-9]", "", string)

        # strip leading and trailing whitespaces
        string = string.strip()

        return string

    judge_list = []
    for sample in samples:
        gold_i = set(_normalize_str(sample["answer"]))
        pred_i = set(_normalize_str(sample["parsed_pred"]))
        if len(pred_i) == 0:
            judge_list.append(0.0)
            continue

        comm_i = gold_i.intersection(pred_i)
        prec_i = len(comm_i) / len(pred_i)
        rec_i = len(comm_i) / len(gold_i)
        f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
        judge_list.append(f1_i)

    f1 = np.mean(judge_list)
    return judge_list, {"f1": f1}
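The metric is a character-level F1: both strings are lowercased, stripped to alphanumerics, and compared as sets of characters. A worked sketch with invented answers (not real WebSRC samples):

```python
samples = [
    {"answer": "$25", "parsed_pred": "25 dollars"},  # partial character overlap
    {"answer": "Yes", "parsed_pred": "yes"},         # exact match after normalization
]
judge_list, metrics = evaluate_websrc(samples)
# Sample 1: gold chars {'2','5'}, pred chars {'2','5','d','o','l','a','r','s'}
#           precision = 2/8, recall = 2/2, F1 = 2*0.25*1.0/1.25 = 0.4
# Sample 2: identical character sets -> F1 = 1.0
print(judge_list)      # [0.4, 1.0]
print(metrics["f1"])   # 0.7
```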