|
import uvicorn |
|
import threading |
|
from collections import Counter |
|
from typing import Optional |
|
from transformers import pipeline |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
import pandas as pd |
|
|
|
from pprint import pprint |
|
|
|
import gradio as gr |
|
from transformers import pipeline |
|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
from typing import List, Dict |
|
|
|
|
|
|
|
# FastAPI application object; the route decorators further down register
# their handlers on it.
app = FastAPI()

# Process-wide lazy caches. Each one stays None until the corresponding
# get_cached_*() helper fills it in on first use.
model_cache: Optional[object] = None
dataset_cache: Optional[object] = None
|
|
|
def load_model(model_name: str = "LampOfSocrates/bert-cased-plodcw-sourav"):
    """Build the token-classification pipeline used for entity tagging.

    The tokenizer and the model weights are both loaded with
    ``from_pretrained`` from the same identifier and then wrapped in a
    transformers ``"ner"`` pipeline. The label map of the loaded model is
    printed for diagnostics.

    Args:
        model_name: Identifier handed to both ``from_pretrained`` calls.
            Defaults to the checkpoint this app was written for, so existing
            zero-argument callers behave exactly as before.

    Returns:
        The object returned by ``pipeline("ner", ...)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    token_classifier = AutoModelForTokenClassification.from_pretrained(model_name)

    id2label = token_classifier.config.id2label
    print(f"Can recognise the following labels {id2label}")

    # Keep the raw model and the pipeline under separate names instead of
    # rebinding one variable to two different kinds of object.
    return pipeline("ner", model=token_classifier, tokenizer=tokenizer)
|
|
|
def load_plod_cw_dataset():
    """Load and return the ``surrey-nlp/PLOD-CW`` dataset via ``datasets``."""
    # Imported inside the function so the `datasets` package is only
    # required when this loader is actually called.
    import datasets

    return datasets.load_dataset("surrey-nlp/PLOD-CW")
|
|
|
def get_cached_data():
    """Return the PLOD-CW dataset, loading it only on the first call.

    The loaded object is stored in the module-level ``dataset_cache`` and
    handed back unchanged on every later call.
    """
    global dataset_cache

    # Fast path: something is already cached.
    if dataset_cache is not None:
        return dataset_cache

    dataset_cache = load_plod_cw_dataset()
    return dataset_cache
|
|
|
def get_cached_model():
    """Return the NER pipeline, building it only on the first call.

    The pipeline is stored in the module-level ``model_cache`` and reused by
    every later call.
    """
    global model_cache

    # Fast path: the pipeline has already been built.
    if model_cache is not None:
        return model_cache

    model_cache = load_model()
    return model_cache
|
|
|
|
|
# Build the pipeline eagerly at import time so model_cache is already
# populated before any request handler asks for it.
model = get_cached_model()
|
|
|
|
|
class Entity(BaseModel):
    # One tagged span, built in get_entities() from a single dict of the
    # pipeline output via Entity(**entity).
    # (Documented with comments rather than a docstring so the generated
    # schema stays exactly as it was.)

    # Predicted label; ner_demo() feeds the same key to get_color_for_label().
    entity: str
    # Score reported by the pipeline for this prediction
    # (presumably the model confidence; confirm against the pipeline docs).
    score: float
    # Start / end positions; ner_demo() uses the same keys as slice bounds
    # into the input text (text[start:end]).
    start: int
    end: int
    # Token text as reported by the pipeline.
    word: str
|
|
|
class NERResponse(BaseModel):
    # Response body of POST /ner: every Entity produced for the submitted
    # text, in the order the pipeline returned them.
    entities: List[Entity]
|
|
|
class NERRequest(BaseModel):
    # Request body of POST /ner: the raw text to run through the pipeline.
    text: str
|
|
|
@app.get("/hello")
def read_root():
    """useful for testing connections"""
    # Static payload: no model or dataset access, so a reply only proves the
    # HTTP layer is reachable.
    greeting = "Hello, World!"
    return {"message": greeting}
|
|
|
|
|
@app.post("/ner", response_model=NERResponse)
def get_entities(request: NERRequest):
    """Run the cached NER pipeline on ``request.text``.

    Args:
        request: Body carrying the text to tag.

    Returns:
        An ``NERResponse`` with one ``Entity`` per dict returned by the
        pipeline; the list is empty when the pipeline returns nothing.
    """
    print(request)

    ner_pipeline = get_cached_model()
    entities = ner_pipeline(request.text)

    # Debug output only. Guarded because the pipeline result can be empty
    # (e.g. for an empty text), in which case indexing [0] used to raise
    # IndexError and turn a valid request into a server error.
    if entities:
        print(entities[0].keys())

    response_entities = [Entity(**entity) for entity in entities]

    if response_entities:
        print(response_entities[0])

    return NERResponse(entities=response_entities)
|
|
|
def get_color_for_label(label: str) -> str:
    """Return the display colour for an entity label.

    Known labels are ``I-LF``, ``B-LF``, ``B-AC`` and ``B-O``; any other
    label falls back to ``"black"``.
    """
    palette = dict(
        [
            ("I-LF", "red"),
            ("B-LF", "pink"),
            ("B-AC", "blue"),
            ("B-O", "green"),
        ]
    )
    if label in palette:
        return palette[label]
    return "black"
|
|
|
|
|
|
|
def ner_demo(text):
    """Tag ``text`` with the cached NER pipeline and render it as HTML.

    Each entity returned by the pipeline is wrapped in a coloured ``<sup>``
    element (colour chosen by ``get_color_for_label``); the text between
    entities is emitted as-is. All pieces of the user's text are
    HTML-escaped so characters such as ``<`` or ``&`` are displayed
    literally instead of being interpreted as markup.

    Args:
        text: Raw input text from the UI.

    Returns:
        An HTML string; empty when ``text`` is empty or ``None``.
    """
    # Function-scope import: only this renderer needs it.
    from html import escape

    if not text:
        # Nothing to tag or render.
        return ""

    ner_pipeline = get_cached_model()
    entities = ner_pipeline(text)

    print("Entities detected {}".format(Counter(entity["entity"] for entity in entities)))

    fragments = []
    last_index = 0

    for entity in entities:
        # start/end are used as slice bounds into the *unescaped* text, so
        # every slice is escaped on its own after it has been cut out.
        start, end, label = entity["start"], entity["end"], entity["entity"]
        color = get_color_for_label(label)

        # Plain text between the previous entity and this one.
        fragments.append(escape(text[last_index:start]))
        fragments.append(
            f'<sup style="color: {color}; font-weight: bold;">{escape(text[start:end])}</sup>'
        )
        last_index = end

    # Trailing text after the last entity.
    fragments.append(escape(text[last_index:]))
    return "".join(fragments)
|
|
|
# Legend colours, looked up once so the intro text below uses exactly the
# colours that get_color_for_label() hands out.
bo_color, bac_color, ilf_color, blf_color = (
    get_color_for_label(tag) for tag in ("B-O", "B-AC", "I-LF", "B-LF")
)

# Introductory blurb (with a colour legend) shown at the top of the UI.
PROJECT_INTRO = f"""This is a HF Spaces hosted Gradio App built by NLP Group 27. \n\n
The model has been trained on surrey-nlp/PLOD-CW dataset.
The following Entities are recognized:
<sup style="color: {bo_color}; font-weight: bold;">B-O</sup>
<sup style="color: {bac_color}; font-weight: bold;">B-AC</sup>
<sup style="color: {ilf_color}; font-weight: bold;">I-LF</sup>
<sup style="color: {blf_color}; font-weight: bold;">B-LF</sup>
<sup style="color: black; font-weight: bold;">Rest</sup>
"""
|
def echo(text, request: gr.Request):
    """Render details of the incoming request as an HTML fragment.

    Args:
        text: Value of the bound input component; not used here.
        request: Request object supplied for the event. When it is falsy
            only an empty ``<div></div>`` is produced.

    Returns:
        An HTML string listing the request headers, the client IP address
        and the query parameters.
    """
    # Function-scope import: only this renderer needs it.
    from html import escape

    res = '<div>'
    if request:
        # Headers and query parameters are controlled by the client, so they
        # are escaped before being embedded in markup.
        res += f"Request headers dictionary: {escape(str(request.headers))} <p>"

        # The client address is not guaranteed to be present; fall back to a
        # placeholder instead of failing on `.host`.
        client = request.client
        host = client.host if client else "unknown"
        res += f"IP address: {escape(str(host))} <p>"

        res += f"Query parameters: {escape(str(dict(request.query_params)))} <p>"
    res += "</div>"

    return res
|
|
|
def sample_data(text):
    """Return a one-row table holding a fixed sample sentence and its length.

    The ``text`` argument is accepted because the UI passes the textbox
    value, but the returned table is always built from the hard-coded
    sentence below.

    Returns:
        A ``pandas.DataFrame`` with the columns ``Text`` and ``Length``.
    """
    sample = "The red dots represents LCI , the bright yellow rectangle represents RV , and the black triangle represents the /TLCnLCI"

    # One record -> one row; key order fixes the column order.
    return pd.DataFrame([{"Text": sample, "Length": len(sample)}])
|
|
|
|
|
|
|
|
|
# Single-function Interface: textbox in, ner_demo() HTML out.
# NOTE: the name `demo` is rebound by the `with gr.Blocks() as demo:` layout
# later in this file, so this object is not the one that gets launched.
demo = gr.Interface(
    fn=ner_demo,
    inputs=gr.Textbox(lines=10, placeholder="Enter text here..."),
    outputs="html",
    title="Named Entity Recognition on PLOD-CW ",
    description=PROJECT_INTRO + "\n\nEnter text to extract named entities using a NER model.",
)
|
|
|
# Blocks layout: heading and intro, the input box with its rendered result,
# a row of three action buttons, then the outputs of the two helper actions.
with gr.Blocks() as demo:
    gr.Markdown("# Named Entity Recognition on PLOD-CW")
    gr.Markdown(PROJECT_INTRO)
    gr.Markdown("### Enter text to extract named entities using a NER model.")

    text_input = gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
    html_output = gr.HTML(label="HTML Output")

    with gr.Row():
        submit_button = gr.Button("Submit")
        echo_button = gr.Button("Echo Client")
        sample_button = gr.Button("Sample PLOD_CW")

    sample_output = gr.Dataframe(label="Sample Table")
    echo_output = gr.HTML(label="HTML Output")

    # Event wiring: every button reads the same textbox and writes to its
    # own output component.
    submit_button.click(fn=ner_demo, inputs=text_input, outputs=html_output)
    echo_button.click(fn=echo, inputs=text_input, outputs=echo_output)
    sample_button.click(fn=sample_data, inputs=text_input, outputs=sample_output)

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|