Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from dataclasses import dataclass | |
from enum import Enum | |
from air_benchmark.tasks.tasks import BenchmarkTable | |
from src.envs import METRIC_LIST | |
def get_safe_name(name: str):
    """Normalize ``name`` into a lowercase, underscore-separated token.

    Hyphens become underscores, every character that is neither alphanumeric
    nor an underscore is dropped, and the remaining characters are lower-cased.
    """
    kept = []
    for ch in name.replace('-', '_'):
        if ch != '_' and not ch.isalnum():
            # Drop punctuation, whitespace and any other separator.
            continue
        kept.append(ch.lower())
    return ''.join(kept)
@dataclass
class Benchmark:
    """Record describing one benchmark entry.

    Instances are built positionally by ``get_benchmarks_enum`` as
    ``Benchmark(name, metric, col_name, domain, lang, task)``, so the field
    order below is part of the interface. The ``@dataclass`` decorator supplies
    the ``__init__`` that makes this call work.
    """

    # Safe (lower-cased, underscore-only) benchmark key: "<domain>_<lang>" for
    # "qa" and "<domain>_<lang>_<dataset>" for "long-doc".
    name: str
    # For "long-doc" an entry of METRIC_LIST; for "qa" the key currently being
    # iterated in the per-language dataset mapping.
    metric: str
    # Name to display in the leaderboard; get_benchmarks_enum sets it to `name`.
    col_name: str
    # Domain key taken from the benchmark table.
    domain: str
    # Language key taken from the benchmark table.
    lang: str
    # Task key taken from the benchmark table ("qa" or "long-doc").
    task: str
def get_benchmarks_enum(benchmark_version, task_type):
    """Collect the benchmarks of one task that provide a "test" split.

    Despite its name, this returns a plain ``dict`` (safe benchmark name ->
    ``Benchmark``), not an ``Enum``; callers wrap the result in ``Enum``.

    Args:
        benchmark_version: Key used to index ``BenchmarkTable``.
        task_type: ``"qa"`` or ``"long-doc"``. Any other value yields an
            empty dict without touching ``BenchmarkTable``.

    Returns:
        A dict mapping each safe benchmark name to its ``Benchmark`` record.
    """
    collected = {}
    if task_type not in ("qa", "long-doc"):
        return collected

    for task, domains in BenchmarkTable[benchmark_version].items():
        if task != task_type:
            continue
        for domain, languages in domains.items():
            for lang, datasets in languages.items():
                if task_type == "qa":
                    # One entry per (domain, lang): every dataset with a test
                    # split writes the same key, so the last such dataset wins.
                    key = get_safe_name(f"{domain}_{lang}")
                    for dataset in datasets:
                        if "test" in datasets[dataset]["splits"]:
                            collected[key] = Benchmark(key, dataset, key, domain, lang, task)
                else:
                    # One entry per (domain, lang, dataset).
                    for dataset in datasets:
                        key = get_safe_name(f"{domain}_{lang}_{dataset}")
                        if "test" not in datasets[dataset]["splits"]:
                            continue
                        # NOTE(review): each pass overwrites the same key, so
                        # only the last entry of METRIC_LIST is retained.
                        for metric in METRIC_LIST:
                            collected[key] = Benchmark(key, metric, key, domain, lang, task)
    return collected
# Benchmark releases for which per-version enums are built.
versions = ("AIR-Bench_24.04", "AIR-Bench_24.05")

# Per-version enums, keyed by the last four characters of the safe version
# name (e.g. "2404" for "AIR-Bench_24.04").
qa_benchmark_dict = {}
long_doc_benchmark_dict = {}
for version in versions:
    safe_version_name = get_safe_name(version)[-4:]
    qa_benchmark_dict[safe_version_name] = Enum(
        f"QABenchmarks_{safe_version_name}",
        get_benchmarks_enum(version, "qa"),
    )
    long_doc_benchmark_dict[safe_version_name] = Enum(
        f"LongDocBenchmarks_{safe_version_name}",
        get_benchmarks_enum(version, "long-doc"),
    )

# Top-level enums whose members are the per-version enum classes built above.
QABenchmarks = Enum('QABenchmarks', qa_benchmark_dict)
LongDocBenchmarks = Enum('LongDocBenchmarks', long_doc_benchmark_dict)