Bugfix: WebSRC should be token-level F1 NOT character-level · dadwadw233/lmms-eval@0a6b210 (original) (raw)

Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ def websrc_process_results(doc, results):
50 50 "websrc_squad_f1": websrc_ans,
51 51 "submission": {
52 52 websrc_ans['question_id']: pred,
53 - },
53 + } if 'question_id' in websrc_ans else None
54 54 }
55 55
56 56
@@ -122,27 +122,39 @@ def _normalize_str(string):
122 122 # lower it
123 123 string = string.lower()
124 124
125 -# strip non-alphanumeric characters
126 -string = re.sub(r"[^a-zA-Z0-9]", "", string)
127 -
128 125 # strip leading and trailing whitespaces
129 126 string = string.strip()
130 127
131 128 return string
132 129
130 +def _tokenize(text):
131 +# Regex pattern to match words and isolate punctuation
132 +pattern = r'\w+|[^\w\s]'
133 +tokens = re.findall(pattern, text)
134 +return tokens
135 +
def _compute_f1(sa, sb):
    """Token-level F1 between gold answer *sa* and prediction *sb*.

    Both strings are normalized (via ``_normalize_str``) and tokenized
    (via ``_tokenize``), then F1 is computed over the *multiset* (bag) of
    tokens, matching the standard SQuAD F1 definition: a token shared k
    times on both sides contributes k matches. The previous set-based
    version collapsed repeated tokens, mis-weighting answers with
    duplicate words.

    Returns 0.0 when either side has no tokens or no tokens overlap.
    """
    from collections import Counter  # stdlib; local import keeps the fix self-contained

    gold = Counter(_tokenize(_normalize_str(sa)))
    pred = Counter(_tokenize(_normalize_str(sb)))

    if not gold or not pred:
        return 0.0

    # Multiset intersection: each token counts min(gold[t], pred[t]) times.
    num_same = sum((gold & pred).values())
    if num_same == 0:
        return 0.0

    # Precision over prediction tokens, recall over gold tokens —
    # same orientation as the original (prec = |common| / |pred|).
    precision = num_same / sum(pred.values())
    recall = num_same / sum(gold.values())
    return 2 * precision * recall / (precision + recall)
154 +
133 155 judge_list = []
134 156 for sample in samples:
135 -gold_i = set(_normalize_str(sample["answer"]))
136 -pred_i = set(_normalize_str( sample["parsed_pred"]))
137 -if len(pred_i) == 0:
138 -judge_list.append(0.0)
139 -continue
140 -
141 -comm_i = gold_i.intersection(pred_i)
142 -prec_i = len(comm_i) / len(pred_i)
143 -rec_i = len(comm_i) / len(gold_i)
144 -f1_i = 2 * prec_i * rec_i / (prec_i + rec_i) if prec_i + rec_i > 0 else 0
145 -judge_list.append(f1_i)
157 +judge_list.append(_compute_f1(sample["answer"], sample["parsed_pred"]))
146 158
147 159 f1 = np.mean(judge_list)
148 160 return judge_list, {"f1": f1}