Spaces:
AIR-Bench
/
Running on CPU Upgrade

leaderboard / src /benchmarks.py
nan's picture
feat: add versioning for the long-doc
bf586e3
raw
history blame
3.2 kB
from dataclasses import dataclass
from enum import Enum
from air_benchmark.tasks.tasks import BenchmarkTable
from src.envs import METRIC_LIST
def get_safe_name(name: str):
    """Return *name* lowercased and reduced to alphanumerics and underscores.

    Hyphens are turned into underscores first. After that, every character
    that is neither an underscore nor alphanumeric (per ``str.isalnum``, which
    is Unicode-aware) is dropped. Each kept character is lowercased.

    The result may contain underscores, so it is not an RFC 1123 hostname
    label. In this module it serves as a dict key and Enum member name.
    """
    underscored = name.replace('-', '_')
    kept = []
    for ch in underscored:
        if ch == '_' or ch.isalnum():
            kept.append(ch.lower())
    return ''.join(kept)
@dataclass
class Benchmark:
    """Describe one benchmark entry, i.e. one leaderboard column.

    Instances are constructed positionally in ``get_benchmarks_enum`` and end
    up as Enum member values, so the field order below is part of the
    interface.
    """
    # get_safe_name(f"{domain}_{lang}") for qa benchmarks, or
    # get_safe_name(f"{domain}_{lang}_{dataset}") for long-doc benchmarks.
    # Presumably this matches the task key in the results JSON; verify there.
    name: str
    # Intended as a metric key (e.g. ndcg_at_1).
    # NOTE(review): the qa branch of get_benchmarks_enum stores a dataset key
    # here; the long-doc branch stores the last entry of METRIC_LIST.
    metric: str
    # Name to display in the leaderboard; equal to `name` everywhere in this file.
    col_name: str
    # Raw (un-sanitized) domain key from BenchmarkTable.
    domain: str
    # Raw (un-sanitized) language key from BenchmarkTable.
    lang: str
    # Task key from BenchmarkTable: "qa" or "long-doc".
    task: str
def _iter_task_languages(benchmark_version, task_type):
    """Yield ``(task, domain, lang, datasets)`` for each language of one task.

    ``datasets`` is the innermost level of ``BenchmarkTable``. Callers iterate
    it for dataset keys and index it as ``datasets[key]["splits"]``.
    """
    for task, domain_dict in BenchmarkTable[benchmark_version].items():
        if task != task_type:
            continue
        for domain, lang_dict in domain_dict.items():
            for lang, datasets in lang_dict.items():
                yield task, domain, lang, datasets


def _has_test_split(datasets, dataset):
    """Return True when ``datasets[dataset]`` lists a ``"test"`` split."""
    return "test" in datasets[dataset]["splits"]


def get_benchmarks_enum(benchmark_version, task_type):
    """Collect the leaderboard benchmarks of one task type for one version.

    Despite its name, this returns a plain ``dict`` mapping a safe benchmark
    name to a ``Benchmark``. The module-level code wraps that dict in an Enum.

    Args:
        benchmark_version: Key into ``BenchmarkTable``, e.g. "AIR-Bench_24.04".
        task_type: "qa" or "long-doc". Any other value yields an empty dict.

    Returns:
        For "qa": one entry per (domain, language) with at least one dataset
        that has a "test" split, keyed by ``get_safe_name(f"{domain}_{lang}")``.
        For "long-doc": one entry per (domain, language, dataset) whose dataset
        has a "test" split, keyed by
        ``get_safe_name(f"{domain}_{lang}_{dataset}")``.
    """
    benchmark_dict = {}
    if task_type == "qa":
        for task, domain, lang, datasets in _iter_task_languages(
                benchmark_version, task_type):
            # A QA column aggregates every dataset of one (domain, language);
            # it exists only if at least one of them has a test split.
            test_datasets = [d for d in datasets if _has_test_split(datasets, d)]
            if not test_datasets:
                continue
            benchmark_name = get_safe_name(f"{domain}_{lang}")
            # NOTE(review): the `metric` slot receives the last test-split
            # dataset key, not a metric name. This is kept unchanged because
            # callers may rely on it; confirm whether it is ever read.
            benchmark_dict[benchmark_name] = Benchmark(
                benchmark_name, test_datasets[-1], benchmark_name,
                domain, lang, task)
    elif task_type == "long-doc":
        # Entries are keyed by benchmark name alone, so only one metric can be
        # stored per entry: the last one in METRIC_LIST. An empty METRIC_LIST
        # therefore yields no long-doc entries.
        metrics = list(METRIC_LIST)
        for task, domain, lang, datasets in _iter_task_languages(
                benchmark_version, task_type):
            for dataset in datasets:
                if not _has_test_split(datasets, dataset) or not metrics:
                    continue
                benchmark_name = get_safe_name(f"{domain}_{lang}_{dataset}")
                benchmark_dict[benchmark_name] = Benchmark(
                    benchmark_name, metrics[-1], benchmark_name,
                    domain, lang, task)
    return benchmark_dict
# Benchmark versions exposed by the leaderboard. Each version is reduced to a
# short slug (e.g. "AIR-Bench_24.04" -> "air_bench_2404" -> "2404") that keys
# the per-version enums below. The [-4:] slice assumes every sanitized version
# name ends in exactly four distinguishing characters; two versions sharing a
# slug would silently overwrite each other.
versions = ("AIR-Bench_24.04", "AIR-Bench_24.05")

# Version slug -> Enum of that version's benchmarks, one mapping per task type.
qa_benchmark_dict = {}
long_doc_benchmark_dict = {}
for version in versions:
    safe_version_name = get_safe_name(version)[-4:]
    qa_benchmark_dict[safe_version_name] = Enum(
        f"QABenchmarks_{safe_version_name}",
        get_benchmarks_enum(version, "qa"),
    )
    long_doc_benchmark_dict[safe_version_name] = Enum(
        f"LongDocBenchmarks_{safe_version_name}",
        get_benchmarks_enum(version, "long-doc"),
    )

# Enums whose members are the per-version enums above: for example,
# QABenchmarks["2404"].value is the Enum built for "AIR-Bench_24.04".
QABenchmarks = Enum('QABenchmarks', qa_benchmark_dict)
LongDocBenchmarks = Enum('LongDocBenchmarks', long_doc_benchmark_dict)