Spaces:
Runtime error
Runtime error
import json | |
from collections import defaultdict, Counter | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
import pandas as pd | |
from transformers import pipeline | |
# Use the non-interactive Agg backend: charts are rendered off-screen and
# returned to Gradio as figure objects, so no GUI/display is required.
plt.switch_backend("Agg")

# Dropdown entries are built as "<label>: <text>" strings, one per record in
# examples.json; every record must provide "label" and "text" keys.
# Read explicitly as UTF-8 so decoding does not depend on the host locale.
with open("examples.json", "r", encoding="utf-8") as f:
    content = json.load(f)
examples = [f"{x['label']}: {x['text']}" for x in content]
# Hugging Face token-classification ("ner") pipeline for medical entities.
# With aggregation_strategy="simple", sub-word tokens are merged into
# entity spans, and each result carries the "entity_group", "word", "score",
# "start" and "end" fields that run_ner reads.
pipe = pipeline(
    task="ner",
    model="Clinical-AI-Apollo/Medical-NER",
    aggregation_strategy="simple",
)
def plot_to_figure(grouped):
    """Render entity-group counts as a vertical bar chart.

    Args:
        grouped: Mapping from entity-group label to count (``run_ner`` passes
            a ``collections.Counter``). Keys become the x-axis categories and
            values the bar heights, in the mapping's iteration order.

    Returns:
        A new ``matplotlib.figure.Figure`` holding the chart.
    """
    # Build the figure with the object-oriented API instead of plt.figure():
    # pyplot keeps every figure it creates alive until explicitly closed, so
    # creating one per request leaks memory, and pyplot's "current figure" is
    # global state shared between concurrent requests.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.margins(0.2)
    # Reserve space at the bottom for the vertically rotated category labels.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    return fig
def run_ner(text):
    """Run the medical NER pipeline on *text* and summarise the entities found.

    Args:
        text: Note text from the UI textbox.

    Returns:
        A 3-tuple ``(ner_content, rows, figure)``:

        - ``ner_content``: dict with the original ``"text"`` and an
          ``"entities"`` list (``entity``, ``word``, ``score``, ``start``,
          ``end`` per entity), sent to the HighlightedText output.
        - ``rows``: ``[entity_group, count]`` pairs for the counts table.
        - ``figure``: bar chart of those counts from ``plot_to_figure``.
    """
    # Nothing to tag in empty / whitespace-only input (e.g. Submit pressed on
    # an empty textbox), so skip the model call entirely.
    raw = pipe(text) if text and text.strip() else []
    entities = [
        {
            "entity": x["entity_group"],
            "word": x["word"],
            # The pipeline's score may be a NumPy scalar; a plain float is
            # always JSON-serializable when Gradio ships it to the browser.
            "score": float(x["score"]),
            "start": x["start"],
            "end": x["end"],
        }
        for x in raw
    ]
    ner_content = {"text": text, "entities": entities}
    grouped = Counter(entity["entity"] for entity in entities)
    rows = [[label, count] for label, count in grouped.items()]
    figure = plot_to_figure(grouped)
    return ner_content, rows, figure
# Gradio UI: a note textbox with an example picker, and three result views
# (highlighted entities, a per-entity count table, and a bar chart).
with gr.Blocks() as demo:
    note = gr.Textbox(label="Note text")
    submit = gr.Button("Submit")
    example_dropdown = gr.Dropdown(label="Examples", choices=examples)
    # Selecting an example copies the chosen string verbatim into the textbox.
    example_dropdown.change(lambda x: x, inputs=example_dropdown, outputs=note)
    highlight = gr.HighlightedText(label="NER", combine_adjacent=True)
    table = gr.Dataframe(headers=["Entity", "Count"])
    # run_ner returns three values (entities, count rows, bar chart), so it
    # needs three output components; without a Plot output the chart it
    # builds has nowhere to be displayed.
    plot = gr.Plot(label="Entity counts")
    ner_outputs = [highlight, table, plot]
    # Clicking Submit and pressing Enter in the textbox both run NER.
    submit.click(run_ner, [note], ner_outputs)
    note.submit(run_ner, [note], ner_outputs)
demo.launch()