# dataset-sampler / app.py
# dhruv-anand-aintech: "Create app.py" (commit c2e8d51)
import streamlit as st
from huggingface_hub import HfApi, HfFolder
from datasets import load_dataset
def fetch_dataset_names(query, limit=100):
    """Return Hub dataset ids containing ``query`` (case-insensitive).

    Used to populate the dataset typeahead in ``main``. The search is
    delegated to the Hub via ``HfApi.list_datasets(search=...)`` so only
    matching datasets are transferred. The original code enumerated every
    dataset on the Hub on each Streamlit rerun.

    Args:
        query: Text typed by the user. ``None``, empty or whitespace-only
            queries return an empty list without contacting the Hub.
        limit: Maximum number of results requested from the Hub (``None``
            for no cap). A cap keeps the select box usable and the call fast.

    Returns:
        A list of dataset ids whose lower-cased form contains the
        lower-cased, stripped ``query``.
    """
    query = (query or "").strip()
    if not query:
        return []
    needle = query.lower()
    api = HfApi()
    matches = api.list_datasets(search=query, limit=limit)
    # Re-apply the substring filter locally. This keeps the original
    # case-insensitive "query in id" contract however the Hub matches.
    return [d.id for d in matches if needle in d.id.lower()]
def create_sampled_dataset(dataset_name, num_rows, user_token, split="train", seed=None):
    """Randomly sample rows from a Hub dataset and push them as a new dataset.

    The sample is uploaded to a new dataset repo in the namespace of the
    account that owns ``user_token``.

    Args:
        dataset_name: Hub id of the source dataset.
        num_rows: Number of rows to sample. If the split has fewer rows,
            the whole (shuffled) split is used.
        user_token: Hugging Face access token. It needs write access because
            the sample is pushed to the token owner's account.
        split: Split to sample from. Defaults to ``"train"``, which the
            original code hard-coded.
        seed: Optional shuffle seed for a reproducible sample.

    Returns:
        The URL of the newly created dataset repo on the Hub.

    Raises:
        ValueError: If ``num_rows`` is less than 1 or ``user_token`` is empty.
    """
    num_rows = int(num_rows)
    if num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}")
    if not user_token:
        raise ValueError("A Hugging Face token is required to create a dataset.")

    # Load only the requested split, not every split of the dataset.
    dataset = load_dataset(dataset_name, split=split)

    # Clamp so select() does not raise IndexError on small splits.
    sample_size = min(num_rows, len(dataset))
    sampled_dataset = dataset.shuffle(seed=seed).select(range(sample_size))

    # Push under the token owner's namespace. Only the last path component
    # of the source id is kept, so the new repo id has a single namespace.
    username = HfApi().whoami(token=user_token)["name"]
    source_name = dataset_name.split("/")[-1]
    repo_id = f"{username}/{source_name}-sample-{sample_size}"
    sampled_dataset.push_to_hub(repo_id, token=user_token)
    return f"https://huggingface.co/datasets/{repo_id}"
def main():
    """Render the Streamlit UI for sampling a Hub dataset.

    The page collects:
    - a Hugging Face token,
    - a dataset, chosen via a search box that fills a select box,
    - a row count.

    When the button is pressed, it calls ``create_sampled_dataset`` and
    reports either the resulting URL or the error.
    """
    st.title("HuggingFace Dataset Sampler")

    # Mask the token so it is not displayed in clear text.
    user_token = st.text_input(
        "Enter your HuggingFace token for authentication", type="password"
    )

    # Dataset input with typeahead: the free-text query narrows the options.
    dataset_query = st.text_input("Enter Dataset Name")
    selected_dataset = None
    if dataset_query.strip():
        try:
            dataset_names = fetch_dataset_names(dataset_query)
        except Exception as e:  # Hub/network errors should not crash the page
            st.error(f"Could not search the Hub: {e}")
        else:
            if dataset_names:
                selected_dataset = st.selectbox("Select Dataset", options=dataset_names)
            else:
                st.warning(f"No datasets found matching '{dataset_query}'.")

    num_rows = st.number_input("Enter number of rows to sample", min_value=1, step=1)

    if st.button("Create Sampled Dataset"):
        if not (user_token and selected_dataset and num_rows):
            st.error("Please fill in all required fields.")
            return
        try:
            # Loading and uploading can take a while; show progress feedback.
            with st.spinner("Sampling and uploading dataset..."):
                dataset_url = create_sampled_dataset(selected_dataset, num_rows, user_token)
        except Exception as e:  # UI boundary: surface any failure to the user
            st.error(f"Error: {e}")
        else:
            st.success(f"Dataset created successfully! Find it here: {dataset_url}")
if __name__ == "__main__":
    # Script entry point: build the Streamlit page.
    main()