Alcime committed on
Commit
fce64e9
1 Parent(s): 6d7b047

Create app.py

Files changed (1)
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import os
+
+ import streamlit as st
+ from datasets import load_dataset
+ from bunkatopics import Bunka
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFaceHub
+
+ # Streamlit app
+ st.title("Bunka Map 🗺️")
+
+ # Input parameters
+ dataset_id = st.text_input("Dataset ID", "bunkalab/medium-sample-technology")
+ language = st.text_input("Language", "english")
+ text_field = st.text_input("Text Field", "title")
+ embedder_model = st.text_input("Embedder Model", "sentence-transformers/distiluse-base-multilingual-cased-v2")
+ sample_size = st.number_input("Sample Size", min_value=100, max_value=10000, value=1000)
+ n_clusters = st.number_input("Number of Clusters", min_value=5, max_value=50, value=15)
+ llm_model = st.text_input("LLM Model", "mistralai/Mistral-7B-Instruct-v0.1")
+
+ # Hugging Face API token input
+ hf_token = st.text_input("Hugging Face API Token", type="password")
+
+
+ @st.cache_data
+ def load_data(dataset_id, text_field, sample_size):
+     """Stream the dataset and keep the first `sample_size` texts."""
+     dataset = load_dataset(dataset_id, streaming=True)
+     docs_sample = []
+     for i, example in enumerate(dataset["train"]):
+         if i >= sample_size:
+             break
+         docs_sample.append(example[text_field])
+     return docs_sample
+
+
+ def plot_map(bunka):
+     """Return the Plotly figure for the current topic map."""
+     return bunka.visualize_topics(
+         width=800,
+         height=800,
+         colorscale="Portland",
+         density=True,
+         label_size_ratio=60,
+         convex_hull=True,
+     )
+
+
+ if st.button("Generate Bunka Map"):
+     docs_sample = load_data(dataset_id, text_field, sample_size)
+
+     # Initialize the embedding model and fit Bunka to the text data
+     embedding_model = HuggingFaceEmbeddings(model_name=embedder_model)
+     bunka = Bunka(embedding_model=embedding_model, language=language)
+     bunka.fit(docs_sample)
+
+     # Generate topics
+     bunka.get_topics(n_clusters=n_clusters, name_length=5, min_count_terms=2)
+
+     # Persist the fitted model: every widget interaction reruns the script,
+     # so UI nested under this button would otherwise vanish on the next click.
+     st.session_state["bunka"] = bunka
+
+ if "bunka" in st.session_state:
+     bunka = st.session_state["bunka"]
+
+     # Visualize topics
+     st.plotly_chart(plot_map(bunka))
+
+     # Clean labels using an LLM
+     if hf_token:
+         os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
+         llm = HuggingFaceHub(repo_id=llm_model, huggingfacehub_api_token=hf_token)
+         bunka.get_clean_topic_name(llm=llm, language=language)
+         st.plotly_chart(plot_map(bunka))
+     else:
+         st.warning("Please provide a Hugging Face API token to clean labels with an LLM.")
+
+     # Manual topic cleaning
+     st.subheader("Manually Clean Topics")
+     cleaned_topics = {}
+     for topic, keywords in bunka.topics_.items():
+         cleaned_topic = st.text_input(f"Topic {topic}", ", ".join(keywords))
+         cleaned_topics[topic] = cleaned_topic.split(", ")
+
+     if st.button("Update Topics"):
+         bunka.topics_ = cleaned_topics
+         st.plotly_chart(plot_map(bunka))
+
+     # Remove unwanted topics
+     st.subheader("Remove Unwanted Topics")
+     topics_to_remove = st.multiselect("Select topics to remove", list(bunka.topics_.keys()))
+     if st.button("Remove Topics"):
+         bunka.clean_data_by_topics(topics_to_remove)
+         st.plotly_chart(plot_map(bunka))
+
+     # Save dataset
+     if st.button("Save Cleaned Dataset"):
+         name = dataset_id.replace("/", "_") + "_cleaned.csv"
+         bunka.df_cleaned_.to_csv(name)
+         st.success(f"Dataset saved as {name}")
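
The streaming-sample pattern used by load_data above can be sanity-checked outside Streamlit before launching the Space. A minimal sketch, assuming the default dataset and text field from the inputs ("bunkalab/medium-sample-technology", "title") and a working network connection:

from itertools import islice

from datasets import load_dataset

# Preview the first few rows of the streamed split without downloading it all.
dataset = load_dataset("bunkalab/medium-sample-technology", streaming=True)
preview = [row["title"] for row in islice(dataset["train"], 5)]
print(preview)

Once the field names look right, the app runs as an ordinary Streamlit script: streamlit run app.py.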