Spaces:
Running
Running
File size: 1,616 Bytes
0f23c4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import streamlit as st
import pandas as pd
from extraction.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline
from extraction.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
import orjson
# Parse config.json once per browser session. Streamlit re-executes this
# script on every interaction, so the parsed dict is stashed in session_state
# and the file is only read when it is not there yet.
# The file is opened in binary mode on purpose: orjson.loads accepts raw bytes
# and decodes them as UTF-8 itself, whereas text mode without an explicit
# encoding would use the platform's locale encoding and could garble
# non-ASCII content (e.g. the example text).
if "config" not in st.session_state:
    with open("config.json", "rb") as config_file:
        st.session_state.config = orjson.loads(config_file.read())
# Browser-tab title/icon and page layout. Per Streamlit's API docs,
# set_page_config has to be called before any other Streamlit command that
# renders to the page, which is why it sits at the top of the script.
# NOTE(review): "π" looks like a mis-encoded emoji; confirm the intended glyph.
st.set_page_config(
    page_title="Keyphrase extraction/generation with Transformers",
    page_icon="π",
    initial_sidebar_state="auto",
    layout="wide",
)
@st.cache(allow_output_mutation=True)
def load_pipeline(chosen_model):
    """Build (and cache) the pipeline that matches ``chosen_model``.

    The pipeline class is chosen by substring match on the model name:
    names containing ``"keyphrase-extraction"`` get a
    ``KeyphraseExtractionPipeline``; names containing
    ``"keyphrase-generation"`` get a ``KeyphraseGenerationPipeline``.

    ``st.cache`` memoizes the result per argument so the model is not
    rebuilt on every script rerun; ``allow_output_mutation=True`` tells it
    not to check the returned object for mutation between runs.
    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; migrate once the pinned version allows.

    Args:
        chosen_model: Model identifier, as listed under ``"models"`` in
            config.json and picked in the sidebar selectbox.

    Returns:
        The constructed pipeline object (callable on the input text).

    Raises:
        ValueError: If ``chosen_model`` matches neither naming pattern.
            Failing here, rather than returning ``None``, surfaces a clear
            message instead of a later ``'NoneType' object is not callable``
            when the pipeline is invoked.
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    raise ValueError(
        f"Unsupported model {chosen_model!r}: expected a name containing "
        "'keyphrase-extraction' or 'keyphrase-generation'."
    )
def extract_keyphrases():
    """``on_click`` callback for the "Extract" button.

    Runs the module-level ``pipe`` (as bound by the most recent script run)
    on ``st.session_state.input_text`` and stores whatever it returns in
    ``st.session_state.keyphrases``, where the output table reads it.
    """
    keyphrases = pipe(st.session_state.input_text)
    st.session_state.keyphrases = keyphrases
# NOTE(review): "π" and "π§" below look like mis-encoded emoji; confirm the
# intended glyphs before changing them.
st.header("π Keyphrase extraction/generation with Transformers")
# Narrow left column for model selection, wide right column for text I/O.
col1, col2 = st.columns([1, 3])

col1.subheader("Select model")
chosen_model = col1.selectbox(
    "Choose your model:",
    st.session_state.config.get("models"),
)
st.session_state.chosen_model = chosen_model
# Cached by load_pipeline, so reruns with the same model reuse the pipeline.
pipe = load_pipeline(st.session_state.chosen_model)

col2.subheader("Input your text")
# key="input_text" makes Streamlit own st.session_state.input_text and update
# it from the browser *before* callbacks run. extract_keyphrases is an
# on_click callback, which executes at the start of the rerun, ahead of this
# line. Assigning the widget's return value here instead only refreshed the
# entry after the script reached this point, so the callback could process
# the text from the previous run.
col2.text_area(
    "Input",
    st.session_state.config.get("example_text"),
    height=150,
    key="input_text",
)

pressed = col2.button("Extract", on_click=extract_keyphrases)
if pressed:
    # The on_click callback has already stored the result in session_state.
    col2.subheader("π§ Output")
    # Presumably a flat sequence of phrase strings, one row each; confirm
    # against the pipeline classes' return types.
    df = pd.DataFrame(data=st.session_state.keyphrases, columns=["Keyphrases"])
    col2.table(df)
|