"""Streamlit demo app for keyphrase extraction/generation with Transformers.

The user picks a model from the ``"models"`` list in ``config.json``, edits the
input text (pre-filled from the config's ``"example_text"``), and presses
"Extract" to run the matching pipeline; the resulting keyphrases are shown in
a one-column table.
"""

import orjson
import pandas as pd
import streamlit as st

from extraction.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline
from extraction.keyphrase_generation_pipeline import KeyphraseGenerationPipeline

# Streamlit requires set_page_config to be the first Streamlit command of a run.
st.set_page_config(
    page_icon="🔑",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="wide",
    initial_sidebar_state="auto",
)

# Load the app configuration once per browser session.
if "config" not in st.session_state:
    # orjson parses bytes directly, so read in binary mode and skip the
    # platform-dependent text decoding step.
    with open("config.json", "rb") as f:
        st.session_state.config = orjson.loads(f.read())

# ``st.cache`` is deprecated; ``st.cache_resource`` is the replacement meant for
# shared, unhashable objects such as loaded models. Fall back to the legacy
# decorator (with the same options as before) on Streamlit versions without it.
if hasattr(st, "cache_resource"):
    _cache_pipeline = st.cache_resource
else:
    _cache_pipeline = st.cache(allow_output_mutation=True)


@_cache_pipeline
def load_pipeline(chosen_model):
    """Build (and cache across reruns) the pipeline for *chosen_model*.

    The pipeline class is selected by substring match on the model name.

    Args:
        chosen_model: model identifier taken from the config's "models" list.

    Returns:
        A ``KeyphraseExtractionPipeline`` or ``KeyphraseGenerationPipeline``.

    Raises:
        ValueError: if the name contains neither "keyphrase-extraction" nor
            "keyphrase-generation" (instead of silently returning None and
            failing later when the pipeline is called).
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    raise ValueError(
        f"Cannot determine pipeline type for model {chosen_model!r}: expected "
        "'keyphrase-extraction' or 'keyphrase-generation' in its name."
    )


def extract_keyphrases():
    """Button callback: run the selected pipeline on the current input text.

    Callbacks execute before the script body reruns, so the current widget
    values are read from session state via the widgets' keys rather than from
    variables left over from the previous run.
    """
    pipe = load_pipeline(st.session_state.chosen_model)
    st.session_state.keyphrases = pipe(st.session_state.input_text)


st.header("🔑 Keyphrase extraction/generation with Transformers")

col1, col2 = st.columns([1, 3])

col1.subheader("Select model")
# key= makes the current selection available as st.session_state.chosen_model
# (including inside callbacks); it must not be reassigned after creation.
chosen_model = col1.selectbox(
    "Choose your model:",
    st.session_state.config.get("models"),
    key="chosen_model",
)

# Load (or fetch from cache) the model eagerly so it is ready before "Extract".
pipe = load_pipeline(chosen_model)

col2.subheader("Input your text")
# key= keeps st.session_state.input_text in sync with the widget's latest value,
# which the callback reads.
col2.text_area(
    "Input",
    st.session_state.config.get("example_text"),
    height=150,
    key="input_text",
)

pressed = col2.button("Extract", on_click=extract_keyphrases)

if pressed:
    col2.subheader("🐧 Output")
    df = pd.DataFrame(data=st.session_state.keyphrases, columns=["Keyphrases"])
    col2.table(df)