update aggregation function for vcr_wiki · EvolvingLMMs-Lab/lmms-eval@47b13b9 (original) (raw)

`@@ -6,10 +6,11 @@

`

6

6

`from functools import partial

`

7

7

``

8

8

`import evaluate

`

``

9

`+

import numpy as np

`

9

10

`import spacy

`

10

11

`from nltk.util import ngrams

`

11

12

`from spacy.cli import download

`

12

``

`-

import numpy as np

`

``

13

+

13

14

`from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

`

14

15

``

15

16

`# Download the English and Chinese models

`

`@@ -262,7 +263,7 @@ def bootstrap_std(data, n_bootstrap=1000, ci=0.95):

`

262

263

`return std, lower_bound, upper_bound

`

263

264

``

264

265

``

265

``

`-

def vcr_aggregate_results(results, args):

`

``

266

`+

def vcr_aggregate_results(results, args, metric='exact_match'):

`

266

267

`"""

`

267

268

` Args:

`

268

269

` results: List[List[Dict]], list of results returned by process_results

`

`@@ -285,9 +286,17 @@ def vcr_aggregate_results(results, args):

`

285

286

`"detailed_results": output_dict_detail_result,

`

286

287

` }

`

287

288

`now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

`

288

``

`-

path = generate_submission_file(f"vcr_submission_{now_date_time}.json", args)

`

``

289

`+

path = generate_submission_file(f"vcr_submission_{metric}_{now_date_time}.json", args)

`

289

290

`with open(path, "w", encoding="utf-8") as f:

`

290

291

`json.dump(output_dict, f, indent=4, ensure_ascii=False)

`

291

292

`# print(f"Submission file saved to {path}")

`

292

293

`eval_logger.info(f"Submission file saved to {path}")

`

293

294

`return mean_score

`

``

295

+

``

296

+

``

297

`+

def vcr_aggregate_exact_match(results, args):
    """Aggregate VCR results under the ``exact_match`` metric label.

    Thin wrapper: forwards ``results`` and ``args`` unchanged to
    ``vcr_aggregate_results`` with ``metric`` pinned to ``"exact_match"``
    and returns whatever that call returns.

    Args:
        results: Passed through as-is to ``vcr_aggregate_results``.
        args: Passed through as-is to ``vcr_aggregate_results``.

    Returns:
        The value returned by ``vcr_aggregate_results``.
    """
    metric_name = "exact_match"
    return vcr_aggregate_results(results, args, metric=metric_name)

`

``

299

+

``

300

+

``

301

`+

def vcr_aggregate_jaccard(results, args):
    """Aggregate VCR results under the ``jaccard`` metric label.

    Thin wrapper: forwards ``results`` and ``args`` unchanged to
    ``vcr_aggregate_results`` with ``metric`` pinned to ``"jaccard"``
    and returns whatever that call returns.

    Args:
        results: Passed through as-is to ``vcr_aggregate_results``.
        args: Passed through as-is to ``vcr_aggregate_results``.

    Returns:
        The value returned by ``vcr_aggregate_results``.
    """
    metric_name = "jaccard"
    aggregated = vcr_aggregate_results(results, args, metric=metric_name)
    return aggregated

`