feat: add revision and timestamp information
Files changed:
- app.py +13 -11
- src/display/utils.py +20 -7
- src/leaderboard/read_evals.py +22 -6
- tests/src/display/test_utils.py +3 -2
- tests/test_utils.py +8 -1
- utils.py +26 -11
app.py
CHANGED
@@ -8,18 +8,18 @@ from src.about import (
     TITLE,
     EVALUATION_QUEUE_TEXT
 )
+from src.benchmarks import DOMAIN_COLS_QA, LANG_COLS_QA, DOMAIN_COLS_LONG_DOC, LANG_COLS_LONG_DOC, METRIC_LIST, \
+    DEFAULT_METRIC
 from src.display.css_html_js import custom_css
-from src.leaderboard.read_evals import get_raw_eval_results, get_leaderboard_df
-
 from src.envs import API, EVAL_RESULTS_PATH, REPO_ID, RESULTS_REPO, TOKEN
+from src.leaderboard.read_evals import get_raw_eval_results, get_leaderboard_df
 from utils import update_table, update_metric, update_table_long_doc, upload_file, get_default_cols, submit_results
-from src.benchmarks import DOMAIN_COLS_QA, LANG_COLS_QA, DOMAIN_COLS_LONG_DOC, LANG_COLS_LONG_DOC, METRIC_LIST, DEFAULT_METRIC
-from src.display.utils import TYPES_QA, TYPES_LONG_DOC
 
 
 def restart_space():
     API.restart_space(repo_id=REPO_ID)
 
+
 try:
     snapshot_download(
         repo_id=RESULTS_REPO, local_dir=EVAL_RESULTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30,
@@ -39,11 +39,12 @@ print(f'QA data loaded: {original_df_qa.shape}')
 print(f'Long-Doc data loaded: {len(original_df_long_doc)}')
 
 leaderboard_df_qa = original_df_qa.copy()
-shown_columns_qa = get_default_cols('qa', leaderboard_df_qa.columns, add_fix_cols=True)
+shown_columns_qa, types_qa = get_default_cols('qa', leaderboard_df_qa.columns, add_fix_cols=True)
 leaderboard_df_qa = leaderboard_df_qa[shown_columns_qa]
 
 leaderboard_df_long_doc = original_df_long_doc.copy()
-shown_columns_long_doc = get_default_cols('long-doc', leaderboard_df_long_doc.columns, add_fix_cols=True)
+shown_columns_long_doc, types_long_doc = get_default_cols('long-doc', leaderboard_df_long_doc.columns,
+                                                          add_fix_cols=True)
 leaderboard_df_long_doc = leaderboard_df_long_doc[shown_columns_long_doc]
 
 
@@ -56,6 +57,7 @@ def update_metric_qa(
 ):
     return update_metric(raw_data, 'qa', metric, domains, langs, reranking_model, query)
 
+
 def update_metric_long_doc(
     metric: str,
     domains: list,
@@ -124,7 +126,7 @@ with demo:
 
             leaderboard_table = gr.components.Dataframe(
                 value=leaderboard_df_qa,
-                datatype=TYPES_QA,
+                datatype=types_qa,
                 elem_id="leaderboard-table",
                 interactive=False,
                 visible=True,
@@ -133,7 +135,7 @@ with demo:
             # Dummy leaderboard for handling the case when the user uses backspace key
             hidden_leaderboard_table_for_search = gr.components.Dataframe(
                 value=leaderboard_df_qa,
-                datatype=TYPES_QA,
+                datatype=types_qa,
                 # headers=COLS,
                 # datatype=TYPES,
                 visible=False,
@@ -234,7 +236,7 @@ with demo:
 
             leaderboard_table_long_doc = gr.components.Dataframe(
                 value=leaderboard_df_long_doc,
-                datatype=TYPES_LONG_DOC,
+                datatype=types_long_doc,
                 elem_id="leaderboard-table-long-doc",
                 interactive=False,
                 visible=True,
@@ -243,7 +245,7 @@ with demo:
             # Dummy leaderboard for handling the case when the user uses backspace key
             hidden_leaderboard_table_for_search = gr.components.Dataframe(
                 value=leaderboard_df_long_doc,
-                datatype=TYPES_LONG_DOC,
+                datatype=types_long_doc,
                 visible=False,
             )
 
@@ -300,7 +302,7 @@ with demo:
         with gr.Row():
            with gr.Column():
                 benchmark_version = gr.Dropdown(
-                    ["AIR-Bench_24.04",],
+                    ["AIR-Bench_24.04", ],
                     value="AIR-Bench_24.04",
                     interactive=True,
                     label="AIR-Bench Version")
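For context on the `datatype=` arguments above: Gradio's `Dataframe` component accepts one type string per displayed column (e.g. "number", "markdown", "date"), which is why the column types now travel alongside the column names. A minimal, self-contained sketch with made-up columns and values (not the leaderboard's real frames):

import gradio as gr
import pandas as pd

# Illustrative columns and per-column Gradio types; in app.py these come from
# get_default_cols(...), which now returns (shown_columns, types).
cols = ["Rank 🏆", "Retrieval Model", "Submission Date", "Average"]
types = ["number", "markdown", "date", "number"]

df = pd.DataFrame(
    [[1, "[my-model](https://example.com)", "2024-05-12", 0.61]],
    columns=cols,
)

with gr.Blocks() as demo:
    # One entry in `datatype` per column keeps markdown links rendered
    # and numeric columns sortable as numbers.
    gr.components.Dataframe(value=df, datatype=types, interactive=False)

if __name__ == "__main__":
    demo.launch()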
src/display/utils.py
CHANGED
@@ -25,29 +25,42 @@ COL_NAME_RERANKING_MODEL = "Reranking Model"
 COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
 COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
 COL_NAME_RANK = "Rank 🏆"
+COL_NAME_REVISION = "Revision"
+COL_NAME_TIMESTAMP = "Submission Date"
 
-
-
+
+def get_default_auto_eval_column_dict():
     auto_eval_column_dict = []
     # Init
     auto_eval_column_dict.append(
-        ["
+        ["rank", ColumnContent, ColumnContent(COL_NAME_RANK, "number", True)]
+    )
+    auto_eval_column_dict.append(
+        ["retrieval_model", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, hidden=False, never_hidden=True)]
     )
     auto_eval_column_dict.append(
-        ["reranking_model", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, never_hidden=True)]
+        ["reranking_model", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, hidden=False, never_hidden=True)]
     )
     auto_eval_column_dict.append(
-        ["
+        ["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
     )
     auto_eval_column_dict.append(
-        ["
+        ["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
     )
     auto_eval_column_dict.append(
         ["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)]
     )
     auto_eval_column_dict.append(
-        ["
+        ["retrieval_model_link", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", False, hidden=True, never_hidden=False)]
+    )
+    auto_eval_column_dict.append(
+        ["reranking_model_link", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", False, hidden=True, never_hidden=False)]
     )
+    return auto_eval_column_dict
+
+def make_autoevalcolumn(cls_name="BenchmarksQA", benchmarks=BenchmarksQA):
+    auto_eval_column_dict = get_default_auto_eval_column_dict()
+    ## Leaderboard columns
     for benchmark in benchmarks:
         auto_eval_column_dict.append(
             [benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
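A note on the structure being built here: each appended entry is a `[attribute_name, ColumnContent, ColumnContent(...)]` triple, and the two `*_link` columns are deliberately appended last so callers can slice them off (utils.py below does exactly that with `[:-2]`). A self-contained sketch, assuming a `ColumnContent` dataclass with the positional fields `(name, type, displayed_by_default)` plus `hidden`/`never_hidden` flags as implied by the constructor calls in this diff:

from dataclasses import dataclass

# Assumed shape of ColumnContent, inferred from the calls above; the real
# definition lives elsewhere in src/display/utils.py.
@dataclass
class ColumnContent:
    name: str
    type: str
    displayed_by_default: bool
    hidden: bool = False
    never_hidden: bool = False

# Abbreviated stand-in for get_default_auto_eval_column_dict(): the two
# *_link entries come last so that [:-2] drops them.
auto_eval_column_dict = [
    ["rank", ColumnContent, ColumnContent("Rank 🏆", "number", True)],
    ["revision", ColumnContent, ColumnContent("Revision", "markdown", True, never_hidden=True)],
    ["timestamp", ColumnContent, ColumnContent("Submission Date", "date", True, never_hidden=True)],
    ["retrieval_model_link", ColumnContent, ColumnContent("Retrieval Model", "markdown", False, hidden=True)],
    ["reranking_model_link", ColumnContent, ColumnContent("Reranking Model", "markdown", False, hidden=True)],
]

fixed = auto_eval_column_dict[:-2]      # drop the two link columns
print([c.name for _, _, c in fixed])    # ['Rank 🏆', 'Revision', 'Submission Date']
print([c.type for _, _, c in fixed])    # ['number', 'markdown', 'date']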
src/leaderboard/read_evals.py
CHANGED
@@ -13,6 +13,8 @@ from src.display.utils import (
     COL_NAME_RETRIEVAL_MODEL,
     COL_NAME_RERANKING_MODEL_LINK,
     COL_NAME_RETRIEVAL_MODEL_LINK,
+    COL_NAME_REVISION,
+    COL_NAME_TIMESTAMP,
     COLS_QA,
     QA_BENCHMARK_COLS,
     COLS_LONG_DOC,
@@ -37,6 +39,7 @@ class EvalResult:
     task: str
     metric: str
     timestamp: str = "" # submission timestamp
+    revision: str = ""
 
 
 @dataclass
@@ -50,7 +53,8 @@ class FullEvalResult:
     retrieval_model_link: str
     reranking_model_link: str
     results: List[EvalResult] # results on all the EvalResults over different tasks and metrics.
-
+    timestamp: str = ""
+    revision: str = ""
 
     @classmethod
     def init_from_json_file(cls, json_filepath):
@@ -65,20 +69,25 @@ class FullEvalResult:
         result_list = []
         retrieval_model_link = ""
         reranking_model_link = ""
+        revision = ""
         for item in model_data:
             config = item.get("config", {})
             # eval results for different metrics
             results = item.get("results", [])
-            retrieval_model_link=config["retrieval_model_link"]
-            if config["reranking_model_link"] is None:
-                reranking_model_link=""
+            retrieval_model_link = config["retrieval_model_link"]
+            if config["reranking_model_link"] is None:
+                reranking_model_link = ""
+            else:
+                reranking_model_link = config["reranking_model_link"]
             eval_result = EvalResult(
                 eval_name=f"{config['retrieval_model']}_{config['reranking_model']}_{config['metric']}",
                 retrieval_model=config["retrieval_model"],
                 reranking_model=config["reranking_model"],
                 results=results,
                 task=config["task"],
-                metric=config["metric"]
+                metric=config["metric"],
+                timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
+                revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e")
             )
             result_list.append(eval_result)
         return cls(
@@ -87,7 +96,9 @@ class FullEvalResult:
             reranking_model=result_list[0].reranking_model,
             retrieval_model_link=retrieval_model_link,
             reranking_model_link=reranking_model_link,
-            results=result_list
+            results=result_list,
+            timestamp=result_list[0].timestamp,
+            revision=result_list[0].revision
         )
 
     def to_dict(self, task='qa', metric='ndcg_at_3') -> List:
@@ -107,6 +118,8 @@ class FullEvalResult:
                 make_clickable_model(self.reranking_model, self.reranking_model_link))
             results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
             results[eval_result.eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
+            results[eval_result.eval_name][COL_NAME_REVISION] = self.revision
+            results[eval_result.eval_name][COL_NAME_TIMESTAMP] = self.timestamp
 
             # print(f'result loaded: {eval_result.eval_name}')
             for result in eval_result.results:
@@ -193,4 +206,7 @@ def get_leaderboard_df(raw_data: List[FullEvalResult], task: str, metric: str) -
     # filter out if any of the benchmarks have not been produced
     df = df[has_no_nan_values(df, _benchmark_cols)]
     df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
+
+    # shorten the revision
+    df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
     return df
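The loader now pulls `revision` and `timestamp` out of each result file's `config` block and falls back to fixed defaults when an older file lacks them. A sketch of the config shape `init_from_json_file` reads after this change; the key names come from the diff, while the values and overall layout are illustrative only:

# Illustrative entry from a results JSON file; only the keys read by the
# loader above are meaningful, the values are made up for the example.
example_entry = {
    "config": {
        "retrieval_model": "my-retrieval-model",
        "retrieval_model_link": "https://example.com/my-retrieval-model",
        "reranking_model": "my-reranking-model",
        "reranking_model_link": None,
        "task": "qa",
        "metric": "ndcg_at_10",
        "timestamp": "2024-05-12T12:24:02Z",             # new: shown as "Submission Date"
        "revision": "3a2ba9dcad796a48a02ca1147557724e",  # new: shortened to 6 chars for display
    },
    "results": [],
}

config = example_entry["config"]
# Same fallbacks the loader applies when older files predate these fields.
timestamp = config.get("timestamp", "2024-05-12T12:24:02Z")
revision = config.get("revision", "3a2ba9dcad796a48a02ca1147557724e")
print(timestamp, revision[:6])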
tests/src/display/test_utils.py
CHANGED
@@ -1,5 +1,5 @@
 import pytest
-from src.display.utils import fields, AutoEvalColumnQA, COLS_QA, COLS_LONG_DOC, COLS_LITE,
+from src.display.utils import fields, AutoEvalColumnQA, COLS_QA, COLS_LONG_DOC, COLS_LITE, TYPES_QA, TYPES_LONG_DOC, QA_BENCHMARK_COLS, LONG_DOC_BENCHMARK_COLS
 
 
 def test_fields():
@@ -11,6 +11,7 @@ def test_macro_variables():
     print(f'COLS_QA: {COLS_QA}')
     print(f'COLS_LONG_DOC: {COLS_LONG_DOC}')
     print(f'COLS_LITE: {COLS_LITE}')
-    print(f'
+    print(f'TYPES_QA: {TYPES_QA}')
+    print(f'TYPES_LONG_DOC: {TYPES_LONG_DOC}')
     print(f'QA_BENCHMARK_COLS: {QA_BENCHMARK_COLS}')
     print(f'LONG_DOC_BENCHMARK_COLS: {LONG_DOC_BENCHMARK_COLS}')
tests/test_utils.py
CHANGED
@@ -1,7 +1,8 @@
 import pandas as pd
 import pytest
 
-from utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc, get_iso_format_timestamp
+from utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc, get_iso_format_timestamp, get_default_cols
+from src.display.utils import COLS_QA
 
 
 @pytest.fixture
@@ -86,3 +87,9 @@ def test_get_iso_format_timestamp():
     assert len(timestamp_fn) == 14
     assert len(timestamp_config) == 20
     assert timestamp_config[-1] == "Z"
+
+
+def test_get_default_cols():
+    cols, types = get_default_cols("qa", COLS_QA)
+    for c, t in zip(cols, types):
+        print(f"type({c}): {t}")
utils.py
CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
 import pandas as pd
 
 from src.benchmarks import BENCHMARK_COLS_QA, BENCHMARK_COLS_LONG_DOC, BenchmarksQA, BenchmarksLongDoc
-from src.display.utils import
+from src.display.utils import COLS_QA, TYPES_QA, COLS_LONG_DOC, TYPES_LONG_DOC, COL_NAME_RANK, COL_NAME_AVG, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_REVISION, COL_NAME_TIMESTAMP, AutoEvalColumnQA, AutoEvalColumnLongDoc, get_default_auto_eval_column_dict
 from src.leaderboard.read_evals import FullEvalResult, get_leaderboard_df
 from src.envs import API, SEARCH_RESULTS_REPO, CACHE_PATH
 from src.display.formatting import styled_message, styled_error
@@ -44,22 +44,37 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
 
 
 def get_default_cols(task: str, columns: list, add_fix_cols: bool=True) -> list:
+    cols = []
+    types = []
     if task == "qa":
-
+        cols_list = COLS_QA
+        types_list = TYPES_QA
+        benchmark_list = BENCHMARK_COLS_QA
     elif task == "long-doc":
-
+        cols_list = COLS_LONG_DOC
+        types_list = TYPES_LONG_DOC
+        benchmark_list = BENCHMARK_COLS_LONG_DOC
     else:
         raise NotImplemented
+    for col_name, col_type in zip(cols_list, types_list):
+        if col_name not in benchmark_list:
+            continue
+        if col_name not in columns:
+            continue
+        cols.append(col_name)
+        types.append(col_type)
+
     if add_fix_cols:
         cols = FIXED_COLS + cols
-
-
-
-
-
-
-
-
+        types = FIXED_COLS_TYPES + types
+    return cols, types
+
+fixed_cols = get_default_auto_eval_column_dict()[:-2]
+
+
+FIXED_COLS = [c.name for _, _, c in fixed_cols]
+FIXED_COLS_TYPES = [c.type for _, _, c in fixed_cols]
+
 
 def select_columns(df: pd.DataFrame, domain_query: list, language_query: list, task: str = "qa") -> pd.DataFrame:
     cols = get_default_cols(task=task, columns=df.columns, add_fix_cols=False)