update aggregation function for vcr_wiki · EvolvingLMMs-Lab/lmms-eval@47b13b9 (original) (raw)
`@@ -6,10 +6,11 @@
`
6
6
`from functools import partial
`
7
7
``
8
8
`import evaluate
`
``
9
`+
import numpy as np
`
9
10
`import spacy
`
10
11
`from nltk.util import ngrams
`
11
12
`from spacy.cli import download
`
12
``
`-
import numpy as np
`
``
13
+
13
14
`from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
`
14
15
``
15
16
`# Download the English and Chinese models
`
`@@ -262,7 +263,7 @@ def bootstrap_std(data, n_bootstrap=1000, ci=0.95):
`
262
263
`return std, lower_bound, upper_bound
`
263
264
``
264
265
``
265
``
`-
def vcr_aggregate_results(results, args):
`
``
266
`+
def vcr_aggregate_results(results, args, metric='exact_match'):
`
266
267
`"""
`
267
268
` Args:
`
268
269
` results: List[List[Dict]], list of results returned by process_results
`
`@@ -285,9 +286,17 @@ def vcr_aggregate_results(results, args):
`
285
286
`"detailed_results": output_dict_detail_result,
`
286
287
` }
`
287
288
`now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
`
288
``
`-
path = generate_submission_file(f"vcr_submission_{now_date_time}.json", args)
`
``
289
`+
path = generate_submission_file(f"vcr_submission_{metric}_{now_date_time}.json", args)
`
289
290
`with open(path, "w", encoding="utf-8") as f:
`
290
291
`json.dump(output_dict, f, indent=4, ensure_ascii=False)
`
291
292
`# print(f"Submission file saved to {path}")
`
292
293
`eval_logger.info(f"Submission file saved to {path}")
`
293
294
`return mean_score
`
``
295
+
``
296
+
``
297
`+
def vcr_aggregate_exact_match(results, args):
    """Aggregate VCR results under the ``exact_match`` metric label.

    Thin wrapper around ``vcr_aggregate_results``: ``results`` and ``args``
    are forwarded unchanged and the metric name is fixed to
    ``"exact_match"``.

    Args:
        results: Passed through as-is to ``vcr_aggregate_results``.
        args: Passed through as-is to ``vcr_aggregate_results``.

    Returns:
        Whatever ``vcr_aggregate_results`` returns for this metric.
    """
    metric_name = "exact_match"
    return vcr_aggregate_results(results, args, metric=metric_name)
`
``
299
+
``
300
+
``
301
`+
def vcr_aggregate_jaccard(results, args):
    """Aggregate VCR results under the ``jaccard`` metric label.

    Thin wrapper around ``vcr_aggregate_results``: ``results`` and ``args``
    are forwarded unchanged and the metric name is fixed to ``"jaccard"``.

    Args:
        results: Passed through as-is to ``vcr_aggregate_results``.
        args: Passed through as-is to ``vcr_aggregate_results``.

    Returns:
        Whatever ``vcr_aggregate_results`` returns for this metric.
    """
    metric_name = "jaccard"
    return vcr_aggregate_results(results, args, metric=metric_name)
`