"""Gradio demo that highlights medical entities found by a token-classification model."""

import html

import gradio as gr
from transformers import pipeline

# Load the token classification model once, at import time.
# The code below reads 'entity_group', 'word', 'start' and 'end' from each
# result, which is the shape produced when an aggregation strategy is set.
pipe = pipeline(
    "token-classification",
    model="Clinical-AI-Apollo/Medical-NER",
    aggregation_strategy='simple',
)

# Background color used for each entity type; types not listed here fall
# back to gray in classify_text().
entity_colors = {
    "AGE": "#ffadad",
    "SEX": "#ffd6a5",
    "DISEASE_DISORDER": "#caffbf",
    "SIGN_SYMPTOM": "#9bf6ff",
    "LAB_VALUE": "#a0c4ff",
    "THERAPEUTIC_PROCEDURE": "#bdb2ff",
    "CLINICAL_EVENT": "#ffc6ff",
    "DIAGNOSTIC_PROCEDURE": "#fffffc",
    "DETAILED_DESCRIPTION": "#fdffb6",
    "BIOLOGICAL_STRUCTURE": "#ffb5a7",
}


def _render_entity(word: str, entity: str, color: str) -> str:
    """Return the HTML for one highlighted entity followed by its type label."""
    # Both values are escaped: `word` is derived from user input and `entity`
    # comes from the model configuration, so neither is trusted as markup.
    return (
        f'<span style="background-color: {color}; color: #000; '
        f'padding: 2px 4px; border-radius: 4px;">'
        f'{html.escape(word)} '
        f'<span style="font-size: 0.75em; font-weight: bold;">'
        f'{html.escape(entity)}</span>'
        f'</span>'
    )


def classify_text(text: str) -> str:
    """Return *text* as an HTML fragment with detected entities highlighted.

    Args:
        text: Free-form text entered by the user. ``None`` or an empty /
            whitespace-only string yields a fragment with no highlights and
            skips the model call.

    Returns:
        An HTML string: the input text, HTML-escaped, in which every entity
        reported by ``pipe`` is wrapped in a colored ``<span>`` followed by
        the entity type name.
    """
    text = text or ""
    # Skip inference when there is nothing to classify.
    results = pipe(text) if text.strip() else []

    parts = []
    last_pos = 0
    for res in results:
        entity = res['entity_group']
        start = res['start']
        end = res['end']

        # Text between the previous entity and this one, shown without
        # highlighting. Escaped so user-typed '<' or '&' render literally.
        parts.append(html.escape(text[last_pos:start]))

        # Default to gray if the entity type has no configured color.
        color = entity_colors.get(entity, "#e0e0e0")
        parts.append(_render_entity(res['word'], entity, color))

        last_pos = end

    # Remainder of the text after the last entity.
    parts.append(html.escape(text[last_pos:]))

    body = "".join(parts)
    # pre-wrap keeps the line breaks the user typed into the multi-line textbox.
    return f'<div style="line-height: 2; white-space: pre-wrap;">{body}</div>'


# Gradio Interface
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=5, label="Enter Medical Text"),
    outputs=gr.HTML(label="Entity Classification with Highlighting and Labels"),
    title="Medical Entity Classification",
    description="Enter medical-related text, and the model will classify medical entities with color highlighting and labels.",
    examples=[
        ["45 year old woman diagnosed with CAD"],
        ["A 65-year-old male presents with acute chest pain and a history of hypertension."],
        ["The patient underwent a laparoscopic cholecystectomy."],
    ],
)

if __name__ == "__main__":
    demo.launch()