# NOTE: this file was recovered from a Hugging Face Spaces page scrape;
# the page status banner ("Spaces: Sleeping") was captured here as stray text.
"""Streamlit app that builds and interactively cleans a Bunka topic map."""

import os

import streamlit as st
from bunkatopics import Bunka
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub

# --- App header and user-configurable parameters ------------------------
st.title("Bunka Map 🗺️")

# Dataset / model configuration. Defaults reproduce the original demo setup.
dataset_id = st.text_input("Dataset ID", "bunkalab/medium-sample-technology")
language = st.text_input("Language", "english")
text_field = st.text_input("Text Field", "title")
embedder_model = st.text_input(
    "Embedder Model",
    "sentence-transformers/distiluse-base-multilingual-cased-v2",
)
sample_size = st.number_input("Sample Size", min_value=100, max_value=10000, value=1000)
n_clusters = st.number_input("Number of Clusters", min_value=5, max_value=50, value=15)
llm_model = st.text_input("LLM Model", "mistralai/Mistral-7B-Instruct-v0.1")

# Hugging Face API token, masked in the UI; only needed for LLM label cleaning.
hf_token = st.text_input("Hugging Face API Token", type="password")
# --- Topic-map generation and cleaning workflow --------------------------
#
# NOTE(review): this logic originally lived entirely inside
# `if st.button("Generate Bunka Map"):`. Streamlit reruns the whole script
# on every widget interaction, and a button only reads True on the run
# triggered by its own click — so every nested widget (the per-topic text
# inputs and the "Update Topics" / "Remove Topics" / "Save" buttons)
# disappeared on the next interaction and its handler could never fire.
# Fix: do the expensive fit once on click, cache the fitted Bunka instance
# in st.session_state, and render the follow-up sections whenever a fitted
# instance exists.


def _load_sample(dataset_id: str, text_field: str, sample_size: int) -> list:
    """Stream the dataset's 'train' split and return up to *sample_size* texts.

    Streaming avoids downloading the full dataset just to take a sample.
    """
    dataset = load_dataset(dataset_id, streaming=True)
    docs = []
    for i, example in enumerate(dataset["train"]):
        if i >= sample_size:
            break
        docs.append(example[text_field])
    return docs


def _topic_map(bunka):
    """Return the standard Plotly topic-map figure for *bunka*."""
    return bunka.visualize_topics(
        width=800,
        height=800,
        colorscale="Portland",
        density=True,
        label_size_ratio=60,
        convex_hull=True,
    )


if st.button("Generate Bunka Map"):
    docs_sample = _load_sample(dataset_id, text_field, sample_size)
    if not docs_sample:
        st.error("No documents found — check the dataset ID and text field.")
    else:
        # Fit Bunka on the sampled texts and pre-compute topics.
        embedding_model = HuggingFaceEmbeddings(model_name=embedder_model)
        bunka = Bunka(embedding_model=embedding_model, language=language)
        bunka.fit(docs_sample)
        bunka.get_topics(n_clusters=n_clusters, name_length=5, min_count_terms=2)
        # Cache so later reruns (any widget interaction) still see the model.
        st.session_state["bunka"] = bunka
        st.session_state["labels_cleaned"] = False

if "bunka" in st.session_state:
    bunka = st.session_state["bunka"]
    st.plotly_chart(_topic_map(bunka))

    # Clean labels with an LLM (optional; requires an API token). Guarded by
    # a session flag so the remote model is not re-queried on every rerun.
    if hf_token:
        if not st.session_state.get("labels_cleaned"):
            os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
            llm = HuggingFaceHub(repo_id=llm_model, huggingfacehub_api_token=hf_token)
            bunka.get_clean_topic_name(llm=llm, language=language)
            st.session_state["labels_cleaned"] = True
        st.plotly_chart(_topic_map(bunka))
    else:
        st.warning("Please provide a Hugging Face API token to clean labels using LLM.")

    # Manual topic cleaning: one editable keyword list per topic.
    # NOTE(review): assumes bunka.topics_ maps topic id -> iterable of
    # keyword strings, as the original code did — confirm against bunkatopics.
    st.subheader("Manually Clean Topics")
    cleaned_topics = {}
    for topic, keywords in bunka.topics_.items():
        edited = st.text_input(f"Topic {topic}", ", ".join(keywords))
        cleaned_topics[topic] = edited.split(", ")
    if st.button("Update Topics"):
        bunka.topics_ = cleaned_topics
        st.plotly_chart(_topic_map(bunka))

    # Drop whole topics (and their documents) from the map.
    st.subheader("Remove Unwanted Topics")
    topics_to_remove = st.multiselect("Select topics to remove", list(bunka.topics_.keys()))
    if st.button("Remove Topics"):
        bunka.clean_data_by_topics(topics_to_remove)
        st.plotly_chart(_topic_map(bunka))

    # Persist the cleaned dataset next to the app as CSV.
    if st.button("Save Cleaned Dataset"):
        name = dataset_id.replace('/', '_') + '_cleaned.csv'
        bunka.df_cleaned_.to_csv(name)
        st.success(f"Dataset saved as {name}")