import gradio as gr from transformers import pipeline # Load the token classification model pipe = pipeline("token-classification", model="Clinical-AI-Apollo/Medical-NER", aggregation_strategy='simple') # Define colors for different entity types entity_colors = { "AGE": "#ffadad", "SEX": "#ffd6a5", "DISEASE_DISORDER": "#caffbf", "SIGN_SYMPTOM": "#9bf6ff", "LAB_VALUE": "#a0c4ff", "THERAPEUTIC_PROCEDURE": "#bdb2ff", "CLINICAL_EVENT": "#ffc6ff", "DIAGNOSTIC_PROCEDURE": "#fffffc", "DETAILED_DESCRIPTION": "#fdffb6", "BIOLOGICAL_STRUCTURE": "#ffb5a7" } def classify_text(text): # Get token classification results result = pipe(text) # Format the results into HTML with color highlighting and entity names highlighted_text = "" last_pos = 0 for res in result: entity = res['entity_group'] word = res['word'] start = res['start'] end = res['end'] # Add text before the entity without highlighting highlighted_text += text[last_pos:start] # Add highlighted entity text with the entity name displayed color = entity_colors.get(entity, "#e0e0e0") # Default to gray if entity type not defined highlighted_text += f""" {word} {entity} """ # Update last position last_pos = end # Add the rest of the text after the last entity highlighted_text += text[last_pos:] return f"