Lamp Socrates
latest
7551cdd
raw
history blame
6.8 kB
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
from pprint import pprint
@st.cache_resource()
def load_trained_model():
    """Build the token-classification pipeline used for predictions.

    Loads the tokenizer and the model weights published under
    ``LampOfSocrates/bert-cased-plodcw-sourav``, prints the model's
    id-to-label mapping, and wraps both in a ``transformers`` "ner"
    pipeline. The function is decorated with ``st.cache_resource``.

    Returns:
        The object returned by ``pipeline("ner", ...)``.
    """
    checkpoint = "LampOfSocrates/bert-cased-plodcw-sourav"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)

    # Report which labels the loaded checkpoint can emit.
    id2label = ner_model.config.id2label
    print(f"Can recognise the following labels {id2label}")

    return pipeline("ner", model=ner_model, tokenizer=ner_tokenizer)
@st.cache_data()
def load_plod_cw_dataset():
    """Load the ``surrey-nlp/PLOD-CW`` dataset.

    The function is decorated with ``st.cache_data``.

    Returns:
        Whatever ``datasets.load_dataset("surrey-nlp/PLOD-CW")`` returns.
    """
    # Imported inside the function, so `datasets` is only imported when this
    # function actually runs.
    from datasets import load_dataset as fetch_dataset

    return fetch_dataset("surrey-nlp/PLOD-CW")
def load_random_examples(dataset_name, num_examples=1):
    """Draw random rows from the test split of the PLOD-CW dataset.

    Args:
        dataset_name (str): Accepted for interface compatibility. It is
            currently not used: the data always comes from
            ``load_plod_cw_dataset()``.
        num_examples (int): How many random rows to draw. The default is 1,
            which is what every call previously produced (the sample size
            used to be hard-coded and this argument was ignored). Values
            larger than the split are clamped to the split size; negative
            values are treated as 0.

    Returns:
        pd.DataFrame: A frame with two rows, ``tokens`` and ``ner_tags``,
        and one column per sampled example (labelled with the example's
        index in the test split).
    """
    dataset = load_plod_cw_dataset()
    test_frame = pd.DataFrame(dataset['test'])

    # DataFrame.sample raises when asked for more rows than exist (without
    # replacement), so keep the request within [0, len(test_frame)].
    sample_size = max(0, min(num_examples, len(test_frame)))
    sampled = test_frame.sample(n=sample_size)

    # Each Series becomes one row of the result.
    return pd.DataFrame((sampled.tokens, sampled.ner_tags))
def render_entities(tokens, entities):
    """Show a two-column Streamlit table pairing tokens with entities.

    Injects a block of custom CSS, writes a title and a short description,
    and then renders a table with the columns "Token" and "Entity".

    Args:
        tokens: Values for the "Token" column.
        entities: Values for the "Entity" column.
    """
    # Custom CSS for the table theme.
    theme_css = """
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #f0f0f5;
color: #333333;
}
table {
width: 100%;
border-collapse: collapse;
}
th, td {
padding: 12px;
text-align: left;
border-bottom: 1px solid #dddddd;
}
th {
background-color: #4CAF50;
color: white;
width: 16.66%;
}
tr:hover {
background-color: #f5f5f5;
}
td {
width: 16.66%;
}
</style>
"""
    st.markdown(theme_css, unsafe_allow_html=True)

    # Heading and explanatory text.
    st.title("Model predicted Token vs Entities Table")
    st.write("This table shows the entity corresponding to each token in a cool and chilled theme.")

    # The table itself.
    st.table(dict(Token=tokens, Entity=entities))
def render_random_examples():
    """Show random PLOD-CW examples in a Streamlit table.

    Injects custom CSS, writes a title and a description, and renders a
    button plus a table. The examples are kept in
    ``st.session_state['random_examples']``; they are (re)loaded when the
    button is pressed or when nothing has been stored yet.
    """
    # Custom CSS for the table theme.
    theme_css = """
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #f0f0f5;
color: #333333;
}
table {
width: 100%;
border-collapse: collapse;
}
th, td {
padding: 12px;
text-align: left;
border-bottom: 1px solid #dddddd;
}
th {
background-color: #4CAF50;
color: white;
width: 16.66%;
}
tr:hover {
background-color: #f5f5f5;
}
td {
width: 16.66%;
}
</style>
"""
    st.markdown(theme_css, unsafe_allow_html=True)

    # Heading and explanatory text.
    st.title("Random Examples from PLOD-CW")
    st.write("This table shows 1 random examples from the PLOD-CW dataset in a cool and chilled theme.")

    # The button is always rendered (it is evaluated first); a fresh sample
    # is drawn when it was clicked or when no sample has been stored yet.
    state_key = 'random_examples'
    refresh_clicked = st.button('Show another set of random examples')
    if refresh_clicked or state_key not in st.session_state:
        st.session_state[state_key] = load_random_examples("surrey-nlp/PLOD-CW")

    # Display the stored sample.
    st.table(st.session_state[state_key])
def predict_using_trained(sentence):
    """Run the trained NER pipeline on a single sentence.

    Args:
        sentence (str | None): The text to tag. ``None``, an empty string,
            or a whitespace-only string is treated as "nothing to tag".

    Returns:
        list: The pipeline's output for ``sentence``, or an empty list when
        there is no text to tag (the model is not loaded in that case).
    """
    # A missing query parameter arrives here as None; do not hand that to
    # the pipeline.
    if sentence is None or not sentence.strip():
        return []

    ner_pipeline = load_trained_model()
    return ner_pipeline(sentence)
def _entities_to_html(text, entities, label_colors):
    """Return ``text`` as an HTML fragment with each entity span highlighted.

    Args:
        text (str): The original input text.
        entities (list[dict]): Pipeline output; each item is read for the
            keys ``start``, ``end`` and ``entity``.
        label_colors (dict): Maps an entity label to a CSS colour. Labels
            that are not listed fall back to ``'lightgray'``.

    Returns:
        str: A ``<div>`` containing the text, with spaces outside entities
        turned into ``<br>`` and every entity wrapped in a coloured ``<div>``.
    """
    from html import escape

    pieces = ["<div>"]
    cursor = 0
    for entity in entities:
        start, end = entity.get('start'), entity.get('end')
        if start is None or end is None:
            # No character offsets for this entity: it cannot be located in
            # the text, so leave that stretch to be emitted as plain text.
            continue
        color = label_colors.get(entity['entity'], 'lightgray')
        # The text comes from the user and is rendered with
        # unsafe_allow_html=True, so escape it before adding markup.
        pieces.append(escape(text[cursor:start]).replace(" ", "<br>"))
        pieces.append(
            f'<div style="background-color: {color}; padding: 5px; '
            f'border-radius: 3px; margin: 5px 0;">{escape(text[start:end])}</div>'
        )
        cursor = end
    # Whatever follows the last entity.
    pieces.append(escape(text[cursor:]).replace(" ", "<br>"))
    pieces.append("</div>")
    return "".join(pieces)


def prep_page():
    """Render the interactive NER page.

    Shows a title and a text area. When text has been entered, the model is
    run on it, the raw predictions are pretty-printed to stdout, the text is
    displayed with highlighted entities, and a token/entity table is
    rendered. The random-examples section is rendered afterwards.
    """
    model = load_trained_model()

    st.title("Named Entity Recognition with BERT on PLOD-CW")
    st.write("Enter a sentence to see the named entities recognized by the model.")

    text = st.text_area("Enter your sentence here:")

    if text:
        st.write("Entities recognized:")
        entities = model(text)
        pprint(entities)

        # Background colour per entity label.
        label_colors = {
            'B-LF': 'lightblue',
            'B-O': 'lightgreen',
            'B-AC': 'lightcoral',
            'I-LF': 'lightyellow',
        }

        styled_text = _entities_to_html(text, entities, label_colors)
        st.markdown(styled_text, unsafe_allow_html=True)

        # render_entities expects one token and one label per row, so pass
        # the per-prediction fields rather than the whole input string and
        # the raw prediction dicts.
        render_entities(
            [entity['word'] for entity in entities],
            [entity['entity'] for entity in entities],
        )

    render_random_examples()
if __name__ == '__main__':
    # Two modes, selected by the URL query string:
    #   * "api" present  -> tag the "sentence" parameter and write the result
    #   * otherwise      -> render the interactive page
    query_params = st.query_params
    if 'api' in query_params:
        sentence = query_params.get('sentence')
        if not sentence:
            # .get() yields None when the parameter is absent; report that
            # instead of passing None (or an empty string) to the model.
            st.error("The 'sentence' query parameter is required when 'api' is set.")
        else:
            entities = predict_using_trained(sentence)
            response = {"sentence": sentence, "entities": entities}
            pprint(response)
            st.write(response)
    else:
        prep_page()