Merge pull request #69 from hunterheiden/hsh/new_task/WebSRC · dadwadw233/lmms-eval@dac58a8

```python
from collections import defaultdict
import re
import ast
import base64
import io
import random
import numpy as np
import os
import json
import logging
from PIL import Image

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

lmms_logger = logging.getLogger("lmms-eval")

OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."


def construct_prompt(doc):
    question = doc["question"]
    question = f"{OPEN_ENDED_PROMPT}\n{question}"
    return question


def websrc_doc_to_text(doc):
    question = construct_prompt(doc)
    return question


def websrc_doc_to_visual(doc):
    img_bs64 = doc["image"]
    img = Image.open(io.BytesIO(base64.b64decode(img_bs64)))
    del doc['image']
    return [img]


def websrc_process_results(doc, results):
    pred = results[0]
    parsed_pred = pred
    id = doc["page_id"]
    websrc_ans = {"id": id, "domain": doc['domain'], "parsed_pred": parsed_pred}
    if "answer" in doc:
        websrc_ans["answer"] = doc["answer"]

    if 'id' in doc:
        websrc_ans['question_id'] = doc['id']

    return {
        "websrc_squad_f1": websrc_ans,
        "submission": {
            websrc_ans['question_id']: pred,
        },
    }


def websrc_test_aggregate_results_for_submission(results, args):
    path = generate_submission_file("websrc_test_for_submission.json", args)
    with open(path, "w") as f:
        out = {}
        for result in results:
            out.update(result)
        json.dump(out, f, indent=4)
    lmms_logger.info(f"Results saved to {path}.")


def websrc_aggregate_results(results):
    evaluation_result = {}

    # Group results by domain
    subset_to_eval_samples = defaultdict(list)
    for result in results:
        subset_to_eval_samples[result["domain"]].append(result)

    # Evaluate each domain
    for subset, sub_eval_samples in subset_to_eval_samples.items():
        judge_dict, metric_dict = evaluate_websrc(sub_eval_samples)
        metric_dict.update({"num_example": len(sub_eval_samples)})
        evaluation_result[subset] = metric_dict

    # Aggregate results for all domains
    printable_results = {}
    for domain in DOMAINS:
        if domain not in evaluation_result:
            continue
        printable_results[domain] = {
            "num": int(evaluation_result[domain]["num_example"]),
            "f1": round(evaluation_result[domain]["f1"], 3),
        }
    all_ins_f1 = np.sum([cat_results["f1"] * cat_results["num_example"] for cat_results in evaluation_result.values()]) / sum(
        [cat_results["num_example"] for cat_results in evaluation_result.values()]
    )
    printable_results["Overall"] = {
        "num": sum([cat_results["num_example"] for cat_results in evaluation_result.values()]),
        "f1": round(all_ins_f1, 3),
    }
    print(printable_results)
    return printable_results["Overall"]["f1"]
```
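The aggregation above weights each domain's F1 by its number of examples rather than averaging the per-domain scores directly. A minimal sketch of that weighting, with made-up domain names and counts (not taken from the benchmark):

```python
# Hypothetical per-domain results, for illustration only.
evaluation_result = {
    "movie": {"f1": 0.80, "num_example": 200},
    "hotel": {"f1": 0.60, "num_example": 100},
}

overall_f1 = sum(r["f1"] * r["num_example"] for r in evaluation_result.values()) / sum(
    r["num_example"] for r in evaluation_result.values()
)
# (0.80 * 200 + 0.60 * 100) / 300 = 220 / 300 ≈ 0.733
print(round(overall_f1, 3))  # 0.733
```

The helper section below, credited to the official MMMU repo, supplies the domain list and the F1 evaluator used above.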

```python
##################
# Helper functions written by official MMMU repo.
##################
DOMAINS = [
    'auto',
    'book',
    'camera',
    'game',
    'jobs',
    'movie',
    'phone',
    'restaurant',
    'sports',
    'university',
    'hotel',
]


def evaluate_websrc(samples):

    def _normalize_str(string):
        # lower it
        string = string.lower()

        # strip non-alphanumeric characters
        string = re.sub(r"[^a-zA-Z0-9]", "", string)

        # strip leading and trailing whitespaces
        string = string.strip()

        return string

    judge_list = []
    for sample in samples:
        gold_i = set(_normalize_str(sample["answer"]))
        pred_i = set(_normalize_str(sample["parsed_pred"]))
        if len(pred_i) == 0:
            judge_list.append(0.0)
            continue

        comm_i = gold_i.intersection(pred_i)
        prec_i = len(comm_i) / len(pred_i)
        rec_i = len(comm_i) / len(gold_i)
        f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
        judge_list.append(f1_i)

    f1 = np.mean(judge_list)
    return judge_list, {"f1": f1}
```
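`evaluate_websrc` computes a character-level F1: each answer and prediction is lowercased, stripped of non-alphanumeric characters, and turned into a set of characters, so precision and recall are measured over unique characters rather than tokens. A small, self-contained sketch with hypothetical inputs (the sample values are made up, not drawn from the dataset):

```python
# Hypothetical samples in the format produced by websrc_process_results.
samples = [
    {"answer": "New York", "parsed_pred": "new york city"},  # partial character overlap
    {"answer": "4.5 stars", "parsed_pred": ""},              # empty prediction scores 0.0
]

judge_list, metrics = evaluate_websrc(samples)
# First sample: the gold set has 7 unique chars, the pred set has 10, all 7 shared,
# so precision = 0.7, recall = 1.0, F1 ≈ 0.82; the second sample scores 0.0.
print(judge_list)     # ≈ [0.82, 0.0]
print(metrics["f1"])  # ≈ 0.41
```

Because scoring operates on character sets, short predictions that share common letters with the gold answer can earn partial credit even when they are semantically wrong; that is the behavior of the helper as written above.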