# ner-turkish / app.py
# Author: pnr-svc
# Last change: "Update app.py" (commit 35494e1)
import streamlit as st
from annotated_text import annotated_text
import transformers
# Hex RGB highlight colour used by annotated_text for each entity-group name
# expected in the pipeline's "entity_group" output field.
ENTITY_TO_COLOR = dict(
    DepositProduct='#edff87',
    Product='#d586ff',
    ProductProblemInfo='#9886ff',
    ServiceInformation='#ff9886',
    ServiceClosest='#ff86b0',
    Location='#d461be',
    ServiceNumber='#f9cde4',
    Brand='#ffd4a4',
    Campaign='#bcffd8',
    ProductSelector='#fb5d4e',
    SpecialCampaign='#f56286',
)
def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns.

    Prefers ``st.cache_resource`` (the replacement for the deprecated
    ``st.cache(allow_output_mutation=True)``) and falls back to the legacy
    decorator on Streamlit versions that do not provide it.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_resource
def get_pipe(model_name="pnr-svc/distilbert-turkish-ner"):
    """Build the token-classification pipeline for *model_name*.

    The result is cached per ``model_name`` so the model and tokenizer are
    loaded only once, not on every Streamlit rerun.

    Args:
        model_name: Hugging Face model id (or local path) to load; defaults
            to the Turkish NER model this app was built for.

    Returns:
        A ``transformers`` token-classification pipeline using
        ``aggregation_strategy="simple"``, so sub-word predictions are merged
        into entity groups carrying ``entity_group``, ``start`` and ``end``.
    """
    model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    return transformers.pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )
def parse_text(text, prediction, colors=None, default_color="#cccccc"):
    """Split *text* into plain strings and highlighted entity tuples.

    The result is meant to be passed as ``annotated_text(*parts)``: stretches
    of *text* outside any entity are plain ``str`` items, and each predicted
    entity becomes a ``(span_text, entity_group, color)`` tuple.

    Args:
        text: The original string that was passed to the pipeline.
        prediction: Aggregated entities from the token-classification
            pipeline. Each item must provide integer ``start``/``end``
            character offsets into *text* and an ``entity_group`` label.
            Items are expected in ascending ``start`` order; an item that
            starts before the previous one ended yields an empty plain-text
            piece rather than an error.
        colors: Mapping of entity group to hex colour; ``None`` means use
            the module-level ``ENTITY_TO_COLOR``.
        default_color: Colour for entity groups missing from *colors*, so an
            unexpected label is still rendered instead of raising KeyError.

    Returns:
        A list of plain strings and entity tuples, always ending with the
        (possibly empty) text after the last entity.
    """
    if colors is None:
        colors = ENTITY_TO_COLOR
    parts = []
    cursor = 0
    for entity in prediction:
        start, end = entity["start"], entity["end"]
        parts.append(text[cursor:start])
        label = entity["entity_group"]
        # Slice the user's text instead of using entity["word"]: the pipeline
        # rebuilds "word" from sub-word tokens, which can differ from the
        # input (e.g. spaces inserted around punctuation).
        parts.append((text[start:end], label, colors.get(label, default_color)))
        cursor = end
    parts.append(text[cursor:])
    return parts
# Must be the first Streamlit command executed by the script.
st.set_page_config(page_title="NER ARÇELİK")
st.title("NER ARÇELİK")
st.write("Type text into the text box and then press 'Predict' to get the named entities.")

default_text = "tekirdağ çerkezköy arçelik yetkili servis no paylaş"
text = st.text_area('Enter text here:', value=default_text)
# Clicking the button reruns the script. Predictions are rendered whenever
# the text is non-blank (including the default text on first load), so the
# button's return value is not needed to gate them.
submit = st.button('Predict')

with st.spinner("Loading model..."):
    pipe = get_pipe()

if text.strip():
    prediction = pipe(text)
    parsed_text = parse_text(text, prediction)
    st.header("Prediction:")
    annotated_text(*parsed_text)
    st.header('Raw values:')
    st.json(prediction)