fix-24-05-display-issue #27
by nan - opened
.github/workflows/main.yaml DELETED
@@ -1,20 +0,0 @@
-name: Sync to Hugging Face hub
-on:
-  push:
-    branches: [main]
-
-  # to run this workflow manually from the Actions tab
-  workflow_dispatch:
-
-jobs:
-  sync-to-hub:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-        with:
-          fetch-depth: 0
-          lfs: true
-      - name: Push to hub
-        env:
-          HF_TOKEN: ${{ secrets.HF_TOKEN }}
-        run: git push https://hanhainebula:[email protected]/spaces/AIR-Bench/leaderboard main
Makefile CHANGED
@@ -3,21 +3,11 @@
 
 style:
 	python -m black --line-length 119 .
-	python -m black --line-length 119 src
 	python -m isort .
-	python -m isort src
 	ruff check --fix .
-	ruff check --fix src
 
 
 quality:
 	python -m black --check --line-length 119 .
-	python -m black --check --line-length 119 src
 	python -m isort --check-only .
-	python -m isort --check-only src
 	ruff check .
-	ruff check src
-
-
-test:
-	python -m pytest tests
README.md CHANGED
@@ -10,14 +10,36 @@ pinned: true
 license: apache-2.0
 ---
 
-# The AIR-Bench Leaderboard repository
+# Start the configuration
 
-This repository contains the code for the AIR-Bench Leaderboard.
+Most of the variables to change for a default leaderboard are in `src/env.py` (replace the path for your leaderboard) and `src/about.py` (for tasks).
 
-| Important Links | Description |
-| ------------------------------------------------------------ | ---------------------------------------------- |
-| [AIR-Bench](https://github.com/AIR-Bench/AIR-Bench) | The main repository for the AIR-Bench project. |
-| [Leaderboard Space](https://huggingface.co/spaces/AIR-Bench/leaderboard) | The leaderboard space on Hugging Face. |
-| [Leaderboard Backend Space](https://huggingface.co/spaces/AIR-Bench/leaderboard_backend) | The leaderboard backend space on Hugging Face. |
-| [Leaderboard Code](https://github.com/AIR-Bench/leaderboard) | The code for the leaderboard. |
-| [Leaderboard Backend Code](https://github.com/AIR-Bench/leaderboard_backend) | The code for the leaderboard backend. |
+Results files should have the following format and be stored as JSON files:
+```json
+{
+    "config": {
+        "model_dtype": "torch.float16", # or torch.bfloat16 or 8bit or 4bit
+        "model_name": "path of the model on the hub: org/model",
+        "model_sha": "revision on the hub",
+    },
+    "results": {
+        "task_name": {
+            "metric_name": score,
+        },
+        "task_name2": {
+            "metric_name": score,
+        }
+    }
+}
+```
+
+Request files are created automatically by this tool.
+
+If you encounter a problem on the Space, don't hesitate to restart it to remove the eval-queue, eval-queue-bk, eval-results, and eval-results-bk folders it creates.
+
+# Code logic for more complex edits
+
+You'll find
+- the main table's column names and properties in `src/display/utils.py`
+- the logic to read all results and request files, then convert them into dataframe lines, in `src/leaderboard/read_evals.py` and `src/populate.py`
+- the logic to allow or filter submissions in `src/submission/submit.py` and `src/submission/check_validity.py`
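The README hunk above documents the results-file format. As a rough illustration only, here is a minimal sketch of flattening such a file into rows; the helper name and path are hypothetical and not part of the leaderboard code:

```python
import json

def read_result_rows(path: str):
    """Hypothetical helper: flatten one results file (format shown above) into rows."""
    with open(path) as f:
        payload = json.load(f)
    model_name = payload["config"]["model_name"]
    rows = []
    for task_name, metrics in payload["results"].items():
        for metric_name, score in metrics.items():
            rows.append((model_name, task_name, metric_name, score))
    return rows

# Hypothetical usage:
# rows = read_result_rows("eval_results/org__model/results.json")
```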
app.py CHANGED
@@ -1,141 +1,131 @@
1
- import os
2
-
3
  import gradio as gr
4
  from apscheduler.schedulers.background import BackgroundScheduler
5
  from huggingface_hub import snapshot_download
6
 
7
- from src.about import BENCHMARKS_TEXT, EVALUATION_QUEUE_TEXT, INTRODUCTION_TEXT, TITLE
8
- from src.benchmarks import LongDocBenchmarks, QABenchmarks
9
- from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
10
- from src.components import (
11
- get_anonymous_checkbox,
12
- get_domain_dropdown,
13
- get_language_dropdown,
14
- get_leaderboard_table,
15
- get_metric_dropdown,
16
- get_noreranking_dropdown,
17
- get_reranking_dropdown,
18
- get_revision_and_ts_checkbox,
19
- get_search_bar,
20
- get_version_dropdown,
 
 
 
 
 
 
 
 
21
  )
22
- from src.css_html_js import custom_css
23
  from src.envs import (
24
  API,
25
- BENCHMARK_VERSION_LIST,
26
- DEFAULT_METRIC_LONG_DOC,
27
- DEFAULT_METRIC_QA,
28
  EVAL_RESULTS_PATH,
29
- LATEST_BENCHMARK_VERSION,
30
- METRIC_LIST,
31
  REPO_ID,
32
  RESULTS_REPO,
33
  TOKEN,
 
 
 
34
  )
35
- from src.loaders import load_eval_results
36
- from src.models import TaskType, model_hyperlink
37
- from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file
38
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def restart_space():
41
  API.restart_space(repo_id=REPO_ID)
42
 
43
 
44
  try:
45
- if not os.environ.get("LOCAL_MODE", False):
46
- print("Running in local mode")
47
- snapshot_download(
48
- repo_id=RESULTS_REPO,
49
- local_dir=EVAL_RESULTS_PATH,
50
- repo_type="dataset",
51
- tqdm_class=None,
52
- etag_timeout=30,
53
- token=TOKEN,
54
- )
55
- except Exception:
56
- print("failed to download")
57
  restart_space()
58
 
59
- global ds_dict
60
- ds_dict = load_eval_results(EVAL_RESULTS_PATH)
61
- global datastore
62
- datastore = ds_dict[LATEST_BENCHMARK_VERSION]
63
-
64
-
65
- def update_qa_metric(
66
- metric: str,
67
- domains: list,
68
- langs: list,
69
- reranking_model: list,
70
- query: str,
71
- show_anonymous: bool,
72
- show_revision_and_timestamp: bool,
73
- ):
74
- global datastore
75
- return update_metric(
76
- datastore,
77
- TaskType.qa,
78
- metric,
79
- domains,
80
- langs,
81
- reranking_model,
82
- query,
83
- show_anonymous,
 
 
 
 
 
 
 
 
 
84
  show_revision_and_timestamp,
85
- )
86
-
87
-
88
- def update_doc_metric(
89
- metric: str,
90
- domains: list,
91
- langs: list,
92
- reranking_model: list,
93
- query: str,
94
- show_anonymous: bool,
95
- show_revision_and_timestamp,
96
  ):
97
- global datastore
98
- return update_metric(
99
- datastore,
100
- TaskType.long_doc,
101
- metric,
102
- domains,
103
- langs,
104
- reranking_model,
105
- query,
106
- show_anonymous,
107
  show_revision_and_timestamp,
108
- )
109
-
110
-
111
- def update_qa_version(version):
112
- global datastore
113
- global ds_dict
114
- datastore = ds_dict[version]
115
- domain_elem = get_domain_dropdown(QABenchmarks[datastore.slug])
116
- lang_elem = get_language_dropdown(QABenchmarks[datastore.slug])
117
- model_elem = get_reranking_dropdown(datastore.reranking_models)
118
- df_elem = get_leaderboard_table(datastore.qa_fmt_df, datastore.qa_types)
119
- hidden_df_elem = get_leaderboard_table(datastore.qa_raw_df, datastore.qa_types, visible=False)
120
- return domain_elem, lang_elem, model_elem, df_elem, hidden_df_elem
121
-
122
-
123
- def update_doc_version(version):
124
- global datastore
125
- global ds_dict
126
- datastore = ds_dict[version]
127
- domain_elem = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
128
- lang_elem = get_language_dropdown(LongDocBenchmarks[datastore.slug])
129
- model_elem = get_reranking_dropdown(datastore.reranking_models)
130
- df_elem = get_leaderboard_table(datastore.doc_fmt_df, datastore.doc_types)
131
- hidden_df_elem = get_leaderboard_table(datastore.doc_raw_df, datastore.doc_types, visible=False)
132
- return domain_elem, lang_elem, model_elem, df_elem, hidden_df_elem
133
 
134
 
135
  demo = gr.Blocks(css=custom_css)
136
 
137
- BM25_LINK = model_hyperlink("https://github.com/castorini/pyserini", "BM25")
138
-
139
  with demo:
140
  gr.HTML(TITLE)
141
  gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
@@ -143,24 +133,25 @@ with demo:
143
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
144
  with gr.TabItem("Results", elem_id="results-tab-table"):
145
  with gr.Row():
146
- version = get_version_dropdown()
147
 
148
  with gr.TabItem("QA", elem_id="qa-benchmark-tab-table", id=0):
149
  with gr.Row():
150
  with gr.Column(min_width=320):
151
  # select domain
152
  with gr.Row():
153
- domains = get_domain_dropdown(QABenchmarks[datastore.slug])
154
  # select language
155
  with gr.Row():
156
- langs = get_language_dropdown(QABenchmarks[datastore.slug])
 
157
  with gr.Column():
158
  # select the metric
159
- metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_QA)
160
  with gr.Row():
161
  show_anonymous = get_anonymous_checkbox()
162
  with gr.Row():
163
- show_rev_ts = get_revision_and_ts_checkbox()
164
  with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
165
  with gr.TabItem("Retrieval + Reranking", id=10):
166
  with gr.Row():
@@ -169,327 +160,273 @@ with demo:
169
  search_bar = get_search_bar()
170
  # select reranking models
171
  with gr.Column():
172
- models = get_reranking_dropdown(datastore.reranking_models)
173
- # shown_table
174
- qa_df_elem_ret_rerank = get_leaderboard_table(datastore.qa_fmt_df, datastore.qa_types)
175
  # Dummy leaderboard for handling the case when the user uses backspace key
176
- qa_df_elem_ret_rerank_hidden = get_leaderboard_table(
177
- datastore.qa_raw_df, datastore.qa_types, visible=False
178
- )
179
-
180
- version.change(
181
- update_qa_version,
182
- version,
183
- [domains, langs, models, qa_df_elem_ret_rerank, qa_df_elem_ret_rerank_hidden],
184
- )
185
 
186
  set_listeners(
187
- TaskType.qa,
188
- qa_df_elem_ret_rerank,
189
- qa_df_elem_ret_rerank_hidden,
190
  search_bar,
191
- version,
192
- domains,
193
- langs,
194
- models,
195
  show_anonymous,
196
- show_rev_ts,
197
  )
198
 
199
  # set metric listener
200
- metric.change(
201
- update_qa_metric,
202
- [metric, domains, langs, models, search_bar, show_anonymous, show_rev_ts],
203
- qa_df_elem_ret_rerank,
204
- queue=True,
 
 
 
 
 
 
 
 
205
  )
206
-
207
  with gr.TabItem("Retrieval Only", id=11):
208
  with gr.Row():
209
  with gr.Column(scale=1):
210
- search_bar_ret = get_search_bar()
211
  with gr.Column(scale=1):
212
- models_ret = get_noreranking_dropdown()
213
-
214
- _qa_df_ret = datastore.qa_fmt_df[datastore.qa_fmt_df[COL_NAME_RERANKING_MODEL] == "NoReranker"]
215
- _qa_df_ret = reset_rank(_qa_df_ret)
216
- qa_df_elem_ret = get_leaderboard_table(_qa_df_ret, datastore.qa_types)
217
-
218
  # Dummy leaderboard for handling the case when the user uses backspace key
219
- _qa_df_ret_hidden = datastore.qa_raw_df[
220
- datastore.qa_raw_df[COL_NAME_RERANKING_MODEL] == "NoReranker"
221
- ]
222
- _qa_df_ret_hidden = reset_rank(_qa_df_ret_hidden)
223
- qa_df_elem_ret_hidden = get_leaderboard_table(
224
- _qa_df_ret_hidden, datastore.qa_types, visible=False
225
- )
226
-
227
- version.change(
228
- update_qa_version,
229
- version,
230
- [
231
- domains,
232
- langs,
233
- models_ret,
234
- qa_df_elem_ret,
235
- qa_df_elem_ret_hidden,
236
- ],
237
- )
238
 
239
  set_listeners(
240
- TaskType.qa,
241
- qa_df_elem_ret,
242
- qa_df_elem_ret_hidden,
243
- search_bar_ret,
244
- version,
245
- domains,
246
- langs,
247
- models_ret,
248
  show_anonymous,
249
- show_rev_ts,
250
  )
251
 
252
- metric.change(
253
- update_qa_metric,
 
254
  [
255
- metric,
256
- domains,
257
- langs,
258
- models_ret,
259
- search_bar_ret,
260
  show_anonymous,
261
- show_rev_ts,
262
  ],
263
- qa_df_elem_ret,
264
- queue=True,
265
  )
266
-
267
  with gr.TabItem("Reranking Only", id=12):
268
- _qa_df_rerank = datastore.qa_fmt_df[datastore.qa_fmt_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
269
- _qa_df_rerank = reset_rank(_qa_df_rerank)
270
- qa_rerank_models = _qa_df_rerank[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
271
  with gr.Row():
272
  with gr.Column(scale=1):
273
- qa_models_rerank = get_reranking_dropdown(qa_rerank_models)
274
  with gr.Column(scale=1):
275
- qa_search_bar_rerank = gr.Textbox(show_label=False, visible=False)
276
- qa_df_elem_rerank = get_leaderboard_table(_qa_df_rerank, datastore.qa_types)
277
-
278
- _qa_df_rerank_hidden = datastore.qa_raw_df[
279
- datastore.qa_raw_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
280
- ]
281
- _qa_df_rerank_hidden = reset_rank(_qa_df_rerank_hidden)
282
- qa_df_elem_rerank_hidden = get_leaderboard_table(
283
- _qa_df_rerank_hidden, datastore.qa_types, visible=False
284
- )
285
-
286
- version.change(
287
- update_qa_version,
288
- version,
289
- [domains, langs, qa_models_rerank, qa_df_elem_rerank, qa_df_elem_rerank_hidden],
290
  )
291
 
292
  set_listeners(
293
- TaskType.qa,
294
- qa_df_elem_rerank,
295
- qa_df_elem_rerank_hidden,
296
- qa_search_bar_rerank,
297
- version,
298
- domains,
299
- langs,
300
- qa_models_rerank,
301
  show_anonymous,
302
- show_rev_ts,
303
  )
304
-
305
- metric.change(
306
- update_qa_metric,
307
  [
308
- metric,
309
- domains,
310
- langs,
311
- qa_models_rerank,
312
- qa_search_bar_rerank,
313
  show_anonymous,
314
- show_rev_ts,
315
  ],
316
- qa_df_elem_rerank,
317
- queue=True,
318
  )
319
  with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
320
  with gr.Row():
321
  with gr.Column(min_width=320):
322
  # select domain
323
  with gr.Row():
324
- domains = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
325
  # select language
326
  with gr.Row():
327
- langs = get_language_dropdown(LongDocBenchmarks[datastore.slug])
 
 
328
  with gr.Column():
329
  # select the metric
330
  with gr.Row():
331
- metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_LONG_DOC)
332
  with gr.Row():
333
  show_anonymous = get_anonymous_checkbox()
334
  with gr.Row():
335
- show_rev_ts = get_revision_and_ts_checkbox()
336
- with gr.Tabs(elem_classes="tab-buttons"):
337
  with gr.TabItem("Retrieval + Reranking", id=20):
338
  with gr.Row():
339
  with gr.Column():
340
  search_bar = get_search_bar()
 
341
  with gr.Column():
342
- models = get_reranking_dropdown(datastore.reranking_models)
343
-
344
- doc_df_elem_ret_rerank = get_leaderboard_table(datastore.doc_fmt_df, datastore.doc_types)
345
 
346
- # Dummy leaderboard for handling the case when the user uses backspace key
347
- doc_df_elem_ret_rerank_hidden = get_leaderboard_table(
348
- datastore.doc_raw_df, datastore.doc_types, visible=False
349
  )
350
 
351
- version.change(
352
- update_doc_version,
353
- version,
354
- [domains, langs, models, doc_df_elem_ret_rerank, doc_df_elem_ret_rerank_hidden],
355
  )
356
 
357
  set_listeners(
358
- TaskType.long_doc,
359
- doc_df_elem_ret_rerank,
360
- doc_df_elem_ret_rerank_hidden,
361
  search_bar,
362
- version,
363
- domains,
364
- langs,
365
- models,
366
  show_anonymous,
367
- show_rev_ts,
368
  )
369
 
370
  # set metric listener
371
- metric.change(
372
- update_doc_metric,
373
  [
374
- metric,
375
- domains,
376
- langs,
377
- models,
378
  search_bar,
379
  show_anonymous,
380
- show_rev_ts,
381
  ],
382
- doc_df_elem_ret_rerank,
383
- queue=True,
384
  )
385
  with gr.TabItem("Retrieval Only", id=21):
386
  with gr.Row():
387
  with gr.Column(scale=1):
388
- search_bar_ret = get_search_bar()
389
  with gr.Column(scale=1):
390
- models_ret = get_noreranking_dropdown()
391
-
392
- _doc_df_ret = datastore.doc_fmt_df[
393
- datastore.doc_fmt_df[COL_NAME_RERANKING_MODEL] == "NoReranker"
394
  ]
395
- _doc_df_ret = reset_rank(_doc_df_ret)
396
- doc_df_elem_ret = get_leaderboard_table(_doc_df_ret, datastore.doc_types)
397
-
398
- _doc_df_ret_hidden = datastore.doc_raw_df[
399
- datastore.doc_raw_df[COL_NAME_RERANKING_MODEL] == "NoReranker"
400
  ]
401
- _doc_df_ret_hidden = reset_rank(_doc_df_ret_hidden)
402
- doc_df_elem_ret_hidden = get_leaderboard_table(
403
- _doc_df_ret_hidden, datastore.doc_types, visible=False
404
- )
405
-
406
- version.change(
407
- update_doc_version,
408
- version,
409
- [domains, langs, models_ret, doc_df_elem_ret, doc_df_elem_ret_hidden],
410
  )
411
 
412
  set_listeners(
413
- TaskType.long_doc,
414
- doc_df_elem_ret,
415
- doc_df_elem_ret_hidden,
416
- search_bar_ret,
417
- version,
418
- domains,
419
- langs,
420
- models_ret,
421
  show_anonymous,
422
- show_rev_ts,
423
  )
424
 
425
- metric.change(
426
- update_doc_metric,
427
  [
428
- metric,
429
- domains,
430
- langs,
431
- models_ret,
432
- search_bar_ret,
433
  show_anonymous,
434
- show_rev_ts,
435
  ],
436
- doc_df_elem_ret,
437
- queue=True,
438
  )
439
  with gr.TabItem("Reranking Only", id=22):
440
- _doc_df_rerank = datastore.doc_fmt_df[
441
- datastore.doc_fmt_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
442
- ]
443
- _doc_df_rerank = reset_rank(_doc_df_rerank)
444
- doc_rerank_models = (
445
- _doc_df_rerank[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
446
- )
447
  with gr.Row():
448
  with gr.Column(scale=1):
449
- doc_models_rerank = get_reranking_dropdown(doc_rerank_models)
450
  with gr.Column(scale=1):
451
- doc_search_bar_rerank = gr.Textbox(show_label=False, visible=False)
452
- doc_df_elem_rerank = get_leaderboard_table(_doc_df_rerank, datastore.doc_types)
453
- _doc_df_rerank_hidden = datastore.doc_raw_df[
454
- datastore.doc_raw_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
455
- ]
456
- _doc_df_rerank_hidden = reset_rank(_doc_df_rerank_hidden)
457
- doc_df_elem_rerank_hidden = get_leaderboard_table(
458
- _doc_df_rerank_hidden, datastore.doc_types, visible=False
459
- )
460
-
461
- version.change(
462
- update_doc_version,
463
- version,
464
- [domains, langs, doc_models_rerank, doc_df_elem_rerank, doc_df_elem_rerank_hidden],
465
  )
466
 
467
  set_listeners(
468
- TaskType.long_doc,
469
- doc_df_elem_rerank,
470
- doc_df_elem_rerank_hidden,
471
- doc_search_bar_rerank,
472
- version,
473
- domains,
474
- langs,
475
- doc_models_rerank,
476
  show_anonymous,
477
- show_rev_ts,
478
  )
479
-
480
- metric.change(
481
- update_doc_metric,
482
  [
483
- metric,
484
- domains,
485
- langs,
486
- doc_models_rerank,
487
- doc_search_bar_rerank,
488
  show_anonymous,
489
- show_rev_ts,
490
  ],
491
- doc_df_elem_rerank,
492
- queue=True,
493
  )
494
 
495
  with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
@@ -506,18 +443,23 @@ with demo:
506
  with gr.Row():
507
  with gr.Column():
508
  reranking_model_name = gr.Textbox(
509
- label="Reranking Model name", info="Optional", value="NoReranker"
 
 
510
  )
511
  with gr.Column():
512
- reranking_model_url = gr.Textbox(label="Reranking Model URL", info="Optional", value="")
 
 
 
 
513
  with gr.Row():
514
  with gr.Column():
515
  benchmark_version = gr.Dropdown(
516
  BENCHMARK_VERSION_LIST,
517
  value=LATEST_BENCHMARK_VERSION,
518
  interactive=True,
519
- label="AIR-Bench Version",
520
- )
521
  with gr.Row():
522
  upload_button = gr.UploadButton("Click to upload search results", file_count="single")
523
  with gr.Row():
@@ -526,8 +468,7 @@ with demo:
526
  is_anonymous = gr.Checkbox(
527
  label="Nope. I want to submit anonymously 🥷",
528
  value=False,
529
- info="Do you want to shown on the leaderboard by default?",
530
- )
531
  with gr.Row():
532
  submit_button = gr.Button("Submit")
533
  with gr.Row():
@@ -537,8 +478,7 @@ with demo:
537
  [
538
  upload_button,
539
  ],
540
- file_output,
541
- )
542
  submit_button.click(
543
  submit_results,
544
  [
@@ -548,10 +488,10 @@ with demo:
548
  reranking_model_name,
549
  reranking_model_url,
550
  benchmark_version,
551
- is_anonymous,
552
  ],
553
  submission_result,
554
- show_progress="hidden",
555
  )
556
 
557
  with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
 
 
 
1
  import gradio as gr
2
  from apscheduler.schedulers.background import BackgroundScheduler
3
  from huggingface_hub import snapshot_download
4
 
5
+ from src.about import (
6
+ INTRODUCTION_TEXT,
7
+ BENCHMARKS_TEXT,
8
+ TITLE,
9
+ EVALUATION_QUEUE_TEXT
10
+ )
11
+ from src.benchmarks import (
12
+ DOMAIN_COLS_QA,
13
+ LANG_COLS_QA,
14
+ DOMAIN_COLS_LONG_DOC,
15
+ LANG_COLS_LONG_DOC,
16
+ METRIC_LIST,
17
+ DEFAULT_METRIC_QA,
18
+ DEFAULT_METRIC_LONG_DOC
19
+ )
20
+ from src.display.css_html_js import custom_css
21
+ from src.display.utils import (
22
+ COL_NAME_IS_ANONYMOUS,
23
+ COL_NAME_REVISION,
24
+ COL_NAME_TIMESTAMP,
25
+ COL_NAME_RERANKING_MODEL,
26
+ COL_NAME_RETRIEVAL_MODEL
27
  )
 
28
  from src.envs import (
29
  API,
 
 
 
30
  EVAL_RESULTS_PATH,
 
 
31
  REPO_ID,
32
  RESULTS_REPO,
33
  TOKEN,
34
+ BM25_LINK,
35
+ BENCHMARK_VERSION_LIST,
36
+ LATEST_BENCHMARK_VERSION
37
  )
38
+ from src.read_evals import (
39
+ get_raw_eval_results,
40
+ get_leaderboard_df
41
+ )
42
+ from src.utils import (
43
+ update_metric,
44
+ upload_file,
45
+ get_default_cols,
46
+ submit_results,
47
+ reset_rank,
48
+ remove_html
49
+ )
50
+ from src.display.gradio_formatting import (
51
+ get_version_dropdown,
52
+ get_search_bar,
53
+ get_reranking_dropdown,
54
+ get_metric_dropdown,
55
+ get_domain_dropdown,
56
+ get_language_dropdown,
57
+ get_anonymous_checkbox,
58
+ get_revision_and_ts_checkbox,
59
+ get_leaderboard_table,
60
+ get_noreranking_dropdown
61
+ )
62
+ from src.display.gradio_listener import set_listeners
63
 
64
  def restart_space():
65
  API.restart_space(repo_id=REPO_ID)
66
 
67
 
68
  try:
69
+ snapshot_download(
70
+ repo_id=RESULTS_REPO, local_dir=EVAL_RESULTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30,
71
+ token=TOKEN
72
+ )
73
+ except Exception as e:
74
+ print(f'failed to download')
 
 
 
 
 
 
75
  restart_space()
76
 
77
+ raw_data = get_raw_eval_results(f"{EVAL_RESULTS_PATH}/{LATEST_BENCHMARK_VERSION}")
78
+
79
+ original_df_qa = get_leaderboard_df(
80
+ raw_data, task='qa', metric=DEFAULT_METRIC_QA)
81
+ original_df_long_doc = get_leaderboard_df(
82
+ raw_data, task='long-doc', metric=DEFAULT_METRIC_LONG_DOC)
83
+ print(f'raw data: {len(raw_data)}')
84
+ print(f'QA data loaded: {original_df_qa.shape}')
85
+ print(f'Long-Doc data loaded: {len(original_df_long_doc)}')
86
+
87
+ leaderboard_df_qa = original_df_qa.copy()
88
+ # leaderboard_df_qa = leaderboard_df_qa[has_no_nan_values(df, _benchmark_cols)]
89
+ shown_columns_qa, types_qa = get_default_cols(
90
+ 'qa', leaderboard_df_qa.columns, add_fix_cols=True)
91
+ leaderboard_df_qa = leaderboard_df_qa[~leaderboard_df_qa[COL_NAME_IS_ANONYMOUS]][shown_columns_qa]
92
+ leaderboard_df_qa.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
93
+
94
+ leaderboard_df_long_doc = original_df_long_doc.copy()
95
+ shown_columns_long_doc, types_long_doc = get_default_cols(
96
+ 'long-doc', leaderboard_df_long_doc.columns, add_fix_cols=True)
97
+ leaderboard_df_long_doc = leaderboard_df_long_doc[~leaderboard_df_long_doc[COL_NAME_IS_ANONYMOUS]][shown_columns_long_doc]
98
+ leaderboard_df_long_doc.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
99
+
100
+ # select reranking model
101
+ reranking_models = sorted(list(frozenset([eval_result.reranking_model for eval_result in raw_data])))
102
+
103
+
104
+ def update_metric_qa(
105
+ metric: str,
106
+ domains: list,
107
+ langs: list,
108
+ reranking_model: list,
109
+ query: str,
110
+ show_anonymous: bool,
111
  show_revision_and_timestamp,
 
 
 
 
 
 
 
 
 
 
 
112
  ):
113
+ return update_metric(raw_data, 'qa', metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
114
+
115
+ def update_metric_long_doc(
116
+ metric: str,
117
+ domains: list,
118
+ langs: list,
119
+ reranking_model: list,
120
+ query: str,
121
+ show_anonymous: bool,
 
122
  show_revision_and_timestamp,
123
+ ):
124
+ return update_metric(raw_data, "long-doc", metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
 
127
  demo = gr.Blocks(css=custom_css)
128
 
 
 
129
  with demo:
130
  gr.HTML(TITLE)
131
  gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
 
133
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
134
  with gr.TabItem("Results", elem_id="results-tab-table"):
135
  with gr.Row():
136
+ selected_version = get_version_dropdown()
137
 
138
  with gr.TabItem("QA", elem_id="qa-benchmark-tab-table", id=0):
139
  with gr.Row():
140
  with gr.Column(min_width=320):
141
  # select domain
142
  with gr.Row():
143
+ selected_domains = get_domain_dropdown(DOMAIN_COLS_QA, DOMAIN_COLS_QA)
144
  # select language
145
  with gr.Row():
146
+ selected_langs = get_language_dropdown(LANG_COLS_QA, LANG_COLS_QA)
147
+
148
  with gr.Column():
149
  # select the metric
150
+ selected_metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_QA)
151
  with gr.Row():
152
  show_anonymous = get_anonymous_checkbox()
153
  with gr.Row():
154
+ show_revision_and_timestamp = get_revision_and_ts_checkbox()
155
  with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
156
  with gr.TabItem("Retrieval + Reranking", id=10):
157
  with gr.Row():
 
160
  search_bar = get_search_bar()
161
  # select reranking models
162
  with gr.Column():
163
+ selected_rerankings = get_reranking_dropdown(reranking_models)
164
+ leaderboard_table = get_leaderboard_table(leaderboard_df_qa, types_qa)
 
165
  # Dummy leaderboard for handling the case when the user uses backspace key
166
+ hidden_leaderboard_table_for_search = get_leaderboard_table(original_df_qa, types_qa, visible=False)
 
 
 
 
 
 
 
 
167
 
168
  set_listeners(
169
+ "qa",
170
+ leaderboard_table,
171
+ hidden_leaderboard_table_for_search,
172
  search_bar,
173
+ selected_domains,
174
+ selected_langs,
175
+ selected_rerankings,
 
176
  show_anonymous,
177
+ show_revision_and_timestamp,
178
  )
179
 
180
  # set metric listener
181
+ selected_metric.change(
182
+ update_metric_qa,
183
+ [
184
+ selected_metric,
185
+ selected_domains,
186
+ selected_langs,
187
+ selected_rerankings,
188
+ search_bar,
189
+ show_anonymous,
190
+ show_revision_and_timestamp,
191
+ ],
192
+ leaderboard_table,
193
+ queue=True
194
  )
 
195
  with gr.TabItem("Retrieval Only", id=11):
196
  with gr.Row():
197
  with gr.Column(scale=1):
198
+ search_bar_retriever = get_search_bar()
199
  with gr.Column(scale=1):
200
+ selected_noreranker = get_noreranking_dropdown()
201
+ lb_df_retriever = leaderboard_df_qa[leaderboard_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"]
202
+ lb_df_retriever = reset_rank(lb_df_retriever)
203
+ lb_table_retriever = get_leaderboard_table(lb_df_retriever, types_qa)
 
 
204
  # Dummy leaderboard for handling the case when the user uses backspace key
205
+ hidden_lb_df_retriever = original_df_qa[original_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"]
206
+ hidden_lb_df_retriever = reset_rank(hidden_lb_df_retriever)
207
+ hidden_lb_table_retriever = get_leaderboard_table(hidden_lb_df_retriever, types_qa, visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  set_listeners(
210
+ "qa",
211
+ lb_table_retriever,
212
+ hidden_lb_table_retriever,
213
+ search_bar_retriever,
214
+ selected_domains,
215
+ selected_langs,
216
+ selected_noreranker,
 
217
  show_anonymous,
218
+ show_revision_and_timestamp,
219
  )
220
 
221
+ # set metric listener
222
+ selected_metric.change(
223
+ update_metric_qa,
224
  [
225
+ selected_metric,
226
+ selected_domains,
227
+ selected_langs,
228
+ selected_noreranker,
229
+ search_bar_retriever,
230
  show_anonymous,
231
+ show_revision_and_timestamp,
232
  ],
233
+ lb_table_retriever,
234
+ queue=True
235
  )
 
236
  with gr.TabItem("Reranking Only", id=12):
237
+ lb_df_reranker = leaderboard_df_qa[leaderboard_df_qa[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
238
+ lb_df_reranker = reset_rank(lb_df_reranker)
239
+ reranking_models_reranker = lb_df_reranker[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
240
  with gr.Row():
241
  with gr.Column(scale=1):
242
+ selected_rerankings_reranker = get_reranking_dropdown(reranking_models_reranker)
243
  with gr.Column(scale=1):
244
+ search_bar_reranker = gr.Textbox(show_label=False, visible=False)
245
+ lb_table_reranker = get_leaderboard_table(lb_df_reranker, types_qa)
246
+ hidden_lb_df_reranker = original_df_qa[original_df_qa[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
247
+ hidden_lb_df_reranker = reset_rank(hidden_lb_df_reranker)
248
+ hidden_lb_table_reranker = get_leaderboard_table(
249
+ hidden_lb_df_reranker, types_qa, visible=False
 
 
 
 
 
 
 
 
 
250
  )
251
 
252
  set_listeners(
253
+ "qa",
254
+ lb_table_reranker,
255
+ hidden_lb_table_reranker,
256
+ search_bar_reranker,
257
+ selected_domains,
258
+ selected_langs,
259
+ selected_rerankings_reranker,
 
260
  show_anonymous,
261
+ show_revision_and_timestamp,
262
  )
263
+ # set metric listener
264
+ selected_metric.change(
265
+ update_metric_qa,
266
  [
267
+ selected_metric,
268
+ selected_domains,
269
+ selected_langs,
270
+ selected_rerankings_reranker,
271
+ search_bar_reranker,
272
  show_anonymous,
273
+ show_revision_and_timestamp,
274
  ],
275
+ lb_table_reranker,
276
+ queue=True
277
  )
278
  with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
279
  with gr.Row():
280
  with gr.Column(min_width=320):
281
  # select domain
282
  with gr.Row():
283
+ selected_domains = get_domain_dropdown(DOMAIN_COLS_LONG_DOC, DOMAIN_COLS_LONG_DOC)
284
  # select language
285
  with gr.Row():
286
+ selected_langs = get_language_dropdown(
287
+ LANG_COLS_LONG_DOC, LANG_COLS_LONG_DOC
288
+ )
289
  with gr.Column():
290
  # select the metric
291
  with gr.Row():
292
+ selected_metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_LONG_DOC)
293
  with gr.Row():
294
  show_anonymous = get_anonymous_checkbox()
295
  with gr.Row():
296
+ show_revision_and_timestamp = get_revision_and_ts_checkbox()
297
+ with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
298
  with gr.TabItem("Retrieval + Reranking", id=20):
299
  with gr.Row():
300
  with gr.Column():
301
  search_bar = get_search_bar()
302
+ # select reranking model
303
  with gr.Column():
304
+ selected_rerankings = get_reranking_dropdown(reranking_models)
 
 
305
 
306
+ lb_table = get_leaderboard_table(
307
+ leaderboard_df_long_doc, types_long_doc
 
308
  )
309
 
310
+ # Dummy leaderboard for handling the case when the user uses backspace key
311
+ hidden_lb_table_for_search = get_leaderboard_table(
312
+ original_df_long_doc, types_long_doc, visible=False
 
313
  )
314
 
315
  set_listeners(
316
+ "long-doc",
317
+ lb_table,
318
+ hidden_lb_table_for_search,
319
  search_bar,
320
+ selected_domains,
321
+ selected_langs,
322
+ selected_rerankings,
 
323
  show_anonymous,
324
+ show_revision_and_timestamp,
325
  )
326
 
327
  # set metric listener
328
+ selected_metric.change(
329
+ update_metric_long_doc,
330
  [
331
+ selected_metric,
332
+ selected_domains,
333
+ selected_langs,
334
+ selected_rerankings,
335
  search_bar,
336
  show_anonymous,
337
+ show_revision_and_timestamp
338
  ],
339
+ lb_table,
340
+ queue=True
341
  )
342
  with gr.TabItem("Retrieval Only", id=21):
343
  with gr.Row():
344
  with gr.Column(scale=1):
345
+ search_bar_retriever = get_search_bar()
346
  with gr.Column(scale=1):
347
+ selected_noreranker = get_noreranking_dropdown()
348
+ lb_df_retriever_long_doc = leaderboard_df_long_doc[
349
+ leaderboard_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
 
350
  ]
351
+ lb_df_retriever_long_doc = reset_rank(lb_df_retriever_long_doc)
352
+ hidden_lb_db_retriever_long_doc = original_df_long_doc[
353
+ original_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
 
 
354
  ]
355
+ hidden_lb_db_retriever_long_doc = reset_rank(hidden_lb_db_retriever_long_doc)
356
+ lb_table_retriever_long_doc = get_leaderboard_table(
357
+ lb_df_retriever_long_doc, types_long_doc)
358
+ hidden_lb_table_retriever_long_doc = get_leaderboard_table(
359
+ hidden_lb_db_retriever_long_doc, types_long_doc, visible=False
 
 
 
 
360
  )
361
 
362
  set_listeners(
363
+ "long-doc",
364
+ lb_table_retriever_long_doc,
365
+ hidden_lb_table_retriever_long_doc,
366
+ search_bar_retriever,
367
+ selected_domains,
368
+ selected_langs,
369
+ selected_noreranker,
 
370
  show_anonymous,
371
+ show_revision_and_timestamp,
372
  )
373
 
374
+ selected_metric.change(
375
+ update_metric_long_doc,
376
  [
377
+ selected_metric,
378
+ selected_domains,
379
+ selected_langs,
380
+ selected_noreranker,
381
+ search_bar_retriever,
382
  show_anonymous,
383
+ show_revision_and_timestamp,
384
  ],
385
+ lb_table_retriever_long_doc,
386
+ queue=True
387
  )
388
  with gr.TabItem("Reranking Only", id=22):
389
+ lb_df_reranker_ldoc = leaderboard_df_long_doc[
390
+ leaderboard_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK
391
+ ]
392
+ lb_df_reranker_ldoc = reset_rank(lb_df_reranker_ldoc)
393
+ reranking_models_reranker_ldoc = lb_df_reranker_ldoc[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
 
 
394
  with gr.Row():
395
  with gr.Column(scale=1):
396
+ selected_rerankings_reranker_ldoc = get_reranking_dropdown(reranking_models_reranker_ldoc)
397
  with gr.Column(scale=1):
398
+ search_bar_reranker_ldoc = gr.Textbox(show_label=False, visible=False)
399
+ lb_table_reranker_ldoc = get_leaderboard_table(lb_df_reranker_ldoc, types_long_doc)
400
+ hidden_lb_df_reranker_ldoc = original_df_long_doc[original_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
401
+ hidden_lb_df_reranker_ldoc = reset_rank(hidden_lb_df_reranker_ldoc)
402
+ hidden_lb_table_reranker_ldoc = get_leaderboard_table(
403
+ hidden_lb_df_reranker_ldoc, types_long_doc, visible=False
 
 
 
 
 
 
 
 
404
  )
405
 
406
  set_listeners(
407
+ "long-doc",
408
+ lb_table_reranker_ldoc,
409
+ hidden_lb_table_reranker_ldoc,
410
+ search_bar_reranker_ldoc,
411
+ selected_domains,
412
+ selected_langs,
413
+ selected_rerankings_reranker_ldoc,
 
414
  show_anonymous,
415
+ show_revision_and_timestamp,
416
  )
417
+ selected_metric.change(
418
+ update_metric_long_doc,
 
419
  [
420
+ selected_metric,
421
+ selected_domains,
422
+ selected_langs,
423
+ selected_rerankings_reranker_ldoc,
424
+ search_bar_reranker_ldoc,
425
  show_anonymous,
426
+ show_revision_and_timestamp,
427
  ],
428
+ lb_table_reranker_ldoc,
429
+ queue=True
430
  )
431
 
432
  with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
 
443
  with gr.Row():
444
  with gr.Column():
445
  reranking_model_name = gr.Textbox(
446
+ label="Reranking Model name",
447
+ info="Optional",
448
+ value="NoReranker"
449
  )
450
  with gr.Column():
451
+ reranking_model_url = gr.Textbox(
452
+ label="Reranking Model URL",
453
+ info="Optional",
454
+ value=""
455
+ )
456
  with gr.Row():
457
  with gr.Column():
458
  benchmark_version = gr.Dropdown(
459
  BENCHMARK_VERSION_LIST,
460
  value=LATEST_BENCHMARK_VERSION,
461
  interactive=True,
462
+ label="AIR-Bench Version")
 
463
  with gr.Row():
464
  upload_button = gr.UploadButton("Click to upload search results", file_count="single")
465
  with gr.Row():
 
468
  is_anonymous = gr.Checkbox(
469
  label="Nope. I want to submit anonymously 🥷",
470
  value=False,
471
+ info="Do you want to shown on the leaderboard by default?")
 
472
  with gr.Row():
473
  submit_button = gr.Button("Submit")
474
  with gr.Row():
 
478
  [
479
  upload_button,
480
  ],
481
+ file_output)
 
482
  submit_button.click(
483
  submit_results,
484
  [
 
488
  reranking_model_name,
489
  reranking_model_url,
490
  benchmark_version,
491
+ is_anonymous
492
  ],
493
  submission_result,
494
+ show_progress="hidden"
495
  )
496
 
497
  with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
pyproject.toml CHANGED
@@ -1,9 +1,9 @@
 [tool.ruff]
 # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
-lint.select = ["E", "F"]
-lint.ignore = ["E501"] # line too long (black is taking care of this)
+select = ["E", "F"]
+ignore = ["E501"] # line too long (black is taking care of this)
 line-length = 119
-lint.fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
+fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
 
 [tool.isort]
 profile = "black"
requirements.txt CHANGED
@@ -2,7 +2,7 @@ APScheduler>=3.10.1
 black>=23.11.0
 click>=8.1.3
 datasets>=2.14.5
-gradio<5.0.0
+gradio>=4.29.0
 gradio_client>=0.16.1
 huggingface-hub>=0.18.0
 numpy>=1.24.2
@@ -12,4 +12,4 @@ requests>=2.31.0
 tqdm>=4.65.0
 accelerate>=0.24.1
 socksio>=1.0.0
-air-benchmark>=0.1.0
+air-benchmark>=0.0.4
src/about.py CHANGED
@@ -8,7 +8,7 @@ INTRODUCTION_TEXT = """
 """
 
 # Which evaluations are you running? how can people reproduce what you have?
-BENCHMARKS_TEXT = """
+BENCHMARKS_TEXT = f"""
 ## How the test data are generated?
 ### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
 
src/benchmarks.py CHANGED
@@ -1,71 +1,92 @@
1
  from dataclasses import dataclass
2
  from enum import Enum
3
-
4
  from air_benchmark.tasks.tasks import BenchmarkTable
5
 
6
- from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
7
- from src.models import TaskType, get_safe_name
8
 
9
 
10
  @dataclass
11
  class Benchmark:
12
  name: str # [domain]_[language]_[metric], task_key in the json file,
13
- metric: str # metric_key in the json file
14
  col_name: str # [domain]_[language], name to display in the leaderboard
15
  domain: str
16
  lang: str
17
  task: str
18
 
19
 
20
- # create a function return an enum class containing all the benchmarks
21
- def get_qa_benchmarks_dict(version: str):
22
- benchmark_dict = {}
23
- for task, domain_dict in BenchmarkTable[version].items():
24
- if task != TaskType.qa.value:
25
- continue
26
- for domain, lang_dict in domain_dict.items():
27
- for lang, dataset_list in lang_dict.items():
28
- benchmark_name = get_safe_name(f"{domain}_{lang}")
29
  col_name = benchmark_name
30
  for metric in dataset_list:
31
- if "test" not in dataset_list[metric]["splits"]:
32
- continue
33
- benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
34
- return benchmark_dict
35
-
36
-
37
- def get_doc_benchmarks_dict(version: str):
38
- benchmark_dict = {}
39
- for task, domain_dict in BenchmarkTable[version].items():
40
- if task != TaskType.long_doc.value:
41
- continue
42
- for domain, lang_dict in domain_dict.items():
43
- for lang, dataset_list in lang_dict.items():
44
  for dataset in dataset_list:
45
  benchmark_name = f"{domain}_{lang}_{dataset}"
46
  benchmark_name = get_safe_name(benchmark_name)
47
  col_name = benchmark_name
48
- if "test" not in dataset_list[dataset]["splits"]:
49
- continue
50
  for metric in METRIC_LIST:
51
- benchmark_dict[benchmark_name] = Benchmark(
52
- benchmark_name, metric, col_name, domain, lang, task
53
- )
54
- return benchmark_dict
55
 
 
 
56
 
57
- _qa_benchmark_dict = {}
58
- for version in BENCHMARK_VERSION_LIST:
59
- safe_version_name = get_safe_name(version)
60
- _qa_benchmark_dict[safe_version_name] = Enum(f"QABenchmarks_{safe_version_name}", get_qa_benchmarks_dict(version))
61
 
62
- _doc_benchmark_dict = {}
63
- for version in BENCHMARK_VERSION_LIST:
64
- safe_version_name = get_safe_name(version)
65
- _doc_benchmark_dict[safe_version_name] = Enum(
66
- f"LongDocBenchmarks_{safe_version_name}", get_doc_benchmarks_dict(version)
67
- )
68
 
 
 
69
 
70
- QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
71
- LongDocBenchmarks = Enum("LongDocBenchmarks", _doc_benchmark_dict)
 
1
  from dataclasses import dataclass
2
  from enum import Enum
 
3
  from air_benchmark.tasks.tasks import BenchmarkTable
4
 
5
+
6
+ def get_safe_name(name: str):
7
+ """Get RFC 1123 compatible safe name"""
8
+ name = name.replace('-', '_')
9
+ return ''.join(
10
+ character.lower()
11
+ for character in name
12
+ if (character.isalnum() or character == '_'))
13
+
14
+
15
+ METRIC_LIST = [
16
+ "ndcg_at_1",
17
+ "ndcg_at_3",
18
+ "ndcg_at_5",
19
+ "ndcg_at_10",
20
+ "ndcg_at_100",
21
+ "ndcg_at_1000",
22
+ "map_at_1",
23
+ "map_at_3",
24
+ "map_at_5",
25
+ "map_at_10",
26
+ "map_at_100",
27
+ "map_at_1000",
28
+ "recall_at_1",
29
+ "recall_at_3",
30
+ "recall_at_5",
31
+ "recall_at_10",
32
+ "recall_at_100",
33
+ "recall_at_1000",
34
+ "precision_at_1",
35
+ "precision_at_3",
36
+ "precision_at_5",
37
+ "precision_at_10",
38
+ "precision_at_100",
39
+ "precision_at_1000",
40
+ "mrr_at_1",
41
+ "mrr_at_3",
42
+ "mrr_at_5",
43
+ "mrr_at_10",
44
+ "mrr_at_100",
45
+ "mrr_at_1000"
46
+ ]
47
 
48
 
49
  @dataclass
50
  class Benchmark:
51
  name: str # [domain]_[language]_[metric], task_key in the json file,
52
+ metric: str # ndcg_at_1 ,metric_key in the json file
53
  col_name: str # [domain]_[language], name to display in the leaderboard
54
  domain: str
55
  lang: str
56
  task: str
57
 
58
 
59
+ qa_benchmark_dict = {}
60
+ long_doc_benchmark_dict = {}
61
+ for task, domain_dict in BenchmarkTable['AIR-Bench_24.04'].items():
62
+ for domain, lang_dict in domain_dict.items():
63
+ for lang, dataset_list in lang_dict.items():
64
+ if task == "qa":
65
+ benchmark_name = f"{domain}_{lang}"
66
+ benchmark_name = get_safe_name(benchmark_name)
 
67
  col_name = benchmark_name
68
  for metric in dataset_list:
69
+ qa_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
70
+ elif task == "long-doc":
 
 
 
 
 
 
 
 
 
 
 
71
  for dataset in dataset_list:
72
  benchmark_name = f"{domain}_{lang}_{dataset}"
73
  benchmark_name = get_safe_name(benchmark_name)
74
  col_name = benchmark_name
 
 
75
  for metric in METRIC_LIST:
76
+ long_doc_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain,
77
+ lang, task)
 
 
78
 
79
+ BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
80
+ BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)
81
 
82
+ BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
83
+ BENCHMARK_COLS_LONG_DOC = [c.col_name for c in long_doc_benchmark_dict.values()]
 
 
84
 
85
+ DOMAIN_COLS_QA = list(frozenset([c.domain for c in qa_benchmark_dict.values()]))
86
+ LANG_COLS_QA = list(frozenset([c.lang for c in qa_benchmark_dict.values()]))
 
 
 
 
87
 
88
+ DOMAIN_COLS_LONG_DOC = list(frozenset([c.domain for c in long_doc_benchmark_dict.values()]))
89
+ LANG_COLS_LONG_DOC = list(frozenset([c.lang for c in long_doc_benchmark_dict.values()]))
90
 
91
+ DEFAULT_METRIC_QA = "ndcg_at_10"
92
+ DEFAULT_METRIC_LONG_DOC = "recall_at_10"
src/{css_html_js.py → display/css_html_js.py} RENAMED
File without changes
src/display/formatting.py ADDED
@@ -0,0 +1,29 @@
+def model_hyperlink(link, model_name):
+    return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
+
+
+def make_clickable_model(model_name: str, model_link: str):
+    # link = f"https://huggingface.co/{model_name}"
+    if not model_link or not model_link.startswith("https://"):
+        return model_name
+    return model_hyperlink(model_link, model_name)
+
+
+def styled_error(error):
+    return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
+
+
+def styled_warning(warn):
+    return f"<p style='color: orange; font-size: 20px; text-align: center;'>{warn}</p>"
+
+
+def styled_message(message):
+    return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
+
+
+def has_no_nan_values(df, columns):
+    return df[columns].notna().all(axis=1)
+
+
+def has_nan_values(df, columns):
+    return df[columns].isna().any(axis=1)
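For orientation, a small hypothetical usage sketch of the helpers added above (the model name and URL below are placeholders, not values from the leaderboard data):

```python
# Hypothetical usage of the helpers in src/display/formatting.py.
from src.display.formatting import make_clickable_model, styled_error, styled_message

# Renders an <a> tag because the link starts with "https://".
cell = make_clickable_model("org/my-model", "https://huggingface.co/org/my-model")

# Falls back to the plain name when no valid link is provided.
plain = make_clickable_model("org/my-model", "")

# HTML status banners, e.g. for the submission tab.
ok = styled_message("Submission received!")
err = styled_error("Invalid file format.")

print(cell, plain, ok, err, sep="\n")
```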
src/{components.py → display/gradio_formatting.py} RENAMED
@@ -1,14 +1,12 @@
1
  import gradio as gr
2
-
3
  from src.envs import BENCHMARK_VERSION_LIST, LATEST_BENCHMARK_VERSION
4
 
5
-
6
  def get_version_dropdown():
7
  return gr.Dropdown(
8
  choices=BENCHMARK_VERSION_LIST,
9
  value=LATEST_BENCHMARK_VERSION,
10
  label="Select the version of AIR-Bench",
11
- interactive=True,
12
  )
13
 
14
 
@@ -16,25 +14,26 @@ def get_search_bar():
16
  return gr.Textbox(
17
  placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
18
  show_label=False,
19
- info="Search the retrieval methods",
20
  )
21
 
22
 
23
  def get_reranking_dropdown(model_list):
24
- return gr.Dropdown(choices=model_list, label="Select the reranking models", interactive=True, multiselect=True)
 
 
 
 
 
25
 
26
 
27
  def get_noreranking_dropdown():
28
  return gr.Dropdown(
29
- choices=[
30
- "NoReranker",
31
- ],
32
- value=[
33
- "NoReranker",
34
- ],
35
  interactive=False,
36
  multiselect=True,
37
- visible=False,
38
  )
39
 
40
 
@@ -53,10 +52,7 @@ def get_metric_dropdown(metric_list, default_metrics):
53
  )
54
 
55
 
56
- def get_domain_dropdown(benchmarks, default_domains=None):
57
- domain_list = list(frozenset([c.value.domain for c in list(benchmarks.value)]))
58
- if default_domains is None:
59
- default_domains = domain_list
60
  return gr.CheckboxGroup(
61
  choices=domain_list,
62
  value=default_domains,
@@ -65,16 +61,13 @@ def get_domain_dropdown(benchmarks, default_domains=None):
65
  )
66
 
67
 
68
- def get_language_dropdown(benchmarks, default_languages=None):
69
- language_list = list(frozenset([c.value.lang for c in list(benchmarks.value)]))
70
- if default_languages is None:
71
- default_languages = language_list
72
  return gr.Dropdown(
73
  choices=language_list,
74
- value=default_languages,
75
  label="Select the languages",
76
  multiselect=True,
77
- interactive=True,
78
  )
79
 
80
 
@@ -82,13 +75,15 @@ def get_anonymous_checkbox():
82
  return gr.Checkbox(
83
  label="Show anonymous submissions",
84
  value=False,
85
- info="The anonymous submissions might have invalid model information.",
86
  )
87
 
88
 
89
  def get_revision_and_ts_checkbox():
90
  return gr.Checkbox(
91
- label="Show submission details", value=False, info="Show the revision and timestamp information of submissions"
 
 
92
  )
93
 
94
 
 
1
  import gradio as gr
 
2
  from src.envs import BENCHMARK_VERSION_LIST, LATEST_BENCHMARK_VERSION
3
 
 
4
  def get_version_dropdown():
5
  return gr.Dropdown(
6
  choices=BENCHMARK_VERSION_LIST,
7
  value=LATEST_BENCHMARK_VERSION,
8
  label="Select the version of AIR-Bench",
9
+ interactive=True
10
  )
11
 
12
 
 
14
  return gr.Textbox(
15
  placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
16
  show_label=False,
17
+ info="Search the retrieval methods"
18
  )
19
 
20
 
21
  def get_reranking_dropdown(model_list):
22
+ return gr.Dropdown(
23
+ choices=model_list,
24
+ label="Select the reranking models",
25
+ interactive=True,
26
+ multiselect=True
27
+ )
28
 
29
 
30
  def get_noreranking_dropdown():
31
  return gr.Dropdown(
32
+ choices=["NoReranker", ],
33
+ value=["NoReranker", ],
 
 
 
 
34
  interactive=False,
35
  multiselect=True,
36
+ visible=False
37
  )
38
 
39
 
 
52
  )
53
 
54
 
55
+ def get_domain_dropdown(domain_list, default_domains):
 
 
 
56
  return gr.CheckboxGroup(
57
  choices=domain_list,
58
  value=default_domains,
 
61
  )
62
 
63
 
64
+ def get_language_dropdown(language_list, default_languages):
 
 
 
65
  return gr.Dropdown(
66
  choices=language_list,
67
+ value=language_list,
68
  label="Select the languages",
69
  multiselect=True,
70
+ interactive=True
71
  )
72
 
73
 
 
75
  return gr.Checkbox(
76
  label="Show anonymous submissions",
77
  value=False,
78
+ info="The anonymous submissions might have invalid model information."
79
  )
80
 
81
 
82
  def get_revision_and_ts_checkbox():
83
  return gr.Checkbox(
84
+ label="Show submission details",
85
+ value=False,
86
+ info="Show the revision and timestamp information of submissions"
87
  )
88
 
89
 
src/display/gradio_listener.py ADDED
@@ -0,0 +1,53 @@
+from src.utils import update_table, update_table_long_doc
+
+
+def set_listeners(
+    task,
+    displayed_leaderboard,
+    hidden_leaderboard,
+    search_bar,
+    selected_domains,
+    selected_langs,
+    selected_rerankings,
+    show_anonymous,
+    show_revision_and_timestamp,
+
+):
+    if task == "qa":
+        update_table_func = update_table
+    elif task == "long-doc":
+        update_table_func = update_table_long_doc
+    else:
+        raise NotImplementedError
+    # Set search_bar listener
+    search_bar.submit(
+        update_table_func,
+        [
+            hidden_leaderboard, # hidden_leaderboard_table_for_search,
+            selected_domains,
+            selected_langs,
+            selected_rerankings,
+            search_bar,
+            show_anonymous,
+        ],
+        displayed_leaderboard
+    )
+
+    # Set column-wise listener
+    for selector in [
+        selected_domains, selected_langs, show_anonymous, show_revision_and_timestamp, selected_rerankings
+    ]:
+        selector.change(
+            update_table_func,
+            [
+                hidden_leaderboard,
+                selected_domains,
+                selected_langs,
+                selected_rerankings,
+                search_bar,
+                show_anonymous,
+                show_revision_and_timestamp
+            ],
+            displayed_leaderboard,
+            queue=True,
+        )
src/{columns.py → display/utils.py} RENAMED
@@ -1,7 +1,9 @@
1
  from dataclasses import dataclass, make_dataclass
2
 
 
3
 
4
- def _fields(raw_class):
 
5
  return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
6
 
7
 
@@ -17,22 +19,28 @@ class ColumnContent:
17
  never_hidden: bool = False
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  def get_default_auto_eval_column_dict():
21
  auto_eval_column_dict = []
22
- auto_eval_column_dict.append(["rank", ColumnContent, ColumnContent(COL_NAME_RANK, "number", True)])
23
  auto_eval_column_dict.append(
24
- [
25
- "retrieval_model",
26
- ColumnContent,
27
- ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, never_hidden=True),
28
- ]
29
  )
30
  auto_eval_column_dict.append(
31
- [
32
- "reranking_model",
33
- ColumnContent,
34
- ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, never_hidden=True),
35
- ]
36
  )
37
  auto_eval_column_dict.append(
38
  ["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
@@ -40,30 +48,14 @@ def get_default_auto_eval_column_dict():
40
  auto_eval_column_dict.append(
41
  ["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
42
  )
43
- auto_eval_column_dict.append(["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)])
44
  auto_eval_column_dict.append(
45
- [
46
- "retrieval_model_link",
47
- ColumnContent,
48
- ColumnContent(
49
- COL_NAME_RETRIEVAL_MODEL_LINK,
50
- "markdown",
51
- False,
52
- hidden=True,
53
- ),
54
- ]
55
  )
56
  auto_eval_column_dict.append(
57
- [
58
- "reranking_model_link",
59
- ColumnContent,
60
- ColumnContent(
61
- COL_NAME_RERANKING_MODEL_LINK,
62
- "markdown",
63
- False,
64
- hidden=True,
65
- ),
66
- ]
67
  )
68
  auto_eval_column_dict.append(
69
  ["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
@@ -71,10 +63,10 @@ def get_default_auto_eval_column_dict():
71
  return auto_eval_column_dict
72
 
73
 
74
- def make_autoevalcolumn(cls_name, benchmarks):
75
  auto_eval_column_dict = get_default_auto_eval_column_dict()
76
- # Leaderboard columns
77
- for benchmark in list(benchmarks.value):
78
  auto_eval_column_dict.append(
79
  [benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
80
  )
@@ -83,24 +75,19 @@ def make_autoevalcolumn(cls_name, benchmarks):
83
  return make_dataclass(cls_name, auto_eval_column_dict, frozen=True)
84
 
85
 
86
- def get_default_col_names_and_types(benchmarks):
87
- AutoEvalColumn = make_autoevalcolumn("AutoEvalColumn", benchmarks)
88
- col_names = [c.name for c in _fields(AutoEvalColumn) if not c.hidden]
89
- col_types = [c.type for c in _fields(AutoEvalColumn) if not c.hidden]
90
- return col_names, col_types
91
 
92
 
93
- def get_fixed_col_names_and_types():
94
- fixed_cols = get_default_auto_eval_column_dict()[:-3]
95
- return [c.name for _, _, c in fixed_cols], [c.type for _, _, c in fixed_cols]
 
 
 
96
 
 
97
 
98
- COL_NAME_AVG = "Average ⬆️"
99
- COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
100
- COL_NAME_RERANKING_MODEL = "Reranking Model"
101
- COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
102
- COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
103
- COL_NAME_RANK = "Rank 🏆"
104
- COL_NAME_REVISION = "Revision"
105
- COL_NAME_TIMESTAMP = "Submission Date"
106
- COL_NAME_IS_ANONYMOUS = "Anonymous Submission"
 
1
  from dataclasses import dataclass, make_dataclass
2
 
3
+ from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
4
 
5
+
6
+ def fields(raw_class):
7
  return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
8
 
9
 
 
19
  never_hidden: bool = False
20
 
21
 
22
+ COL_NAME_AVG = "Average ⬆️"
23
+ COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
24
+ COL_NAME_RERANKING_MODEL = "Reranking Model"
25
+ COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
26
+ COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
27
+ COL_NAME_RANK = "Rank 🏆"
28
+ COL_NAME_REVISION = "Revision"
29
+ COL_NAME_TIMESTAMP = "Submission Date"
30
+ COL_NAME_IS_ANONYMOUS = "Anonymous Submission"
31
+
32
+
33
  def get_default_auto_eval_column_dict():
34
  auto_eval_column_dict = []
35
+ # Init
36
  auto_eval_column_dict.append(
37
+ ["rank", ColumnContent, ColumnContent(COL_NAME_RANK, "number", True)]
 
 
 
 
38
  )
39
  auto_eval_column_dict.append(
40
+ ["retrieval_model", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, hidden=False, never_hidden=True)]
41
+ )
42
+ auto_eval_column_dict.append(
43
+ ["reranking_model", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, hidden=False, never_hidden=True)]
 
44
  )
45
  auto_eval_column_dict.append(
46
  ["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
 
48
  auto_eval_column_dict.append(
49
  ["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
50
  )
 
51
  auto_eval_column_dict.append(
52
+ ["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)]
53
+ )
54
+ auto_eval_column_dict.append(
55
+ ["retrieval_model_link", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False)]
 
 
 
 
 
 
56
  )
57
  auto_eval_column_dict.append(
58
+ ["reranking_model_link", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False)]
 
 
 
 
 
 
 
 
 
59
  )
60
  auto_eval_column_dict.append(
61
  ["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
 
63
  return auto_eval_column_dict
64
 
65
 
66
+ def make_autoevalcolumn(cls_name="BenchmarksQA", benchmarks=BenchmarksQA):
67
  auto_eval_column_dict = get_default_auto_eval_column_dict()
68
+ ## Leaderboard columns
69
+ for benchmark in benchmarks:
70
  auto_eval_column_dict.append(
71
  [benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
72
  )
 
75
  return make_dataclass(cls_name, auto_eval_column_dict, frozen=True)
76
 
77
 
78
+ AutoEvalColumnQA = make_autoevalcolumn(
79
+ "AutoEvalColumnQA", BenchmarksQA)
80
+ AutoEvalColumnLongDoc = make_autoevalcolumn(
81
+ "AutoEvalColumnLongDoc", BenchmarksLongDoc)
 
82
 
83
 
84
+ # Column selection
85
+ COLS_QA = [c.name for c in fields(AutoEvalColumnQA) if not c.hidden]
86
+ COLS_LONG_DOC = [c.name for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
87
+ TYPES_QA = [c.type for c in fields(AutoEvalColumnQA) if not c.hidden]
88
+ TYPES_LONG_DOC = [c.type for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
89
+ COLS_LITE = [c.name for c in fields(AutoEvalColumnQA) if c.displayed_by_default and not c.hidden]
90
 
91
+ QA_BENCHMARK_COLS = [t.value.col_name for t in BenchmarksQA]
92
 
93
+ LONG_DOC_BENCHMARK_COLS = [t.value.col_name for t in BenchmarksLongDoc]
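Since `get_default_auto_eval_column_dict` and `make_autoevalcolumn` now build the column dataclasses dynamically, here is a minimal sketch of how the generated metadata can be inspected (assuming the package layout of this PR, with the repository root on `PYTHONPATH`):

```python
# Sketch: inspect the dynamically generated leaderboard column metadata.
# Assumes src/display/utils.py from this PR is importable.
from src.display.utils import AutoEvalColumnQA, COLS_QA, TYPES_QA, fields

# Each ColumnContent carries a display name, a gradio column type,
# and visibility flags (displayed_by_default, hidden, never_hidden).
for col in fields(AutoEvalColumnQA):
    print(col.name, col.type, col.displayed_by_default, col.hidden)

# COLS_QA / TYPES_QA are the non-hidden subsets used to build the QA table.
assert len(COLS_QA) == len(TYPES_QA)
```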
 
 
 
 
 
 
 
 
src/envs.py CHANGED
@@ -1,14 +1,12 @@
1
  import os
2
-
3
  from huggingface_hub import HfApi
4
 
5
  # Info to change for your repository
6
  # ----------------------------------
7
  TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
8
 
9
- OWNER = (
10
- "AIR-Bench" # Change to your org - don't forget to create a results and request dataset, with the correct format!
11
- )
12
  # ----------------------------------
13
 
14
  REPO_ID = f"{OWNER}/leaderboard"
@@ -17,7 +15,7 @@ RESULTS_REPO = f"{OWNER}/eval_results"
17
  # repo for submitting the evaluation
18
  SEARCH_RESULTS_REPO = f"{OWNER}/search_results"
19
 
20
- # If you set up a cache later, just change HF_HOME
21
  CACHE_PATH = os.getenv("HF_HOME", ".")
22
 
23
  # Local caches
@@ -25,43 +23,11 @@ EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval_results")
25
 
26
  API = HfApi(token=TOKEN)
27
 
 
 
28
  BENCHMARK_VERSION_LIST = [
29
  "AIR-Bench_24.04",
30
  "AIR-Bench_24.05",
31
  ]
32
 
33
- LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[0]
34
- DEFAULT_METRIC_QA = "ndcg_at_10"
35
- DEFAULT_METRIC_LONG_DOC = "recall_at_10"
36
- METRIC_LIST = [
37
- "ndcg_at_1",
38
- "ndcg_at_3",
39
- "ndcg_at_5",
40
- "ndcg_at_10",
41
- "ndcg_at_100",
42
- "ndcg_at_1000",
43
- "map_at_1",
44
- "map_at_3",
45
- "map_at_5",
46
- "map_at_10",
47
- "map_at_100",
48
- "map_at_1000",
49
- "recall_at_1",
50
- "recall_at_3",
51
- "recall_at_5",
52
- "recall_at_10",
53
- "recall_at_100",
54
- "recall_at_1000",
55
- "precision_at_1",
56
- "precision_at_3",
57
- "precision_at_5",
58
- "precision_at_10",
59
- "precision_at_100",
60
- "precision_at_1000",
61
- "mrr_at_1",
62
- "mrr_at_3",
63
- "mrr_at_5",
64
- "mrr_at_10",
65
- "mrr_at_100",
66
- "mrr_at_1000",
67
- ]
 
1
  import os
2
+ from src.display.formatting import model_hyperlink
3
  from huggingface_hub import HfApi
4
 
5
  # Info to change for your repository
6
  # ----------------------------------
7
  TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
8
 
9
+ OWNER = "AIR-Bench" # "nan" # Change to your org - don't forget to create a results and request dataset, with the correct format!
 
 
10
  # ----------------------------------
11
 
12
  REPO_ID = f"{OWNER}/leaderboard"
 
15
  # repo for submitting the evaluation
16
  SEARCH_RESULTS_REPO = f"{OWNER}/search_results"
17
 
18
+ # If you set up a cache later, just change HF_HOME
19
  CACHE_PATH = os.getenv("HF_HOME", ".")
20
 
21
  # Local caches
 
23
 
24
  API = HfApi(token=TOKEN)
25
 
26
+ BM25_LINK = model_hyperlink("https://github.com/castorini/pyserini", "BM25")
27
+
28
  BENCHMARK_VERSION_LIST = [
29
  "AIR-Bench_24.04",
30
  "AIR-Bench_24.05",
31
  ]
32
 
33
+ LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[-1]
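Note that `LATEST_BENCHMARK_VERSION` is now taken from the end of `BENCHMARK_VERSION_LIST`, so the most recently appended version becomes the default. A minimal illustration of the behaviour (list values copied from this file; the 24.06 entry is hypothetical):

```python
# Sketch: the default benchmark version tracks the last entry of the list.
BENCHMARK_VERSION_LIST = [
    "AIR-Bench_24.04",
    "AIR-Bench_24.05",
]
LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[-1]
assert LATEST_BENCHMARK_VERSION == "AIR-Bench_24.05"

# Appending a hypothetical "AIR-Bench_24.06" would make it the new default
# without any further change to LATEST_BENCHMARK_VERSION.
```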
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/loaders.py DELETED
@@ -1,88 +0,0 @@
1
- import os.path
2
- from pathlib import Path
3
- from typing import Dict, List, Union
4
-
5
- import pandas as pd
6
-
7
- from src.columns import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP
8
- from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA
9
- from src.models import FullEvalResult, LeaderboardDataStore, TaskType, get_safe_name
10
- from src.utils import get_default_cols, get_leaderboard_df, reset_rank
11
-
12
- pd.options.mode.copy_on_write = True
13
-
14
-
15
- def load_raw_eval_results(results_path: Union[Path, str]) -> List[FullEvalResult]:
16
- """
17
- Load the evaluation results from a json file
18
- """
19
- model_result_filepaths = []
20
- for root, dirs, files in os.walk(results_path):
21
- if len(files) == 0:
22
- continue
23
-
24
- # select the latest results
25
- for file in files:
26
- if not (file.startswith("results") and file.endswith(".json")):
27
- print(f"skip {file}")
28
- continue
29
- model_result_filepaths.append(os.path.join(root, file))
30
-
31
- eval_results = {}
32
- for model_result_filepath in model_result_filepaths:
33
- # create evaluation results
34
- try:
35
- eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
36
- except UnicodeDecodeError:
37
- print(f"loading file failed. {model_result_filepath}")
38
- continue
39
- print(f"file loaded: {model_result_filepath}")
40
- timestamp = eval_result.timestamp
41
- eval_results[timestamp] = eval_result
42
-
43
- results = []
44
- for k, v in eval_results.items():
45
- try:
46
- v.to_dict()
47
- results.append(v)
48
- except KeyError:
49
- print(f"loading failed: {k}")
50
- continue
51
- return results
52
-
53
-
54
- def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
55
- ds = LeaderboardDataStore(version, get_safe_name(version))
56
- ds.raw_data = load_raw_eval_results(file_path)
57
- print(f"raw data: {len(ds.raw_data)}")
58
-
59
- ds.qa_raw_df = get_leaderboard_df(ds, TaskType.qa, DEFAULT_METRIC_QA)
60
- print(f"QA data loaded: {ds.qa_raw_df.shape}")
61
- ds.qa_fmt_df = ds.qa_raw_df.copy()
62
- qa_cols, ds.qa_types = get_default_cols(TaskType.qa, ds.slug, add_fix_cols=True)
63
- # by default, drop the anonymous submissions
64
- ds.qa_fmt_df = ds.qa_fmt_df[~ds.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]
65
- # reset the rank after dropping the anonymous submissions
66
- ds.qa_fmt_df = reset_rank(ds.qa_fmt_df)
67
- ds.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
68
-
69
- ds.doc_raw_df = get_leaderboard_df(ds, TaskType.long_doc, DEFAULT_METRIC_LONG_DOC)
70
- print(f"Long-Doc data loaded: {len(ds.doc_raw_df)}")
71
- ds.doc_fmt_df = ds.doc_raw_df.copy()
72
- doc_cols, ds.doc_types = get_default_cols(TaskType.long_doc, ds.slug, add_fix_cols=True)
73
- # by default, drop the anonymous submissions
74
- ds.doc_fmt_df = ds.doc_fmt_df[~ds.doc_fmt_df[COL_NAME_IS_ANONYMOUS]][doc_cols]
75
- # reset the rank after dropping the anonymous submissions
76
- ds.doc_fmt_df = reset_rank(ds.doc_fmt_df)
77
- ds.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
78
-
79
- ds.reranking_models = sorted(list(frozenset([eval_result.reranking_model for eval_result in ds.raw_data])))
80
- return ds
81
-
82
-
83
- def load_eval_results(file_path: Union[str, Path]) -> Dict[str, LeaderboardDataStore]:
84
- output = {}
85
- for version in BENCHMARK_VERSION_LIST:
86
- fn = f"{file_path}/{version}"
87
- output[version] = load_leaderboard_datastore(fn, version)
88
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/{models.py → read_evals.py} RENAMED
@@ -1,21 +1,38 @@
1
  import json
 
2
  from collections import defaultdict
3
  from dataclasses import dataclass
4
- from enum import Enum
5
  from typing import List
6
 
7
  import pandas as pd
8
 
9
- from src.columns import (
10
- COL_NAME_IS_ANONYMOUS,
11
  COL_NAME_RERANKING_MODEL,
12
- COL_NAME_RERANKING_MODEL_LINK,
13
  COL_NAME_RETRIEVAL_MODEL,
 
14
  COL_NAME_RETRIEVAL_MODEL_LINK,
15
  COL_NAME_REVISION,
16
  COL_NAME_TIMESTAMP,
 
 
 
 
 
 
 
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @dataclass
21
  class EvalResult:
@@ -23,7 +40,6 @@ class EvalResult:
23
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
24
  domains, languages, and datasets
25
  """
26
-
27
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
28
  retrieval_model: str
29
  reranking_model: str
@@ -40,7 +56,6 @@ class FullEvalResult:
40
  """
41
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
42
  """
43
-
44
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
45
  retrieval_model: str
46
  reranking_model: str
@@ -64,6 +79,7 @@ class FullEvalResult:
64
  result_list = []
65
  retrieval_model_link = ""
66
  reranking_model_link = ""
 
67
  for item in model_data:
68
  config = item.get("config", {})
69
  # eval results for different metrics
@@ -82,26 +98,24 @@ class FullEvalResult:
82
  metric=config["metric"],
83
  timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
84
  revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
85
- is_anonymous=config.get("is_anonymous", False),
86
  )
87
  result_list.append(eval_result)
88
- eval_result = result_list[0]
89
  return cls(
90
- eval_name=f"{eval_result.retrieval_model}_{eval_result.reranking_model}",
91
- retrieval_model=eval_result.retrieval_model,
92
- reranking_model=eval_result.reranking_model,
93
  retrieval_model_link=retrieval_model_link,
94
  reranking_model_link=reranking_model_link,
95
  results=result_list,
96
- timestamp=eval_result.timestamp,
97
- revision=eval_result.revision,
98
- is_anonymous=eval_result.is_anonymous,
99
  )
100
 
101
- def to_dict(self, task="qa", metric="ndcg_at_3") -> List:
102
  """
103
- Convert the results in all the EvalResults over different tasks and metrics.
104
- The output is a list of dict compatible with the dataframe UI
105
  """
106
  results = defaultdict(dict)
107
  for eval_result in self.results:
@@ -109,66 +123,106 @@ class FullEvalResult:
109
  continue
110
  if eval_result.task != task:
111
  continue
112
- eval_name = eval_result.eval_name
113
- results[eval_name]["eval_name"] = eval_name
114
- results[eval_name][COL_NAME_RETRIEVAL_MODEL] = make_clickable_model(
115
- self.retrieval_model, self.retrieval_model_link
116
- )
117
- results[eval_name][COL_NAME_RERANKING_MODEL] = make_clickable_model(
118
- self.reranking_model, self.reranking_model_link
119
- )
120
- results[eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
121
- results[eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
122
- results[eval_name][COL_NAME_REVISION] = self.revision
123
- results[eval_name][COL_NAME_TIMESTAMP] = self.timestamp
124
- results[eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
125
-
126
  for result in eval_result.results:
127
  # add result for each domain, language, and dataset
128
  domain = result["domain"]
129
  lang = result["lang"]
130
  dataset = result["dataset"]
131
  value = result["value"] * 100
132
- if dataset == "default":
133
  benchmark_name = f"{domain}_{lang}"
134
  else:
135
  benchmark_name = f"{domain}_{lang}_{dataset}"
136
- results[eval_name][get_safe_name(benchmark_name)] = value
137
  return [v for v in results.values()]
138
 
139
 
140
- @dataclass
141
- class LeaderboardDataStore:
142
- version: str
143
- slug: str
144
- raw_data: list = None
145
- qa_raw_df: pd.DataFrame = pd.DataFrame()
146
- doc_raw_df: pd.DataFrame = pd.DataFrame()
147
- qa_fmt_df: pd.DataFrame = pd.DataFrame()
148
- doc_fmt_df: pd.DataFrame = pd.DataFrame()
149
- reranking_models: list = None
150
- qa_types: list = None
151
- doc_types: list = None
152
-
153
-
154
- # Define an enum class with the name `TaskType`. There are two types of tasks, `qa` and `long-doc`.
155
- class TaskType(Enum):
156
- qa = "qa"
157
- long_doc = "long-doc"
158
-
159
-
160
- def make_clickable_model(model_name: str, model_link: str):
161
- # link = f"https://huggingface.co/{model_name}"
162
- if not model_link or not model_link.startswith("https://"):
163
- return model_name
164
- return model_hyperlink(model_link, model_name)
165
-
166
-
167
- def model_hyperlink(link, model_name):
168
- return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
169
-
170
-
171
- def get_safe_name(name: str):
172
- """Get RFC 1123 compatible safe name"""
173
- name = name.replace("-", "_")
174
- return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+ import os.path
3
  from collections import defaultdict
4
  from dataclasses import dataclass
 
5
  from typing import List
6
 
7
  import pandas as pd
8
 
9
+ from src.benchmarks import get_safe_name
10
+ from src.display.utils import (
11
  COL_NAME_RERANKING_MODEL,
 
12
  COL_NAME_RETRIEVAL_MODEL,
13
+ COL_NAME_RERANKING_MODEL_LINK,
14
  COL_NAME_RETRIEVAL_MODEL_LINK,
15
  COL_NAME_REVISION,
16
  COL_NAME_TIMESTAMP,
17
+ COL_NAME_IS_ANONYMOUS,
18
+ COLS_QA,
19
+ QA_BENCHMARK_COLS,
20
+ COLS_LONG_DOC,
21
+ LONG_DOC_BENCHMARK_COLS,
22
+ COL_NAME_AVG,
23
+ COL_NAME_RANK
24
  )
25
 
26
+ from src.display.formatting import make_clickable_model
27
+
28
+ pd.options.mode.copy_on_write = True
29
+
30
+ def calculate_mean(row):
31
+ if pd.isna(row).any():
32
+ return -1
33
+ else:
34
+ return row.mean()
35
+
36
 
37
  @dataclass
38
  class EvalResult:
 
40
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
41
  domains, languages, and datasets
42
  """
 
43
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
44
  retrieval_model: str
45
  reranking_model: str
 
56
  """
57
  Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
58
  """
 
59
  eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
60
  retrieval_model: str
61
  reranking_model: str
 
79
  result_list = []
80
  retrieval_model_link = ""
81
  reranking_model_link = ""
82
+ revision = ""
83
  for item in model_data:
84
  config = item.get("config", {})
85
  # eval results for different metrics
 
98
  metric=config["metric"],
99
  timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
100
  revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
101
+ is_anonymous=config.get("is_anonymous", False)
102
  )
103
  result_list.append(eval_result)
 
104
  return cls(
105
+ eval_name=f"{result_list[0].retrieval_model}_{result_list[0].reranking_model}",
106
+ retrieval_model=result_list[0].retrieval_model,
107
+ reranking_model=result_list[0].reranking_model,
108
  retrieval_model_link=retrieval_model_link,
109
  reranking_model_link=reranking_model_link,
110
  results=result_list,
111
+ timestamp=result_list[0].timestamp,
112
+ revision=result_list[0].revision,
113
+ is_anonymous=result_list[0].is_anonymous
114
  )
115
 
116
+ def to_dict(self, task='qa', metric='ndcg_at_3') -> List:
117
  """
118
+ Convert the results in all the EvalResults over different tasks and metrics. The output is a list of dict compatible with the dataframe UI
 
119
  """
120
  results = defaultdict(dict)
121
  for eval_result in self.results:
 
123
  continue
124
  if eval_result.task != task:
125
  continue
126
+ results[eval_result.eval_name]["eval_name"] = eval_result.eval_name
127
+ results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL] = (
128
+ make_clickable_model(self.retrieval_model, self.retrieval_model_link))
129
+ results[eval_result.eval_name][COL_NAME_RERANKING_MODEL] = (
130
+ make_clickable_model(self.reranking_model, self.reranking_model_link))
131
+ results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
132
+ results[eval_result.eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
133
+ results[eval_result.eval_name][COL_NAME_REVISION] = self.revision
134
+ results[eval_result.eval_name][COL_NAME_TIMESTAMP] = self.timestamp
135
+ results[eval_result.eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
136
+
137
+ # print(f'result loaded: {eval_result.eval_name}')
 
 
138
  for result in eval_result.results:
139
  # add result for each domain, language, and dataset
140
  domain = result["domain"]
141
  lang = result["lang"]
142
  dataset = result["dataset"]
143
  value = result["value"] * 100
144
+ if dataset == 'default':
145
  benchmark_name = f"{domain}_{lang}"
146
  else:
147
  benchmark_name = f"{domain}_{lang}_{dataset}"
148
+ results[eval_result.eval_name][get_safe_name(benchmark_name)] = value
149
  return [v for v in results.values()]
150
 
151
 
152
+ def get_raw_eval_results(results_path: str) -> List[FullEvalResult]:
153
+ """
154
+ Load the evaluation results from the result json files under the given path
155
+ """
156
+ model_result_filepaths = []
157
+ for root, dirs, files in os.walk(results_path):
158
+ if len(files) == 0:
159
+ continue
160
+
161
+ # select the latest results
162
+ for file in files:
163
+ if not (file.startswith("results") and file.endswith(".json")):
164
+ print(f'skip {file}')
165
+ continue
166
+ model_result_filepaths.append(os.path.join(root, file))
167
+
168
+ eval_results = {}
169
+ for model_result_filepath in model_result_filepaths:
170
+ # create evaluation results
171
+ try:
172
+ eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
173
+ except UnicodeDecodeError as e:
174
+ print(f"loading file failed. {model_result_filepath}")
175
+ continue
176
+ print(f'file loaded: {model_result_filepath}')
177
+ timestamp = eval_result.timestamp
178
+ eval_results[timestamp] = eval_result
179
+
180
+ results = []
181
+ for k, v in eval_results.items():
182
+ try:
183
+ v.to_dict()
184
+ results.append(v)
185
+ except KeyError:
186
+ print(f"loading failed: {k}")
187
+ continue
188
+ return results
189
+
190
+
191
+ def get_leaderboard_df(raw_data: List[FullEvalResult], task: str, metric: str) -> pd.DataFrame:
192
+ """
193
+ Creates a dataframe from all the individual experiment results
194
+ """
195
+ cols = [COL_NAME_IS_ANONYMOUS, ]
196
+ if task == "qa":
197
+ cols += COLS_QA
198
+ benchmark_cols = QA_BENCHMARK_COLS
199
+ elif task == "long-doc":
200
+ cols += COLS_LONG_DOC
201
+ benchmark_cols = LONG_DOC_BENCHMARK_COLS
202
+ else:
203
+ raise NotImplementedError
204
+ all_data_json = []
205
+ for v in raw_data:
206
+ all_data_json += v.to_dict(task=task, metric=metric)
207
+ df = pd.DataFrame.from_records(all_data_json)
208
+ # print(f'dataframe created: {df.shape}')
209
+
210
+ _benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
211
+
212
+ # calculate the average score for selected benchmarks
213
+ df[COL_NAME_AVG] = df[list(_benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
214
+ df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
215
+ df.reset_index(inplace=True, drop=True)
216
+
217
+ _cols = frozenset(cols).intersection(frozenset(df.columns.to_list()))
218
+ df = df[_cols].round(decimals=2)
219
+
220
+ # rank the rows by the average score
221
+ df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
222
+
223
+ # shorten the revision
224
+ df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
225
+
226
+ # # replace "0" with "-" for average score
227
+ # df[COL_NAME_AVG] = df[COL_NAME_AVG].replace(0, "-")
228
+ return df
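Taken together, `get_raw_eval_results` and `get_leaderboard_df` turn a folder of `results*.json` files into the dataframe that backs the UI. A minimal usage sketch (the results directory is hypothetical; the call signatures are the ones defined above):

```python
# Sketch: build the QA leaderboard dataframe from a local results folder.
# "eval_results/AIR-Bench_24.05" is a hypothetical path following the
# layout used by the tests in this PR.
from src.read_evals import get_leaderboard_df, get_raw_eval_results

raw_data = get_raw_eval_results("eval_results/AIR-Bench_24.05")
qa_df = get_leaderboard_df(raw_data, task="qa", metric="ndcg_at_10")

# Rows are sorted by the average score; entries with missing benchmarks
# get an average of -1 via calculate_mean.
print(qa_df[["Rank 🏆", "Retrieval Method", "Reranking Model", "Average ⬆️"]].head())
```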
src/utils.py CHANGED
@@ -1,37 +1,24 @@
1
- import hashlib
2
  import json
3
- import re
4
  from datetime import datetime, timezone
5
  from pathlib import Path
 
6
 
7
  import pandas as pd
8
 
9
- from src.benchmarks import LongDocBenchmarks, QABenchmarks
10
- from src.columns import (
11
- COL_NAME_AVG,
12
- COL_NAME_IS_ANONYMOUS,
13
- COL_NAME_RANK,
14
- COL_NAME_RERANKING_MODEL,
15
- COL_NAME_RETRIEVAL_MODEL,
16
- COL_NAME_REVISION,
17
- COL_NAME_TIMESTAMP,
18
- get_default_col_names_and_types,
19
- get_fixed_col_names_and_types,
20
- )
21
- from src.envs import API, LATEST_BENCHMARK_VERSION, SEARCH_RESULTS_REPO
22
- from src.models import TaskType, get_safe_name
23
-
24
-
25
- def calculate_mean(row):
26
- if pd.isna(row).any():
27
- return -1
28
- else:
29
- return row.mean()
30
 
31
 
32
  def remove_html(input_str):
33
  # Regular expression for finding HTML tags
34
- clean = re.sub(r"<.*?>", "", input_str)
35
  return clean
36
 
37
 
@@ -68,152 +55,160 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
68
  return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
69
 
70
 
71
- def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) -> tuple:
72
  cols = []
73
  types = []
74
- if task == TaskType.qa:
75
- benchmarks = QABenchmarks[version_slug]
76
- elif task == TaskType.long_doc:
77
- benchmarks = LongDocBenchmarks[version_slug]
 
 
 
 
78
  else:
79
- raise NotImplementedError
80
- cols_list, types_list = get_default_col_names_and_types(benchmarks)
81
- benchmark_list = [c.value.col_name for c in list(benchmarks.value)]
82
  for col_name, col_type in zip(cols_list, types_list):
83
  if col_name not in benchmark_list:
84
  continue
 
 
85
  cols.append(col_name)
86
  types.append(col_type)
 
87
  if add_fix_cols:
88
  _cols = []
89
  _types = []
90
- fixed_cols, fixed_cols_types = get_fixed_col_names_and_types()
91
  for col_name, col_type in zip(cols, types):
92
- if col_name in fixed_cols:
93
  continue
94
  _cols.append(col_name)
95
  _types.append(col_type)
96
- cols = fixed_cols + _cols
97
- types = fixed_cols_types + _types
98
  return cols, types
99
 
100
 
101
- def get_selected_cols(task, version_slug, domains, languages):
102
- cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
 
 
 
 
 
 
 
 
 
 
 
 
103
  selected_cols = []
104
  for c in cols:
105
- if task == TaskType.qa:
106
- eval_col = QABenchmarks[version_slug].value[c].value
107
- elif task == TaskType.long_doc:
108
- eval_col = LongDocBenchmarks[version_slug].value[c].value
109
- else:
110
- raise NotImplementedError
111
- if eval_col.domain not in domains:
112
  continue
113
- if eval_col.lang not in languages:
114
  continue
115
  selected_cols.append(c)
116
  # We use COLS to maintain sorting
117
- return selected_cols
118
-
119
-
120
- def select_columns(
121
- df: pd.DataFrame,
122
- domains: list,
123
- languages: list,
124
- task: TaskType = TaskType.qa,
125
- reset_ranking: bool = True,
126
- version_slug: str = None,
127
- ) -> pd.DataFrame:
128
- selected_cols = get_selected_cols(task, version_slug, domains, languages)
129
- fixed_cols, _ = get_fixed_col_names_and_types()
130
- filtered_df = df[fixed_cols + selected_cols]
131
- filtered_df.replace({"": pd.NA}, inplace=True)
132
  if reset_ranking:
133
  filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
134
  filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
135
  filtered_df.reset_index(inplace=True, drop=True)
136
  filtered_df = reset_rank(filtered_df)
 
137
  return filtered_df
138
 
139
 
140
- def _update_df_elem(
141
- task: TaskType,
142
- version: str,
143
- source_df: pd.DataFrame,
144
- domains: list,
145
- langs: list,
146
- reranking_query: list,
147
- query: str,
148
- show_anonymous: bool,
149
- reset_ranking: bool = True,
150
- show_revision_and_timestamp: bool = False,
151
  ):
152
- filtered_df = source_df.copy()
153
  if not show_anonymous:
154
  filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
155
  filtered_df = filter_models(filtered_df, reranking_query)
156
  filtered_df = filter_queries(query, filtered_df)
157
- filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking, get_safe_name(version))
158
  if not show_revision_and_timestamp:
159
  filtered_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
160
  return filtered_df
161
 
162
 
163
- def update_doc_df_elem(
164
- version: str,
165
- hidden_df: pd.DataFrame,
166
- domains: list,
167
- langs: list,
168
- reranking_query: list,
169
- query: str,
170
- show_anonymous: bool,
171
- show_revision_and_timestamp: bool = False,
172
- reset_ranking: bool = True,
173
  ):
174
- return _update_df_elem(
175
- TaskType.long_doc,
176
- version,
177
- hidden_df,
178
- domains,
179
- langs,
180
- reranking_query,
181
- query,
182
- show_anonymous,
183
- reset_ranking,
184
- show_revision_and_timestamp,
185
- )
 
 
 
 
 
186
 
187
 
188
  def update_metric(
189
- datastore,
190
- task: TaskType,
191
- metric: str,
192
- domains: list,
193
- langs: list,
194
- reranking_model: list,
195
- query: str,
196
- show_anonymous: bool = False,
197
- show_revision_and_timestamp: bool = False,
198
  ) -> pd.DataFrame:
199
- if task == TaskType.qa:
200
- update_func = update_qa_df_elem
201
- elif task == TaskType.long_doc:
202
- update_func = update_doc_df_elem
203
- else:
204
- raise NotImplementedError
205
- df_elem = get_leaderboard_df(datastore, task=task, metric=metric)
206
- version = datastore.version
207
- return update_func(
208
- version,
209
- df_elem,
210
- domains,
211
- langs,
212
- reranking_model,
213
- query,
214
- show_anonymous,
215
- show_revision_and_timestamp,
216
- )
 
 
 
 
217
 
218
 
219
  def upload_file(filepath: str):
@@ -223,6 +218,7 @@ def upload_file(filepath: str):
223
  return filepath
224
 
225
 
 
226
  def get_iso_format_timestamp():
227
  # Get the current timestamp with UTC as the timezone
228
  current_timestamp = datetime.now(timezone.utc)
@@ -231,15 +227,15 @@ def get_iso_format_timestamp():
231
  current_timestamp = current_timestamp.replace(microsecond=0)
232
 
233
  # Convert to ISO 8601 format and replace the offset with 'Z'
234
- iso_format_timestamp = current_timestamp.isoformat().replace("+00:00", "Z")
235
- filename_friendly_timestamp = current_timestamp.strftime("%Y%m%d%H%M%S")
236
  return iso_format_timestamp, filename_friendly_timestamp
237
 
238
 
239
  def calculate_file_md5(file_path):
240
  md5 = hashlib.md5()
241
 
242
- with open(file_path, "rb") as f:
243
  while True:
244
  data = f.read(4096)
245
  if not data:
@@ -250,14 +246,13 @@ def calculate_file_md5(file_path):
250
 
251
 
252
  def submit_results(
253
- filepath: str,
254
- model: str,
255
- model_url: str,
256
- reranking_model: str = "",
257
- reranking_model_url: str = "",
258
- version: str = LATEST_BENCHMARK_VERSION,
259
- is_anonymous=False,
260
- ):
261
  if not filepath.endswith(".zip"):
262
  return styled_error(f"file uploading aborted. wrong file type: {filepath}")
263
 
@@ -270,13 +265,11 @@ def submit_results(
270
  if not model_url.startswith("https://") and not model_url.startswith("http://"):
271
  # TODO: retrieve the model page and find the model name on the page
272
  return styled_error(
273
- f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
274
- )
275
  if reranking_model != "NoReranker":
276
  if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
277
  return styled_error(
278
- f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
279
- )
280
 
281
  # rename the uploaded file
282
  input_fp = Path(filepath)
@@ -286,15 +279,14 @@ def submit_results(
286
  input_folder_path = input_fp.parent
287
 
288
  if not reranking_model:
289
- reranking_model = "NoReranker"
290
-
291
  API.upload_file(
292
  path_or_fileobj=filepath,
293
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
294
  repo_id=SEARCH_RESULTS_REPO,
295
  repo_type="dataset",
296
- commit_message=f"feat: submit {model} to evaluate",
297
- )
298
 
299
  output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
300
  output_config = {
@@ -305,7 +297,7 @@ def submit_results(
305
  "version": f"{version}",
306
  "is_anonymous": is_anonymous,
307
  "revision": f"{revision}",
308
- "timestamp": f"{timestamp_config}",
309
  }
310
  with open(input_folder_path / output_config_fn, "w") as f:
311
  json.dump(output_config, f, indent=4, ensure_ascii=False)
@@ -314,8 +306,7 @@ def submit_results(
314
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
315
  repo_id=SEARCH_RESULTS_REPO,
316
  repo_type="dataset",
317
- commit_message=f"feat: submit {model} + {reranking_model} config",
318
- )
319
  return styled_message(
320
  f"Thanks for submission!\n"
321
  f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
@@ -325,125 +316,3 @@ def submit_results(
325
  def reset_rank(df):
326
  df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
327
  return df
328
-
329
-
330
- def get_leaderboard_df(datastore, task: TaskType, metric: str) -> pd.DataFrame:
331
- """
332
- Creates a dataframe from all the individual experiment results
333
- """
334
- # load the selected metrics into a DataFrame from the raw json
335
- all_data_json = []
336
- for v in datastore.raw_data:
337
- all_data_json += v.to_dict(task=task.value, metric=metric)
338
- df = pd.DataFrame.from_records(all_data_json)
339
-
340
- # calculate the average scores for selected task
341
- if task == TaskType.qa:
342
- benchmarks = QABenchmarks[datastore.slug]
343
- elif task == TaskType.long_doc:
344
- benchmarks = LongDocBenchmarks[datastore.slug]
345
- else:
346
- raise NotImplementedError
347
- valid_cols = frozenset(df.columns.to_list())
348
- benchmark_cols = []
349
- for t in list(benchmarks.value):
350
- if t.value.col_name not in valid_cols:
351
- continue
352
- benchmark_cols.append(t.value.col_name)
353
-
354
- # filter out the columns that are not in the data
355
- df[COL_NAME_AVG] = df[list(benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
356
- df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
357
- df.reset_index(inplace=True, drop=True)
358
-
359
- # filter out columns that are not in the data
360
- display_cols = [COL_NAME_IS_ANONYMOUS, COL_NAME_AVG]
361
- default_cols, _ = get_default_col_names_and_types(benchmarks)
362
- for col in default_cols:
363
- if col in valid_cols:
364
- display_cols.append(col)
365
- df = df[display_cols].round(decimals=2)
366
-
367
- # rank the scores
368
- df = reset_rank(df)
369
-
370
- # shorten the revision
371
- df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
372
-
373
- return df
374
-
375
-
376
- def set_listeners(
377
- task: TaskType,
378
- target_df,
379
- source_df,
380
- search_bar,
381
- version,
382
- selected_domains,
383
- selected_langs,
384
- selected_rerankings,
385
- show_anonymous,
386
- show_revision_and_timestamp,
387
- ):
388
- if task == TaskType.qa:
389
- update_table_func = update_qa_df_elem
390
- elif task == TaskType.long_doc:
391
- update_table_func = update_doc_df_elem
392
- else:
393
- raise NotImplementedError
394
- selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
395
- search_bar_args = [
396
- source_df,
397
- version,
398
- ] + selector_list
399
- selector_args = (
400
- [version, source_df]
401
- + selector_list
402
- + [
403
- show_revision_and_timestamp,
404
- ]
405
- )
406
- # Set search_bar listener
407
- search_bar.submit(update_table_func, search_bar_args, target_df)
408
-
409
- # Set column-wise listener
410
- for selector in selector_list:
411
- selector.change(
412
- update_table_func,
413
- selector_args,
414
- target_df,
415
- queue=True,
416
- )
417
-
418
-
419
- def update_qa_df_elem(
420
- version: str,
421
- hidden_df: pd.DataFrame,
422
- domains: list,
423
- langs: list,
424
- reranking_query: list,
425
- query: str,
426
- show_anonymous: bool,
427
- show_revision_and_timestamp: bool = False,
428
- reset_ranking: bool = True,
429
- ):
430
- return _update_df_elem(
431
- TaskType.qa,
432
- version,
433
- hidden_df,
434
- domains,
435
- langs,
436
- reranking_query,
437
- query,
438
- show_anonymous,
439
- reset_ranking,
440
- show_revision_and_timestamp,
441
- )
442
-
443
-
444
- def styled_error(error):
445
- return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
446
-
447
-
448
- def styled_message(message):
449
- return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
 
 
1
  import json
2
+ import hashlib
3
  from datetime import datetime, timezone
4
  from pathlib import Path
5
+ from typing import List
6
 
7
  import pandas as pd
8
 
9
+ from src.benchmarks import BENCHMARK_COLS_QA, BENCHMARK_COLS_LONG_DOC, BenchmarksQA, BenchmarksLongDoc
10
+ from src.display.formatting import styled_message, styled_error
11
+ from src.display.utils import COLS_QA, TYPES_QA, COLS_LONG_DOC, TYPES_LONG_DOC, COL_NAME_RANK, COL_NAME_AVG, \
12
+ COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_IS_ANONYMOUS, COL_NAME_TIMESTAMP, COL_NAME_REVISION, get_default_auto_eval_column_dict
13
+ from src.envs import API, SEARCH_RESULTS_REPO, LATEST_BENCHMARK_VERSION
14
+ from src.read_evals import FullEvalResult, get_leaderboard_df, calculate_mean
15
+
16
+ import re
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def remove_html(input_str):
20
  # Regular expression for finding HTML tags
21
+ clean = re.sub(r'<.*?>', '', input_str)
22
  return clean
23
 
24
 
 
55
  return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
56
 
57
 
58
+ def get_default_cols(task: str, columns: list = [], add_fix_cols: bool = True) -> tuple:
59
  cols = []
60
  types = []
61
+ if task == "qa":
62
+ cols_list = COLS_QA
63
+ types_list = TYPES_QA
64
+ benchmark_list = BENCHMARK_COLS_QA
65
+ elif task == "long-doc":
66
+ cols_list = COLS_LONG_DOC
67
+ types_list = TYPES_LONG_DOC
68
+ benchmark_list = BENCHMARK_COLS_LONG_DOC
69
  else:
70
+ raise NotImplementedError
 
 
71
  for col_name, col_type in zip(cols_list, types_list):
72
  if col_name not in benchmark_list:
73
  continue
74
+ if len(columns) > 0 and col_name not in columns:
75
+ continue
76
  cols.append(col_name)
77
  types.append(col_type)
78
+
79
  if add_fix_cols:
80
  _cols = []
81
  _types = []
 
82
  for col_name, col_type in zip(cols, types):
83
+ if col_name in FIXED_COLS:
84
  continue
85
  _cols.append(col_name)
86
  _types.append(col_type)
87
+ cols = FIXED_COLS + _cols
88
+ types = FIXED_COLS_TYPES + _types
89
  return cols, types
90
 
91
 
92
+ fixed_cols = get_default_auto_eval_column_dict()[:-3]
93
+
94
+ FIXED_COLS = [c.name for _, _, c in fixed_cols]
95
+ FIXED_COLS_TYPES = [c.type for _, _, c in fixed_cols]
96
+
97
+
98
+ def select_columns(
99
+ df: pd.DataFrame,
100
+ domain_query: list,
101
+ language_query: list,
102
+ task: str = "qa",
103
+ reset_ranking: bool = True
104
+ ) -> pd.DataFrame:
105
+ cols, _ = get_default_cols(task=task, columns=df.columns, add_fix_cols=False)
106
  selected_cols = []
107
  for c in cols:
108
+ if task == "qa":
109
+ eval_col = BenchmarksQA[c].value
110
+ elif task == "long-doc":
111
+ eval_col = BenchmarksLongDoc[c].value
112
+ if eval_col.domain not in domain_query:
 
 
113
  continue
114
+ if eval_col.lang not in language_query:
115
  continue
116
  selected_cols.append(c)
117
  # We use COLS to maintain sorting
118
+ filtered_df = df[FIXED_COLS + selected_cols]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  if reset_ranking:
120
  filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
121
  filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
122
  filtered_df.reset_index(inplace=True, drop=True)
123
  filtered_df = reset_rank(filtered_df)
124
+
125
  return filtered_df
126
 
127
 
128
+ def _update_table(
129
+ task: str,
130
+ hidden_df: pd.DataFrame,
131
+ domains: list,
132
+ langs: list,
133
+ reranking_query: list,
134
+ query: str,
135
+ show_anonymous: bool,
136
+ reset_ranking: bool = True,
137
+ show_revision_and_timestamp: bool = False
 
138
  ):
139
+ filtered_df = hidden_df.copy()
140
  if not show_anonymous:
141
  filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
142
  filtered_df = filter_models(filtered_df, reranking_query)
143
  filtered_df = filter_queries(query, filtered_df)
144
+ filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking)
145
  if not show_revision_and_timestamp:
146
  filtered_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
147
  return filtered_df
148
 
149
 
150
+ def update_table(
151
+ hidden_df: pd.DataFrame,
152
+ domains: list,
153
+ langs: list,
154
+ reranking_query: list,
155
+ query: str,
156
+ show_anonymous: bool,
157
+ show_revision_and_timestamp: bool = False,
158
+ reset_ranking: bool = True
 
159
  ):
160
+ return _update_table(
161
+ "qa", hidden_df, domains, langs, reranking_query, query, show_anonymous, reset_ranking, show_revision_and_timestamp)
162
+
163
+
164
+ def update_table_long_doc(
165
+ hidden_df: pd.DataFrame,
166
+ domains: list,
167
+ langs: list,
168
+ reranking_query: list,
169
+ query: str,
170
+ show_anonymous: bool,
171
+ show_revision_and_timestamp: bool = False,
172
+ reset_ranking: bool = True
173
+
174
+ ):
175
+ return _update_table(
176
+ "long-doc", hidden_df, domains, langs, reranking_query, query, show_anonymous, reset_ranking, show_revision_and_timestamp)
177
 
178
 
179
  def update_metric(
180
+ raw_data: List[FullEvalResult],
181
+ task: str,
182
+ metric: str,
183
+ domains: list,
184
+ langs: list,
185
+ reranking_model: list,
186
+ query: str,
187
+ show_anonymous: bool = False,
188
+ show_revision_and_timestamp: bool = False,
189
  ) -> pd.DataFrame:
190
+ if task == 'qa':
191
+ leaderboard_df = get_leaderboard_df(raw_data, task=task, metric=metric)
192
+ return update_table(
193
+ leaderboard_df,
194
+ domains,
195
+ langs,
196
+ reranking_model,
197
+ query,
198
+ show_anonymous,
199
+ show_revision_and_timestamp
200
+ )
201
+ elif task == "long-doc":
202
+ leaderboard_df = get_leaderboard_df(raw_data, task=task, metric=metric)
203
+ return update_table_long_doc(
204
+ leaderboard_df,
205
+ domains,
206
+ langs,
207
+ reranking_model,
208
+ query,
209
+ show_anonymous,
210
+ show_revision_and_timestamp
211
+ )
212
 
213
 
214
  def upload_file(filepath: str):
 
218
  return filepath
219
 
220
 
221
+
222
  def get_iso_format_timestamp():
223
  # Get the current timestamp with UTC as the timezone
224
  current_timestamp = datetime.now(timezone.utc)
 
227
  current_timestamp = current_timestamp.replace(microsecond=0)
228
 
229
  # Convert to ISO 8601 format and replace the offset with 'Z'
230
+ iso_format_timestamp = current_timestamp.isoformat().replace('+00:00', 'Z')
231
+ filename_friendly_timestamp = current_timestamp.strftime('%Y%m%d%H%M%S')
232
  return iso_format_timestamp, filename_friendly_timestamp
233
 
234
 
235
  def calculate_file_md5(file_path):
236
  md5 = hashlib.md5()
237
 
238
+ with open(file_path, 'rb') as f:
239
  while True:
240
  data = f.read(4096)
241
  if not data:
 
246
 
247
 
248
  def submit_results(
249
+ filepath: str,
250
+ model: str,
251
+ model_url: str,
252
+ reranking_model: str="",
253
+ reranking_model_url: str="",
254
+ version: str=LATEST_BENCHMARK_VERSION,
255
+ is_anonymous=False):
 
256
  if not filepath.endswith(".zip"):
257
  return styled_error(f"file uploading aborted. wrong file type: {filepath}")
258
 
 
265
  if not model_url.startswith("https://") and not model_url.startswith("http://"):
266
  # TODO: retrieve the model page and find the model name on the page
267
  return styled_error(
268
+ f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}")
 
269
  if reranking_model != "NoReranker":
270
  if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
271
  return styled_error(
272
+ f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}")
 
273
 
274
  # rename the uploaded file
275
  input_fp = Path(filepath)
 
279
  input_folder_path = input_fp.parent
280
 
281
  if not reranking_model:
282
+ reranking_model = 'NoReranker'
283
+
284
  API.upload_file(
285
  path_or_fileobj=filepath,
286
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
287
  repo_id=SEARCH_RESULTS_REPO,
288
  repo_type="dataset",
289
+ commit_message=f"feat: submit {model} to evaluate")
 
290
 
291
  output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
292
  output_config = {
 
297
  "version": f"{version}",
298
  "is_anonymous": is_anonymous,
299
  "revision": f"{revision}",
300
+ "timestamp": f"{timestamp_config}"
301
  }
302
  with open(input_folder_path / output_config_fn, "w") as f:
303
  json.dump(output_config, f, indent=4, ensure_ascii=False)
 
306
  path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
307
  repo_id=SEARCH_RESULTS_REPO,
308
  repo_type="dataset",
309
+ commit_message=f"feat: submit {model} + {reranking_model} config")
 
310
  return styled_message(
311
  f"Thanks for submission!\n"
312
  f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
 
316
  def reset_rank(df):
317
  df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
318
  return df
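These `update_*` helpers are what the Gradio callbacks call whenever a filter widget changes. A minimal sketch of driving `update_metric` directly (the results path and the filter values are hypothetical):

```python
# Sketch: re-render the QA table for a different metric with a few filters.
# raw_data would normally be loaded once at app start-up.
from src.read_evals import get_raw_eval_results
from src.utils import update_metric

raw_data = get_raw_eval_results("eval_results/AIR-Bench_24.05")  # hypothetical path
df = update_metric(
    raw_data,
    task="qa",
    metric="ndcg_at_100",
    domains=["wiki", "news"],   # keep only these domains
    langs=["en", "zh"],         # keep only these languages
    reranking_model=[],         # an empty list keeps every reranker
    query="bge",                # substring match on the retrieval method
    show_anonymous=False,
)
print(df.shape)
```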
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/display/test_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from src.display.utils import fields, AutoEvalColumnQA, COLS_QA, COLS_LONG_DOC, COLS_LITE, TYPES_QA, TYPES_LONG_DOC, QA_BENCHMARK_COLS, LONG_DOC_BENCHMARK_COLS, get_default_auto_eval_column_dict
3
+
4
+
5
+ def test_fields():
6
+ for c in fields(AutoEvalColumnQA):
7
+ print(c)
8
+
9
+
10
+ def test_macro_variables():
11
+ print(f'COLS_QA: {COLS_QA}')
12
+ print(f'COLS_LONG_DOC: {COLS_LONG_DOC}')
13
+ print(f'COLS_LITE: {COLS_LITE}')
14
+ print(f'TYPES_QA: {TYPES_QA}')
15
+ print(f'TYPES_LONG_DOC: {TYPES_LONG_DOC}')
16
+ print(f'QA_BENCHMARK_COLS: {QA_BENCHMARK_COLS}')
17
+ print(f'LONG_DOC_BENCHMARK_COLS: {LONG_DOC_BENCHMARK_COLS}')
18
+
19
+
20
+ def test_get_default_auto_eval_column_dict():
21
+ auto_eval_column_dict_list = get_default_auto_eval_column_dict()
22
+ assert len(auto_eval_column_dict_list) == 9
23
+
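The new tests can be run on their own; a small sketch that invokes pytest programmatically (assumes pytest is installed and the repository root is the working directory):

```python
# Sketch: run the display-column tests; -s shows the print() output
# from test_macro_variables.
import pytest

raise SystemExit(pytest.main(["-s", "tests/src/display/test_utils.py"]))
```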
tests/src/test_benchmarks.py CHANGED
@@ -1,33 +1,9 @@
1
- import pytest
2
 
3
- from src.benchmarks import LongDocBenchmarks, QABenchmarks
4
- from src.envs import BENCHMARK_VERSION_LIST
5
 
6
- # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
7
- # 24.05
8
- # | Task | dev | test |
9
- # | ---- | --- | ---- |
10
- # | Long-Doc | 4 | 11 |
11
- # | QA | 54 | 53 |
12
- #
13
- # 24.04
14
- # | Task | test |
15
- # | ---- | ---- |
16
- # | Long-Doc | 15 |
17
- # | QA | 13 |
18
 
19
 
20
- @pytest.mark.parametrize("num_datasets_dict", [{"air_bench_2404": 13, "air_bench_2405": 53}])
21
- def test_qa_benchmarks(num_datasets_dict):
22
- assert len(QABenchmarks) == len(BENCHMARK_VERSION_LIST)
23
- for benchmark_list in list(QABenchmarks):
24
- version_slug = benchmark_list.name
25
- assert num_datasets_dict[version_slug] == len(benchmark_list.value)
26
-
27
-
28
- @pytest.mark.parametrize("num_datasets_dict", [{"air_bench_2404": 15, "air_bench_2405": 11}])
29
- def test_doc_benchmarks(num_datasets_dict):
30
- assert len(LongDocBenchmarks) == len(BENCHMARK_VERSION_LIST)
31
- for benchmark_list in list(LongDocBenchmarks):
32
- version_slug = benchmark_list.name
33
- assert num_datasets_dict[version_slug] == len(benchmark_list.value)
 
1
+ from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
2
 
 
 
3
 
4
+ def test_qabenchmarks():
5
+ print(list(BenchmarksQA))
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
+ def test_longdocbenchmarks():
9
+ print(list(BenchmarksLongDoc))
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/test_columns.py DELETED
@@ -1,119 +0,0 @@
1
- import pytest
2
-
3
- from src.benchmarks import LongDocBenchmarks, QABenchmarks
4
- from src.columns import (
5
- COL_NAME_AVG,
6
- COL_NAME_RANK,
7
- COL_NAME_RERANKING_MODEL,
8
- COL_NAME_RETRIEVAL_MODEL,
9
- COL_NAME_REVISION,
10
- COL_NAME_TIMESTAMP,
11
- get_default_auto_eval_column_dict,
12
- get_default_col_names_and_types,
13
- get_fixed_col_names_and_types,
14
- make_autoevalcolumn,
15
- )
16
-
17
- # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
18
- # 24.05
19
- # | Task | dev | test |
20
- # | ---- | --- | ---- |
21
- # | Long-Doc | 4 | 11 |
22
- # | QA | 54 | 53 |
23
- #
24
- # 24.04
25
- # | Task | test |
26
- # | ---- | ---- |
27
- # | Long-Doc | 15 |
28
- # | QA | 13 |
29
-
30
-
31
- @pytest.fixture()
32
- def expected_col_names():
33
- return [
34
- "rank",
35
- "retrieval_model",
36
- "reranking_model",
37
- "revision",
38
- "timestamp",
39
- "average",
40
- "retrieval_model_link",
41
- "reranking_model_link",
42
- "is_anonymous",
43
- ]
44
-
45
-
46
- @pytest.fixture()
47
- def expected_hidden_col_names():
48
- return [
49
- "retrieval_model_link",
50
- "reranking_model_link",
51
- "is_anonymous",
52
- ]
53
-
54
-
55
- def test_get_default_auto_eval_column_dict(expected_col_names, expected_hidden_col_names):
56
- col_list = get_default_auto_eval_column_dict()
57
- assert len(col_list) == 9
58
- hidden_cols = []
59
- for col_tuple, expected_col in zip(col_list, expected_col_names):
60
- col, _, col_content = col_tuple
61
- assert col == expected_col
62
- if col_content.hidden:
63
- hidden_cols.append(col)
64
- assert hidden_cols == expected_hidden_col_names
65
-
66
-
67
- def test_get_fixed_col_names_and_types():
68
- col_names, col_types = get_fixed_col_names_and_types()
69
- assert len(col_names) == 6
70
- assert len(col_types) == 6
71
- expected_col_and_type = [
72
- (COL_NAME_RANK, "number"),
73
- (COL_NAME_RETRIEVAL_MODEL, "markdown"),
74
- (COL_NAME_RERANKING_MODEL, "markdown"),
75
- (COL_NAME_REVISION, "markdown"),
76
- (COL_NAME_TIMESTAMP, "date"),
77
- (COL_NAME_AVG, "number"),
78
- ]
79
- for col_name, col_type, (c_name, c_type) in zip(col_names, col_types, expected_col_and_type):
80
- assert col_name == c_name
81
- assert col_type == c_type
82
-
83
-
84
- @pytest.mark.parametrize(
85
- "benchmarks, expected_benchmark_len",
86
- [
87
- (QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
88
- (LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
89
- ],
90
- )
91
- def test_make_autoevalcolumn(benchmarks, expected_benchmark_len, expected_col_names):
92
- expected_default_attrs = frozenset(expected_col_names)
93
- for benchmark in benchmarks:
94
- TestEvalColumn = make_autoevalcolumn("TestEvalColumn", benchmark)
95
- attrs = []
96
- for k, v in TestEvalColumn.__dict__.items():
97
- if not k.startswith("__"):
98
- attrs.append(k)
99
- attrs = frozenset(attrs)
100
- assert expected_default_attrs.issubset(attrs)
101
- benchmark_attrs = attrs.difference(expected_default_attrs)
102
- assert len(benchmark_attrs) == expected_benchmark_len[benchmark.name]
103
-
104
-
105
- @pytest.mark.parametrize(
106
- "benchmarks, expected_benchmark_len",
107
- [
108
- (QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
109
- (LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
110
- ],
111
- )
112
- def test_get_default_col_names_and_types(
113
- benchmarks, expected_benchmark_len, expected_col_names, expected_hidden_col_names
114
- ):
115
- default_col_len = len(expected_col_names)
116
- hidden_col_len = len(expected_hidden_col_names)
117
- for benchmark in benchmarks:
118
- col_names, col_types = get_default_col_names_and_types(benchmark)
119
- assert len(col_names) == expected_benchmark_len[benchmark.name] + default_col_len - hidden_col_len
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/test_envs.py DELETED
@@ -1,14 +0,0 @@
1
- from air_benchmark.tasks import BenchmarkTable
2
-
3
- from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA, METRIC_LIST
4
-
5
-
6
- def test_benchmark_version_list():
7
- leaderboard_versions = frozenset(BENCHMARK_VERSION_LIST)
8
- available_versions = frozenset([k for k in BenchmarkTable.keys()])
9
- assert leaderboard_versions.issubset(available_versions)
10
-
11
-
12
- def test_default_metrics():
13
- assert DEFAULT_METRIC_QA in METRIC_LIST
14
- assert DEFAULT_METRIC_LONG_DOC in METRIC_LIST
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/test_loaders.py DELETED
@@ -1,46 +0,0 @@
1
- from pathlib import Path
2
-
3
- import pandas as pd
4
- import pytest
5
-
6
- from src.loaders import load_eval_results, load_leaderboard_datastore, load_raw_eval_results
7
-
8
- cur_fp = Path(__file__)
9
-
10
-
11
- @pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
12
- def test_load_raw_eval_results(version):
13
- raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
14
- assert len(raw_data) == 1
15
- full_eval_result = raw_data[0]
16
- expected_attr = [
17
- "eval_name",
18
- "retrieval_model",
19
- "reranking_model",
20
- "retrieval_model_link",
21
- "reranking_model_link",
22
- "results",
23
- "timestamp",
24
- "revision",
25
- "is_anonymous",
26
- ]
27
- result_attr = [k for k in full_eval_result.__dict__.keys() if k[:2] != "__" and k[-2:] != "__"]
28
- assert sorted(expected_attr) == sorted(result_attr)
29
-
30
-
31
- @pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
32
- def test_load_leaderboard_datastore(version):
33
- file_path = cur_fp.parents[1] / f"toydata/eval_results/{version}"
34
- datastore = load_leaderboard_datastore(file_path, version)
35
- for k, v in datastore.__dict__.items():
36
- if k[:2] != "__" and k[-2:] != "__":
37
- if isinstance(v, list):
38
- assert v
39
- elif isinstance(v, pd.DataFrame):
40
- assert not v.empty
41
-
42
-
43
- def test_load_eval_results():
44
- file_path = cur_fp.parents[1] / "toydata/eval_results/"
45
- datastore_dict = load_eval_results(file_path)
46
- assert len(datastore_dict) == 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/test_models.py DELETED
@@ -1,89 +0,0 @@
1
- from pathlib import Path
2
-
3
- import pytest
4
-
5
- from src.models import EvalResult, FullEvalResult
6
-
7
- cur_fp = Path(__file__)
8
-
9
-
10
- # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
11
- # 24.05
12
- # | Task | dev | test |
13
- # | ---- | --- | ---- |
14
- # | Long-Doc | 4 | 11 |
15
- # | QA | 54 | 53 |
16
- #
17
- # 24.04
18
- # | Task | test |
19
- # | ---- | ---- |
20
- # | Long-Doc | 15 |
21
- # | QA | 13 |
22
- NUM_QA_BENCHMARKS_24_05 = 53
23
- NUM_DOC_BENCHMARKS_24_05 = 11
24
- NUM_QA_BENCHMARKS_24_04 = 13
25
- NUM_DOC_BENCHMARKS_24_04 = 15
26
-
27
-
28
- def test_eval_result():
29
- EvalResult(
30
- eval_name="eval_name",
31
- retrieval_model="bge-m3",
32
- reranking_model="NoReranking",
33
- results=[{"domain": "law", "lang": "en", "dataset": "lex_files_500K-600K", "value": 0.45723}],
34
- task="qa",
35
- metric="ndcg_at_3",
36
- timestamp="2024-05-14T03:09:08Z",
37
- revision="1e243f14bd295ccdea7a118fe847399d",
38
- is_anonymous=True,
39
- )
40
-
41
-
42
- @pytest.mark.parametrize(
43
- "file_path",
44
- [
45
- "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
46
- "AIR-Bench_24.05/bge-m3/NoReranker/results.json",
47
- ],
48
- )
49
- def test_full_eval_result_init_from_json_file(file_path):
50
- json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
51
- full_eval_result = FullEvalResult.init_from_json_file(json_fp)
52
- assert json_fp.parents[0].stem == full_eval_result.reranking_model
53
- assert json_fp.parents[1].stem == full_eval_result.retrieval_model
54
- assert len(full_eval_result.results) == 70
55
-
56
-
57
- @pytest.mark.parametrize(
58
- "file_path, task, expected_num_results",
59
- [
60
- ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "qa", NUM_QA_BENCHMARKS_24_04),
61
- (
62
- "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
63
- "long-doc",
64
- NUM_DOC_BENCHMARKS_24_04,
65
- ),
66
- ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "qa", NUM_QA_BENCHMARKS_24_05),
67
- ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_05),
68
- ],
69
- )
70
- def test_full_eval_result_to_dict(file_path, task, expected_num_results):
71
- json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
72
- full_eval_result = FullEvalResult.init_from_json_file(json_fp)
73
- result_dict_list = full_eval_result.to_dict(task)
74
- assert len(result_dict_list) == 1
75
- result = result_dict_list[0]
76
- attr_list = frozenset(
77
- [
78
- "eval_name",
79
- "Retrieval Method",
80
- "Reranking Model",
81
- "Retrieval Model LINK",
82
- "Reranking Model LINK",
83
- "Revision",
84
- "Submission Date",
85
- "Anonymous Submission",
86
- ]
87
- )
88
- result_cols = list(result.keys())
89
- assert len(result_cols) == (expected_num_results + len(attr_list))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/src/test_read_evals.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from src.read_evals import FullEvalResult, get_raw_eval_results, get_leaderboard_df
4
+
5
+ cur_fp = Path(__file__)
6
+
7
+
8
+ def test_init_from_json_file():
9
+ json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
10
+ full_eval_result = FullEvalResult.init_from_json_file(json_fp)
11
+ num_different_task_domain_lang_metric_dataset_combination = 6
12
+ assert len(full_eval_result.results) == \
13
+ num_different_task_domain_lang_metric_dataset_combination
14
+ assert full_eval_result.retrieval_model == "bge-m3"
15
+ assert full_eval_result.reranking_model == "bge-reranker-v2-m3"
16
+
17
+
18
+ def test_to_dict():
19
+ json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
20
+ full_eval_result = FullEvalResult.init_from_json_file(json_fp)
21
+ result_list = full_eval_result.to_dict(task='qa', metric='ndcg_at_1')
22
+ assert len(result_list) == 1
23
+ result_dict = result_list[0]
24
+ assert result_dict["Retrieval Model"] == "bge-m3"
25
+ assert result_dict["Reranking Model"] == "bge-reranker-v2-m3"
26
+ assert result_dict["wiki_en"] is not None
27
+ assert result_dict["wiki_zh"] is not None
28
+
29
+
30
+ def test_get_raw_eval_results():
31
+ results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
32
+ results = get_raw_eval_results(results_path)
33
+ # only load the latest results
34
+ assert len(results) == 4
35
+ assert results[0].eval_name == "bge-base-en-v1.5_NoReranker"
36
+ assert len(results[0].results) == 70
37
+ assert results[0].eval_name == "bge-base-en-v1.5_bge-reranker-v2-m3"
38
+ assert len(results[1].results) == 70
39
+
40
+
41
+ def test_get_leaderboard_df():
42
+ results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
43
+ raw_data = get_raw_eval_results(results_path)
44
+ df = get_leaderboard_df(raw_data, 'qa', 'ndcg_at_10')
45
+ assert df.shape[0] == 4
46
+ # the results contain only one embedding model
47
+ # for i in range(4):
48
+ # assert df["Retrieval Model"][i] == "bge-m3"
49
+ # # the results contain only two reranking model
50
+ # assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
51
+ # assert df["Reranking Model"][1] == "NoReranker"
52
+ # assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
53
+ # assert not df[['Average ⬆️', 'wiki_en', 'wiki_zh', ]].isnull().values.any()
54
+
55
+
56
+ def test_get_leaderboard_df_long_doc():
57
+ results_path = cur_fp.parents[2] / "toydata" / "test_results"
58
+ raw_data = get_raw_eval_results(results_path)
59
+ df = get_leaderboard_df(raw_data, 'long-doc', 'ndcg_at_1')
60
+ assert df.shape[0] == 2
61
+ # the results contain only one embedding model
62
+ for i in range(2):
63
+ assert df["Retrieval Model"][i] == "bge-m3"
64
+ # the results contains only two reranking model
65
+ assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
66
+ assert df["Reranking Model"][1] == "NoReranker"
67
+ assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
68
+ assert not df[['Average ⬆️', 'law_en_lex_files_500k_600k', ]].isnull().values.any()
tests/src/test_utils.py DELETED
@@ -1,237 +0,0 @@
1
- from pathlib import Path
2
-
3
- import pandas as pd
4
- import pytest
5
-
6
- from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
7
- from src.models import TaskType, model_hyperlink
8
- from src.utils import (
9
- _update_df_elem,
10
- calculate_mean,
11
- filter_models,
12
- filter_queries,
13
- get_default_cols,
14
- get_leaderboard_df,
15
- get_selected_cols,
16
- remove_html,
17
- select_columns,
18
- )
19
-
20
- cur_fp = Path(__file__)
21
-
22
- NUM_QA_BENCHMARKS_24_05 = 53
23
- NUM_DOC_BENCHMARKS_24_05 = 11
24
- NUM_QA_BENCHMARKS_24_04 = 13
25
- NUM_DOC_BENCHMARKS_24_04 = 15
26
-
27
-
28
- @pytest.fixture
29
- def toy_df():
30
- return pd.DataFrame(
31
- {
32
- "Retrieval Method": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
33
- "Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
34
- "Rank 🏆": [1, 2, 3, 4],
35
- "Revision": ["123", "234", "345", "456"],
36
- "Submission Date": ["", "", "", ""],
37
- "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
38
- "wiki_en": [0.8, 0.7, 0.2, 0.1],
39
- "wiki_zh": [0.4, 0.1, 0.4, 0.3],
40
- "news_en": [0.8, 0.7, 0.2, 0.1],
41
- "news_zh": [0.4, 0.1, 0.2, 0.3],
42
- "Anonymous Submission": [False, False, False, True],
43
- }
44
- )
45
-
46
-
47
- def test_remove_html():
48
- model_name = "jina-embeddings-v3"
49
- html_str = model_hyperlink("https://jina.ai", model_name)
50
- output_str = remove_html(html_str)
51
- assert output_str == model_name
52
-
53
-
54
- def test_calculate_mean():
55
- valid_row = [1, 3]
56
- invalid_row = [2, pd.NA]
57
- df = pd.DataFrame([valid_row, invalid_row], columns=["a", "b"])
58
- result = list(df.apply(calculate_mean, axis=1))
59
- assert result[0] == sum(valid_row) / 2
60
- assert result[1] == -1
61
-
62
-
63
- @pytest.mark.parametrize(
64
- "models, expected",
65
- [
66
- (["model1", "model3"], 2),
67
- (["model1", "model_missing"], 1),
68
- (["model1", "model2", "model3"], 3),
69
- (
70
- [
71
- "model1",
72
- ],
73
- 1,
74
- ),
75
- ([], 3),
76
- ],
77
- )
78
- def test_filter_models(models, expected):
79
- df = pd.DataFrame(
80
- {
81
- COL_NAME_RERANKING_MODEL: [
82
- "model1",
83
- "model2",
84
- "model3",
85
- ],
86
- "col2": [1, 2, 3],
87
- }
88
- )
89
- output_df = filter_models(df, models)
90
- assert len(output_df) == expected
91
-
92
-
93
- @pytest.mark.parametrize(
94
- "query, expected",
95
- [
96
- ("model1;model3", 2),
97
- ("model1;model4", 1),
98
- ("model1;model2;model3", 3),
99
- ("model1", 1),
100
- ("", 3),
101
- ],
102
- )
103
- def test_filter_queries(query, expected):
104
- df = pd.DataFrame(
105
- {
106
- COL_NAME_RETRIEVAL_MODEL: [
107
- "model1",
108
- "model2",
109
- "model3",
110
- ],
111
- COL_NAME_RERANKING_MODEL: [
112
- "model4",
113
- "model5",
114
- "model6",
115
- ],
116
- }
117
- )
118
- output_df = filter_queries(query, df)
119
- assert len(output_df) == expected
120
-
121
-
122
- @pytest.mark.parametrize(
123
- "task_type, slug, add_fix_cols, expected",
124
- [
125
- (TaskType.qa, "air_bench_2404", True, NUM_QA_BENCHMARKS_24_04),
126
- (TaskType.long_doc, "air_bench_2404", True, NUM_DOC_BENCHMARKS_24_04),
127
- (TaskType.qa, "air_bench_2405", False, NUM_QA_BENCHMARKS_24_05),
128
- (TaskType.long_doc, "air_bench_2405", False, NUM_DOC_BENCHMARKS_24_05),
129
- ],
130
- )
131
- def test_get_default_cols(task_type, slug, add_fix_cols, expected):
132
- attr_cols = ["Rank 🏆", "Retrieval Method", "Reranking Model", "Revision", "Submission Date", "Average ⬆️"]
133
- cols, types = get_default_cols(task_type, slug)
134
- cols_set = frozenset(cols)
135
- attrs_set = frozenset(attr_cols)
136
- if add_fix_cols:
137
- assert attrs_set.issubset(cols_set)
138
- benchmark_cols = list(cols_set.difference(attrs_set))
139
- assert len(benchmark_cols) == expected
140
-
141
-
142
- @pytest.mark.parametrize(
143
- "task_type, domains, languages, expected",
144
- [
145
- (
146
- TaskType.qa,
147
- ["wiki", "news"],
148
- [
149
- "zh",
150
- ],
151
- ["wiki_zh", "news_zh"],
152
- ),
153
- (
154
- TaskType.qa,
155
- [
156
- "law",
157
- ],
158
- ["zh", "en"],
159
- ["law_en"],
160
- ),
161
- (
162
- TaskType.long_doc,
163
- ["healthcare"],
164
- ["zh", "en"],
165
- [
166
- "healthcare_en_pubmed_100k_200k_1",
167
- "healthcare_en_pubmed_100k_200k_2",
168
- "healthcare_en_pubmed_100k_200k_3",
169
- "healthcare_en_pubmed_40k_50k_5_merged",
170
- "healthcare_en_pubmed_30k_40k_10_merged",
171
- ],
172
- ),
173
- ],
174
- )
175
- def test_get_selected_cols(task_type, domains, languages, expected):
176
- slug = "air_bench_2404"
177
- cols = get_selected_cols(task_type, slug, domains, languages)
178
- assert sorted(cols) == sorted(expected)
179
-
180
-
181
- @pytest.mark.parametrize("reset_rank", [False])
182
- def test_select_columns(toy_df, reset_rank):
183
- expected = [
184
- "Rank 🏆",
185
- "Retrieval Method",
186
- "Reranking Model",
187
- "Revision",
188
- "Submission Date",
189
- "Average ⬆️",
190
- "news_zh",
191
- ]
192
- df_result = select_columns(toy_df, ["news"], ["zh"], version_slug="air_bench_2404", reset_ranking=reset_rank)
193
- assert len(df_result.columns) == len(expected)
194
- if reset_rank:
195
- assert df_result["Average ⬆️"].equals(df_result["news_zh"])
196
- else:
197
- assert df_result["Average ⬆️"].equals(toy_df["Average ⬆️"])
198
-
199
-
200
- @pytest.mark.parametrize(
201
- "reset_rank, show_anony",
202
- [
203
- (False, True),
204
- (True, True),
205
- (True, False),
206
- ],
207
- )
208
- def test__update_df_elem(toy_df, reset_rank, show_anony):
209
- df = _update_df_elem(TaskType.qa, "AIR-Bench_24.04", toy_df, ["news"], ["zh"], [], "", show_anony, reset_rank)
210
- if show_anony:
211
- assert df.shape[0] == 4
212
- else:
213
- assert df.shape[0] == 3
214
- if show_anony:
215
- if reset_rank:
216
- assert df["Average ⬆️"].equals(df["news_zh"])
217
- else:
218
- assert df["Average ⬆️"].equals(toy_df["Average ⬆️"])
219
-
220
-
221
- @pytest.mark.parametrize(
222
- "version, task_type",
223
- [
224
- ("AIR-Bench_24.04", TaskType.qa),
225
- ("AIR-Bench_24.04", TaskType.long_doc),
226
- ("AIR-Bench_24.05", TaskType.qa),
227
- ("AIR-Bench_24.05", TaskType.long_doc),
228
- ],
229
- )
230
- def test_get_leaderboard_df(version, task_type):
231
- from src.loaders import load_raw_eval_results
232
- from src.models import LeaderboardDataStore, get_safe_name
233
-
234
- raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
235
- ds = LeaderboardDataStore(version, get_safe_name(version), raw_data=raw_data)
236
- df = get_leaderboard_df(ds, task_type, "ndcg_at_10")
237
- assert df.shape[0] == 1
tests/test_utils.py ADDED
@@ -0,0 +1,115 @@
1
+ import pandas as pd
2
+ import pytest
3
+
4
+ from src.utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc, get_iso_format_timestamp, get_default_cols, update_table
5
+ from src.display.utils import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RANK, COL_NAME_AVG
6
+
7
+
8
+ @pytest.fixture
9
+ def toy_df():
10
+ return pd.DataFrame(
11
+ {
12
+ "Retrieval Model": [
13
+ "bge-m3",
14
+ "bge-m3",
15
+ "jina-embeddings-v2-base",
16
+ "jina-embeddings-v2-base"
17
+ ],
18
+ "Reranking Model": [
19
+ "bge-reranker-v2-m3",
20
+ "NoReranker",
21
+ "bge-reranker-v2-m3",
22
+ "NoReranker"
23
+ ],
24
+ "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
25
+ "wiki_en": [0.8, 0.7, 0.2, 0.1],
26
+ "wiki_zh": [0.4, 0.1, 0.4, 0.3],
27
+ "news_en": [0.8, 0.7, 0.2, 0.1],
28
+ "news_zh": [0.4, 0.1, 0.4, 0.3],
29
+ }
30
+ )
31
+
32
+
33
+ @pytest.fixture
34
+ def toy_df_long_doc():
35
+ return pd.DataFrame(
36
+ {
37
+ "Retrieval Model": [
38
+ "bge-m3",
39
+ "bge-m3",
40
+ "jina-embeddings-v2-base",
41
+ "jina-embeddings-v2-base"
42
+ ],
43
+ "Reranking Model": [
44
+ "bge-reranker-v2-m3",
45
+ "NoReranker",
46
+ "bge-reranker-v2-m3",
47
+ "NoReranker"
48
+ ],
49
+ "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
50
+ "law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
51
+ "law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],
52
+ "law_en_lex_files_500k_600k": [0.8, 0.7, 0.2, 0.1],
53
+ "law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
54
+ }
55
+ )
56
+ def test_filter_models(toy_df):
57
+ df_result = filter_models(toy_df, ["bge-reranker-v2-m3", ])
58
+ assert len(df_result) == 2
59
+ assert df_result.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
60
+
61
+
62
+ def test_search_table(toy_df):
63
+ df_result = search_table(toy_df, "jina")
64
+ assert len(df_result) == 2
65
+ assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
66
+
67
+
68
+ def test_filter_queries(toy_df):
69
+ df_result = filter_queries("jina", toy_df)
70
+ assert len(df_result) == 2
71
+ assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
72
+
73
+
74
+ def test_select_columns(toy_df):
75
+ df_result = select_columns(toy_df, ['news',], ['zh',])
76
+ assert len(df_result.columns) == 4
77
+ assert df_result['Average ⬆️'].equals(df_result['news_zh'])
78
+
79
+
80
+ def test_update_table_long_doc(toy_df_long_doc):
81
+ df_result = update_table_long_doc(toy_df_long_doc, ['law',], ['en',], ["bge-reranker-v2-m3", ], "jina")
82
+ print(df_result)
83
+
84
+
85
+ def test_get_iso_format_timestamp():
86
+ timestamp_config, timestamp_fn = get_iso_format_timestamp()
87
+ assert len(timestamp_fn) == 14
88
+ assert len(timestamp_config) == 20
89
+ assert timestamp_config[-1] == "Z"
90
+
91
+
92
+ def test_get_default_cols():
93
+ cols, types = get_default_cols("qa")
94
+ for c, t in zip(cols, types):
95
+ print(f"type({c}): {t}")
96
+ assert len(frozenset(cols)) == len(cols)
97
+
98
+
99
+ def test_update_table():
100
+ df = pd.DataFrame(
101
+ {
102
+ COL_NAME_IS_ANONYMOUS: [False, False, False],
103
+ COL_NAME_REVISION: ["a1", "a2", "a3"],
104
+ COL_NAME_TIMESTAMP: ["2024-05-12T12:24:02Z"] * 3,
105
+ COL_NAME_RERANKING_MODEL: ["NoReranker"] * 3,
106
+ COL_NAME_RETRIEVAL_MODEL: ["Foo"] * 3,
107
+ COL_NAME_RANK: [1, 2, 3],
108
+ COL_NAME_AVG: [0.1, 0.2, 0.3], # unsorted values
109
+ "wiki_en": [0.1, 0.2, 0.3]
110
+ }
111
+ )
112
+ results = update_table(df, "wiki", "en", ["NoReranker"], "", show_anonymous=False, reset_ranking=False, show_revision_and_timestamp=False)
113
+ # keep the RANK the same regardless of the unsorted averages
114
+ assert results[COL_NAME_RANK].to_list() == [1, 2, 3]
115
+
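As a quick illustration of what `test_filter_models` asserts, the same filtering can be reproduced with plain pandas boolean indexing (a sketch only; this is not the project's `filter_models` implementation):

```python
import pandas as pd

# The toy fixture from the tests, trimmed to the columns the assertion touches.
toy_df = pd.DataFrame(
    {
        "Retrieval Model": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
        "Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
        "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
    }
)

# Keep only rows whose reranking model is in the selected list.
kept = toy_df[toy_df["Reranking Model"].isin(["bge-reranker-v2-m3"])]
assert len(kept) == 2
assert kept.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
```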
tests/toydata/eval_results/AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json DELETED
The diff for this file is too large to render. See raw diff
 
tests/toydata/eval_results/AIR-Bench_24.05/bge-m3/NoReranker/results.json DELETED
The diff for this file is too large to render. See raw diff
 
tests/toydata/test_data.json ADDED
@@ -0,0 +1,98 @@
1
+ [
2
+ {
3
+ "config": {
4
+ "retrieval_model": "bge-m3",
5
+ "reranking_model": "bge-reranker-v2-m3",
6
+ "task": "long_doc",
7
+ "metric": "ndcg_at_1"
8
+ },
9
+ "results": [
10
+ {
11
+ "domain": "law",
12
+ "lang": "en",
13
+ "dataset": "lex_files_500K-600K",
14
+ "value": 0.75723
15
+ }
16
+ ]
17
+ },
18
+ {
19
+ "config": {
20
+ "retrieval_model": "bge-m3",
21
+ "reranking_model": "bge-reranker-v2-m3",
22
+ "task": "long_doc",
23
+ "metric": "ndcg_at_3"
24
+ },
25
+ "results": [
26
+ {
27
+ "domain": "law",
28
+ "lang": "en",
29
+ "dataset": "lex_files_500K-600K",
30
+ "value": 0.69909
31
+ }
32
+ ]
33
+ },
34
+ {
35
+ "config": {
36
+ "retrieval_model": "bge-m3",
37
+ "reranking_model": "bge-reranker-v2-m3",
38
+ "task": "qa",
39
+ "metric": "ndcg_at_1"
40
+ },
41
+ "results": [
42
+ {
43
+ "domain": "wiki",
44
+ "lang": "en",
45
+ "dataset": "unknown",
46
+ "value": 0.69083
47
+ }
48
+ ]
49
+ },
50
+ {
51
+ "config": {
52
+ "retrieval_model": "bge-m3",
53
+ "reranking_model": "bge-reranker-v2-m3",
54
+ "task": "qa",
55
+ "metric": "ndcg_at_3"
56
+ },
57
+ "results": [
58
+ {
59
+ "domain": "wiki",
60
+ "lang": "en",
61
+ "dataset": "unknown",
62
+ "value": 0.73359
63
+ }
64
+ ]
65
+ },
66
+ {
67
+ "config": {
68
+ "retrieval_model": "bge-m3",
69
+ "reranking_model": "bge-reranker-v2-m3",
70
+ "task": "qa",
71
+ "metric": "ndcg_at_1"
72
+ },
73
+ "results": [
74
+ {
75
+ "domain": "wiki",
76
+ "lang": "zh",
77
+ "dataset": "unknown",
78
+ "value": 0.78358
79
+ }
80
+ ]
81
+ },
82
+ {
83
+ "config": {
84
+ "retrieval_model": "bge-m3",
85
+ "reranking_model": "bge-reranker-v2-m3",
86
+ "task": "qa",
87
+ "metric": "ndcg_at_3"
88
+ },
89
+ "results": [
90
+ {
91
+ "domain": "wiki",
92
+ "lang": "zh",
93
+ "dataset": "unknown",
94
+ "value": 0.78358
95
+ }
96
+ ]
97
+ }
98
+ ]
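The toy file above is a JSON array of `{"config": ..., "results": [...]}` records; a minimal sketch for inspecting it (assuming it is read from the repository root):

```python
import json
from pathlib import Path

records = json.loads(Path("tests/toydata/test_data.json").read_text())
for record in records:
    cfg = record["config"]
    for entry in record["results"]:
        # One row per (task, metric, domain, lang, dataset) combination.
        print(cfg["task"], cfg["metric"], entry["domain"], entry["lang"], entry["dataset"], entry["value"])
```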
tests/toydata/test_results/bge-m3/NoReranker/results_2023-11-21T18-10-08.json ADDED
@@ -0,0 +1,98 @@
1
+ [
2
+ {
3
+ "config": {
4
+ "retrieval_model": "bge-m3",
5
+ "reranking_model": "NoReranker",
6
+ "task": "long_doc",
7
+ "metric": "ndcg_at_1"
8
+ },
9
+ "results": [
10
+ {
11
+ "domain": "law",
12
+ "lang": "en",
13
+ "dataset": "lex_files_500K-600K",
14
+ "value": 0.45723
15
+ }
16
+ ]
17
+ },
18
+ {
19
+ "config": {
20
+ "retrieval_model": "bge-m3",
21
+ "reranking_model": "NoReranker",
22
+ "task": "long_doc",
23
+ "metric": "ndcg_at_3"
24
+ },
25
+ "results": [
26
+ {
27
+ "domain": "law",
28
+ "lang": "en",
29
+ "dataset": "lex_files_500K-600K",
30
+ "value": 0.49909
31
+ }
32
+ ]
33
+ },
34
+ {
35
+ "config": {
36
+ "retrieval_model": "bge-m3",
37
+ "reranking_model": "NoReranker",
38
+ "task": "qa",
39
+ "metric": "ndcg_at_1"
40
+ },
41
+ "results": [
42
+ {
43
+ "domain": "wiki",
44
+ "lang": "en",
45
+ "dataset": "unknown",
46
+ "value": 0.49083
47
+ }
48
+ ]
49
+ },
50
+ {
51
+ "config": {
52
+ "retrieval_model": "bge-m3",
53
+ "reranking_model": "NoReranker",
54
+ "task": "qa",
55
+ "metric": "ndcg_at_3"
56
+ },
57
+ "results": [
58
+ {
59
+ "domain": "wiki",
60
+ "lang": "en",
61
+ "dataset": "unknown",
62
+ "value": 0.43359
63
+ }
64
+ ]
65
+ },
66
+ {
67
+ "config": {
68
+ "retrieval_model": "bge-m3",
69
+ "reranking_model": "NoReranker",
70
+ "task": "qa",
71
+ "metric": "ndcg_at_1"
72
+ },
73
+ "results": [
74
+ {
75
+ "domain": "wiki",
76
+ "lang": "zh",
77
+ "dataset": "unknown",
78
+ "value": 0.78358
79
+ }
80
+ ]
81
+ },
82
+ {
83
+ "config": {
84
+ "retrieval_model": "bge-m3",
85
+ "reranking_model": "NoReranker",
86
+ "task": "qa",
87
+ "metric": "ndcg_at_3"
88
+ },
89
+ "results": [
90
+ {
91
+ "domain": "wiki",
92
+ "lang": "zh",
93
+ "dataset": "unknown",
94
+ "value": 0.78358
95
+ }
96
+ ]
97
+ }
98
+ ]
tests/toydata/test_results/bge-m3/bge-reranker-v2-m3/results_2023-11-21T18-10-08.json ADDED
@@ -0,0 +1,98 @@
1
+ [
2
+ {
3
+ "config": {
4
+ "retrieval_model": "bge-m3",
5
+ "reranking_model": "bge-reranker-v2-m3",
6
+ "task": "long_doc",
7
+ "metric": "ndcg_at_1"
8
+ },
9
+ "results": [
10
+ {
11
+ "domain": "law",
12
+ "lang": "en",
13
+ "dataset": "lex_files_500K-600K",
14
+ "value": 0.75723
15
+ }
16
+ ]
17
+ },
18
+ {
19
+ "config": {
20
+ "retrieval_model": "bge-m3",
21
+ "reranking_model": "bge-reranker-v2-m3",
22
+ "task": "long_doc",
23
+ "metric": "ndcg_at_3"
24
+ },
25
+ "results": [
26
+ {
27
+ "domain": "law",
28
+ "lang": "en",
29
+ "dataset": "lex_files_500K-600K",
30
+ "value": 0.69909
31
+ }
32
+ ]
33
+ },
34
+ {
35
+ "config": {
36
+ "retrieval_model": "bge-m3",
37
+ "reranking_model": "bge-reranker-v2-m3",
38
+ "task": "qa",
39
+ "metric": "ndcg_at_1"
40
+ },
41
+ "results": [
42
+ {
43
+ "domain": "wiki",
44
+ "lang": "en",
45
+ "dataset": "unknown",
46
+ "value": 0.69083
47
+ }
48
+ ]
49
+ },
50
+ {
51
+ "config": {
52
+ "retrieval_model": "bge-m3",
53
+ "reranking_model": "bge-reranker-v2-m3",
54
+ "task": "qa",
55
+ "metric": "ndcg_at_3"
56
+ },
57
+ "results": [
58
+ {
59
+ "domain": "wiki",
60
+ "lang": "en",
61
+ "dataset": "unknown",
62
+ "value": 0.73359
63
+ }
64
+ ]
65
+ },
66
+ {
67
+ "config": {
68
+ "retrieval_model": "bge-m3",
69
+ "reranking_model": "bge-reranker-v2-m3",
70
+ "task": "qa",
71
+ "metric": "ndcg_at_1"
72
+ },
73
+ "results": [
74
+ {
75
+ "domain": "wiki",
76
+ "lang": "zh",
77
+ "dataset": "unknown",
78
+ "value": 0.78358
79
+ }
80
+ ]
81
+ },
82
+ {
83
+ "config": {
84
+ "retrieval_model": "bge-m3",
85
+ "reranking_model": "bge-reranker-v2-m3",
86
+ "task": "qa",
87
+ "metric": "ndcg_at_3"
88
+ },
89
+ "results": [
90
+ {
91
+ "domain": "wiki",
92
+ "lang": "zh",
93
+ "dataset": "unknown",
94
+ "value": 0.78358
95
+ }
96
+ ]
97
+ }
98
+ ]