Spaces:
AIR-Bench
/
Running on CPU Upgrade

nan commited on
Commit
6f9f649
1 Parent(s): 0401aeb

refactor: use enum class for the task type

Browse files
Files changed (6) hide show
  1. app.py +9 -8
  2. src/benchmarks.py +8 -7
  3. src/loaders.py +22 -29
  4. src/models.py +13 -2
  5. src/utils.py +45 -54
  6. tests/test_utils.py +2 -2
app.py CHANGED
@@ -35,6 +35,7 @@ from src.envs import (
35
  TOKEN,
36
  )
37
  from src.loaders import load_eval_results
 
38
  from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file
39
 
40
 
@@ -75,7 +76,7 @@ def update_qa_metric(
75
  global datastore
76
  return update_metric(
77
  datastore,
78
- "qa",
79
  metric,
80
  domains,
81
  langs,
@@ -98,7 +99,7 @@ def update_doc_metric(
98
  global datastore
99
  return update_metric(
100
  datastore,
101
- "long-doc",
102
  metric,
103
  domains,
104
  langs,
@@ -181,7 +182,7 @@ with demo:
181
  )
182
 
183
  set_listeners(
184
- "qa",
185
  qa_df_elem_ret_rerank,
186
  qa_df_elem_ret_rerank_hidden,
187
  search_bar,
@@ -224,7 +225,7 @@ with demo:
224
  )
225
 
226
  set_listeners(
227
- "qa",
228
  qa_df_elem_ret,
229
  qa_df_elem_ret_hidden,
230
  search_bar_ret,
@@ -281,7 +282,7 @@ with demo:
281
  )
282
 
283
  set_listeners(
284
- "qa",
285
  qa_df_elem_rerank,
286
  qa_df_elem_rerank_hidden,
287
  qa_search_bar_rerank,
@@ -348,7 +349,7 @@ with demo:
348
  )
349
 
350
  set_listeners(
351
- "long-doc",
352
  doc_df_elem_ret_rerank,
353
  doc_df_elem_ret_rerank_hidden,
354
  search_bar,
@@ -405,7 +406,7 @@ with demo:
405
  )
406
 
407
  set_listeners(
408
- "long-doc",
409
  doc_df_elem_ret,
410
  doc_df_elem_ret_hidden,
411
  search_bar_ret,
@@ -462,7 +463,7 @@ with demo:
462
  )
463
 
464
  set_listeners(
465
- "long-doc",
466
  doc_df_elem_rerank,
467
  doc_df_elem_rerank_hidden,
468
  doc_search_bar_rerank,
 
35
  TOKEN,
36
  )
37
  from src.loaders import load_eval_results
38
+ from src.models import TaskType
39
  from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file
40
 
41
 
 
76
  global datastore
77
  return update_metric(
78
  datastore,
79
+ TaskType.qa,
80
  metric,
81
  domains,
82
  langs,
 
99
  global datastore
100
  return update_metric(
101
  datastore,
102
+ TaskType.long_doc,
103
  metric,
104
  domains,
105
  langs,
 
182
  )
183
 
184
  set_listeners(
185
+ TaskType.qa,
186
  qa_df_elem_ret_rerank,
187
  qa_df_elem_ret_rerank_hidden,
188
  search_bar,
 
225
  )
226
 
227
  set_listeners(
228
+ TaskType.qa,
229
  qa_df_elem_ret,
230
  qa_df_elem_ret_hidden,
231
  search_bar_ret,
 
282
  )
283
 
284
  set_listeners(
285
+ TaskType.qa,
286
  qa_df_elem_rerank,
287
  qa_df_elem_rerank_hidden,
288
  qa_search_bar_rerank,
 
349
  )
350
 
351
  set_listeners(
352
+ TaskType.long_doc,
353
  doc_df_elem_ret_rerank,
354
  doc_df_elem_ret_rerank_hidden,
355
  search_bar,
 
406
  )
407
 
408
  set_listeners(
409
+ TaskType.long_doc,
410
  doc_df_elem_ret,
411
  doc_df_elem_ret_hidden,
412
  search_bar_ret,
 
463
  )
464
 
465
  set_listeners(
466
+ TaskType.long_doc,
467
  doc_df_elem_rerank,
468
  doc_df_elem_rerank_hidden,
469
  doc_search_bar_rerank,
src/benchmarks.py CHANGED
@@ -4,6 +4,7 @@ from enum import Enum
4
  from air_benchmark.tasks.tasks import BenchmarkTable
5
 
6
  from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
 
7
 
8
 
9
  def get_safe_name(name: str):
@@ -23,11 +24,11 @@ class Benchmark:
23
 
24
 
25
  # create a function return an enum class containing all the benchmarks
26
- def get_benchmarks_enum(benchmark_version, task_type):
27
  benchmark_dict = {}
28
- if task_type == "qa":
29
  for task, domain_dict in BenchmarkTable[benchmark_version].items():
30
- if task != task_type:
31
  continue
32
  for domain, lang_dict in domain_dict.items():
33
  for lang, dataset_list in lang_dict.items():
@@ -39,9 +40,9 @@ def get_benchmarks_enum(benchmark_version, task_type):
39
  benchmark_dict[benchmark_name] = Benchmark(
40
  benchmark_name, metric, col_name, domain, lang, task
41
  )
42
- elif task_type == "long-doc":
43
  for task, domain_dict in BenchmarkTable[benchmark_version].items():
44
- if task != task_type:
45
  continue
46
  for domain, lang_dict in domain_dict.items():
47
  for lang, dataset_list in lang_dict.items():
@@ -62,14 +63,14 @@ qa_benchmark_dict = {}
62
  for version in BENCHMARK_VERSION_LIST:
63
  safe_version_name = get_safe_name(version)[-4:]
64
  qa_benchmark_dict[safe_version_name] = Enum(
65
- f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, "qa")
66
  )
67
 
68
  long_doc_benchmark_dict = {}
69
  for version in BENCHMARK_VERSION_LIST:
70
  safe_version_name = get_safe_name(version)[-4:]
71
  long_doc_benchmark_dict[safe_version_name] = Enum(
72
- f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, "long-doc")
73
  )
74
 
75
 
 
4
  from air_benchmark.tasks.tasks import BenchmarkTable
5
 
6
  from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
7
+ from src.models import TaskType
8
 
9
 
10
  def get_safe_name(name: str):
 
24
 
25
 
26
  # create a function return an enum class containing all the benchmarks
27
+ def get_benchmarks_enum(benchmark_version: str, task_type: TaskType):
28
  benchmark_dict = {}
29
+ if task_type == TaskType.qa:
30
  for task, domain_dict in BenchmarkTable[benchmark_version].items():
31
+ if task != task_type.value:
32
  continue
33
  for domain, lang_dict in domain_dict.items():
34
  for lang, dataset_list in lang_dict.items():
 
40
  benchmark_dict[benchmark_name] = Benchmark(
41
  benchmark_name, metric, col_name, domain, lang, task
42
  )
43
+ elif task_type == TaskType.long_doc:
44
  for task, domain_dict in BenchmarkTable[benchmark_version].items():
45
+ if task != task_type.value:
46
  continue
47
  for domain, lang_dict in domain_dict.items():
48
  for lang, dataset_list in lang_dict.items():
 
63
  for version in BENCHMARK_VERSION_LIST:
64
  safe_version_name = get_safe_name(version)[-4:]
65
  qa_benchmark_dict[safe_version_name] = Enum(
66
+ f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.qa)
67
  )
68
 
69
  long_doc_benchmark_dict = {}
70
  for version in BENCHMARK_VERSION_LIST:
71
  safe_version_name = get_safe_name(version)[-4:]
72
  long_doc_benchmark_dict[safe_version_name] = Enum(
73
+ f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.long_doc)
74
  )
75
 
76
 
src/loaders.py CHANGED
@@ -11,7 +11,7 @@ from src.envs import (
11
  DEFAULT_METRIC_LONG_DOC,
12
  DEFAULT_METRIC_QA,
13
  )
14
- from src.models import FullEvalResult, LeaderboardDataStore
15
  from src.utils import get_default_cols, get_leaderboard_df
16
 
17
  pd.options.mode.copy_on_write = True
@@ -64,34 +64,27 @@ def get_safe_name(name: str):
64
 
65
  def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
66
  slug = get_safe_name(version)[-4:]
67
- lb_data_store = LeaderboardDataStore(version, slug, None, None, None, None, None, None, None, None)
68
- lb_data_store.raw_data = load_raw_eval_results(file_path)
69
- print(f"raw data: {len(lb_data_store.raw_data)}")
70
-
71
- lb_data_store.qa_raw_df = get_leaderboard_df(lb_data_store, task="qa", metric=DEFAULT_METRIC_QA)
72
- print(f"QA data loaded: {lb_data_store.qa_raw_df.shape}")
73
- lb_data_store.qa_fmt_df = lb_data_store.qa_raw_df.copy()
74
- shown_columns_qa, types_qa = get_default_cols("qa", lb_data_store.slug, add_fix_cols=True)
75
- lb_data_store.qa_types = types_qa
76
- lb_data_store.qa_fmt_df = lb_data_store.qa_fmt_df[
77
- ~lb_data_store.qa_fmt_df[COL_NAME_IS_ANONYMOUS]
78
- ][shown_columns_qa]
79
- lb_data_store.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
80
-
81
- lb_data_store.doc_raw_df = get_leaderboard_df(lb_data_store, task="long-doc", metric=DEFAULT_METRIC_LONG_DOC)
82
- print(f"Long-Doc data loaded: {len(lb_data_store.doc_raw_df)}")
83
- lb_data_store.doc_fmt_df = lb_data_store.doc_raw_df.copy()
84
- shown_columns_long_doc, types_long_doc = get_default_cols("long-doc", lb_data_store.slug, add_fix_cols=True)
85
- lb_data_store.doc_types = types_long_doc
86
- lb_data_store.doc_fmt_df = lb_data_store.doc_fmt_df[
87
- ~lb_data_store.doc_fmt_df[COL_NAME_IS_ANONYMOUS]
88
- ][shown_columns_long_doc]
89
- lb_data_store.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
90
-
91
- lb_data_store.reranking_models = sorted(
92
- list(frozenset([eval_result.reranking_model for eval_result in lb_data_store.raw_data]))
93
- )
94
- return lb_data_store
95
 
96
 
97
  def load_eval_results(file_path: str) -> Dict[str, LeaderboardDataStore]:
 
11
  DEFAULT_METRIC_LONG_DOC,
12
  DEFAULT_METRIC_QA,
13
  )
14
+ from src.models import FullEvalResult, LeaderboardDataStore, TaskType
15
  from src.utils import get_default_cols, get_leaderboard_df
16
 
17
  pd.options.mode.copy_on_write = True
 
64
 
65
  def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
66
  slug = get_safe_name(version)[-4:]
67
+ datastore = LeaderboardDataStore(version, slug, None, None, None, None, None, None, None, None)
68
+ datastore.raw_data = load_raw_eval_results(file_path)
69
+ print(f"raw data: {len(datastore.raw_data)}")
70
+
71
+ datastore.qa_raw_df = get_leaderboard_df(datastore, TaskType.qa, DEFAULT_METRIC_QA)
72
+ print(f"QA data loaded: {datastore.qa_raw_df.shape}")
73
+ datastore.qa_fmt_df = datastore.qa_raw_df.copy()
74
+ qa_cols, datastore.qa_types = get_default_cols(TaskType.qa, datastore.slug, add_fix_cols=True)
75
+ datastore.qa_fmt_df = datastore.qa_fmt_df[~datastore.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]
76
+ datastore.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
77
+
78
+ datastore.doc_raw_df = get_leaderboard_df(datastore, TaskType.long_doc, DEFAULT_METRIC_LONG_DOC)
79
+ print(f"Long-Doc data loaded: {len(datastore.doc_raw_df)}")
80
+ datastore.doc_fmt_df = datastore.doc_raw_df.copy()
81
+ doc_cols, datastore.doc_types = get_default_cols(TaskType.long_doc, datastore.slug, add_fix_cols=True)
82
+ datastore.doc_fmt_df = datastore.doc_fmt_df[~datastore.doc_fmt_df[COL_NAME_IS_ANONYMOUS]][doc_cols]
83
+ datastore.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
84
+
85
+ datastore.reranking_models = \
86
+ sorted(list(frozenset([eval_result.reranking_model for eval_result in datastore.raw_data])))
87
+ return datastore
 
 
 
 
 
 
 
88
 
89
 
90
  def load_eval_results(file_path: str) -> Dict[str, LeaderboardDataStore]:
src/models.py CHANGED
@@ -1,11 +1,12 @@
1
  import json
 
 
2
  from collections import defaultdict
3
  from dataclasses import dataclass
4
  from typing import List, Optional
5
 
6
  import pandas as pd
7
 
8
- from src.benchmarks import get_safe_name
9
  from src.display.formatting import make_clickable_model
10
  from src.envs import (
11
  COL_NAME_IS_ANONYMOUS,
@@ -17,6 +18,10 @@ from src.envs import (
17
  COL_NAME_TIMESTAMP,
18
  )
19
 
 
 
 
 
20
 
21
  @dataclass
22
  class EvalResult:
@@ -147,4 +152,10 @@ class LeaderboardDataStore:
147
  doc_fmt_df: Optional[pd.DataFrame]
148
  reranking_models: Optional[list]
149
  qa_types: Optional[list]
150
- doc_types: Optional[list]
 
 
 
 
 
 
 
1
  import json
2
+ from enum import Enum
3
+
4
  from collections import defaultdict
5
  from dataclasses import dataclass
6
  from typing import List, Optional
7
 
8
  import pandas as pd
9
 
 
10
  from src.display.formatting import make_clickable_model
11
  from src.envs import (
12
  COL_NAME_IS_ANONYMOUS,
 
18
  COL_NAME_TIMESTAMP,
19
  )
20
 
21
+ def get_safe_name(name: str):
22
+ """Get RFC 1123 compatible safe name"""
23
+ name = name.replace("-", "_")
24
+ return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
25
 
26
  @dataclass
27
  class EvalResult:
 
152
  doc_fmt_df: Optional[pd.DataFrame]
153
  reranking_models: Optional[list]
154
  qa_types: Optional[list]
155
+ doc_types: Optional[list]
156
+
157
+
158
+ # Define an enum class with the name `TaskType`. There are two types of tasks, `qa` and `long-doc`.
159
+ class TaskType(Enum):
160
+ qa = "qa"
161
+ long_doc = "long-doc"
src/utils.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
6
 
7
  import pandas as pd
8
 
 
9
  from src.benchmarks import LongDocBenchmarks, QABenchmarks
10
  from src.display.columns import get_default_col_names_and_types, get_fixed_col_names_and_types
11
  from src.display.formatting import styled_error, styled_message
@@ -69,12 +70,12 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
69
  return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
70
 
71
 
72
- def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tuple:
73
  cols = []
74
  types = []
75
- if task == "qa":
76
  benchmarks = QABenchmarks[version_slug]
77
- elif task == "long-doc":
78
  benchmarks = LongDocBenchmarks[version_slug]
79
  else:
80
  raise NotImplementedError
@@ -85,7 +86,6 @@ def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tupl
85
  continue
86
  cols.append(col_name)
87
  types.append(col_type)
88
-
89
  if add_fix_cols:
90
  _cols = []
91
  _types = []
@@ -104,16 +104,16 @@ def select_columns(
104
  df: pd.DataFrame,
105
  domain_query: list,
106
  language_query: list,
107
- task: str = "qa",
108
  reset_ranking: bool = True,
109
  version_slug: str = None,
110
  ) -> pd.DataFrame:
111
  cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
112
  selected_cols = []
113
  for c in cols:
114
- if task == "qa":
115
  eval_col = QABenchmarks[version_slug].value[c].value
116
- elif task == "long-doc":
117
  eval_col = LongDocBenchmarks[version_slug].value[c].value
118
  else:
119
  raise NotImplementedError
@@ -141,10 +141,10 @@ def get_safe_name(name: str):
141
  return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
142
 
143
 
144
- def _update_table(
145
- task: str,
146
  version: str,
147
- hidden_df: pd.DataFrame,
148
  domains: list,
149
  langs: list,
150
  reranking_query: list,
@@ -154,7 +154,7 @@ def _update_table(
154
  show_revision_and_timestamp: bool = False,
155
  ):
156
  version_slug = get_safe_name(version)[-4:]
157
- filtered_df = hidden_df.copy()
158
  if not show_anonymous:
159
  filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
160
  filtered_df = filter_models(filtered_df, reranking_query)
@@ -165,7 +165,7 @@ def _update_table(
165
  return filtered_df
166
 
167
 
168
- def update_table_long_doc(
169
  version: str,
170
  hidden_df: pd.DataFrame,
171
  domains: list,
@@ -176,8 +176,8 @@ def update_table_long_doc(
176
  show_revision_and_timestamp: bool = False,
177
  reset_ranking: bool = True,
178
  ):
179
- return _update_table(
180
- "long-doc",
181
  version,
182
  hidden_df,
183
  domains,
@@ -192,7 +192,7 @@ def update_table_long_doc(
192
 
193
  def update_metric(
194
  datastore,
195
- task: str,
196
  metric: str,
197
  domains: list,
198
  langs: list,
@@ -201,33 +201,24 @@ def update_metric(
201
  show_anonymous: bool = False,
202
  show_revision_and_timestamp: bool = False,
203
  ) -> pd.DataFrame:
204
- # raw_data = datastore.raw_data
205
- if task == "qa":
206
- leaderboard_df = get_leaderboard_df(datastore, task=task, metric=metric)
207
- version = datastore.version
208
- return update_table(
209
- version,
210
- leaderboard_df,
211
- domains,
212
- langs,
213
- reranking_model,
214
- query,
215
- show_anonymous,
216
- show_revision_and_timestamp,
217
- )
218
- elif task == "long-doc":
219
- leaderboard_df = get_leaderboard_df(datastore, task=task, metric=metric)
220
- version = datastore.version
221
- return update_table_long_doc(
222
- version,
223
- leaderboard_df,
224
- domains,
225
- langs,
226
- reranking_model,
227
- query,
228
- show_anonymous,
229
- show_revision_and_timestamp,
230
- )
231
 
232
 
233
  def upload_file(filepath: str):
@@ -341,7 +332,7 @@ def reset_rank(df):
341
  return df
342
 
343
 
344
- def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
345
  """
346
  Creates a dataframe from all the individual experiment results
347
  """
@@ -349,9 +340,9 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
349
  cols = [
350
  COL_NAME_IS_ANONYMOUS,
351
  ]
352
- if task == "qa":
353
  benchmarks = QABenchmarks[datastore.slug]
354
- elif task == "long-doc":
355
  benchmarks = LongDocBenchmarks[datastore.slug]
356
  else:
357
  raise NotImplementedError
@@ -360,7 +351,7 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
360
  benchmark_cols = [t.value.col_name for t in list(benchmarks.value)]
361
  all_data_json = []
362
  for v in raw_data:
363
- all_data_json += v.to_dict(task=task, metric=metric)
364
  df = pd.DataFrame.from_records(all_data_json)
365
 
366
  _benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
@@ -385,7 +376,7 @@ def get_leaderboard_df(datastore, task: str, metric: str) -> pd.DataFrame:
385
 
386
 
387
  def set_listeners(
388
- task,
389
  target_df,
390
  source_df,
391
  search_bar,
@@ -396,10 +387,10 @@ def set_listeners(
396
  show_anonymous,
397
  show_revision_and_timestamp,
398
  ):
399
- if task == "qa":
400
- update_table_func = update_table
401
- elif task == "long-doc":
402
- update_table_func = update_table_long_doc
403
  else:
404
  raise NotImplementedError
405
  selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
@@ -427,7 +418,7 @@ def set_listeners(
427
  )
428
 
429
 
430
- def update_table(
431
  version: str,
432
  hidden_df: pd.DataFrame,
433
  domains: list,
@@ -438,8 +429,8 @@ def update_table(
438
  show_revision_and_timestamp: bool = False,
439
  reset_ranking: bool = True,
440
  ):
441
- return _update_table(
442
- "qa",
443
  version,
444
  hidden_df,
445
  domains,
 
6
 
7
  import pandas as pd
8
 
9
+ from src.models import TaskType
10
  from src.benchmarks import LongDocBenchmarks, QABenchmarks
11
  from src.display.columns import get_default_col_names_and_types, get_fixed_col_names_and_types
12
  from src.display.formatting import styled_error, styled_message
 
70
  return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
71
 
72
 
73
+ def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) -> tuple:
74
  cols = []
75
  types = []
76
+ if task == TaskType.qa:
77
  benchmarks = QABenchmarks[version_slug]
78
+ elif task == TaskType.long_doc:
79
  benchmarks = LongDocBenchmarks[version_slug]
80
  else:
81
  raise NotImplementedError
 
86
  continue
87
  cols.append(col_name)
88
  types.append(col_type)
 
89
  if add_fix_cols:
90
  _cols = []
91
  _types = []
 
104
  df: pd.DataFrame,
105
  domain_query: list,
106
  language_query: list,
107
+ task: TaskType = TaskType.qa,
108
  reset_ranking: bool = True,
109
  version_slug: str = None,
110
  ) -> pd.DataFrame:
111
  cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
112
  selected_cols = []
113
  for c in cols:
114
+ if task == TaskType.qa:
115
  eval_col = QABenchmarks[version_slug].value[c].value
116
+ elif task == TaskType.long_doc:
117
  eval_col = LongDocBenchmarks[version_slug].value[c].value
118
  else:
119
  raise NotImplementedError
 
141
  return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
142
 
143
 
144
+ def _update_df_elem(
145
+ task: TaskType,
146
  version: str,
147
+ source_df: pd.DataFrame,
148
  domains: list,
149
  langs: list,
150
  reranking_query: list,
 
154
  show_revision_and_timestamp: bool = False,
155
  ):
156
  version_slug = get_safe_name(version)[-4:]
157
+ filtered_df = source_df.copy()
158
  if not show_anonymous:
159
  filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
160
  filtered_df = filter_models(filtered_df, reranking_query)
 
165
  return filtered_df
166
 
167
 
168
+ def update_doc_df_elem(
169
  version: str,
170
  hidden_df: pd.DataFrame,
171
  domains: list,
 
176
  show_revision_and_timestamp: bool = False,
177
  reset_ranking: bool = True,
178
  ):
179
+ return _update_df_elem(
180
+ TaskType.long_doc,
181
  version,
182
  hidden_df,
183
  domains,
 
192
 
193
  def update_metric(
194
  datastore,
195
+ task: TaskType,
196
  metric: str,
197
  domains: list,
198
  langs: list,
 
201
  show_anonymous: bool = False,
202
  show_revision_and_timestamp: bool = False,
203
  ) -> pd.DataFrame:
204
+ if task == TaskType.qa:
205
+ update_func = update_qa_df_elem
206
+ elif task == TaskType.long_doc:
207
+ update_func = update_doc_df_elem
208
+ else:
209
+ raise NotImplemented
210
+ df_elem = get_leaderboard_df(datastore, task=task, metric=metric)
211
+ version = datastore.version
212
+ return update_func(
213
+ version,
214
+ df_elem,
215
+ domains,
216
+ langs,
217
+ reranking_model,
218
+ query,
219
+ show_anonymous,
220
+ show_revision_and_timestamp,
221
+ )
 
 
 
 
 
 
 
 
 
222
 
223
 
224
  def upload_file(filepath: str):
 
332
  return df
333
 
334
 
335
+ def get_leaderboard_df(datastore, task: TaskType, metric: str) -> pd.DataFrame:
336
  """
337
  Creates a dataframe from all the individual experiment results
338
  """
 
340
  cols = [
341
  COL_NAME_IS_ANONYMOUS,
342
  ]
343
+ if task == TaskType.qa:
344
  benchmarks = QABenchmarks[datastore.slug]
345
+ elif task == TaskType.long_doc:
346
  benchmarks = LongDocBenchmarks[datastore.slug]
347
  else:
348
  raise NotImplementedError
 
351
  benchmark_cols = [t.value.col_name for t in list(benchmarks.value)]
352
  all_data_json = []
353
  for v in raw_data:
354
+ all_data_json += v.to_dict(task=task.value, metric=metric)
355
  df = pd.DataFrame.from_records(all_data_json)
356
 
357
  _benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
 
376
 
377
 
378
  def set_listeners(
379
+ task: TaskType,
380
  target_df,
381
  source_df,
382
  search_bar,
 
387
  show_anonymous,
388
  show_revision_and_timestamp,
389
  ):
390
+ if task == TaskType.qa:
391
+ update_table_func = update_qa_df_elem
392
+ elif task == TaskType.long_doc:
393
+ update_table_func = update_doc_df_elem
394
  else:
395
  raise NotImplementedError
396
  selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
 
418
  )
419
 
420
 
421
+ def update_qa_df_elem(
422
  version: str,
423
  hidden_df: pd.DataFrame,
424
  domains: list,
 
429
  show_revision_and_timestamp: bool = False,
430
  reset_ranking: bool = True,
431
  ):
432
+ return _update_df_elem(
433
+ TaskType.qa,
434
  version,
435
  hidden_df,
436
  domains,
tests/test_utils.py CHANGED
@@ -18,7 +18,7 @@ from src.utils import (
18
  get_iso_format_timestamp,
19
  search_table,
20
  select_columns,
21
- update_table_long_doc,
22
  )
23
 
24
 
@@ -90,7 +90,7 @@ def test_select_columns(toy_df):
90
 
91
 
92
  def test_update_table_long_doc(toy_df_long_doc):
93
- df_result = update_table_long_doc(
94
  toy_df_long_doc,
95
  [
96
  "law",
 
18
  get_iso_format_timestamp,
19
  search_table,
20
  select_columns,
21
+ update_doc_df_elem,
22
  )
23
 
24
 
 
90
 
91
 
92
  def test_update_table_long_doc(toy_df_long_doc):
93
+ df_result = update_doc_df_elem(
94
  toy_df_long_doc,
95
  [
96
  "law",