Init webSRC · EvolvingLMMs-Lab/lmms-eval@955bd06

from collections import defaultdict
import re
import ast
import base64
import io
import random
import numpy as np
import os
import json
import logging
from PIL import Image

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

lmms_logger = logging.getLogger("lmms-eval")

OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."

20
`+
def construct_prompt(doc):
`
``
21
`+
question = doc["question"]
`
``
22
`+
question = f"{question}\n{OPEN_ENDED_PROMPT}"
`
``
23
`+
question = f"{OPEN_ENDED_PROMPT}\n{question}"
`
``
24
`+
return question
`
``
25
+
``
26
+
``
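# Prompt construction example (hypothetical doc, for illustration only):
#   construct_prompt({"question": "What is the battery capacity?"})
#   -> "What is the battery capacity?\nAnswer the question using a single word or phrase."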


def websrc_doc_to_text(doc):
    question = construct_prompt(doc)
    return question


def websrc_doc_to_visual(doc):
    # Decode the base64-encoded page screenshot and drop the raw string from the doc.
    img_bs64 = doc["image"]
    img = Image.open(io.BytesIO(base64.b64decode(img_bs64)))
    del doc["image"]
    return [img]


def websrc_process_results(doc, results):
    pred = results[0]
    parsed_pred = pred
    id = doc["page_id"]
    websrc_ans = {"id": id, "domain": doc["domain"], "answer": doc["answer"], "parsed_pred": parsed_pred}
    return {
        "websrc_squad_f1": websrc_ans,
        "submission": {
            id: pred,
        },
    }

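# Note: the "websrc_squad_f1" entries above feed websrc_aggregate_results below,
# while the "submission" entries feed websrc_test_aggregate_results_for_submission
# (the exact wiring is defined in the task's yaml config, not in this file).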


def websrc_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("websrc_test_for_submission.json", args)
    with open(path, "w") as f:
        json.dump(results, f)
    lmms_logger.info(f"Results saved to {path}.")


def websrc_aggregate_results(results):
    evaluation_result = {}

    # Group results by domain
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result["domain"]].append(result)

    # Evaluate each domain
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_websrc(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict

    # Aggregate results for all domains
    printable_results = {}
    for domain in DOMAINS:
        if domain not in evaluation_result:
            continue
        printable_results[domain] = {
            "num": int(evaluation_result[domain]["num_example"]),
            "f1": round(evaluation_result[domain]["f1"], 3),
        }
    # Overall F1 is the example-count-weighted mean of the per-domain F1 scores.
    all_ins_f1 = np.sum([cat_results["f1"] * cat_results["num_example"] for cat_results in evaluation_result.values()]) / sum(
        [cat_results["num_example"] for cat_results in evaluation_result.values()]
    )
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "f1": round(all_ins_f1, 3),
    }
    print(printable_results)
    return printable_results["Overall"]["f1"]


##################
# Helper functions written by official MMMU repo.
##################
DOMAINS = [
    "auto",
    "book",
    "camera",
    "game",
    "jobs",
    "movie",
    "phone",
    "restaurant",
    "sports",
    "university",
    "hotel",
]


def evaluate_websrc(samples):

    def _normalize_str(string):
        # lower it
        string = string.lower()

        # strip non-alphanumeric characters
        string = re.sub(r"[^a-zA-Z0-9]", "", string)

        # strip leading and trailing whitespaces
        string = string.strip()

        return string

    judge_list = []
    for sample in samples:
        # Character-level set overlap between the normalized gold answer and prediction.
        gold_i = set(_normalize_str(sample["answer"]))
        pred_i = set(_normalize_str(sample["parsed_pred"]))
        if len(pred_i) == 0:
            judge_list.append(0.0)
            continue

        comm_i = gold_i.intersection(pred_i)
        prec_i = len(comm_i) / len(pred_i)
        rec_i = len(comm_i) / len(gold_i)
        f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
        judge_list.append(f1_i)

    f1 = np.mean(judge_list)
    return judge_list, {"f1": f1}
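

# A minimal, hypothetical usage sketch of evaluate_websrc (not part of the original file);
# it shows the character-set F1 on two made-up samples.
if __name__ == "__main__":
    demo_samples = [
        {"answer": "Yes", "parsed_pred": "yes"},  # identical after normalization -> F1 = 1.0
        {"answer": "499 dollars", "parsed_pred": ""},  # empty prediction -> F1 = 0.0
    ]
    demo_judge_list, demo_metrics = evaluate_websrc(demo_samples)
    print(demo_judge_list, demo_metrics["f1"])  # [1.0, 0.0] 0.5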