Spaces:
Sleeping
Sleeping
# Inspired by https://huggingface.co/spaces/davanstrien/dataset_column_search
import os | |
from functools import lru_cache | |
from urllib.parse import quote | |
import faiss | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from dotenv import load_dotenv | |
from httpx import Client | |
from huggingface_hub import HfApi | |
from huggingface_hub.utils import logging | |
from sentence_transformers import SentenceTransformer | |
from tqdm.contrib.concurrent import thread_map | |
# Load variables from a local .env file (if present) before reading the token.
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Raise explicitly instead of using `assert`: assertions are stripped when
# Python runs with -O, which would let the app start without credentials.
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"

logger = logging.get_logger(__name__)

# Python f-strings interpolate with `{...}` only; a leading `$` (JavaScript
# template-literal syntax) would be sent literally as part of the token.
headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}

# Shared HTTP client for datasets-server requests and Hub API handle.
client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)
def get_first_config_name(dataset: str):
    """Return the config name of the first split listed for *dataset*.

    Args:
        dataset: Hub id of the dataset to look up.

    Returns:
        The ``"config"`` value of the first entry in the ``/splits`` response,
        or ``None`` if the request or the response parsing fails (the error
        is logged).
    """
    try:
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={dataset}")
        data = resp.json()
        # "config" already holds the config name; indexing it with [0] would
        # return only its first character.
        return data["splits"][0]["config"]
    except Exception as e:
        logger.error(f"Failed to get splits for {dataset}: {e}")
        return None
def datasets_server_valid_rows(dataset: str):
    """Query the datasets server's ``/is-valid`` endpoint for *dataset*.

    Returns:
        The ``"viewer"`` field of the JSON response, or ``None`` when the
        request or the parsing fails (the failure is logged).
    """
    try:
        url = f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={dataset}"
        payload = client.get(url).json()
        return payload["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {dataset}: {e}")
    return None
def dataset_is_valid(dataset):
    """Return *dataset* itself when the datasets server reports it valid, else ``None``.

    The check is done on ``dataset.id`` via ``datasets_server_valid_rows``.
    """
    if datasets_server_valid_rows(dataset.id):
        return dataset
    return None
def get_first_config_and_split_name(hub_id: str):
    """Return the ``(config, split)`` names of the first split listed for *hub_id*.

    Args:
        hub_id: Hub id of the dataset to look up.

    Returns:
        A 2-tuple ``(config, split)`` taken from the first entry of the
        ``/splits`` response, or ``None`` if the request or the response
        parsing fails (the error is logged).
    """
    try:
        # Use the shared base URL constant, like the other datasets-server helpers.
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
        data = resp.json()
        first_split = data["splits"][0]
        return first_split["config"], first_split["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None
def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the ``/info`` payload for one config of a dataset.

    Args:
        hub_id: Hub id of the dataset.
        config: Config name to query. When omitted, the first config reported
            by ``get_first_config_and_split_name`` is used.

    Returns:
        The decoded JSON response, or ``None`` when no config was given and
        none could be discovered.

    Raises:
        Whatever ``resp.raise_for_status()`` raises for an error status; this
        function does not catch it.
    """
    if config is None:
        first = get_first_config_and_split_name(hub_id)
        if first is None:
            return None
        # The helper returns (config, split); only the config name is needed.
        config = first[0]

    info_url = f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    resp = client.get(info_url)
    resp.raise_for_status()
    return resp.json()
def dataset_with_info(dataset):
    """Build the index record for *dataset* from its datasets-server info.

    Args:
        dataset: Object exposing ``id``, ``likes``, ``downloads``,
            ``created_at`` and ``tags`` attributes.

    Returns:
        A dict with the dataset id, its comma-joined column names, the text
        used for embedding (``<name without namespace>-<column names>``) and
        the metadata attributes listed above; or ``None`` when no info is
        available, the features are ``None``, or any error occurs (errors are
        logged).
    """
    try:
        info = get_dataset_info(dataset.id)
        if not info:
            return None
        columns = info.get("dataset_info", {}).get("features", {})
        if columns is None:
            return None

        # Join once and reuse for both the column list and the embedding text.
        column_names = ",".join(columns.keys())
        short_name = str(dataset.id).split("/")[-1]
        return {
            "dataset": dataset.id,
            "column_names": column_names,
            # The literal used to define "text" twice; only the later value
            # (name without namespace) ever took effect, so only that one is
            # kept, in the position of the first occurrence to preserve key order.
            "text": f"{short_name}-{column_names}",
            "likes": dataset.likes,
            "downloads": dataset.downloads,
            "created_at": dataset.created_at,
            "tags": dataset.tags,
        }
    except Exception as e:
        logger.error(f"Failed to get info for {dataset.id}: {e}")
    return None
def prep_data():
    """Collect index records for the datasets listed on the Hub.

    Lists the datasets through ``api.list_datasets``, keeps those for which
    ``dataset_is_valid`` returns a value, then builds a record for each with
    ``dataset_with_info``. Both steps are mapped with ``thread_map``, and a
    progress count is printed after each stage.

    Returns:
        The list of non-``None`` records produced by ``dataset_with_info``.
    """
    hub_datasets = list(api.list_datasets(limit=None, sort="createdAt", direction=-1))
    print(f"Found {len(hub_datasets)} datasets in the hub.")

    checked = thread_map(dataset_is_valid, hub_datasets)
    datasets_with_server = [ds for ds in checked if ds is not None]
    print(f"Found {len(datasets_with_server)} valid datasets.")

    records = [
        record
        for record in thread_map(dataset_with_info, datasets_with_server)
        if record is not None
    ]
    print(f"Found {len(records)} datasets with info.")
    return records
# Build the corpus once at import time: one row per dataset record.
all_datasets = prep_data()
all_datasets_df = pd.DataFrame.from_dict(all_datasets)
print(all_datasets_df.head())

# Embed the "text" column (dataset name plus column names) of every row.
text = all_datasets_df["text"]
encoder = SentenceTransformer("Snowflake/snowflake-arctic-embed-s")
vectors = encoder.encode(text)
vector_dimension = vectors.shape[1]

print("Start indexing")
# Vectors are L2-normalised before being added; `search` normalises its query
# vector the same way before querying this index.
faiss.normalize_L2(vectors)
index = faiss.IndexFlatL2(vector_dimension)
index.add(vectors)
print("Indexing done")
def render_model_hub_link(hub_id):
    """Return an HTML anchor linking to the Hub dataset page for *hub_id*.

    The id is percent-encoded with ``quote`` for the ``href`` and shown
    verbatim as the link text; the link opens in a new tab.
    """
    href = "https://huggingface.co/datasets/" + quote(hub_id)
    return (
        f'<a target="_blank" href="{href}" '
        f'style="color: var(--link-text-color); text-decoration: underline;'
        f'text-decoration-style: dotted;">{hub_id}</a>'
    )
def search(dataset_name, k):
    """Return the datasets whose indexed text is closest to *dataset_name*'s.

    Args:
        dataset_name: Hub id that must match a value in the ``dataset`` column
            of ``all_datasets_df``.
        k: Number of nearest neighbours to request. Accepts an int or a
            whole-valued float (the UI slider may deliver a float).

    Returns:
        A DataFrame with the distances, the neighbour positions (``ann``) and
        the metadata columns of each neighbour, with ``dataset`` rendered as an
        HTML link and the ``text`` column removed. If *dataset_name* is not in
        the corpus, a one-row DataFrame with an ``error`` column is returned.
    """
    print(f"start search for {dataset_name}")

    matches = all_datasets_df[all_datasets_df.dataset == dataset_name]
    if matches.empty:
        return pd.DataFrame([{"error": "❌ Dataset does not exist or is not supported"}])

    # Re-embed the stored text of the first matching row and normalise it the
    # same way the indexed vectors were normalised.
    text = matches.iloc[0]["text"]
    search_vector = encoder.encode(text)
    _vector = np.array([search_vector])
    faiss.normalize_L2(_vector)

    # The neighbour count must be an integer; cast defensively because the
    # slider value may arrive as a float.
    distances, ann = index.search(_vector, k=int(k))

    results = pd.DataFrame({"distances": distances[0], "ann": ann[0]})
    # `ann` holds row positions in the index, which line up with the row index
    # of all_datasets_df; the default (inner) merge keeps only matching rows.
    merge = pd.merge(results, all_datasets_df, left_on="ann", right_index=True)
    merge["dataset"] = merge["dataset"].apply(render_model_hub_link)
    return merge.drop("text", axis=1)
# Gradio UI: a dataset-name box and a top-k slider feed `search`; the result
# is rendered in a DataFrame whose cells are treated as markdown so the
# dataset links are clickable.
with gr.Blocks() as demo:
    gr.Markdown("# Search similar Datasets on Hugging Face")
    gr.Markdown("This space shows similar datasets based on a name and columns. It uses https://github.com/facebookresearch/faiss for vector indexing.")
    gr.Markdown("'Text' column was used for indexing. Where text is a concatenation of 'dataset_name'-'column_names'")

    dataset_name = gr.Textbox(value="sksayril/medicine-info", label="Dataset Name")
    k = gr.Slider(
        minimum=5,
        maximum=200,
        value=20,
        step=5,
        interactive=True,
        label="Top K Nearest Neighbors",
    )
    btn = gr.Button(value="Show similar datasets")
    df = gr.DataFrame(datatype="markdown")

    btn.click(fn=search, inputs=[dataset_name, k], outputs=df)
    gr.Markdown("This space was inspired by https://huggingface.co/spaces/davanstrien/dataset_column_search")

demo.launch()