Spaces:
Sleeping
Sleeping
import streamlit as st | |
import wandb | |
from transformers import pipeline | |
from transformers import AutoTokenizer, AutoModelForTokenClassification | |
# Demo widget: let the user pick a number and show its square.
x = st.slider('Select a value')
x_squared = x * x
st.write(x, 'squared is', x_squared)
@st.cache_resource
def load_trained_model():
    """Build the NER pipeline for the fine-tuned BERT checkpoint.

    The tokenizer and the token-classification model are both loaded from
    the ``LampOfSocrates/bert-base-cased-sourav`` checkpoint and wrapped in
    a ``transformers`` ``"ner"`` pipeline.

    The function is wrapped in ``st.cache_resource`` so the (expensive)
    model load happens once and the same pipeline object is reused on
    later script reruns, instead of being rebuilt on every widget
    interaction.

    Returns:
        The ``pipeline("ner", ...)`` object; call it with a string to get
        the per-token entity predictions.
    """
    checkpoint = "LampOfSocrates/bert-base-cased-sourav"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint)

    # Log the id -> label mapping so the label names the model emits are
    # visible in the server output.
    print(model.config.id2label)

    return pipeline("ner", model=model, tokenizer=tokenizer)
# Background colour used for each entity label when highlighting.
_ENTITY_LABEL_COLORS = {
    'ORG': 'lightblue',
    'PER': 'lightgreen',
    'LOC': 'lightcoral',
    'MISC': 'lightyellow',
}


def get_entity_html(text, entities, label_colors=None):
    """Return *text* as HTML with every recognised entity wrapped in ``<mark>``.

    Args:
        text: The raw sentence that was passed to the NER pipeline.
        entities: Iterable of dicts with ``'start'``, ``'end'`` (character
            offsets into *text*) and ``'entity'`` (label string) keys.
        label_colors: Optional mapping from label to CSS colour. Defaults
            to ``_ENTITY_LABEL_COLORS``.

    Returns:
        An HTML string. All text taken from *text* is HTML-escaped, so
        characters such as ``<`` and ``&`` typed by the user are displayed
        literally instead of being interpreted as markup.
    """
    from html import escape  # local import keeps this block self-contained

    if label_colors is None:
        label_colors = _ENTITY_LABEL_COLORS

    # Entities without character offsets cannot be located in the text.
    spans = [
        e for e in entities
        if e.get('start') is not None and e.get('end') is not None
    ]
    spans.sort(key=lambda e: e['start'])

    parts = []
    last_idx = 0
    for entity in spans:
        start = entity['start']
        end = entity['end']
        if start < last_idx:
            # Overlaps a span that was already emitted; skipping it avoids
            # duplicating text in the output.
            continue
        label = entity['entity']

        # Try the label as-is first, then without a tagging-scheme prefix
        # (e.g. "B-ORG" -> "ORG") so prefixed labels still get a colour.
        color = label_colors.get(label)
        if color is None:
            color = label_colors.get(label.split('-', 1)[-1], 'lightgray')

        # Plain text between the previous entity and this one.
        parts.append(escape(text[last_idx:start]))
        # The entity itself, highlighted.
        parts.append(
            f'<mark style="background-color: {color}; border-radius: 3px;">'
            f'{escape(text[start:end])}</mark>'
        )
        last_idx = end

    # Whatever follows the last entity.
    parts.append(escape(text[last_idx:]))
    return "".join(parts)


def prep_page():
    """Render the NER demo page.

    Loads the NER pipeline via ``load_trained_model()``, shows a title and
    a text area, and, once the user has entered some text, runs the
    pipeline on it and displays the sentence with the recognised entities
    highlighted.
    """
    model = load_trained_model()

    # Streamlit app
    st.title("Named Entity Recognition with BERT on PLOD-CW")
    st.write("Enter a sentence to see the named entities recognized by the model.")

    # Text input
    text = st.text_area("Enter your sentence here:")
    if not text:
        return

    # Perform NER and display results
    st.write("Entities recognized:")
    entities = model(text)

    # unsafe_allow_html is required for the <mark> tags; the user's text is
    # escaped inside get_entity_html before it reaches this call.
    st.markdown(get_entity_html(text, entities), unsafe_allow_html=True)
if __name__ == '__main__':
    # Render the page when this file is executed as a script.
    # No load_model_from_wandb() helper is defined or imported in this
    # file, so calling it here would raise NameError before the page is
    # rendered; prep_page() obtains the model it needs on its own.
    prep_page()