Datasets-Metrics-Viewer / src /logic /data_fetching.py
hynky's picture
hynky HF staff
UI overhaul + separation of concerns
40e38d3
raw
history blame
4.45 kB
import os
import json
import tempfile
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
from datatrove.io import get_datafolder
from datatrove.utils.stats import MetricStatsDict
import gradio as gr
import tenacity
def find_folders(base_folder: str, path: str) -> List[str]:
    """List the directories found under ``path`` inside ``base_folder``.

    Args:
        base_folder: Location handed to ``get_datafolder`` to open the data folder.
        path: Path inside the data folder whose entries are listed.

    Returns:
        The sorted ``name`` values of all directory entries reported by
        ``ls``; an empty list when ``path`` does not exist.
    """
    data_folder = get_datafolder(base_folder)
    if not data_folder.exists(path):
        return []

    subfolders = []
    for entry in data_folder.ls(path, detail=True):
        if entry["type"] != "directory":
            continue
        # Skip an entry that names the queried directory itself.
        if entry["name"].rstrip("/") == path:
            continue
        subfolders.append(entry["name"])
    subfolders.sort()
    return subfolders
def find_metrics_folders(base_folder: str) -> List[str]:
    """Return the sorted, de-duplicated directories at the top of ``base_folder``.

    Args:
        base_folder: Location handed to ``get_datafolder`` to open the data folder.

    Returns:
        Sorted names of every entry of type ``"directory"`` reported by a
        depth-1 ``find`` (with directories included) from the folder root.
    """
    data_folder = get_datafolder(base_folder)
    listing = data_folder.find("", detail=True, maxdepth=1, withdirs=True)
    unique_dirs = {
        name
        for name, info in listing.items()
        if info["type"] == "directory"
    }
    return sorted(unique_dirs)
def fetch_datasets(base_folder: str):
    """Discover the datasets under ``base_folder`` and prepare the UI updates.

    Args:
        base_folder: Location whose metric folders are listed.

    Returns:
        A 3-tuple of: the sorted dataset names, a Gradio update offering those
        names as choices with no value selected, and the result of
        ``fetch_groups`` computed over all datasets in ``"union"`` mode.
    """
    available = sorted(find_metrics_folders(base_folder))
    datasets_update = gr.update(choices=available, value=None)
    groups_update = fetch_groups(base_folder, available, None, "union")
    return available, datasets_update, groups_update
def fetch_groups(base_folder: str, datasets: List[str], old_groups: str, type: str = "intersection"):
    """Build the Gradio update listing the groupings available for ``datasets``.

    Args:
        base_folder: Location of the data folder holding the datasets.
        datasets: Dataset paths whose sub-folders are inspected.
        old_groups: Previously selected grouping; kept as the value when it is
            still among the new choices.
        type: ``"intersection"`` keeps only groupings present in every
            dataset; any other value takes the union across datasets.

    Returns:
        A ``gr.update`` with the sorted choices and the selected value. The
        value falls back to the only choice when exactly one exists, and to
        ``None`` otherwise. Choices are empty when there are no datasets.
    """
    if not datasets:
        return gr.update(choices=[], value=None)

    def _group_names(run):
        # Only the last path component identifies the grouping.
        return [Path(folder).name for folder in find_folders(base_folder, run)]

    # Listings are independent per dataset, so they are fetched concurrently.
    with ThreadPoolExecutor() as pool:
        per_dataset = [set(names) for names in pool.map(_group_names, datasets)]
    if not per_dataset:
        return gr.update(choices=[], value=None)

    combine = set.intersection if type == "intersection" else set.union
    candidates = combine(*per_dataset)

    selected = old_groups if old_groups and old_groups in candidates else None
    if not selected and len(candidates) == 1:
        (selected,) = candidates
    return gr.update(choices=sorted(candidates), value=selected)
def fetch_metrics(base_folder: str, datasets: List[str], group: str, old_metrics: str, type: str = "intersection"):
    """Build the Gradio update listing the metrics available for ``datasets``.

    Args:
        base_folder: Location of the data folder holding the datasets.
        datasets: Dataset paths whose ``<dataset>/<group>`` folders are inspected.
        group: Grouping whose metric folders are listed.
        old_metrics: Previously selected metric; kept as the value when it is
            still among the new choices.
        type: ``"intersection"`` keeps only metrics present in every dataset;
            any other value takes the union across datasets.

    Returns:
        A ``gr.update`` with the sorted choices and the selected value. The
        value falls back to the only choice when exactly one exists, and to
        ``None`` otherwise. Choices are empty when no group or no datasets
        are given.
    """
    # Guard against a missing selection the same way fetch_groups does;
    # without it a ``None`` datasets value would blow up inside executor.map.
    if not group or not datasets:
        return gr.update(choices=[], value=None)

    def _metric_names(run):
        # Only the last path component identifies the metric.
        return [Path(folder).name for folder in find_folders(base_folder, f"{run}/{group}")]

    # Listings are independent per dataset, so they are fetched concurrently.
    with ThreadPoolExecutor() as executor:
        per_dataset = [set(names) for names in executor.map(_metric_names, datasets)]
    if not per_dataset:
        return gr.update(choices=[], value=None)

    combine = set.intersection if type == "intersection" else set.union
    candidates = combine(*per_dataset)

    selected = old_metrics if old_metrics and old_metrics in candidates else None
    if not selected and len(candidates) == 1:
        (selected,) = candidates
    return gr.update(choices=sorted(candidates), value=selected)
def reverse_search(base_folder: str, possible_datasets: List[str], grouping: str, metric_name: str) -> str:
    """Find which of ``possible_datasets`` contain the given metric.

    Args:
        base_folder: Location of the data folder holding the datasets.
        possible_datasets: Candidate dataset paths to probe.
        grouping: Grouping folder under which the metric is looked up.
        metric_name: Name of the metric to look for.

    Returns:
        The matching dataset names joined by newlines, in input order.
    """

    def _dataset_if_present(dataset):
        if metric_exists(base_folder, dataset, metric_name, grouping):
            return dataset
        return None

    # Existence checks are independent, so they are probed concurrently;
    # map() preserves the order of ``possible_datasets``.
    with ThreadPoolExecutor() as pool:
        hits = [
            name
            for name in pool.map(_dataset_if_present, possible_datasets)
            if name is not None
        ]
    return "\n".join(hits)
def reverse_search_add(datasets: List[str], reverse_search_results: str) -> List[str]:
    """Merge reverse-search hits into the current dataset selection.

    Args:
        datasets: Currently selected datasets; ``None`` is treated as empty.
        reverse_search_results: Newline-separated dataset names; ``None`` or
            an empty string contributes nothing.

    Returns:
        The sorted, de-duplicated combination of both inputs.
    """
    current = datasets or []
    lines = (reverse_search_results or "").splitlines()
    # Drop blank lines: splitting an empty result would otherwise add a
    # bogus "" dataset to the selection.
    found = (line.strip() for line in lines)
    return sorted({*current, *(name for name in found if name)})
def metric_exists(base_folder: str, path: str, metric_name: str, group_by: str) -> bool:
    """Check whether a metric file exists for a dataset.

    Args:
        base_folder: Location handed to ``get_datafolder`` to open the data folder.
        path: Dataset path inside the data folder.
        metric_name: Name of the metric folder.
        group_by: Grouping folder that contains the metric.

    Returns:
        Whether ``<path>/<group_by>/<metric_name>/metric.json`` exists.
    """
    data_folder = get_datafolder(base_folder)
    metric_file = f"{path}/{group_by}/{metric_name}/metric.json"
    return data_folder.exists(metric_file)
@tenacity.retry(stop=tenacity.stop_after_attempt(5))
def load_metrics(base_folder: str, path: str, metric_name: str, group_by: str) -> MetricStatsDict:
    """Load one metric's statistics from its ``metric.json`` file.

    The call is wrapped in a tenacity retry that stops after 5 attempts.

    Args:
        base_folder: Location handed to ``get_datafolder`` to open the data folder.
        path: Dataset path inside the data folder.
        metric_name: Name of the metric folder.
        group_by: Grouping folder that contains the metric.

    Returns:
        The ``MetricStatsDict`` built from the parsed JSON content of
        ``<path>/<group_by>/<metric_name>/metric.json``.
    """
    data_folder = get_datafolder(base_folder)
    metric_file = f"{path}/{group_by}/{metric_name}/metric.json"
    with data_folder.open(metric_file) as handle:
        raw_stats = json.load(handle)
    return MetricStatsDict.from_dict(raw_stats)
def load_data(dataset_path: str, base_folder: str, grouping: str, metric_name: str) -> MetricStatsDict:
    """Load a dataset's metric statistics by delegating to ``load_metrics``.

    This wrapper only adapts the argument order and names.

    Args:
        dataset_path: Dataset path inside the data folder.
        base_folder: Location of the data folder.
        grouping: Grouping folder that contains the metric.
        metric_name: Name of the metric to load.

    Returns:
        Whatever ``load_metrics`` returns for the same dataset, grouping and metric.
    """
    return load_metrics(
        base_folder=base_folder,
        path=dataset_path,
        metric_name=metric_name,
        group_by=grouping,
    )