import uvicorn
from collections import Counter
from typing import List, Optional

import pandas as pd
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

# Define the FastAPI app and module-level caches for the model and dataset
app = FastAPI()
model_cache: Optional[object] = None
dataset_cache: Optional[object] = None
def load_model():
    """Load the fine-tuned token-classification model and wrap it in a NER pipeline."""
    tokenizer = AutoTokenizer.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
    model = AutoModelForTokenClassification.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")

    # Report the label mapping the model was trained with
    id2label = model.config.id2label
    print(f"Can recognise the following labels: {id2label}")

    # Build the NER pipeline from the fine-tuned model and its tokenizer
    ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer)
    return ner_pipeline
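# Quick sanity check of the pipeline output shape (illustrative sketch; the exact
# tokens, offsets, and scores depend on the model). Each prediction from the "ner"
# pipeline is a dict with the keys 'entity', 'score', 'index', 'word', 'start', and
# 'end', which is what the Entity response model below relies on:
#
#   nlp = load_model()
#   print(nlp("Residual volume ( RV ) was measured."))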
def load_plod_cw_dataset():
    """Load the surrey-nlp/PLOD-CW dataset (imported lazily to keep startup light)."""
    from datasets import load_dataset

    dataset = load_dataset("surrey-nlp/PLOD-CW")
    return dataset


def get_cached_data():
    """Return the cached dataset, loading it on first use."""
    global dataset_cache
    if dataset_cache is None:
        dataset_cache = load_plod_cw_dataset()
    return dataset_cache
def get_cached_model():
    """Return the cached NER pipeline, loading it on first use."""
    global model_cache
    if model_cache is None:
        model_cache = load_model()
    return model_cache


# Cache the model when the server starts
model = get_cached_model()
# plod_cw = get_cached_data()
class Entity(BaseModel):
    entity: str
    score: float
    start: int
    end: int
    word: str


class NERResponse(BaseModel):
    entities: List[Entity]


class NERRequest(BaseModel):
    text: str
@app.get("/hello")
def read_root():
    """Health-check endpoint, useful for testing connections."""
    return {"message": "Hello, World!"}


@app.post("/ner", response_model=NERResponse)
def get_entities(request: NERRequest):
    """Run the cached NER pipeline over the posted text and return the entities."""
    model = get_cached_model()
    entities = model(request.text)

    # Convert the pipeline output to the response model, casting the numpy score to float
    response_entities = [
        Entity(
            entity=entity["entity"],
            score=float(entity["score"]),
            start=entity["start"],
            end=entity["end"],
            word=entity["word"],
        )
        for entity in entities
    ]
    return NERResponse(entities=response_entities)
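# Example request against the /ner route (a sketch: it assumes the FastAPI app is
# actually being served, e.g. via the uvicorn/mount_gradio_app variant sketched at
# the bottom of this file; adjust host and port to your deployment):
#
#   curl -X POST http://localhost:7860/ner \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Residual volume ( RV ) was measured."}'
#
# The response mirrors NERResponse:
#   {"entities": [{"entity": ..., "score": ..., "start": ..., "end": ..., "word": ...}, ...]}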
def get_color_for_label(label: str) -> str:
    # Define a mapping of labels to colors
    color_mapping = {
        "I-LF": "red",
        "B-LF": "pink",
        "B-AC": "blue",
        "B-O": "green",
        # Add more labels and colors as needed
    }
    return color_mapping.get(label, "black")  # Default to black if label not found
# Define the Gradio interface function
def ner_demo(text):
    """Render the input text as HTML with detected entities highlighted by label color."""
    model = get_cached_model()
    entities = model(text)
    print("Entities detected: {}".format(Counter(entity["entity"] for entity in entities)))

    all_html = ""
    last_index = 0
    for entity in entities:
        start, end, label = entity["start"], entity["end"], entity["entity"]
        color = get_color_for_label(label)
        entity_text = text[start:end]
        colored_entity = f'<sup style="color: {color}; font-weight: bold;">{entity_text}</sup>'

        # Append the plain text before the entity, then the highlighted entity
        all_html += text[last_index:start]
        all_html += colored_entity
        last_index = end

    # Append the remaining text after the last entity
    all_html += text[last_index:]
    return all_html
bo_color = get_color_for_label("B-O")
bac_color = get_color_for_label("B-AC")
ilf_color = get_color_for_label("I-LF")
blf_color = get_color_for_label("B-LF")

PROJECT_INTRO = f"""This is a Gradio app hosted on Hugging Face Spaces, built by NLP Group 27.

The model has been trained on the surrey-nlp/PLOD-CW dataset.
The following entity labels are recognised:
<sup style="color: {bo_color}; font-weight: bold;">B-O</sup>
<sup style="color: {bac_color}; font-weight: bold;">B-AC</sup>
<sup style="color: {ilf_color}; font-weight: bold;">I-LF</sup>
<sup style="color: {blf_color}; font-weight: bold;">B-LF</sup>
<sup style="color: black; font-weight: bold;">Rest</sup>
"""
def echo(text, request: gr.Request):
    """Return request metadata (headers, client IP, query parameters) as HTML."""
    res = "<div>"
    if request:
        res += f"Request headers dictionary: {request.headers} <p>"
        res += f"IP address: {request.client.host} <p>"
        res += f"Query parameters: {dict(request.query_params)} <p>"
    res += "</div>"
    return res
def sample_data(text):
    """Return a one-row DataFrame built from a hard-coded PLOD-CW-style sentence (the input text is ignored)."""
    text = "The red dots represents LCI , the bright yellow rectangle represents RV , and the black triangle represents the /TLCnLCI"
    data = {
        "Text": [text],
        "Length": [len(text)],
    }
    df = pd.DataFrame(data)
    return df
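# If real dataset rows are wanted instead of the hard-coded sentence, a sketch using
# the cached PLOD-CW split could look like this (assumption: the 'test' split exposes
# a 'tokens' column, as PLOD-CW does on the Hub):
#
#   dat = get_cached_data()
#   rows = dat["test"].shuffle(seed=0).select(range(5))
#   df = pd.DataFrame({"Text": [" ".join(t) for t in rows["tokens"]],
#                      "Length": [len(t) for t in rows["tokens"]]})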
# Create the Gradio UI as a Blocks layout so the NER demo, the request echo,
# and the sample table share one page
with gr.Blocks() as demo:
    gr.Markdown("# Named Entity Recognition on PLOD-CW")
    gr.Markdown(PROJECT_INTRO)
    gr.Markdown("### Enter text to extract named entities using a NER model.")

    text_input = gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
    html_output = gr.HTML(label="NER Output")

    with gr.Row():
        submit_button = gr.Button("Submit")
        echo_button = gr.Button("Echo Client")
        sample_button = gr.Button("Sample PLOD_CW")

    sample_output = gr.Dataframe(label="Sample Table")
    echo_output = gr.HTML(label="Echo Output")

    submit_button.click(ner_demo, inputs=text_input, outputs=html_output)
    echo_button.click(echo, inputs=text_input, outputs=echo_output)
    sample_button.click(sample_data, inputs=text_input, outputs=sample_output)
# Launch the Gradio app (Spaces serves the UI on port 7860)
demo.launch(server_name="0.0.0.0", server_port=7860)
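# Alternative serving sketch (an assumption, not what this Space currently does):
# the FastAPI routes defined above (/hello, /ner) are only reachable if the FastAPI
# app itself is served. One option is to mount the Gradio Blocks onto the FastAPI app
# and run both under uvicorn instead of calling demo.launch():
#
#   app = gr.mount_gradio_app(app, demo, path="/")
#   uvicorn.run(app, host="0.0.0.0", port=7860)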