refactor: use enum class for the task type
Files changed:
- app.py +9 -8
- src/benchmarks.py +8 -7
- src/loaders.py +22 -29
- src/models.py +13 -2
- src/utils.py +45 -54
- tests/test_utils.py +2 -2
app.py
CHANGED
@@ -35,6 +35,7 @@ from src.envs import (
     TOKEN,
 )
 from src.loaders import load_eval_results
+from src.models import TaskType
 from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file


@@ -75,7 +76,7 @@ def update_qa_metric(
     global datastore
     return update_metric(
         datastore,
-        "qa",
+        TaskType.qa,
         metric,
         domains,
         langs,
@@ -98,7 +99,7 @@ def update_doc_metric(
     global datastore
     return update_metric(
         datastore,
-        "long-doc",
+        TaskType.long_doc,
         metric,
         domains,
         langs,
@@ -181,7 +182,7 @@ with demo:
         )

         set_listeners(
-            "qa",
+            TaskType.qa,
             qa_df_elem_ret_rerank,
             qa_df_elem_ret_rerank_hidden,
             search_bar,
@@ -224,7 +225,7 @@ with demo:
         )

         set_listeners(
-            "qa",
+            TaskType.qa,
             qa_df_elem_ret,
             qa_df_elem_ret_hidden,
             search_bar_ret,
@@ -281,7 +282,7 @@ with demo:
         )

         set_listeners(
-            "qa",
+            TaskType.qa,
             qa_df_elem_rerank,
             qa_df_elem_rerank_hidden,
             qa_search_bar_rerank,
@@ -348,7 +349,7 @@ with demo:
         )

         set_listeners(
-            "long-doc",
+            TaskType.long_doc,
             doc_df_elem_ret_rerank,
             doc_df_elem_ret_rerank_hidden,
             search_bar,
@@ -405,7 +406,7 @@ with demo:
         )

         set_listeners(
-            "long-doc",
+            TaskType.long_doc,
             doc_df_elem_ret,
             doc_df_elem_ret_hidden,
             search_bar_ret,
@@ -462,7 +463,7 @@ with demo:
         )

         set_listeners(
-            "long-doc",
+            TaskType.long_doc,
             doc_df_elem_rerank,
             doc_df_elem_rerank_hidden,
             doc_search_bar_rerank,
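The app.py side of the change is purely mechanical: every call that used to pass the task as a bare string now passes a TaskType member. A short, self-contained sketch of what that buys at the call site (TaskType copied from src/models.py further down; describe() is a toy consumer, not the leaderboard's real update_metric):

from enum import Enum


class TaskType(Enum):
    qa = "qa"
    long_doc = "long-doc"


def describe(task: TaskType) -> str:
    # A toy consumer: annotating with TaskType documents the contract and
    # lets type checkers flag a stray string argument.
    return f"building the {task.value} leaderboard"


print(describe(TaskType.qa))       # ok
# describe(TaskType.long_docs)     # AttributeError: a typo in the member name fails immediately
# describe("long-docs")            # a typo in the old string interface only surfaced at a later comparison, if at all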
src/benchmarks.py
CHANGED
@@ -4,6 +4,7 @@ from enum import Enum
 from air_benchmark.tasks.tasks import BenchmarkTable

 from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
+from src.models import TaskType


 def get_safe_name(name: str):
@@ -23,11 +24,11 @@ class Benchmark:


 # create a function return an enum class containing all the benchmarks
-def get_benchmarks_enum(benchmark_version, task_type):
+def get_benchmarks_enum(benchmark_version: str, task_type: TaskType):
     benchmark_dict = {}
-    if task_type == "qa":
+    if task_type == TaskType.qa:
         for task, domain_dict in BenchmarkTable[benchmark_version].items():
-            if task != task_type:
+            if task != task_type.value:
                 continue
             for domain, lang_dict in domain_dict.items():
                 for lang, dataset_list in lang_dict.items():
@@ -39,9 +40,9 @@ def get_benchmarks_enum(benchmark_version, task_type):
                     benchmark_dict[benchmark_name] = Benchmark(
                         benchmark_name, metric, col_name, domain, lang, task
                     )
-    elif task_type == "long-doc":
+    elif task_type == TaskType.long_doc:
         for task, domain_dict in BenchmarkTable[benchmark_version].items():
-            if task != task_type:
+            if task != task_type.value:
                 continue
             for domain, lang_dict in domain_dict.items():
                 for lang, dataset_list in lang_dict.items():
@@ -62,14 +63,14 @@ qa_benchmark_dict = {}
 for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)[-4:]
     qa_benchmark_dict[safe_version_name] = Enum(
-        f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, "qa")
+        f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.qa)
     )

 long_doc_benchmark_dict = {}
 for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)[-4:]
     long_doc_benchmark_dict[safe_version_name] = Enum(
-        f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, "long-doc")
+        f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.long_doc)
     )
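The per-version benchmark enums are still built with the functional Enum() API: the dict returned by get_benchmarks_enum supplies the member names (its keys) and the member values (Benchmark instances). A small sketch of that API with plain string values standing in for the Benchmark objects:

from enum import Enum

# Member names come from the dict keys, member values from the dict values.
benchmark_dict = {"wiki_en": "wiki/en", "law_zh": "law/zh"}  # toy values instead of Benchmark(...)
QABenchmarksDemo = Enum("QABenchmarksDemo", benchmark_dict)

print(list(QABenchmarksDemo))             # [<QABenchmarksDemo.wiki_en: 'wiki/en'>, <QABenchmarksDemo.law_zh: 'law/zh'>]
print(QABenchmarksDemo["wiki_en"].value)  # 'wiki/en'

This layering is consistent with how utils.py reads QABenchmarks[version_slug].value[c].value: the outer lookup selects the per-version enum, and each member's value is the underlying Benchmark record.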
src/loaders.py
CHANGED
@@ -11,7 +11,7 @@ from src.envs import (
     DEFAULT_METRIC_LONG_DOC,
     DEFAULT_METRIC_QA,
 )
-from src.models import FullEvalResult, LeaderboardDataStore
+from src.models import FullEvalResult, LeaderboardDataStore, TaskType
 from src.utils import get_default_cols, get_leaderboard_df

 pd.options.mode.copy_on_write = True
@@ -64,34 +64,27 @@ def get_safe_name(name: str):

 def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
     slug = get_safe_name(version)[-4:]
-    print(f"raw data: {len(
-    print(f"QA data loaded: {
-    ][shown_columns_long_doc]
-    lb_data_store.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
-    lb_data_store.reranking_models = sorted(
-        list(frozenset([eval_result.reranking_model for eval_result in lb_data_store.raw_data]))
-    )
-    return lb_data_store
+    datastore = LeaderboardDataStore(version, slug, None, None, None, None, None, None, None, None)
+    datastore.raw_data = load_raw_eval_results(file_path)
+    print(f"raw data: {len(datastore.raw_data)}")
+
+    datastore.qa_raw_df = get_leaderboard_df(datastore, TaskType.qa, DEFAULT_METRIC_QA)
+    print(f"QA data loaded: {datastore.qa_raw_df.shape}")
+    datastore.qa_fmt_df = datastore.qa_raw_df.copy()
+    qa_cols, datastore.qa_types = get_default_cols(TaskType.qa, datastore.slug, add_fix_cols=True)
+    datastore.qa_fmt_df = datastore.qa_fmt_df[~datastore.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]
+    datastore.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
+
+    datastore.doc_raw_df = get_leaderboard_df(datastore, TaskType.long_doc, DEFAULT_METRIC_LONG_DOC)
+    print(f"Long-Doc data loaded: {len(datastore.doc_raw_df)}")
+    datastore.doc_fmt_df = datastore.doc_raw_df.copy()
+    doc_cols, datastore.doc_types = get_default_cols(TaskType.long_doc, datastore.slug, add_fix_cols=True)
+    datastore.doc_fmt_df = datastore.doc_fmt_df[~datastore.doc_fmt_df[COL_NAME_IS_ANONYMOUS]][doc_cols]
+    datastore.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
+
+    datastore.reranking_models = \
+        sorted(list(frozenset([eval_result.reranking_model for eval_result in datastore.raw_data])))
+    return datastore


 def load_eval_results(file_path: str) -> Dict[str, LeaderboardDataStore]:
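The reworked loader fills the datastore field by field and formats each view with a compact pandas idiom: df[~df[flag]][cols] drops the rows whose boolean flag is set, then keeps only the chosen columns. A self-contained sketch with toy column names and rows (COL_NAME_IS_ANONYMOUS and the real column lists live in src.envs):

import pandas as pd

df = pd.DataFrame(
    {
        "Retrieval Model": ["model-a", "model-b", "model-c"],
        "Average": [55.2, 54.8, 60.1],
        "Anonymous": [False, False, True],
    }
)

# Same shape as datastore.qa_fmt_df[~datastore.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]:
# apply the boolean mask first, then select the display columns.
shown_cols = ["Retrieval Model", "Average"]
fmt_df = df[~df["Anonymous"]][shown_cols]
print(fmt_df)  # model-a and model-b remain; the anonymous row is hidden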
src/models.py
CHANGED
@@ -1,11 +1,12 @@
 import json
+from enum import Enum
+
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional

 import pandas as pd

-from src.benchmarks import get_safe_name
 from src.display.formatting import make_clickable_model
 from src.envs import (
     COL_NAME_IS_ANONYMOUS,
@@ -17,6 +18,10 @@ from src.envs import (
     COL_NAME_TIMESTAMP,
 )

+def get_safe_name(name: str):
+    """Get RFC 1123 compatible safe name"""
+    name = name.replace("-", "_")
+    return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))

 @dataclass
 class EvalResult:
@@ -147,4 +152,10 @@ class LeaderboardDataStore:
     doc_fmt_df: Optional[pd.DataFrame]
     reranking_models: Optional[list]
     qa_types: Optional[list]
-    doc_types: Optional[list]
+    doc_types: Optional[list]
+
+
+# Define an enum class with the name `TaskType`. There are two types of tasks, `qa` and `long-doc`.
+class TaskType(Enum):
+    qa = "qa"
+    long_doc = "long-doc"
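TaskType keeps the exact strings used by the benchmark table as member values, so callers can compare on the member (task == TaskType.qa) and still recover the raw string via task.value where the old string interface survives, as get_leaderboard_df does with v.to_dict(task=task.value, ...). Moving get_safe_name into models.py also lets benchmarks.py import TaskType without an import cycle, since models.py no longer imports from src.benchmarks. A short sketch of the enum's behavior:

from enum import Enum


class TaskType(Enum):
    qa = "qa"
    long_doc = "long-doc"


task = TaskType.long_doc
print(task == TaskType.long_doc)  # True: compare on the member, not the raw string
print(task.value)                 # 'long-doc': raw string kept for BenchmarkTable keys and to_dict()
print(TaskType("long-doc"))       # TaskType.long_doc: lookup by value
print(TaskType["long_doc"])       # TaskType.long_doc: lookup by member name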
src/utils.py
CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path

 import pandas as pd

+from src.models import TaskType
 from src.benchmarks import LongDocBenchmarks, QABenchmarks
 from src.display.columns import get_default_col_names_and_types, get_fixed_col_names_and_types
 from src.display.formatting import styled_error, styled_message
@@ -69,12 +70,12 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
     return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]


-def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tuple:
+def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) -> tuple:
     cols = []
     types = []
-    if task == "qa":
+    if task == TaskType.qa:
         benchmarks = QABenchmarks[version_slug]
-    elif task == "long-doc":
+    elif task == TaskType.long_doc:
         benchmarks = LongDocBenchmarks[version_slug]
     else:
         raise NotImplementedError
@@ -85,7 +86,6 @@ def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tuple:
             continue
         cols.append(col_name)
         types.append(col_type)
-
     if add_fix_cols:
         _cols = []
         _types = []
@@ -104,16 +104,16 @@ def select_columns(
     df: pd.DataFrame,
     domain_query: list,
     language_query: list,
-    task: str = "qa",
+    task: TaskType = TaskType.qa,
     reset_ranking: bool = True,
     version_slug: str = None,
 ) -> pd.DataFrame:
     cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
     selected_cols = []
     for c in cols:
-        if task == "qa":
+        if task == TaskType.qa:
             eval_col = QABenchmarks[version_slug].value[c].value
-        elif task == "long-doc":
+        elif task == TaskType.long_doc:
             eval_col = LongDocBenchmarks[version_slug].value[c].value
         else:
             raise NotImplementedError
@@ -141,10 +141,10 @@ def get_safe_name(name: str):
     return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))


-def _update_table(
-    task: str,
+def _update_df_elem(
+    task: TaskType,
     version: str,
-    hidden_df: pd.DataFrame,
+    source_df: pd.DataFrame,
     domains: list,
     langs: list,
     reranking_query: list,
@@ -154,7 +154,7 @@ def _update_table(
     show_revision_and_timestamp: bool = False,
 ):
     version_slug = get_safe_name(version)[-4:]
-    filtered_df = hidden_df.copy()
+    filtered_df = source_df.copy()
     if not show_anonymous:
         filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
     filtered_df = filter_models(filtered_df, reranking_query)
@@ -165,7 +165,7 @@ def _update_table(
     return filtered_df


-def update_table_long_doc(
+def update_doc_df_elem(
     version: str,
     hidden_df: pd.DataFrame,
     domains: list,
@@ -176,8 +176,8 @@ def update_table_long_doc(
     show_revision_and_timestamp: bool = False,
     reset_ranking: bool = True,
 ):
-    return _update_table(
-        "long-doc",
+    return _update_df_elem(
+        TaskType.long_doc,
         version,
         hidden_df,
         domains,
@@ -192,7 +192,7 @@ def update_table_long_doc(

 def update_metric(
     datastore,
-    task: str,
+    task: TaskType,
     metric: str,
     domains: list,
     langs: list,
@@ -201,33 +201,24 @@ def update_metric(
     show_anonymous: bool = False,
     show_revision_and_timestamp: bool = False,
 ) -> pd.DataFrame:
-        version,
-        leaderboard_df,
-        domains,
-        langs,
-        reranking_model,
-        query,
-        show_anonymous,
-        show_revision_and_timestamp,
-    )
+    if task == TaskType.qa:
+        update_func = update_qa_df_elem
+    elif task == TaskType.long_doc:
+        update_func = update_doc_df_elem
+    else:
+        raise NotImplemented
+    df_elem = get_leaderboard_df(datastore, task=task, metric=metric)
+    version = datastore.version
+    return update_func(
+        version,
+        df_elem,
+        domains,
+        langs,
+        reranking_model,
+        query,
+        show_anonymous,
+        show_revision_and_timestamp,
+    )


 def upload_file(filepath: str):
@@ -341,7 +332,7 @@ def reset_rank(df):
     return df


-def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
+def get_leaderboard_df(datastore, task: TaskType, metric: str) -> pd.DataFrame:
     """
     Creates a dataframe from all the individual experiment results
     """
@@ -349,9 +340,9 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
     cols = [
         COL_NAME_IS_ANONYMOUS,
     ]
-    if task == "qa":
+    if task == TaskType.qa:
         benchmarks = QABenchmarks[datastore.slug]
-    elif task == "long-doc":
+    elif task == TaskType.long_doc:
         benchmarks = LongDocBenchmarks[datastore.slug]
     else:
         raise NotImplementedError
@@ -360,7 +351,7 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
     benchmark_cols = [t.value.col_name for t in list(benchmarks.value)]
     all_data_json = []
     for v in raw_data:
-        all_data_json += v.to_dict(task=task, metric=metric)
+        all_data_json += v.to_dict(task=task.value, metric=metric)
     df = pd.DataFrame.from_records(all_data_json)

     _benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
@@ -385,7 +376,7 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:


 def set_listeners(
-    task,
+    task: TaskType,
     target_df,
     source_df,
     search_bar,
@@ -396,10 +387,10 @@ def set_listeners(
     show_anonymous,
     show_revision_and_timestamp,
 ):
-    if task == "qa":
-        update_table_func = update_table
-    elif task == "long-doc":
-        update_table_func = update_table_long_doc
+    if task == TaskType.qa:
+        update_table_func = update_qa_df_elem
+    elif task == TaskType.long_doc:
+        update_table_func = update_doc_df_elem
     else:
         raise NotImplementedError
     selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
@@ -427,7 +418,7 @@ def set_listeners(
     )


-def update_table(
+def update_qa_df_elem(
     version: str,
     hidden_df: pd.DataFrame,
     domains: list,
@@ -438,8 +429,8 @@ def update_table(
     show_revision_and_timestamp: bool = False,
     reset_ranking: bool = True,
 ):
-    return _update_table(
-        "qa",
+    return _update_df_elem(
+        TaskType.qa,
         version,
         hidden_df,
         domains,
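update_metric now resolves the handler from the enum instead of comparing strings. One caveat in the new branch: it raises NotImplemented, which is the binary-operator sentinel rather than an exception, so in current CPython that line would itself fail with a TypeError; the neighbouring helpers use the conventional NotImplementedError. A minimal sketch of the same dispatch written as a lookup table, with toy handlers standing in for update_qa_df_elem and update_doc_df_elem:

from enum import Enum
from typing import Callable, Dict


class TaskType(Enum):
    qa = "qa"
    long_doc = "long-doc"


def qa_handler(version: str) -> str:   # toy stand-in for update_qa_df_elem
    return f"qa table for {version}"


def doc_handler(version: str) -> str:  # toy stand-in for update_doc_df_elem
    return f"long-doc table for {version}"


# The if/elif chains in update_metric and set_listeners can also be a dict keyed by the enum.
HANDLERS: Dict[TaskType, Callable[[str], str]] = {
    TaskType.qa: qa_handler,
    TaskType.long_doc: doc_handler,
}


def update_metric_demo(task: TaskType, version: str) -> str:
    try:
        handler = HANDLERS[task]
    except KeyError:
        raise NotImplementedError(f"unsupported task: {task!r}")  # exception class, not the NotImplemented sentinel
    return handler(version)


print(update_metric_demo(TaskType.long_doc, "24.04"))  # example version string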
tests/test_utils.py
CHANGED
@@ -18,7 +18,7 @@ from src.utils import (
     get_iso_format_timestamp,
     search_table,
     select_columns,
-    update_table_long_doc,
+    update_doc_df_elem,
 )


@@ -90,7 +90,7 @@ def test_select_columns(toy_df):


 def test_update_table_long_doc(toy_df_long_doc):
-    df_result = update_table_long_doc(
+    df_result = update_doc_df_elem(
         toy_df_long_doc,
         [
             "law",
|