init include vcr · EvolvingLMMs-Lab/lmms-eval@96e8d98
from collections import defaultdict
import os
from difflib import SequenceMatcher as SM
import datetime
import json
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
import evaluate
import logging
import spacy
from spacy.cli import download
from nltk.util import ngrams
from functools import partial

# Download the English and Chinese models
download("en_core_web_sm")
download("zh_core_web_sm")
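# Note: the spaCy models are fetched and loaded at import time, so the first
# import of this module may take a while and may require network access.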

eval_logger = logging.getLogger("lmms-eval")

dir_name = os.path.dirname(os.path.abspath(__file__))

rouge = evaluate.load("rouge")
nlp_en = spacy.load("en_core_web_sm")
nlp_zh = spacy.load("zh_core_web_sm")
nlp = {"en": nlp_en, "zh": nlp_zh}

aggregate_results_template = {
    "max_sim_val": 0,
    "precision": 0,
    "recall": 0,
    "f1": 0,
    "jaccard": 0,
    "rouge1": 0,
}


def vcr_doc_to_visual(doc):
    return [doc["stacked_image"].convert("RGB"), doc["only_it_image"].convert("RGB")]


def vcr_doc_to_text(doc, model_specific_prompt_kwargs=None):
    # Default to empty strings so the prompt stays well-defined when either key
    # (or the whole kwargs dict) is missing.
    pre_prompt = ""
    post_prompt = ""
    if model_specific_prompt_kwargs is not None:
        if "pre_prompt" in model_specific_prompt_kwargs:
            pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
        if "post_prompt" in model_specific_prompt_kwargs:
            post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{post_prompt}"


def tokenize(text, language):
    """
    Tokenize the text and return the tokens.

    Parameters:
    text (str): The text to tokenize.
    language (str): The language of the text.

    Returns:
    list: The list of tokens.
    """
    assert language in ["en", "zh"]
    nlp_lang = nlp[language]
    processed_text = nlp_lang(text)
    return [token.text for token in processed_text]


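# Illustrative example (hypothetical strings, not from the dataset):
# tokenize("The quick brown fox", "en") -> ["The", "quick", "brown", "fox"].
# Chinese text is segmented by the zh_core_web_sm pipeline rather than split on
# whitespace, which is why the n-gram joiner below differs per language.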
def vcr_process_results_single(doc, result, language):
    """
    Args:
        doc: an instance of the eval dataset
        result: the model prediction (str) for one image
        language: "en" or "zh"
    Returns:
        a dictionary with key: metric name, value: metric value
    """
    assert language in ["en", "zh"], f"Language {language} is not supported."
    crossed_text = doc["crossed_text"]
    tokens_result = tokenize(result, language)
    tokens_crossed_text = tokenize(crossed_text, language)

    # Candidate n-grams from the prediction, with n equal to the number of
    # tokens in the covered (crossed-out) text; keep only candidates that share
    # at least one token with the reference.
    splitter = " " if language == "en" else ""
    ngrams_ = ngrams(tokens_result, len(tokens_crossed_text))
    max_sim_val = 0
    max_sim_string = ""
    max_sim_ngram = []
    tokens_crossed_text_set = set(tokens_crossed_text)
    ngrams_hasjoint = [
        ngram for ngram in ngrams_ if not set(ngram).isdisjoint(tokens_crossed_text_set)
    ]

    # Keep the candidate n-gram whose joined string is most similar to the
    # reference text (SequenceMatcher ratio).
    for ngram in ngrams_hasjoint:
        result_ngram = splitter.join(ngram)
        similarity = SM(None, result_ngram, crossed_text).ratio()
        if similarity > max_sim_val:
            max_sim_val = similarity
            max_sim_string = result_ngram
            max_sim_ngram = ngram

    # Evaluate
    if len(max_sim_ngram) == 0:
        # No candidate n-gram shares a token with the reference: score zero.
        return {
            "crossed_text": crossed_text,
            "max_sim_val": 0,
            "max_sim_string": "",
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "jaccard": 0,
            "rouge1": 0,
            "exact_match": 0,
        }
    pred_set = set(max_sim_ngram)
    ref_set = set(tokens_crossed_text)
    correct_tokens = pred_set.intersection(ref_set)
    len_correct_tokens = len(correct_tokens)

    # Token-level precision/recall/F1/Jaccard between the best-matching n-gram
    # and the reference, plus ROUGE-1 and exact match.
    precision = len_correct_tokens / len(pred_set)
    recall = len_correct_tokens / len(ref_set)
    if (precision + recall) == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    union = pred_set.union(ref_set)
    jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
    rouge_1 = rouge.compute(
        predictions=[max_sim_string],
        references=[crossed_text],
        tokenizer=partial(tokenize, language=language),
        rouge_types=["rouge1"],
    )["rouge1"]
    exact_match = float(list(max_sim_ngram) == list(tokens_crossed_text))
    out = {
        "crossed_text": crossed_text,
        "max_sim_string": max_sim_string,
        "max_sim_val": max_sim_val,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "jaccard": jaccard,
        "rouge1": rouge_1,
        "exact_match": exact_match,
    }
    return out
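
# Illustrative example (hypothetical values, not taken from the dataset): if the
# model output is "The covered text says quick brown fox here" and crossed_text
# is "quick brown fox", the best-matching 3-gram is ("quick", "brown", "fox"),
# so max_sim_val, precision, recall, f1, jaccard and exact_match all come out as
# 1.0; a partially correct restoration such as "quick red fox" would lower the
# token-overlap metrics accordingly.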


def vcr_en_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: a list of two predictions, one per image (stacked_image, only_it_image)
    Returns:
        a dictionary with keys "res_stacked_image" and "res_only_it_image",
        each mapping metric names to values
    """
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    output = {
        "res_stacked_image": vcr_process_results_single(doc, results[0], "en"),
        "res_only_it_image": vcr_process_results_single(doc, results[1], "en"),
    }
    return output


def vcr_zh_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: a list of two predictions, one per image (stacked_image, only_it_image)
    Returns:
        a dictionary with keys "res_stacked_image" and "res_only_it_image",
        each mapping metric names to values
    """
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    output = {
        "res_stacked_image": vcr_process_results_single(doc, results[0], "zh"),
        "res_only_it_image": vcr_process_results_single(doc, results[1], "zh"),
    }
    return output


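# Each element of `results` passed to the aggregation below is one dict produced
# by vcr_en_process_results / vcr_zh_process_results above, i.e.
# {"res_stacked_image": {metric: value, ...}, "res_only_it_image": {metric: value, ...}}.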
def vcr_aggregate_results(results):
    """
    Args:
        results: a list of values returned by process_results
    Returns:
        a dictionary of dictionaries of float, where the outer dictionary has
        keys "res_stacked_image" and "res_only_it_image"
    """

    output = {
        "res_stacked_image": dict(aggregate_results_template),
        "res_only_it_image": dict(aggregate_results_template),
    }
    # Average every metric over all results. Each inner_dict maps the two image
    # types to flat {metric_name: value} dicts (the structure produced by the
    # process_results functions above), so only one level of lookup is needed.
    for target_domain in output.keys():
        for target_metric_name in output[target_domain].keys():
            score = 0
            count = 0
            for inner_dict in results:
                if target_domain not in inner_dict:
                    continue
                for metric_name, metric_value in inner_dict[target_domain].items():
                    if metric_name == target_metric_name:
                        score += metric_value
                        count += 1
            output[target_domain][target_metric_name] = score / count if count > 0 else 0
    return output
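

# Minimal local sanity check (hypothetical doc/prediction values; in normal use
# the lmms-eval task runner calls the functions above itself).
if __name__ == "__main__":
    fake_doc = {"crossed_text": "quick brown fox"}
    prediction = "The covered text says quick brown fox here"
    single = vcr_process_results_single(fake_doc, prediction, "en")
    print(single["max_sim_string"], single["f1"], single["exact_match"])

    # Aggregate two identical per-sample results, as if both image variants had
    # been answered with the same prediction.
    per_sample = {"res_stacked_image": single, "res_only_it_image": single}
    print(vcr_aggregate_results([per_sample, per_sample]))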