Merge pull request #100 from Gumpest/main · EvolvingLMMs-Lab/lmms-eval@67b64ea
from collections import defaultdict
import logging
import os

from anls import anls_score

eval_logger = logging.getLogger("lmms-eval")

dir_name = os.path.dirname(os.path.abspath(__file__))

# 19 classes
eval_type_dict = {
    "Sensation": ["count", "color", "scene", "poster", "attribute_recognition", "ocr", "position"],
    "Cognition": ["calculation", "code", "translation", "math", "cross_instance_reason", "attribute_reason"],
    "Knowledge": ["celebrity", "chemistry", "physics", "biology", "landmark", "artwork"],
}

def conbench_doc_to_visual(doc):
    return [doc["image"].convert("RGB")]


def conbench_doc_to_text(doc):
    question = doc["question"].strip()
    return question

def parse_pred_ans_NY(pred_ans):
    # Map a free-form prediction to "yes" / "no" / "other"
    pred_label = None
    if pred_ans in ["yes", "no"]:
        pred_label = pred_ans
    else:
        prefix_pred_ans = pred_ans[:4]

        if "yes" in prefix_pred_ans:
            pred_label = "yes"
        elif "no" in prefix_pred_ans:
            pred_label = "no"
        else:
            pred_label = "other"
    return pred_label


def parse_pred_ans_choice(pred_ans):
    # Keep only the first non-space character, i.e. the chosen option letter
    return pred_ans.replace(" ", "")[0]

def conbench_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: [pred]
    Returns:
        a dictionary with key: metric name (in this case ConScore_D), value: metric value
    """
    pred = results[0]
    pred = pred.replace("\n", "").lower()
    # parser
    if doc["question_field"] == "N/Y":
        pred_ans = parse_pred_ans_NY(pred)
    elif doc["question_field"] == "Choices":
        pred_ans = parse_pred_ans_choice(pred)
    else:
        pred_ans = pred

    gt_ans = doc["answer"].lower()

    # score
    score = (
        1
        if (doc["question_field"] == "Q/A" and anls_score(prediction=pred_ans, gold_labels=[gt_ans], threshold=0.95) >= 0.4) or (gt_ans == pred_ans)
        else 0
    )
    # Note: the key name here is very important. It decides which aggregation function will receive the results.
    # We note down the question id/category to help us aggregate the results later.
    return {"ConScore_D": {"image_id": doc["image_id"], "question_field": doc["question_field"], "score": score}}

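# Illustrative example (not part of the original PR): assuming a dataset row with the
# fields used above (image_id, question_field, answer), an N/Y-style prediction would be
# scored like this; the doc dict and prediction string are made up for illustration.
#
#     fake_doc = {"image_id": 7, "question_field": "N/Y", "answer": "Yes"}
#     conbench_process_results(fake_doc, ["Yes, it is."])
#     # -> {"ConScore_D": {"image_id": 7, "question_field": "N/Y", "score": 1}}
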
def conbench_aggregate_results(results):
    """
    Args:
        results: a list of values returned by process_results
    Returns:
        A score (ConScore_D): the fraction of images whose three questions are all answered correctly
    """
    # Sum the per-question scores for each image
    summary = defaultdict(int)
    for result in results:
        image_id = result["image_id"]
        score = result["score"]
        summary[image_id] += score

    # An image counts as consistent only if all three of its questions scored 1
    cnt_con = 0
    for image_id, score in summary.items():
        if score == 3:
            cnt_con += 1

    print("Consistency Cases are ", cnt_con)
    cnt_con = cnt_con / (len(results) / 3)
    eval_logger.info(f"ConScore_D: {cnt_con:.2f}")
    return cnt_con

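# A minimal, self-contained sketch (not part of the original PR) of how the two hooks fit
# together: the harness would normally collect the inner "ConScore_D" dicts produced by
# conbench_process_results and hand them to conbench_aggregate_results. The synthetic
# results below assume two images with three question fields each; only the first image
# answers all three correctly, so the expected ConScore_D is 0.50.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    synthetic_results = [
        {"image_id": 1, "question_field": "N/Y", "score": 1},
        {"image_id": 1, "question_field": "Choices", "score": 1},
        {"image_id": 1, "question_field": "Q/A", "score": 1},
        {"image_id": 2, "question_field": "N/Y", "score": 1},
        {"image_id": 2, "question_field": "Choices", "score": 0},
        {"image_id": 2, "question_field": "Q/A", "score": 1},
    ]
    conbench_aggregate_results(synthetic_results)  # logs "ConScore_D: 0.50"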