|
import html

import gradio as gr
from transformers import pipeline
|
|
|
|
|
# Medical NER pipeline, created once at module import and shared by every
# request. aggregation_strategy="simple" asks the pipeline for grouped entity
# spans; each result is read via its "entity_group", "word", "start" and
# "end" keys.
pipe = pipeline(
    task="token-classification",
    model="Clinical-AI-Apollo/Medical-NER",
    aggregation_strategy="simple",
)
|
|
|
|
|
# Background colour used to highlight each entity group returned by the NER
# pipeline. Groups missing from this mapping get a fallback colour from the
# code that looks the palette up. Every value must stay clearly distinguishable
# from a white page background, or the highlight for that group is invisible.
entity_colors = {
    "AGE": "#ffadad",
    "SEX": "#ffd6a5",
    "DISEASE_DISORDER": "#caffbf",
    "SIGN_SYMPTOM": "#9bf6ff",
    "LAB_VALUE": "#a0c4ff",
    "THERAPEUTIC_PROCEDURE": "#bdb2ff",
    "CLINICAL_EVENT": "#ffc6ff",
    # A near-white value here made diagnostic-procedure spans unreadable as
    # highlights; use a saturated colour not shared with any other group.
    "DIAGNOSTIC_PROCEDURE": "#f9c74f",
    "DETAILED_DESCRIPTION": "#fdffb6",
    "BIOLOGICAL_STRUCTURE": "#ffb5a7",
}
|
|
|
def classify_text(text, ner=None, colors=None):
    """Run medical NER on *text* and return it as highlighted HTML.

    Each detected entity is wrapped in a coloured ``<span>`` followed by a
    small label naming its entity group; text between entities is copied
    through unchanged. All embedded text (user input and labels) is
    HTML-escaped, so input such as ``<script>`` is displayed literally rather
    than interpreted by the browser.

    Args:
        text: Input string from the Gradio textbox. Empty or ``None`` input
            returns an empty wrapper without calling the model.
        ner: Optional callable ``ner(text) -> list[dict]`` whose dicts carry
            ``entity_group``, ``start`` and ``end`` (character offsets into
            *text*). Defaults to the module-level ``pipe``.
        colors: Optional mapping from entity group to CSS colour. Defaults to
            the module-level ``entity_colors``; unknown groups are shown in
            grey (``#e0e0e0``).

    Returns:
        An HTML ``<div>`` string suitable for ``gr.HTML``.
    """
    if ner is None:
        ner = pipe
    if colors is None:
        colors = entity_colors

    # pre-wrap keeps the user's line breaks from the multi-line textbox.
    wrapper = (
        "<div style='font-family: Arial, sans-serif; line-height: 1.5; "
        "white-space: pre-wrap;'>{}</div>"
    )
    if not text:
        return wrapper.format("")

    parts = []
    cursor = 0
    # Emit entities in document order; skip any span overlapping one already
    # emitted, since re-emitting it would duplicate text.
    for ent in sorted(ner(text), key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        if start < cursor:
            continue
        label = ent["entity_group"]
        color = colors.get(label, "#e0e0e0")
        parts.append(html.escape(text[cursor:start]))
        # Slice the user's own text instead of the model's "word" field, which
        # is rebuilt from tokens and may differ in casing/spacing. Markup is
        # kept compact so no stray whitespace is rendered around the entity.
        parts.append(
            f"<span style='background-color:{color}; padding:2px; border-radius:5px;'>"
            f"{html.escape(text[start:end])}"
            "<span style='display:inline-block; background-color:#fff; color:#000; "
            "border-radius:3px; padding:2px; margin-left:5px; font-size:10px;'>"
            f"{html.escape(label)}</span></span>"
        )
        cursor = end
    parts.append(html.escape(text[cursor:]))
    return wrapper.format("".join(parts))
|
|
|
|
|
# Gradio UI: one multi-line textbox in, rendered HTML out, wired to
# classify_text. Clicking an example fills the textbox with that prompt.
demo = gr.Interface(
    classify_text,
    gr.Textbox(label="Enter Medical Text", lines=5),
    gr.HTML(label="Entity Classification with Highlighting and Labels"),
    title="Medical Entity Classification",
    description="Enter medical-related text, and the model will classify medical entities with color highlighting and labels.",
    examples=[
        ["45 year old woman diagnosed with CAD"],
        ["A 65-year-old male presents with acute chest pain and a history of hypertension."],
        ["The patient underwent a laparoscopic cholecystectomy."],
    ],
)
|
|
|
def main():
    """Start the Gradio server for the demo interface."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|