"""Streamlit page that clusters the texts of an uploaded CSV file into topics."""
import pandas as pd
import streamlit as st
from Functionalities import NLP_Helper
from Functionalities.TopicClustering import TopicClustering
from streamlit_extras.dataframe_explorer import dataframe_explorer


class TopicClusterView:
    """Streamlit view: upload texts, cluster them into topics, explore/export the result.

    Streamlit re-executes the whole script on every interaction, so a new
    ``TopicClusterView`` is created on each run (see the ``__main__`` block).
    The clustering result therefore lives in ``st.session_state.topic_cluster``
    rather than on the instance.
    """

    def __init__(self):
        # Initial value of the neighborhood-size slider in visualize_clusters().
        self.n_neighbors = 10
        self.topic_cluster = None
        self.representation_model = None
        self.sentence_model = None
        self.text_col = None
        self.text_df = None
        self.text_file = None

        # set_page_config has to be the first Streamlit command of the page.
        st.set_page_config(page_title='Topic Clustering', layout="wide")
        # Keep a result from a previous run; only initialise on the first run.
        if 'topic_cluster' not in st.session_state:
            st.session_state.topic_cluster = None
        st.header("Topic Clustering")

    @staticmethod
    def _has_model() -> bool:
        """Return True when the session holds a clustering result with a topic model."""
        topic_cluster = st.session_state.topic_cluster
        return topic_cluster is not None and topic_cluster.topic_model is not None

    def input_params(self) -> None:
        """Render the input widgets and the "Cluster" button.

        Asks for a CSV file. Once a file is uploaded, also asks for the text
        column, the text-embedding (sentence) model and the representation
        model, then shows a button that runs :meth:`run_clustering` as its
        ``on_click`` callback.
        """
        self.text_file = st.file_uploader(label="Upload the CSV file containing the texts to cluster")
        if not self.text_file:
            return

        self.text_df = pd.read_csv(self.text_file)
        self.text_col = st.selectbox(
            label=f"Choose the column to use for topic clustering in **{self.text_file.name}**",
            options=self.text_df.columns,
        )
        self.sentence_model = st.selectbox(
            label="Choose the text embedding model",
            options=NLP_Helper.TRANSFORMERS,
            help="; ".join(NLP_Helper.TRANSFORMERS_INFO),
        )
        self.representation_model = st.selectbox(
            label="Choose the representation model",
            options=NLP_Helper.BERTOPIC_REPRESENTATIONS,
        )
        # Streamlit runs the callback before the next script rerun, using the
        # attribute values this instance collected during the current run.
        st.button("Cluster", on_click=self.run_clustering)

    def run_clustering(self) -> None:
        """Cluster ``self.text_df[self.text_col]`` and store the result in session state."""
        self.topic_cluster = TopicClustering(keyword_df=self.text_df,
                                             text_col=self.text_col,
                                             representation_model=self.representation_model,
                                             sentence_model=self.sentence_model)
        self.topic_cluster.topic_cluster_bert()
        st.session_state.topic_cluster = self.topic_cluster

    def show_and_download_df(self) -> None:
        """Show the clustered data, allow renaming topics, and offer CSV downloads.

        Does nothing until a clustering result with a topic model exists.
        """
        if not self._has_model():
            return
        topic_cluster = st.session_state.topic_cluster

        filtered_df = dataframe_explorer(topic_cluster.keyword_df)
        st.dataframe(filtered_df)

        with st.expander("Rename Topics"):
            for idx, topic_name in enumerate(topic_cluster.topic_names):
                cur_topic_col, new_topic_col = st.columns(2)
                cur_topic_col.write(topic_name)
                with new_topic_col:
                    # The index makes the key unique even if two topics share a
                    # name; the name makes it change after a rename.
                    topic_cluster.topic_name_mapping[topic_name] = st.text_input(
                        "New topic name", topic_name, key=f"rename_topic_{idx}_{topic_name}"
                    )
            if st.button("Update Topic Names"):
                topic_cluster.update_topic_names()
                # st.experimental_rerun is deprecated (and removed in newer
                # releases) in favour of st.rerun; support both.
                rerun = getattr(st, "rerun", None) or st.experimental_rerun
                rerun()

        st.download_button(
            "Press to Download as CSV",
            topic_cluster.keyword_df.to_csv(index=False).encode('utf-8'),
            "Clustered.csv",
            "text/csv",
            key='download-csv',
        )

        with st.expander("Download as CSV for Bulk upload in Google Ads"):
            campaign_name = st.text_input("Campaign Name", "Demo Campaign")
            # Build the export once so the preview and the download are the same frame.
            google_ads_df = topic_cluster.get_df_in_google_ads_format(campaign_name)
            st.dataframe(google_ads_df)
            st.download_button(
                "Download as CSV for Bulk upload in Google Ads",
                google_ads_df.to_csv(index=False).encode('utf-8'),
                f"{campaign_name}_keywords_upload.csv",
                "text/csv",
                key='download-google-csv',
            )

    def visualize_clusters(self) -> None:
        """Plot the clustered documents, with a slider for ``n_neighbors``.

        Does nothing until a clustering result with a topic model exists.
        """
        if not self._has_model():
            return
        # value= keeps the slider's starting point in sync with __init__;
        # without it the slider would start at min_value.
        self.n_neighbors = st.slider(label='Size of the local neighborhood',
                                     min_value=2, max_value=100,
                                     value=self.n_neighbors, step=1)
        if st.button("Visualize Topic Clusters"):
            fig = st.session_state.topic_cluster.visualize_documents(n_neighbors=self.n_neighbors)
            fig.update_layout(title=None)
            st.plotly_chart(fig, use_container_width=True, theme=None)

    def visualize_topic_distribution(self) -> None:
        """Plot the topic distribution once a clustering result with a topic model exists."""
        if not self._has_model():
            return
        fig = st.session_state.topic_cluster.visualize_topic_distribution()
        st.plotly_chart(fig, use_container_width=True, theme=None)


if __name__ == '__main__':
    topic_cluster_view = TopicClusterView()
    # NOTE: a third 'Topic Distribution' tab (calling
    # topic_cluster_view.visualize_topic_distribution()) is currently disabled.
    tab1, tab2 = st.tabs(['Clustering Process', 'Cluster Visualization'])
    with tab1:
        topic_cluster_view.input_params()
        topic_cluster_view.show_and_download_df()
    with tab2:
        topic_cluster_view.visualize_clusters()