Spaces:
Runtime error
Runtime error
feat: fix the table updating
Browse files
- app.py +132 -7
- src/benchmarks.py +17 -13
- src/display/utils.py +2 -1
- src/leaderboard/read_evals.py +1 -1
- src/populate.py +8 -3
- tests/src/display/test_utils.py +5 -3
- tests/src/test_populate.py +16 -0
- tests/test_utils.py +30 -2
- tests/toydata/test_results/bge-m3/NoReranker/results_2023-12-21T18-10-08.json +1 -1
- utils.py +55 -23
app.py
CHANGED
@@ -10,15 +10,17 @@ from src.about import (
|
|
10 |
from src.display.css_html_js import custom_css
|
11 |
from src.display.utils import (
|
12 |
QA_BENCHMARK_COLS,
|
13 |
-
|
|
|
|
|
14 |
TYPES,
|
15 |
AutoEvalColumnQA,
|
16 |
fields
|
17 |
)
|
18 |
from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, QUEUE_REPO, REPO_ID, RESULTS_REPO, TOKEN
|
19 |
from src.populate import get_leaderboard_df
|
20 |
-
from utils import update_table, update_metric
|
21 |
-
from src.benchmarks import DOMAIN_COLS_QA, LANG_COLS_QA, metric_list
|
22 |
|
23 |
|
24 |
def restart_space():
|
@@ -43,9 +45,15 @@ def restart_space():
|
|
43 |
|
44 |
from src.leaderboard.read_evals import get_raw_eval_results
|
45 |
raw_data_qa = get_raw_eval_results(EVAL_RESULTS_PATH, EVAL_REQUESTS_PATH)
|
46 |
-
original_df_qa = get_leaderboard_df(raw_data_qa,
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
leaderboard_df = original_df_qa.copy()
|
|
|
|
|
49 |
|
50 |
|
51 |
def update_metric_qa(
|
@@ -55,7 +63,18 @@ def update_metric_qa(
|
|
55 |
reranking_model: list,
|
56 |
query: str,
|
57 |
):
|
58 |
-
return update_metric(raw_data_qa, metric, domains, langs, reranking_model, query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# (
|
60 |
# finished_eval_queue_df,
|
61 |
# running_eval_queue_df,
|
@@ -178,7 +197,113 @@ with demo:
|
|
178 |
queue=True
|
179 |
)
|
180 |
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
|
184 |
gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
|
|
|
10 |
from src.display.css_html_js import custom_css
|
11 |
from src.display.utils import (
|
12 |
QA_BENCHMARK_COLS,
|
13 |
+
LONG_DOC_BENCHMARK_COLS,
|
14 |
+
COLS_QA,
|
15 |
+
COLS_LONG_DOC,
|
16 |
TYPES,
|
17 |
AutoEvalColumnQA,
|
18 |
fields
|
19 |
)
|
20 |
from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, QUEUE_REPO, REPO_ID, RESULTS_REPO, TOKEN
|
21 |
from src.populate import get_leaderboard_df
|
22 |
+
from utils import update_table, update_metric, update_table_long_doc
|
23 |
+
from src.benchmarks import DOMAIN_COLS_QA, LANG_COLS_QA, DOMAIN_COLS_LONG_DOC, LANG_COLS_LONG_DOC, metric_list
|
24 |
|
25 |
|
26 |
def restart_space():
|
|
|
45 |
|
46 |
from src.leaderboard.read_evals import get_raw_eval_results
|
47 |
raw_data_qa = get_raw_eval_results(EVAL_RESULTS_PATH, EVAL_REQUESTS_PATH)
|
48 |
+
original_df_qa = get_leaderboard_df(raw_data_qa, COLS_QA, QA_BENCHMARK_COLS, task='qa', metric='ndcg_at_3')
|
49 |
+
original_df_long_doc = get_leaderboard_df(raw_data_qa, COLS_LONG_DOC, LONG_DOC_BENCHMARK_COLS, task='long_doc', metric='ndcg_at_3')
|
50 |
+
print(f'raw data: {len(raw_data_qa)}')
|
51 |
+
print(f'QA data loaded: {original_df_qa.shape}')
|
52 |
+
print(f'Long-Doc data loaded: {len(original_df_long_doc)}')
|
53 |
+
|
54 |
leaderboard_df = original_df_qa.copy()
|
55 |
+
leaderboard_df_long_doc = original_df_long_doc.copy()
|
56 |
+
print(leaderboard_df_long_doc.head())
|
57 |
|
58 |
|
59 |
def update_metric_qa(
|
|
|
63 |
reranking_model: list,
|
64 |
query: str,
|
65 |
):
|
66 |
+
return update_metric(raw_data_qa, 'qa', metric, domains, langs, reranking_model, query)
|
67 |
+
|
68 |
+
def update_metric_long_doc(
|
69 |
+
metric: str,
|
70 |
+
domains: list,
|
71 |
+
langs: list,
|
72 |
+
reranking_model: list,
|
73 |
+
query: str,
|
74 |
+
):
|
75 |
+
return update_metric(raw_data_qa, 'long_doc', metric, domains, langs, reranking_model, query)
|
76 |
+
|
77 |
+
|
78 |
# (
|
79 |
# finished_eval_queue_df,
|
80 |
# running_eval_queue_df,
|
|
|
197 |
queue=True
|
198 |
)
|
199 |
|
200 |
+
with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
|
201 |
+
with gr.Row():
|
202 |
+
with gr.Column():
|
203 |
+
with gr.Row():
|
204 |
+
search_bar = gr.Textbox(
|
205 |
+
placeholder=" 🔍 Search for your model (separate multiple queries with `;`) and press ENTER...",
|
206 |
+
show_label=False,
|
207 |
+
elem_id="search-bar-long-doc",
|
208 |
+
)
|
209 |
+
# select the metric
|
210 |
+
selected_metric = gr.Dropdown(
|
211 |
+
choices=metric_list,
|
212 |
+
value=metric_list[1],
|
213 |
+
label="Select the metric",
|
214 |
+
interactive=True,
|
215 |
+
elem_id="metric-select-long-doc",
|
216 |
+
)
|
217 |
+
with gr.Column(min_width=320):
|
218 |
+
# select domain
|
219 |
+
with gr.Row():
|
220 |
+
selected_domains = gr.CheckboxGroup(
|
221 |
+
choices=DOMAIN_COLS_LONG_DOC,
|
222 |
+
value=DOMAIN_COLS_LONG_DOC,
|
223 |
+
label="Select the domains",
|
224 |
+
elem_id="domain-column-select-long-doc",
|
225 |
+
interactive=True,
|
226 |
+
)
|
227 |
+
# select language
|
228 |
+
with gr.Row():
|
229 |
+
selected_langs = gr.CheckboxGroup(
|
230 |
+
choices=LANG_COLS_LONG_DOC,
|
231 |
+
value=LANG_COLS_LONG_DOC,
|
232 |
+
label="Select the languages",
|
233 |
+
elem_id="language-column-select-long-doc",
|
234 |
+
interactive=True
|
235 |
+
)
|
236 |
+
# select reranking model
|
237 |
+
reranking_models = list(frozenset([eval_result.reranking_model for eval_result in raw_data_qa]))
|
238 |
+
with gr.Row():
|
239 |
+
selected_rerankings = gr.CheckboxGroup(
|
240 |
+
choices=reranking_models,
|
241 |
+
value=reranking_models,
|
242 |
+
label="Select the reranking models",
|
243 |
+
elem_id="reranking-select-long-doc",
|
244 |
+
interactive=True
|
245 |
+
)
|
246 |
+
|
247 |
+
leaderboard_table_long_doc = gr.components.Dataframe(
|
248 |
+
value=leaderboard_df_long_doc,
|
249 |
+
# headers=shown_columns,
|
250 |
+
# datatype=TYPES,
|
251 |
+
elem_id="leaderboard-table-long-doc",
|
252 |
+
interactive=False,
|
253 |
+
visible=True,
|
254 |
+
)
|
255 |
+
|
256 |
+
# Dummy leaderboard for handling the case when the user uses backspace key
|
257 |
+
hidden_leaderboard_table_for_search = gr.components.Dataframe(
|
258 |
+
value=leaderboard_df_long_doc,
|
259 |
+
# headers=COLS,
|
260 |
+
# datatype=TYPES,
|
261 |
+
visible=False,
|
262 |
+
)
|
263 |
+
|
264 |
+
# Set search_bar listener
|
265 |
+
search_bar.submit(
|
266 |
+
update_table_long_doc,
|
267 |
+
[
|
268 |
+
hidden_leaderboard_table_for_search,
|
269 |
+
selected_domains,
|
270 |
+
selected_langs,
|
271 |
+
selected_rerankings,
|
272 |
+
search_bar,
|
273 |
+
],
|
274 |
+
leaderboard_table_long_doc,
|
275 |
+
)
|
276 |
+
|
277 |
+
# Set column-wise listener
|
278 |
+
for selector in [
|
279 |
+
selected_domains, selected_langs, selected_rerankings
|
280 |
+
]:
|
281 |
+
selector.change(
|
282 |
+
update_table_long_doc,
|
283 |
+
[
|
284 |
+
hidden_leaderboard_table_for_search,
|
285 |
+
selected_domains,
|
286 |
+
selected_langs,
|
287 |
+
selected_rerankings,
|
288 |
+
search_bar,
|
289 |
+
],
|
290 |
+
leaderboard_table_long_doc,
|
291 |
+
queue=True,
|
292 |
+
)
|
293 |
+
|
294 |
+
# set metric listener
|
295 |
+
selected_metric.change(
|
296 |
+
update_metric_long_doc,
|
297 |
+
[
|
298 |
+
selected_metric,
|
299 |
+
selected_domains,
|
300 |
+
selected_langs,
|
301 |
+
selected_rerankings,
|
302 |
+
search_bar,
|
303 |
+
],
|
304 |
+
leaderboard_table_long_doc,
|
305 |
+
queue=True
|
306 |
+
)
|
307 |
|
308 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
|
309 |
gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
|
src/benchmarks.py
CHANGED
@@ -52,19 +52,19 @@ dataset_dict = {
|
|
52 |
},
|
53 |
"healthcare": {
|
54 |
"en": [
|
55 |
-
"
|
56 |
-
"
|
57 |
-
"
|
58 |
-
"
|
59 |
-
"
|
60 |
]
|
61 |
},
|
62 |
"law": {
|
63 |
"en": [
|
64 |
-
"
|
65 |
-
"
|
66 |
-
"
|
67 |
-
"
|
68 |
]
|
69 |
}
|
70 |
}
|
@@ -121,21 +121,25 @@ for task, domain_dict in dataset_dict.items():
|
|
121 |
if task == "qa":
|
122 |
benchmark_name = f"{domain}_{lang}"
|
123 |
benchmark_name = get_safe_name(benchmark_name)
|
124 |
-
col_name =
|
125 |
for metric in dataset_list:
|
126 |
qa_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
|
127 |
elif task == "long_doc":
|
128 |
for dataset in dataset_list:
|
129 |
-
|
|
|
|
|
130 |
for metric in metric_list:
|
131 |
-
benchmark_name = f"{domain}_{lang}_{dataset}_{metric}"
|
132 |
-
benchmark_name = get_safe_name(benchmark_name)
|
133 |
long_doc_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
|
134 |
|
135 |
BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
|
136 |
BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)
|
137 |
|
138 |
BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
|
|
|
139 |
|
140 |
DOMAIN_COLS_QA = list(frozenset([c.domain for c in qa_benchmark_dict.values()]))
|
141 |
LANG_COLS_QA = list(frozenset([c.lang for c in qa_benchmark_dict.values()]))
|
|
|
|
|
|
|
|
52 |
},
|
53 |
"healthcare": {
|
54 |
"en": [
|
55 |
+
"pubmed_100k-200k_1",
|
56 |
+
"pubmed_100k-200k_2",
|
57 |
+
"pubmed_100k-200k_3",
|
58 |
+
"pubmed_40k-50k_5-merged",
|
59 |
+
"pubmed_30k-40k_10-merged"
|
60 |
]
|
61 |
},
|
62 |
"law": {
|
63 |
"en": [
|
64 |
+
"lex_files_300k-400k",
|
65 |
+
"lex_files_400k-500k",
|
66 |
+
"lex_files_500k-600k",
|
67 |
+
"lex_files_600k-700k"
|
68 |
]
|
69 |
}
|
70 |
}
|
|
|
121 |
if task == "qa":
|
122 |
benchmark_name = f"{domain}_{lang}"
|
123 |
benchmark_name = get_safe_name(benchmark_name)
|
124 |
+
col_name = benchmark_name
|
125 |
for metric in dataset_list:
|
126 |
qa_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
|
127 |
elif task == "long_doc":
|
128 |
for dataset in dataset_list:
|
129 |
+
benchmark_name = f"{domain}_{lang}_{dataset}"
|
130 |
+
benchmark_name = get_safe_name(benchmark_name)
|
131 |
+
col_name = benchmark_name
|
132 |
for metric in metric_list:
|
|
|
|
|
133 |
long_doc_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
|
134 |
|
135 |
BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
|
136 |
BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)
|
137 |
|
138 |
BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
|
139 |
+
BENCHMARK_COLS_LONG_DOC = [c.col_name for c in long_doc_benchmark_dict.values()]
|
140 |
|
141 |
DOMAIN_COLS_QA = list(frozenset([c.domain for c in qa_benchmark_dict.values()]))
|
142 |
LANG_COLS_QA = list(frozenset([c.lang for c in qa_benchmark_dict.values()]))
|
143 |
+
|
144 |
+
DOMAIN_COLS_LONG_DOC = list(frozenset([c.domain for c in long_doc_benchmark_dict.values()]))
|
145 |
+
LANG_COLS_LONG_DOC = list(frozenset([c.lang for c in long_doc_benchmark_dict.values()]))
|
src/display/utils.py
CHANGED
@@ -55,7 +55,8 @@ class EvalQueueColumn: # Queue column
|
|
55 |
|
56 |
|
57 |
# Column selection
|
58 |
-
|
|
|
59 |
TYPES = [c.type for c in fields(AutoEvalColumnQA) if not c.hidden]
|
60 |
COLS_LITE = [c.name for c in fields(AutoEvalColumnQA) if c.displayed_by_default and not c.hidden]
|
61 |
|
|
|
55 |
|
56 |
|
57 |
# Column selection
|
58 |
+
COLS_QA = [c.name for c in fields(AutoEvalColumnQA) if not c.hidden]
|
59 |
+
COLS_LONG_DOC = [c.name for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
|
60 |
TYPES = [c.type for c in fields(AutoEvalColumnQA) if not c.hidden]
|
61 |
COLS_LITE = [c.name for c in fields(AutoEvalColumnQA) if c.displayed_by_default and not c.hidden]
|
62 |
|
src/leaderboard/read_evals.py
CHANGED
@@ -87,7 +87,7 @@ class FullEvalResult:
|
|
87 |
if task == 'qa':
|
88 |
benchmark_name = f"{domain}_{lang}"
|
89 |
elif task == 'long_doc':
|
90 |
-
benchmark_name = f"{domain}_{lang}_{dataset}
|
91 |
results[eval_result.eval_name][get_safe_name(benchmark_name)] = value
|
92 |
return [v for v in results.values()]
|
93 |
|
|
|
87 |
if task == 'qa':
|
88 |
benchmark_name = f"{domain}_{lang}"
|
89 |
elif task == 'long_doc':
|
90 |
+
benchmark_name = f"{domain}_{lang}_{dataset}"
|
91 |
results[eval_result.eval_name][get_safe_name(benchmark_name)] = value
|
92 |
return [v for v in results.values()]
|
93 |
|
src/populate.py
CHANGED
@@ -4,7 +4,7 @@ import os
|
|
4 |
import pandas as pd
|
5 |
|
6 |
from src.display.formatting import has_no_nan_values, make_clickable_model
|
7 |
-
from src.display.utils import AutoEvalColumnQA, EvalQueueColumn
|
8 |
from src.leaderboard.read_evals import get_raw_eval_results, EvalResult, FullEvalResult
|
9 |
from typing import Tuple, List
|
10 |
|
@@ -19,8 +19,13 @@ def get_leaderboard_df(raw_data: List[FullEvalResult], cols: list, benchmark_col
|
|
19 |
|
20 |
# calculate the average score for selected benchmarks
|
21 |
_benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
24 |
df.reset_index(inplace=True)
|
25 |
|
26 |
_cols = frozenset(cols).intersection(frozenset(df.columns.to_list()))
|
|
|
4 |
import pandas as pd
|
5 |
|
6 |
from src.display.formatting import has_no_nan_values, make_clickable_model
|
7 |
+
from src.display.utils import AutoEvalColumnQA, AutoEvalColumnLongDoc, EvalQueueColumn
|
8 |
from src.leaderboard.read_evals import get_raw_eval_results, EvalResult, FullEvalResult
|
9 |
from typing import Tuple, List
|
10 |
|
|
|
19 |
|
20 |
# calculate the average score for selected benchmarks
|
21 |
_benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
|
22 |
+
if task == 'qa':
|
23 |
+
df[AutoEvalColumnQA.average.name] = df[list(_benchmark_cols)].mean(axis=1).round(decimals=2)
|
24 |
+
df = df.sort_values(by=[AutoEvalColumnQA.average.name], ascending=False)
|
25 |
+
elif task == "long_doc":
|
26 |
+
df[AutoEvalColumnLongDoc.average.name] = df[list(_benchmark_cols)].mean(axis=1).round(decimals=2)
|
27 |
+
df = df.sort_values(by=[AutoEvalColumnLongDoc.average.name], ascending=False)
|
28 |
+
|
29 |
df.reset_index(inplace=True)
|
30 |
|
31 |
_cols = frozenset(cols).intersection(frozenset(df.columns.to_list()))
|
tests/src/display/test_utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import pytest
|
2 |
-
from src.display.utils import fields, AutoEvalColumnQA, AutoEvalColumnLongDoc,
|
3 |
|
4 |
|
5 |
def test_fields():
|
@@ -8,8 +8,10 @@ def test_fields():
|
|
8 |
|
9 |
|
10 |
def test_macro_variables():
|
11 |
-
print(f'
|
|
|
12 |
print(f'COLS_LITE: {COLS_LITE}')
|
13 |
print(f'TYPES: {TYPES}')
|
14 |
print(f'EVAL_COLS: {EVAL_COLS}')
|
15 |
-
print(f'
|
|
|
|
1 |
import pytest
|
2 |
+
from src.display.utils import fields, AutoEvalColumnQA, AutoEvalColumnLongDoc, COLS_QA, COLS_LONG_DOC, COLS_LITE, TYPES, EVAL_COLS, QA_BENCHMARK_COLS, LONG_DOC_BENCHMARK_COLS
|
3 |
|
4 |
|
5 |
def test_fields():
|
|
|
8 |
|
9 |
|
10 |
def test_macro_variables():
|
11 |
+
print(f'COLS_QA: {COLS_QA}')
|
12 |
+
print(f'COLS_LONG_DOC: {COLS_LONG_DOC}')
|
13 |
print(f'COLS_LITE: {COLS_LITE}')
|
14 |
print(f'TYPES: {TYPES}')
|
15 |
print(f'EVAL_COLS: {EVAL_COLS}')
|
16 |
+
print(f'QA_BENCHMARK_COLS: {QA_BENCHMARK_COLS}')
|
17 |
+
print(f'LONG_DOC_BENCHMARK_COLS: {LONG_DOC_BENCHMARK_COLS}')
|
tests/src/test_populate.py
CHANGED
@@ -23,3 +23,19 @@ def test_get_leaderboard_df():
|
|
23 |
assert not df[['Average ⬆️', 'wiki_en', 'wiki_zh',]].isnull().values.any()
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
assert not df[['Average ⬆️', 'wiki_en', 'wiki_zh',]].isnull().values.any()
|
24 |
|
25 |
|
26 |
+
def test_get_leaderboard_df_long_doc():
|
27 |
+
requests_path = cur_fp.parents[1] / "toydata" / "test_requests"
|
28 |
+
results_path = cur_fp.parents[1] / "toydata" / "test_results"
|
29 |
+
cols = ['Retrieval Model', 'Reranking Model', 'Average ⬆️', 'law_en_lex_files_500k_600k',]
|
30 |
+
benchmark_cols = ['law_en_lex_files_500k_600k',]
|
31 |
+
raw_data = get_raw_eval_results(results_path, requests_path)
|
32 |
+
df = get_leaderboard_df(raw_data, cols, benchmark_cols, 'long_doc', 'ndcg_at_1')
|
33 |
+
assert df.shape[0] == 2
|
34 |
+
# the results contain only one embedding model
|
35 |
+
for i in range(2):
|
36 |
+
assert df["Retrieval Model"][i] == "bge-m3"
|
37 |
+
# the results contain only two reranking models
|
38 |
+
assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
|
39 |
+
assert df["Reranking Model"][1] == "NoReranker"
|
40 |
+
assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
|
41 |
+
assert not df[['Average ⬆️', 'law_en_lex_files_500k_600k',]].isnull().values.any()
|
tests/test_utils.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import pandas as pd
|
2 |
import pytest
|
3 |
|
4 |
-
from utils import filter_models, search_table, filter_queries, select_columns
|
5 |
|
6 |
|
7 |
@pytest.fixture
|
@@ -29,6 +29,29 @@ def toy_df():
|
|
29 |
)
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def test_filter_models(toy_df):
|
33 |
df_result = filter_models(toy_df, ["bge-reranker-v2-m3", ])
|
34 |
assert len(df_result) == 2
|
@@ -50,4 +73,9 @@ def test_filter_queries(toy_df):
|
|
50 |
def test_select_columns(toy_df):
|
51 |
df_result = select_columns(toy_df, ['news',], ['zh',])
|
52 |
assert len(df_result.columns) == 4
|
53 |
-
assert df_result['Average ⬆️'].equals(df_result['news_zh'])
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
import pytest
|
3 |
|
4 |
+
from utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc
|
5 |
|
6 |
|
7 |
@pytest.fixture
|
|
|
29 |
)
|
30 |
|
31 |
|
32 |
+
@pytest.fixture
|
33 |
+
def toy_df_long_doc():
|
34 |
+
return pd.DataFrame(
|
35 |
+
{
|
36 |
+
"Retrieval Model": [
|
37 |
+
"bge-m3",
|
38 |
+
"bge-m3",
|
39 |
+
"jina-embeddings-v2-base",
|
40 |
+
"jina-embeddings-v2-base"
|
41 |
+
],
|
42 |
+
"Reranking Model": [
|
43 |
+
"bge-reranker-v2-m3",
|
44 |
+
"NoReranker",
|
45 |
+
"bge-reranker-v2-m3",
|
46 |
+
"NoReranker"
|
47 |
+
],
|
48 |
+
"Average ⬆️": [0.6, 0.4, 0.3, 0.2],
|
49 |
+
"law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
|
50 |
+
"law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],
|
51 |
+
"law_en_lex_files_500k_600k": [0.8, 0.7, 0.2, 0.1],
|
52 |
+
"law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
|
53 |
+
}
|
54 |
+
)
|
55 |
def test_filter_models(toy_df):
|
56 |
df_result = filter_models(toy_df, ["bge-reranker-v2-m3", ])
|
57 |
assert len(df_result) == 2
|
|
|
73 |
def test_select_columns(toy_df):
|
74 |
df_result = select_columns(toy_df, ['news',], ['zh',])
|
75 |
assert len(df_result.columns) == 4
|
76 |
+
assert df_result['Average ⬆️'].equals(df_result['news_zh'])
|
77 |
+
|
78 |
+
|
79 |
+
def test_update_table_long_doc(toy_df_long_doc):
|
80 |
+
df_result = update_table_long_doc(toy_df_long_doc, ['law',], ['en',], ["bge-reranker-v2-m3", ], "jina")
|
81 |
+
print(df_result)
|
tests/toydata/test_results/bge-m3/NoReranker/results_2023-12-21T18-10-08.json
CHANGED
@@ -11,7 +11,7 @@
|
|
11 |
"domain": "law",
|
12 |
"lang": "en",
|
13 |
"dataset": "lex_files_500K-600K",
|
14 |
-
"value": 0.
|
15 |
}
|
16 |
]
|
17 |
},
|
|
|
11 |
"domain": "law",
|
12 |
"lang": "en",
|
13 |
"dataset": "lex_files_500K-600K",
|
14 |
+
"value": 0.45723
|
15 |
}
|
16 |
]
|
17 |
},
|
utils.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
import pandas as pd
|
2 |
|
3 |
-
from src.display.utils import AutoEvalColumnQA,
|
4 |
-
from src.benchmarks import BENCHMARK_COLS_QA, BenchmarksQA
|
5 |
from src.leaderboard.read_evals import FullEvalResult
|
6 |
from typing import List
|
7 |
from src.populate import get_leaderboard_df
|
8 |
-
from src.display.utils import COLS, QA_BENCHMARK_COLS
|
9 |
|
10 |
|
11 |
def filter_models(df: pd.DataFrame, reranking_query: list) -> pd.DataFrame:
|
@@ -38,19 +37,29 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
|
|
38 |
return df[(df[AutoEvalColumnQA.retrieval_model.name].str.contains(query, case=False))]
|
39 |
|
40 |
|
41 |
-
def select_columns(df: pd.DataFrame, domain_query: list, language_query: list) -> pd.DataFrame:
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
selected_cols = []
|
48 |
-
for c in
|
49 |
if c not in df.columns:
|
50 |
continue
|
51 |
-
if
|
52 |
-
|
53 |
-
|
|
|
54 |
if eval_col.domain not in domain_query:
|
55 |
continue
|
56 |
if eval_col.lang not in language_query:
|
@@ -58,7 +67,7 @@ def select_columns(df: pd.DataFrame, domain_query: list, language_query: list) -
|
|
58 |
selected_cols.append(c)
|
59 |
# We use COLS to maintain sorting
|
60 |
filtered_df = df[always_here_cols + selected_cols]
|
61 |
-
filtered_df[
|
62 |
return filtered_df
|
63 |
|
64 |
|
@@ -75,20 +84,43 @@ def update_table(
|
|
75 |
return df
|
76 |
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
def update_metric(
|
79 |
raw_data: List[FullEvalResult],
|
|
|
80 |
metric: str,
|
81 |
domains: list,
|
82 |
langs: list,
|
83 |
reranking_model: list,
|
84 |
query: str,
|
85 |
) -> pd.DataFrame:
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
|
3 |
+
from src.display.utils import AutoEvalColumnQA, AutoEvalColumnLongDoc, COLS_QA, COLS_LONG_DOC, QA_BENCHMARK_COLS, LONG_DOC_BENCHMARK_COLS
|
4 |
+
from src.benchmarks import BENCHMARK_COLS_QA, BENCHMARK_COLS_LONG_DOC, BenchmarksQA, BenchmarksLongDoc
|
5 |
from src.leaderboard.read_evals import FullEvalResult
|
6 |
from typing import List
|
7 |
from src.populate import get_leaderboard_df
|
|
|
8 |
|
9 |
|
10 |
def filter_models(df: pd.DataFrame, reranking_query: list) -> pd.DataFrame:
|
|
|
37 |
return df[(df[AutoEvalColumnQA.retrieval_model.name].str.contains(query, case=False))]
|
38 |
|
39 |
|
40 |
+
def select_columns(df: pd.DataFrame, domain_query: list, language_query: list, task: str="qa") -> pd.DataFrame:
|
41 |
+
if task == "qa":
|
42 |
+
always_here_cols = [
|
43 |
+
AutoEvalColumnQA.retrieval_model.name,
|
44 |
+
AutoEvalColumnQA.reranking_model.name,
|
45 |
+
AutoEvalColumnQA.average.name
|
46 |
+
]
|
47 |
+
cols = list(frozenset(COLS_QA).intersection(frozenset(BENCHMARK_COLS_QA)))
|
48 |
+
elif task == "long_doc":
|
49 |
+
always_here_cols = [
|
50 |
+
AutoEvalColumnLongDoc.retrieval_model.name,
|
51 |
+
AutoEvalColumnLongDoc.reranking_model.name,
|
52 |
+
AutoEvalColumnLongDoc.average.name
|
53 |
+
]
|
54 |
+
cols = list(frozenset(COLS_LONG_DOC).intersection(frozenset(BENCHMARK_COLS_LONG_DOC)))
|
55 |
selected_cols = []
|
56 |
+
for c in cols:
|
57 |
if c not in df.columns:
|
58 |
continue
|
59 |
+
if task == "qa":
|
60 |
+
eval_col = BenchmarksQA[c].value
|
61 |
+
elif task == "long_doc":
|
62 |
+
eval_col = BenchmarksLongDoc[c].value
|
63 |
if eval_col.domain not in domain_query:
|
64 |
continue
|
65 |
if eval_col.lang not in language_query:
|
|
|
67 |
selected_cols.append(c)
|
68 |
# We use COLS to maintain sorting
|
69 |
filtered_df = df[always_here_cols + selected_cols]
|
70 |
+
filtered_df[always_here_cols[2]] = filtered_df[selected_cols].mean(axis=1).round(decimals=2)
|
71 |
return filtered_df
|
72 |
|
73 |
|
|
|
84 |
return df
|
85 |
|
86 |
|
87 |
+
def update_table_long_doc(
|
88 |
+
hidden_df: pd.DataFrame,
|
89 |
+
domains: list,
|
90 |
+
langs: list,
|
91 |
+
reranking_query: list,
|
92 |
+
query: str,
|
93 |
+
):
|
94 |
+
filtered_df = filter_models(hidden_df, reranking_query)
|
95 |
+
filtered_df = filter_queries(query, filtered_df)
|
96 |
+
df = select_columns(filtered_df, domains, langs, task='long_doc')
|
97 |
+
return df
|
98 |
+
|
99 |
+
|
100 |
def update_metric(
|
101 |
raw_data: List[FullEvalResult],
|
102 |
+
task: str,
|
103 |
metric: str,
|
104 |
domains: list,
|
105 |
langs: list,
|
106 |
reranking_model: list,
|
107 |
query: str,
|
108 |
) -> pd.DataFrame:
|
109 |
+
if task == 'qa':
|
110 |
+
leaderboard_df = get_leaderboard_df(raw_data, COLS_QA, QA_BENCHMARK_COLS, task=task, metric=metric)
|
111 |
+
return update_table(
|
112 |
+
leaderboard_df,
|
113 |
+
domains,
|
114 |
+
langs,
|
115 |
+
reranking_model,
|
116 |
+
query
|
117 |
+
)
|
118 |
+
elif task == 'long_doc':
|
119 |
+
leaderboard_df = get_leaderboard_df(raw_data, COLS_LONG_DOC, LONG_DOC_BENCHMARK_COLS, task=task, metric=metric)
|
120 |
+
return update_table_long_doc(
|
121 |
+
leaderboard_df,
|
122 |
+
domains,
|
123 |
+
langs,
|
124 |
+
reranking_model,
|
125 |
+
query
|
126 |
+
)
|