Spaces:
Sleeping
Sleeping
import os | |
import random | |
from collections import Counter | |
from datasets import Dataset, load_dataset | |
from fasthtml.common import * | |
from fastlite import database | |
from huggingface_hub import create_repo, login | |
# Authenticate against the Hugging Face Hub with the token from the environment.
login(token=os.environ.get("HF_TOKEN"))

# Load the source examples and tag each with its position in the dataset.
# A key named "example_id" already present in a record takes precedence.
fact_dataset = [
    {"example_id": idx, **record}
    for idx, record in enumerate(
        load_dataset("griffin/iclr2025_data_scores", split="train").to_list()
    )
]

# SQLite store for curation decisions; the table is created only when it is
# not already present in the database.
db = database("data/examples.db")
examples = db.t.examples
if examples not in db.t:
    examples.create(
        id=int,
        example_id=int,
        question_type=str,
        question=str,
        answer=str,
        decision=str,
        pk="id",
    )

# Distinct question types found in the dataset, in sorted order.
question_types = sorted({record["question_type"] for record in fact_dataset})
def get_stats():
    """Count total and curated examples for every question type.

    Returns:
        A dict mapping each entry of ``question_types`` to
        ``{"total": n, "curated": m}``, where ``n`` counts examples in
        ``fact_dataset`` and ``m`` counts rows in the ``examples`` table.
    """
    totals = Counter()
    for record in fact_dataset:
        totals[record["question_type"]] += 1

    curated = Counter()
    for row in examples.rows:
        curated[row["question_type"]] += 1

    # Counter yields 0 for missing keys, so uncurated types report 0.
    return {
        q_type: {"total": totals[q_type], "curated": curated[q_type]}
        for q_type in question_types
    }
def get_example(selected_type=None):
    """Pick a random example that has not been curated yet.

    Args:
        selected_type: Optional question type. When truthy, only examples of
            that type are considered.

    Returns:
        A dict holding the display fields of the chosen example (only the
        keys the example actually has), or ``None`` when no candidate is left.
    """
    evaluated_ids = {row["example_id"] for row in examples()}
    print(f"completed: {evaluated_ids}")

    candidates = []
    for record in fact_dataset:
        if record["example_id"] in evaluated_ids:
            continue
        if selected_type and record["question_type"] != selected_type:
            continue
        candidates.append(record)

    if not candidates:
        return None

    chosen = random.choice(candidates)
    # Fields exposed to the curator, in display order.
    display_keys = (
        "example_id",
        "question_type",
        "question",
        "rationale",
        "answer",
        "log_ll",
        "oracle_log_ll",
        "oracle_advantage",
        "prediction",
        "prediction_oracle",
        "accuracy",
        "accuracy_oracle",
        "accuracy_status",
    )
    return {key: chosen[key] for key in display_keys if key in chosen}
# Application setup: a dark-theme stylesheet passed to the app as a header.
_APP_CSS = """
body { background-color: #1e1e1e; color: #d4d4d4; font-family: Arial, sans-serif; }
h1, h2, h3 { color: #61dafb; }
.example-container { margin-top: 20px; }
.example-table { border-collapse: collapse; width: 100%; }
.example-table th, .example-table td { border: 1px solid #3a3a3a; padding: 8px; text-align: left; }
.example-table th { background-color: #2a2a2a; color: #61dafb; }
.example-table td { color: #d4d4d4; }
#evaluation-form { margin-top: 20px; }
#evaluation-form button { margin-right: 10px; background-color: #0e639c; color: white; border: none; padding: 10px 20px; cursor: pointer; }
#evaluation-form button:hover { background-color: #1177bb; }
select { background-color: #2a2a2a; color: #d4d4d4; border: 1px solid #3a3a3a; padding: 5px; }
a { color: #61dafb; text-decoration: none; }
a:hover { text-decoration: underline; }
"""
style = Style(_APP_CSS)

# `rt` is the route decorator belonging to `app`.
app, rt = fast_app(hdrs=(style,))
def render_stats(stats):
    """Build the HTML table summarising curation progress.

    Args:
        stats: Mapping of question type to a dict with ``"total"`` and
            ``"curated"`` counts; it must have an entry for every item in
            ``question_types``.

    Returns:
        A ``Table`` with a header row followed by one row per question type,
        showing the curated count (with its share of the total) and the total.
    """
    header = Tr(Th("Question Type"), Th("Curated"), Th("Total"))

    body_rows = []
    for q_type in question_types:
        counts = stats[q_type]
        done = counts["curated"]
        total = counts["total"]
        body_rows.append(
            Tr(
                Td(q_type),
                Td(f"{done} ({done/total:.1%})"),
                Td(total),
            )
        )

    return Table(header, *body_rows, cls="stats-table")
def render_example(example):
    """Render one example as a key/value table plus the Good/Bad decision form.

    Args:
        example: Dict of fields to display; it must contain ``"example_id"``,
            which is also sent back through a hidden form input.

    Returns:
        A ``Div`` (id ``example-details``) holding the table and the form.
    """

    def decision_button(label, value):
        # Both buttons post to /evaluate and swap the response into the
        # element with id "example-container".
        return Button(
            label,
            name="decision",
            value=value,
            hx_post="/evaluate",
            hx_target="#example-container",
        )

    field_rows = [Tr(Th(key), Td(str(value))) for key, value in example.items()]
    return Div(
        Table(*field_rows, cls="example-table"),
        Form(
            decision_button("Good Example", "good"),
            decision_button("Bad Example", "bad"),
            Hidden(
                name="example_id",
                value=str(example["example_id"]),
                id="hidden-example-id",
            ),
            # The stylesheet scopes its button rules to #evaluation-form;
            # without this id those rules never match anything.
            id="evaluation-form",
        ),
        id="example-details",
    )
def upload_to_hf():
    """Push every stored annotation to the Hub dataset repository.

    Creates the private dataset repo if it does not exist yet, then uploads
    all rows of the ``examples`` table as a dataset.
    """
    repo_id = "rbiswasfc/iclr-eval-examples"
    hf_token = os.environ.get("HF_TOKEN")

    create_repo(
        repo_id=repo_id,
        token=hf_token,
        private=True,
        repo_type="dataset",
        exist_ok=True,
    )

    annotations = examples()
    Dataset.from_list(annotations).push_to_hub(repo_id, token=hf_token)
@rt("/")
def get(question_type: str = None):
    """Serve the main curation page.

    Registered on ``/`` so that the dropdown's ``hx_get="/"`` and the
    "Back to Curation" link resolve to this handler.

    Args:
        question_type: Optional question type to restrict the example to.
            ``None`` or an empty string means "any type".

    Returns:
        A titled page with the type dropdown, either an example to judge or
        the completion message with stats, and a link to the stats page.
    """
    example = get_example(question_type)

    dropdown = Select(
        # The placeholder submits "" rather than nothing, so treat every
        # falsy value as "no type selected".
        Option("Question Types", value="", selected=not question_type),
        *[Option(qt, value=qt, selected=qt == question_type) for qt in question_types],
        name="question_type",
        hx_get="/",
        hx_target="body",
        hx_push_url="true",
    )

    if example is None:
        # Stats are only shown on this branch, so only compute them here.
        content = Div(
            H2("All examples of this type have been evaluated!"),
            render_stats(get_stats()),
        )
    else:
        content = Div(
            H2("Example"),
            Div(
                render_example(example),
                id="example-container",
                # The stylesheet targets the class, the htmx buttons the id.
                cls="example-container",
            ),
        )

    view_stats_link = A("Curation Stats", href="/stats", cls="view-stats-link")
    return Titled(
        "Example Curation",
        H2("Question Type"),
        dropdown,
        content,
        Div(),
        view_stats_link,
    )
@rt("/evaluate")
def post(decision: str, example_id: str):
    """Record a curation decision and return the next example to judge.

    Registered on ``/evaluate``, the target of the decision buttons.

    Args:
        decision: Value of the pressed button (the form sends "good" or "bad").
        example_id: Id of the judged example, as sent by the hidden input.

    Returns:
        The rendered next example of the same question type, a completion
        message when none is left, or an error message for an unknown id.
    """
    print(f"params to post: {decision}, {example_id}")

    try:
        example_idx = int(example_id)
    except ValueError:
        example_idx = -1
    # The id doubles as a list index. Reject anything outside the dataset;
    # a negative value would otherwise silently index from the end.
    if not 0 <= example_idx < len(fact_dataset):
        return Div(H2("Unknown example id."))

    example_dict = fact_dataset[example_idx]

    # "id" is the table's integer primary key, so SQLite assigns it when it is
    # omitted. Counting rows to derive it scans the table and can collide
    # after a deletion.
    examples.insert(
        {
            "example_id": example_dict["example_id"],
            "question_type": example_dict["question_type"],
            "question": example_dict["question"],
            "answer": example_dict["answer"],
            "decision": decision,
        }
    )
    upload_to_hf()

    new_example = get_example(example_dict["question_type"])
    if new_example is None:
        return Div(H2("All examples of this type have been evaluated!"))
    return render_example(new_example)
@rt("/stats")
def get():
    """Serve the curation statistics page.

    Registered on ``/stats``, the target of the "Curation Stats" link.

    Returns:
        A titled page with the per-type stats table and a link back to the
        curation page.
    """
    stats_table = render_stats(get_stats())
    return Titled(
        "Curation Statistics",
        Div(
            stats_table,
            A("Back to Curation", href="/", cls="back-link"),
            cls="container",
        ),
    )
# serve() | |
if __name__ == "__main__":
    import os

    import uvicorn

    # Listen on all interfaces; the port comes from PORT, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)