google_ads_space / pages /2_Topic_Cluster.py
zayed-upal
Google ads format download added, topic name rename option added
4c25316
raw
history blame contribute delete
No virus
5.95 kB
import pandas as pd
import streamlit as st
from Functionalities import NLP_Helper
from Functionalities.TopicClustering import TopicClustering
from streamlit_extras.dataframe_explorer import dataframe_explorer
class TopicClusterView:
    """
    Streamlit view for the "Topic Clustering" page.

    The clustering result is kept in ``st.session_state.topic_cluster`` so it is still
    available on later script runs; the instance attributes only hold the inputs chosen
    during the current run.
    """

    def __init__(self):
        """
        Reset the per-run inputs, make sure the session slot exists and draw the page header.
        """
        # Default neighbourhood size offered by the slider in visualize_clusters().
        self.n_neighbors = 10
        self.topic_cluster = None
        self.representation_model = None
        self.sentence_model = None
        self.text_col = None
        self.text_df = None
        self.text_file = None
        # Create the session slot only once; an existing clustering result is left untouched.
        if 'topic_cluster' not in st.session_state:
            st.session_state.topic_cluster = None
        st.set_page_config(page_title='Topic Clustering', layout="wide")
        st.header("Topic Clustering")

    @staticmethod
    def _has_fitted_model(topic_cluster) -> bool:
        """
        Tell whether a clustering object exists and carries a topic model.
        :param topic_cluster: the object stored in the session state, or None
        :return: True when both the object and its ``topic_model`` are not None
        """
        return topic_cluster is not None and topic_cluster.topic_model is not None

    @staticmethod
    def _rerun() -> None:
        """
        Rerun the script, using ``st.rerun`` when this Streamlit version provides it and
        the older ``st.experimental_rerun`` otherwise.
        :return:
        """
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()

    def input_params(self) -> None:
        """
        Takes csv file input, name of text col, select option for sentence model and representation model.
        The selectors and the "Cluster" button are only drawn once a CSV file was uploaded and parsed.
        :return:
        """
        self.text_file = st.file_uploader(label="Upload the CSV file containing the texts to cluster")
        if not self.text_file:
            return

        try:
            self.text_df = pd.read_csv(self.text_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            # Report an unreadable upload on the page instead of failing with a traceback.
            self.text_df = None
            st.error(f"Could not read **{self.text_file.name}** as CSV: {err}")
            return

        self.text_col = st.selectbox(
            label=f"Choose the column to use for topic clustering in **{self.text_file.name}**",
            options=self.text_df.columns
        )
        self.sentence_model = st.selectbox(
            label="Choose the text embedding model",
            options=NLP_Helper.TRANSFORMERS,
            help="; ".join(NLP_Helper.TRANSFORMERS_INFO)
        )
        self.representation_model = st.selectbox(
            label="Choose the representation model",
            options=NLP_Helper.BERTOPIC_REPRESENTATIONS,
        )
        st.button("Cluster", on_click=self.run_clustering)

    def run_clustering(self) -> None:
        """
        Callback of the "Cluster" button: builds a TopicClustering from the current inputs,
        calls ``topic_cluster_bert()`` on it and stores the object in the session state.
        :return:
        """
        if self.text_df is None or self.text_col is None:
            # Nothing to cluster without a parsed file and a chosen column.
            st.warning("Upload a CSV file and choose a text column before clustering.")
            return
        self.topic_cluster = TopicClustering(keyword_df=self.text_df, text_col=self.text_col,
                                             representation_model=self.representation_model,
                                             sentence_model=self.sentence_model)
        self.topic_cluster.topic_cluster_bert()
        st.session_state.topic_cluster = self.topic_cluster

    def show_and_download_df(self):
        """
        Shows the clustered dataframe, lets the user rename topics and offers two CSV downloads
        (the dataframe as is, and the Google Ads bulk upload format).
        Draws nothing while the session state holds no clustering object with a topic model.
        :return:
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._has_fitted_model(topic_cluster):
            return

        filtered_df = dataframe_explorer(topic_cluster.keyword_df)
        st.dataframe(filtered_df)

        with st.expander("Rename Topics"):
            for topic_name in topic_cluster.topic_names:
                cur_topic_col, new_topic_col = st.columns(2)
                cur_topic_col.write(topic_name)
                with new_topic_col:
                    # The current name is the default, so untouched inputs map a topic to itself.
                    topic_cluster.topic_name_mapping[topic_name] = st.text_input("New topic name", topic_name)
            if st.button("Update Topic Names"):
                topic_cluster.update_topic_names()
                # Rerun so the widgets above are redrawn after the update.
                self._rerun()

        st.download_button(
            "Press to Download as CSV",
            topic_cluster.keyword_df.to_csv(index=False).encode('utf-8'),
            "Clustered.csv",
            "text/csv",
            key='download-csv'
        )

        with st.expander("Download as CSV for Bulk upload in Google Ads"):
            campaign_name = st.text_input("Campaign Name", "Demo Campaign")
            # Build the export once and reuse it for both the preview and the download.
            google_ads_df = topic_cluster.get_df_in_google_ads_format(campaign_name)
            st.dataframe(google_ads_df)
            st.download_button(
                "Download as CSV for Bulk upload in Google Ads",
                google_ads_df.to_csv(index=False).encode('utf-8'),
                f"{campaign_name}_keywords_upload.csv",
                "text/csv",
                key='download-google-csv'
            )

    def visualize_clusters(self):
        """
        Draws the neighbourhood-size slider and, when "Visualize Topic Clusters" is pressed,
        the plotly figure returned by ``visualize_documents``.
        Draws nothing while the session state holds no clustering object with a topic model.
        :return:
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._has_fitted_model(topic_cluster):
            return

        # Start the slider at the default from __init__; without value= it starts at min_value.
        self.n_neighbors = st.slider(label='Size of the local neighborhood', min_value=2, max_value=100,
                                     value=self.n_neighbors, step=1)
        if st.button("Visualize Topic Clusters"):
            fig = topic_cluster.visualize_documents(n_neighbors=self.n_neighbors)
            fig.update_layout(title=None)
            st.plotly_chart(fig, use_container_width=True, theme=None)

    def visualize_topic_distribution(self):
        """
        Draws the plotly figure returned by ``visualize_topic_distribution`` of the clustering object.
        Draws nothing while the session state holds no clustering object with a topic model.
        :return:
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._has_fitted_model(topic_cluster):
            return

        fig = topic_cluster.visualize_topic_distribution()
        st.plotly_chart(fig, use_container_width=True, theme=None)
if __name__ == '__main__':
    # Build the page: one view object, rendered across two tabs.
    view = TopicClusterView()
    # NOTE: a third 'Topic Distribution' tab backed by
    # TopicClusterView.visualize_topic_distribution() is currently disabled.
    clustering_tab, visualization_tab = st.tabs(['Clustering Process', 'Cluster Visualization'])

    # Inputs first, then the result table with its rename and download options.
    with clustering_tab:
        view.input_params()
        view.show_and_download_df()

    with visualization_tab:
        view.visualize_clusters()