init include vcr · EvolvingLMMs-Lab/lmms-eval@96e8d98

from collections import defaultdict
import os
from difflib import SequenceMatcher as SM
import datetime
import json
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
import evaluate
import logging
import spacy
from spacy.cli import download
from nltk.util import ngrams
from functools import partial

# Download the English and Chinese models
download("en_core_web_sm")
download("zh_core_web_sm")

eval_logger = logging.getLogger("lmms-eval")

dir_name = os.path.dirname(os.path.abspath(__file__))

rouge = evaluate.load("rouge")
nlp_en = spacy.load("en_core_web_sm")
nlp_zh = spacy.load("zh_core_web_sm")
nlp = {"en": nlp_en, "zh": nlp_zh}


aggregate_results_template = {
    "max_sim_val": 0,
    "precision": 0,
    "recall": 0,
    "f1": 0,
    "jaccard": 0,
    "rouge1": 0,
}



def vcr_doc_to_visual(doc):
    return [doc["stacked_image"].convert("RGB"), doc["only_it_image"].convert("RGB")]


def vcr_doc_to_text(doc, model_specific_prompt_kwargs=None):
    # Fall back to empty prompt fragments so the prompt is still well-formed when
    # the kwargs dict (or one of its keys) is not provided.
    if model_specific_prompt_kwargs is None:
        model_specific_prompt_kwargs = {}
    pre_prompt = ""
    post_prompt = ""
    if "pre_prompt" in model_specific_prompt_kwargs:
        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    if "post_prompt" in model_specific_prompt_kwargs:
        post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{post_prompt}"



def tokenize(text, language):
    """
    Tokenize the text and return the tokens.

    Parameters:
        text (str): The text to tokenize.
        language (str): The language of the text.

    Returns:
        list: The list of tokens.
    """
    assert language in ["en", "zh"]
    nlp_lang = nlp[language]
    processed_text = nlp_lang(text)
    return [token.text for token in processed_text]



def vcr_process_results_single(doc, result, language):
    """
    Args:
        doc: an instance of the eval dataset
        result: the model prediction for one image setting (str)
        language: "en" or "zh"
    Returns:
        a dictionary with key: metric name, value: metric value for this blank
    """
    assert language in ["en", "zh"], f"Language {language} is not supported."
    crossed_text = doc["crossed_text"]
    tokens_result = tokenize(result, language)
    tokens_crossed_text = tokenize(crossed_text, language)

    splitter = " " if language == "en" else ""
    ngrams_ = ngrams(tokens_result, len(tokens_crossed_text))
    max_sim_val = 0
    max_sim_string = ""
    max_sim_ngram = []
    tokens_crossed_text_set = set(tokens_crossed_text)
    # Keep only candidate n-grams that share at least one token with the reference.
    ngrams_hasjoint = [
        ngram for ngram in ngrams_ if not set(ngram).isdisjoint(tokens_crossed_text_set)
    ]

    for ngram in ngrams_hasjoint:
        result_ngram = splitter.join(ngram)
        similarity = SM(None, result_ngram, crossed_text).ratio()
        if similarity > max_sim_val:
            max_sim_val = similarity
            max_sim_string = result_ngram
            max_sim_ngram = ngram

    # Evaluate
    if len(max_sim_ngram) == 0:
        return {
            "crossed_text": crossed_text,
            "max_sim_val": 0,
            "max_sim_string": "",
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "jaccard": 0,
            "rouge1": 0,
            "exact_match": 0,
        }
    pred_set = set(max_sim_ngram)
    ref_set = set(tokens_crossed_text)
    correct_tokens = pred_set.intersection(ref_set)
    len_correct_tokens = len(correct_tokens)

    precision = len_correct_tokens / len(pred_set)
    recall = len_correct_tokens / len(ref_set)
    if (precision + recall) == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    union = pred_set.union(ref_set)
    jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
    rouge_1 = rouge.compute(
        predictions=[max_sim_string],
        references=[crossed_text],
        tokenizer=partial(tokenize, language=language),
        rouge_types=["rouge1"],
    )["rouge1"]
    exact_match = float(list(max_sim_ngram) == list(tokens_crossed_text))
    out = {
        "crossed_text": crossed_text,
        "max_sim_string": max_sim_string,
        "max_sim_val": max_sim_val,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "jaccard": jaccard,
        "rouge1": rouge_1,
        "exact_match": exact_match,
    }
    return out



def vcr_en_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: [stacked_image_pred, only_it_image_pred]
    Returns:
        a dictionary with key: metric name, value: metric value
    """
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    output = {
        "res_stacked_image": vcr_process_results_single(doc, results[0], "en"),
        "res_only_it_image": vcr_process_results_single(doc, results[1], "en"),
    }
    return output



def vcr_zh_process_results(doc, results):
    """
    Args:
        doc: an instance of the eval dataset
        results: [stacked_image_pred, only_it_image_pred]
    Returns:
        a dictionary with key: metric name, value: metric value
    """
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    output = {
        "res_stacked_image": vcr_process_results_single(doc, results[0], "zh"),
        "res_only_it_image": vcr_process_results_single(doc, results[1], "zh"),
    }
    return output



def vcr_aggregate_results(results):
    """
    Args:
        results: a list of values returned by process_results
    Returns:
        A dictionary of dictionaries of floats, where the outer dictionary has keys
        "res_stacked_image" and "res_only_it_image"
    """
    output = {
        "res_stacked_image": {
            "max_sim_val": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "jaccard": 0,
            "rouge1": 0,
        },
        "res_only_it_image": {
            "max_sim_val": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "jaccard": 0,
            "rouge1": 0,
        },
    }
    for target_domain in output.keys():
        for target_metric_name in output[target_domain].keys():
            score = 0
            count = 0
            for inner_dict in results:
                for inner_key, inner_value in inner_dict.items():
                    if inner_key == target_domain:
                        # inner_value is expected to map blank ids to per-blank metric dicts.
                        for blank_id, blank_metrics in inner_value.items():
                            for metric_name, metric_value in blank_metrics.items():
                                if metric_name == target_metric_name:
                                    score += metric_value
                                    count += 1
            output[target_domain][target_metric_name] = score / count
    return output
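

# A minimal smoke-test sketch: the toy document and prediction below are made-up,
# hypothetical values used only to illustrate how the per-blank scorer behaves;
# running it assumes the spaCy models and rouge backend loaded at import time are
# available locally.
if __name__ == "__main__":
    toy_doc = {"crossed_text": "quick brown fox"}  # hypothetical covered span
    toy_prediction = "The quick brown fox jumps over the lazy dog"  # hypothetical model output

    # The scorer slides a window of len(tokenize(crossed_text)) tokens over the
    # prediction, keeps the window most similar to the reference, and reports
    # precision / recall / F1 / Jaccard / ROUGE-1 / exact match for that window.
    single = vcr_process_results_single(toy_doc, toy_prediction, "en")
    print(single["max_sim_string"], single["exact_match"])  # "quick brown fox", 1.0

    # vcr_aggregate_results iterates blank-id -> metric-dict mappings, so the
    # per-blank result is wrapped in one extra level of nesting before aggregating.
    wrapped = {
        "res_stacked_image": {"blank_0": single},
        "res_only_it_image": {"blank_0": single},
    }
    print(vcr_aggregate_results([wrapped]))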