import re

import numpy as np
import orjson
import pandas as pd
import streamlit as st
from annotated_text.util import get_annotated_html
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode

from pipelines.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline
from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline

# One-time session initialisation: load the app configuration and create the
# empty history frame and keyphrase list the rest of the script relies on.
if "config" not in st.session_state:
    # Read raw bytes: orjson parses bytes directly, so the result does not
    # depend on the platform's default text encoding.
    with open("config.json", "rb") as f:
        st.session_state.config = orjson.loads(f.read())
    st.session_state.data_frame = pd.DataFrame(columns=["model"])
    st.session_state.keyphrases = []

st.set_page_config(
    page_icon="🔑",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="wide",
)

# The guard must test the same key it initialises ("selected_rows").
if "selected_rows" not in st.session_state:
    st.session_state.selected_rows = []

st.header("🔑 Keyphrase extraction/generation with Transformers")
col1, col2 = st.empty().columns(2)


@st.cache(allow_output_mutation=True)
def load_pipeline(chosen_model):
    """Build the pipeline matching ``chosen_model``.

    Args:
        chosen_model: Model identifier string. The pipeline class is picked
            by substring: "keyphrase-extraction" or "keyphrase-generation".

    Returns:
        A ``KeyphraseExtractionPipeline`` or ``KeyphraseGenerationPipeline``
        constructed with ``chosen_model``, or ``None`` when the identifier
        contains neither substring.
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    return None


def extract_keyphrases():
    """Run the current pipeline on the input text and append it to the history.

    Reads ``st.session_state.input_text`` and ``st.session_state.chosen_model``
    and calls the module-level ``pipe``. Stores the pipeline output in
    ``st.session_state.keyphrases`` and appends one row (model, text, then one
    column per keyphrase named "0", "1", ...) to
    ``st.session_state.data_frame``. Missing cells are filled with "".
    """
    keyphrases = pipe(st.session_state.input_text)
    st.session_state.keyphrases = keyphrases

    # Build the row as a plain list of strings so it does not go through
    # NumPy's dtype promotion (relevant when no keyphrases are returned).
    row = [
        st.session_state.chosen_model,
        st.session_state.input_text,
        *(str(keyphrase) for keyphrase in keyphrases),
    ]
    columns = ["model", "text", *(str(i) for i in range(len(keyphrases)))]

    st.session_state.data_frame = pd.concat(
        [
            st.session_state.data_frame,
            pd.DataFrame(data=[row], columns=columns),
        ],
        ignore_index=True,
        axis=0,
    ).fillna("")


def get_annotated_text(text, keyphrases):
    """Split ``text`` into pieces for ``get_annotated_html``, marking keyphrases.

    Args:
        text: The text to annotate.
        keyphrases: Iterable of keyphrase strings to highlight (matched
            case-insensitively).

    Returns:
        A list whose items are either ``(fragment, "KEY", "#21c354")`` tuples
        for keyphrase fragments or plain strings (padded with spaces) for all
        other words.
    """
    # Glue multi-word keyphrases into a single space-free token using "$K" as
    # the joiner, so the later split on " " keeps each keyphrase in one piece.
    for keyphrase in keyphrases:
        glued = keyphrase.replace(" ", "$K")
        text = re.sub(
            # Escape the keyphrase: it is literal text, not a regex pattern.
            f"({re.escape(keyphrase)})",
            # Callable replacement: the glued text is inserted verbatim
            # (no backslash/group-reference processing).
            lambda _match, glued=glued: glued,
            text,
            flags=re.I,
        )

    words = text.split(" ")
    # Position of the last piece of *this* text (after gluing), not of
    # whatever currently sits in the input box.
    last_index = len(words) - 1
    known_keyphrases = set(keyphrases)
    punctuation = re.compile(r"[^\w\s]")

    result = []
    for i, word in enumerate(words):
        if punctuation.sub("", word) in known_keyphrases:
            # Single-word keyphrase, possibly with attached punctuation.
            result.append((word, "KEY", "#21c354"))
        elif "$K" in word:
            # Multi-word keyphrase: restore the spaces that were glued over.
            result.append((word.replace("$K", " "), "KEY", "#21c354"))
        elif i == last_index:
            result.append(f" {word}")
        elif i == 0:
            result.append(f"{word} ")
        else:
            result.append(f" {word} ")
    return result


def rerender_output(layout):
    """Render the annotated text into ``layout``.

    Shows the latest extraction when no history row is selected; otherwise
    shows the first selected history row.

    Args:
        layout: Streamlit container providing ``subheader`` and ``markdown``.
    """
    layout.subheader("🐧 Output")

    selected = st.session_state.selected_rows
    # Explicit len() checks: these values may be array-like / DataFrame
    # objects whose truthiness is not a simple emptiness test.
    if len(st.session_state.keyphrases) > 0 and len(selected) == 0:
        text = st.session_state.input_text
        keyphrases = st.session_state.keyphrases
    else:
        text = selected["text"].values[0]
        # Every column except "model" and "text" holds a keyphrase; blank
        # cells are padding from rows with fewer keyphrases.
        keyphrase_columns = selected.columns.difference(["model", "text"])
        first_row = (
            selected.loc[:, keyphrase_columns].astype(str).values.tolist()[0]
        )
        keyphrases = [keyphrase for keyphrase in first_row if keyphrase != ""]

    result = get_annotated_text(text, keyphrases)
    layout.markdown(
        get_annotated_html(*result),
        unsafe_allow_html=True,
    )


# --- Input column: model choice, text box and the extract button -----------
chosen_model = col1.selectbox(
    "Choose your model:",
    st.session_state.config.get("models"),
)
st.session_state.chosen_model = chosen_model

pipe = load_pipeline(
    f"{st.session_state.config.get('model_author')}/{st.session_state.chosen_model}"
)

st.session_state.input_text = col1.text_area(
    "Input", st.session_state.config.get("example_text"), height=300
)
pressed = col1.button("Extract", on_click=extract_keyphrases)

# --- History grid: one selectable row per past extraction ------------------
if len(st.session_state.data_frame.columns) > 0:
    st.subheader("📜 History")
    builder = GridOptionsBuilder.from_dataframe(
        st.session_state.data_frame, sortable=False
    )
    builder.configure_selection(selection_mode="single", use_checkbox=True)
    builder.configure_column("text", hide=True)
    go = builder.build()
    data = AgGrid(
        st.session_state.data_frame,
        gridOptions=go,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )
    st.session_state.selected_rows = pd.DataFrame(data["selected_rows"])

# --- Output column: only drawn once there is something to show -------------
if len(st.session_state.selected_rows) > 0 or len(st.session_state.keyphrases) > 0:
    rerender_output(col2)