from typing import Any, Dict, Iterator, List, Optional

import requests
from huggingface_hub import add_collection_item, create_collection
from tqdm.auto import tqdm


class DatasetSearchClient:
    """Minimal client for the librarian-bots dataset-column-search API."""

    def __init__(
        self,
        base_url: str = "https://librarian-bots-dataset-column-search-api.hf.space",
    ):
        self.base_url = base_url

    def search(
        self, columns: List[str], match_all: bool = False, page_size: int = 100
    ) -> Iterator[Dict[str, Any]]:
        """
        Search datasets using the provided API, automatically handling pagination.

        Args:
            columns (List[str]): List of column names to search for.
            match_all (bool, optional): If True, match all columns. If False, match any column. Defaults to False.
            page_size (int, optional): Number of results per page. Defaults to 100.

        Yields:
            Dict[str, Any]: Each dataset result from all pages.

        Raises:
            requests.RequestException: If there's an error with the HTTP request.
            ValueError: If the API returns an unexpected response format.
        """
        page = 1
        total_results = None

        # Keep requesting pages until every result reported by the API has been yielded.
        while total_results is None or (page - 1) * page_size < total_results:
            params = {
                "columns": columns,
                "match_all": str(match_all).lower(),
                "page": page,
                "page_size": page_size,
            }

            try:
                # A timeout keeps a stalled connection from hanging the generator.
                response = requests.get(
                    f"{self.base_url}/search", params=params, timeout=30
                )
                response.raise_for_status()
                data = response.json()

                if not {"total", "page", "page_size", "results"}.issubset(data.keys()):
                    raise ValueError("Unexpected response format from the API")

                if total_results is None:
                    total_results = data["total"]

                yield from data["results"]
                page += 1

            except requests.RequestException as e:
                raise requests.RequestException(
                    f"Error connecting to the API: {str(e)}"
                ) from e
            except ValueError as e:
                raise ValueError(f"Error processing API response: {str(e)}") from e


client = DatasetSearchClient()
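
# Usage sketch (not executed here): stream every dataset on the Hub that has
# all three DPO-style columns, lazily and across all result pages.
#
#     for dataset in client.search(["chosen", "rejected", "prompt"], match_all=True):
#         print(dataset["hub_id"])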


def update_collection_for_dataset(
    collection_name: Optional[str] = None,
    dataset_columns: Optional[List[str]] = None,
    collection_description: Optional[str] = None,
    collection_namespace: Optional[str] = None,
):
    """Create (or fetch) a collection and add every dataset whose columns match."""
    # create_collection accepts namespace=None, so a single call handles both the
    # default-namespace and explicit-namespace cases.
    collection = create_collection(
        collection_name,
        exists_ok=True,
        description=collection_description,
        namespace=collection_namespace,
    )
    results = list(
        tqdm(
            client.search(dataset_columns, match_all=True),
            desc="Searching datasets...",
            leave=False,
        )
    )
    for result in tqdm(results, desc="Adding datasets to collection...", leave=False):
        try:
            add_collection_item(
                collection.slug, result["hub_id"], item_type="dataset", exists_ok=True
            )
        except Exception as e:
            print(
                f"Error adding dataset {result['hub_id']} to collection {collection_name}: {str(e)}"
            )
    return f"https://huggingface.co/collections/{collection.slug}"


collections = [
    {
        "dataset_columns": ["chosen", "rejected", "prompt"],
        "collection_description": "Datasets suitable for DPO based on having 'chosen', 'rejected', and 'prompt' columns. Created using librarian-bots/dataset-column-search-api",
        "collection_name": "Direct Preference Optimization Datasets",
    },
    {
        "dataset_columns": ["image", "chosen", "rejected"],
        "collection_description": "Datasets suitable for Image Preference Optimization based on having 'image', 'chosen', and 'rejected' columns",
        "collection_name": "Image Preference Optimization Datasets",
    },
    {
        "collection_name": "Alpaca Style Datasets",
        "dataset_columns": ["instruction", "input", "output"],
        "collection_description": "Datasets which follow the Alpaca Style format based on having 'instruction', 'input', and 'output' columns",
    },
]
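
# Driver sketch (an assumed addition; the `__main__` guard and print format are
# not from the source): apply each configuration above in turn.
if __name__ == "__main__":
    for config in collections:
        url = update_collection_for_dataset(**config)
        print(f"Updated collection: {url}")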