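# NOTE: residue from the web page this file was copied from; kept as comments
# so it is not parsed as code.
#   arunavsk1's picture
#   Update app.py
#   ceedf97
#   raw
#   history blame contribute delete
#   No virus
#   5.34 kB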
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, DebertaV2Tokenizer, DebertaV2Model
import sentencepiece
import streamlit as st
import pandas as pd
import spacy
# Sample sentences the user can pick from instead of typing their own text.
example_list = [
    "The primary outcome was overall survival (OS).",
    (
        "Overall survival was not significantly different between the groups "
        "(hazard ratio [HR], 0.87; 95% CI, 0.66-1.16; P = .34)."
    ),
    (
        "Afatinib-an oral irreversible ErbB family blocker-improves "
        "progression-free survival compared with pemetrexed and cisplatin "
        "for first-line treatment of patients with EGFR mutation-positive "
        "advanced non-small-cell lung cancer (NSCLC)."
    ),
]
# Checkpoints offered in the sidebar. Only one is enabled at the moment.
model_list = ['arunavsk1/my-awesome-pubmed-bert']

# Aggregation strategy handed to the NER pipeline. It is hard-coded to 'none';
# the sidebar selector that used to choose it is disabled.
aggregation = 'none'

# Page chrome: wide layout plus the demo title.
st.set_page_config(layout="wide")
st.title("Demo for EIC NER")

# Unlabelled sidebar radio that picks which checkpoint to load.
model_checkpoint = st.sidebar.radio("", model_list)
# Let the user either type free text or start from one of the bundled examples.
st.subheader("Select Text Input Method")
input_method = st.radio(
    "",
    ('Select from Examples', 'Write or Paste New Text'),
)

if input_method == "Write or Paste New Text":
    # Free-form entry starts from an empty text area.
    st.subheader("Text to Run")
    input_text = st.text_area(
        'Write or Paste Text Below',
        value="",
        height=128,
        max_chars=None,
        key=2,
    )
elif input_method == 'Select from Examples':
    # The chosen example pre-fills a text area that the user can still edit.
    selected_text = st.selectbox(
        'Select Text from List',
        example_list,
        index=0,
        key=1,
    )
    st.subheader("Text to Run")
    input_text = st.text_area(
        "Selected Text",
        selected_text,
        height=128,
        max_chars=None,
        key=2,
    )
@st.cache(allow_output_mutation=True)
def setModel(model_checkpoint, aggregation):
    """Build the NER pipeline for a checkpoint.

    The result is cached through ``st.cache`` (with
    ``allow_output_mutation=True``).

    Args:
        model_checkpoint: Identifier passed to ``from_pretrained`` for both
            the token-classification model and its tokenizer.
        aggregation: Value forwarded as the pipeline's
            ``aggregation_strategy``.

    Returns:
        The ``'ner'`` pipeline wrapping the loaded model and tokenizer.
    """
    # Dict values are evaluated in order, so the model is loaded before the
    # tokenizer.
    pipeline_kwargs = {
        "model": AutoModelForTokenClassification.from_pretrained(model_checkpoint),
        "tokenizer": AutoTokenizer.from_pretrained(model_checkpoint),
        "aggregation_strategy": aggregation,
    }
    return pipeline('ner', **pipeline_kwargs)
def get_html(html: str) -> str:
    """Wrap rendered markup in a bordered, horizontally scrollable ``<div>``.

    Newlines in ``html`` are replaced with single spaces before wrapping.

    This function is intentionally not cached: it is a cheap, pure string
    operation, so caching would only add the cost of hashing the whole
    payload and retain one cache entry per distinct render.

    Args:
        html: Markup to embed in the wrapper.

    Returns:
        The wrapper ``<div>`` containing the flattened markup.
    """
    WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
    flattened = html.replace("\n", " ")
    return WRAPPER.format(flattened)
# Run the selected model on the current text and show the predictions, first
# as displaCy-style highlighted text and then as a table.
Run_Button = st.button("Run", key=None)

if Run_Button and not input_text.strip():
    # Blank input yields nothing to display, so prompt the user instead of
    # loading and running the model.
    st.warning("Please enter some text before pressing Run.")
elif Run_Button:
    ner_pipeline = setModel(model_checkpoint, aggregation)
    output = ner_pipeline(input_text)

    # The label is read from 'entity' when aggregation is 'none' and from
    # 'entity_group' otherwise; resolve the key once instead of per entity.
    label_key = "entity" if aggregation == "none" else "entity_group"
    cols_to_keep = ['word', label_key, 'score', 'start', 'end']
    # Passing ``columns`` keeps the expected layout even when ``output`` is
    # empty; selecting the columns from an empty frame would raise KeyError.
    df_final = pd.DataFrame(output, columns=cols_to_keep)

    st.subheader("Spacy Style Display")

    # Translation from the label ids emitted by the pipeline to BIO tags.
    entity_map = {
        'LABEL_0': 'O',
        'LABEL_1': 'B-Intervention',
        'LABEL_2': 'I-Intervention',
        'LABEL_3': 'B-Outcome',
        'LABEL_4': 'I-Outcome',
        'LABEL_5': 'B-Value',
        'LABEL_6': 'I-Value',
    }

    # Input for displaCy's manual mode: the raw text plus labelled spans.
    # A label missing from ``entity_map`` is passed through unchanged rather
    # than aborting the whole render with a KeyError.
    spacy_display = {
        "text": input_text,
        "title": None,
        "ents": [
            {
                "start": entity["start"],
                "end": entity["end"],
                "label": entity_map.get(entity[label_key], entity[label_key]),
            }
            for entity in output
        ],
    }

    # Labels passed to displaCy as the ones to render ('O' is not listed);
    # the B- and I- tag of each entity type share one colour.
    entity_list = ['B-Intervention', 'I-Intervention', 'B-Outcome', 'I-Outcome', 'B-Value', 'I-Value']
    colors = {
        'B-Intervention': '#85DCDF',
        'I-Intervention': '#85DCDF',
        'B-Outcome': '#DF85DC',
        'I-Outcome': '#DF85DC',
        'B-Value': '#DCDF85',
        'I-Value': '#DCDF85',
    }
    html = spacy.displacy.render(
        spacy_display,
        style="ent",
        minify=True,
        manual=True,
        options={"ents": entity_list, "colors": colors},
    )
    style = "<style>mark.entity { display: inline-block }</style>"
    st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)

    st.subheader("Recognized Entities")
    st.dataframe(df_final)