|
import gradio as gr |
|
from faster_whisper import WhisperModel |
|
from pydantic import BaseModel, Field, AliasChoices, field_validator, ValidationError |
|
from typing import List |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
import csv |
|
import json |
|
import tempfile |
|
import torch |
|
import os |
|
|
|
|
|
# Export COMMANDLINE_ARGS into this process's environment at import time.
# NOTE(review): nothing in this file reads COMMANDLINE_ARGS; it may be a
# leftover from another launcher and have no effect here — confirm.
os.environ.update(COMMANDLINE_ARGS="--no-gradio-queue")
|
|
|
|
|
|
|
|
|
# Model identifiers: two Hugging Face Hub checkpoints and a faster-whisper size.
numind_checkpoint = "numind/NuExtract-tiny"
llama_checkpoint = "Atereoyin/Llama3_finetuned_for_medical_entity_extraction"
whisper_checkpoint = "base"

# One bitsandbytes config shared by both transformers models: 8-bit weights.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Everything below runs once at import time; the resulting objects are the
# module-level models/tokenizers used by the request handlers.
whisper_model = WhisperModel(whisper_checkpoint, device="cuda")

numind_model = AutoModelForCausalLM.from_pretrained(
    numind_checkpoint,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
numind_tokenizer = AutoTokenizer.from_pretrained(numind_checkpoint)

# NOTE(review): trust_remote_code=True allows the Hub repository to execute its
# own Python code while loading; Llama architectures are supported natively by
# transformers, so the flag may be unnecessary for this checkpoint — confirm.
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_checkpoint,
    quantization_config=quantization_config,
    trust_remote_code=True,
)
llama_tokenizer = AutoTokenizer.from_pretrained(llama_checkpoint)
|
|
|
|
|
def transcribe_audio(audio_file_path, model=None):
    """Transcribe an audio file to text.

    Args:
        audio_file_path: path of the audio file to transcribe.
        model: object exposing faster-whisper's ``transcribe(path, beam_size=...)``
            API; defaults to the module-level ``whisper_model``.

    Returns:
        The segment texts concatenated and stripped of surrounding whitespace,
        or ``""`` if transcription fails (the error is printed).
    """
    if model is None:
        model = whisper_model
    try:
        segments, _info = model.transcribe(audio_file_path, beam_size=5)
        # Consume ``segments`` inside the ``try`` so errors raised while
        # iterating it are handled here as well.
        return "".join(segment.text for segment in segments).strip()
    except Exception as e:
        # Returning the error text would hand it to the extraction models as if
        # it were the transcript; report it and return an empty transcript.
        print(f"Error during transcription: {e}")
        return ""
|
|
|
|
|
def predict_NuExtract(model, tokenizer, text, schema, example=("", "", "")):
    """Fill a JSON template from ``text`` using a NuExtract-style model.

    Args:
        model: causal LM with ``generate`` and ``device`` (the NuExtract model).
        tokenizer: tokenizer matching ``model``.
        text: free text to extract from (here, the audio transcription).
        schema: JSON template string; it is parsed and re-serialized with indent=4.
        example: iterable of JSON example strings; empty strings are skipped.
            (A tuple default avoids a shared mutable default argument.)

    Returns:
        The generated text between ``<|output|>`` and ``<|end-output|>``, or
        ``""`` when the decoded output contains no ``<|output|>`` marker.

    Raises:
        json.JSONDecodeError: if ``schema`` or a non-empty example is not valid JSON.
    """
    parts = ["<|input|>\n### Template:\n", json.dumps(json.loads(schema), indent=4), "\n"]
    for ex in example:
        if ex != "":
            parts.append("### Example:\n" + json.dumps(json.loads(ex), indent=4) + "\n")
    parts.append("### Text:\n" + text + "\n<|output|>\n")
    input_llm = "".join(parts)

    input_ids = tokenizer(
        input_llm, return_tensors="pt", truncation=True, max_length=4000
    ).to(model.device)
    output = tokenizer.decode(model.generate(**input_ids)[0], skip_special_tokens=True)

    # partition (unlike split(...)[1]) does not raise when the marker is absent,
    # e.g. when max_length truncation cut the end of the prompt off; the empty
    # result is then reported by the caller's JSON parsing.
    generated = output.partition("<|output|>")[2]
    return generated.partition("<|end-output|>")[0]
|
|
|
|
|
|
|
def prompt_format(text):
    """Build the Alpaca-style prompt that asks the Llama model for medical entities.

    Args:
        text: the conversation transcript, placed under "### Input:".

    Returns:
        The prompt string, ending right after "### Response:\\n" so that the
        model's continuation is the answer.
    """
    # NOTE(review): whitespace inside these literals is part of the prompt and
    # should match the format the checkpoint was fine-tuned on — confirm.
    task = "\n".join((
        "Extract the following entities from the medical conversation:",
        "* **Symptoms:** List all the symptoms the patient mentions.",
        "* **Diagnosis:** List the doctor's diagnosis or potential diagnoses.",
        "* **Medical History:** Summarize the patient's relevant medical history.",
        "* **Action Plan:** List the recommended actions or treatment plan.",
        "",
        "Provide the result in the following JSON format:",
        "{",
        '"Symptoms": [...],',
        '"Diagnosis": [...],',
        '"Medical history": [...],',
        '"Action plan": [...]',
        "}",
    ))
    sections = (
        "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
        f"### Instruction:\n{task}",
        f"### Input:\n{text}",
        "### Response:\n",
    )
    return "\n\n".join(sections)
|
|
|
|
|
|
|
def validate_medical_record(response):
    """Normalize a parsed model answer into the four medical-record fields.

    Args:
        response: mapping of field name to value parsed from the Llama output.
            Recognized keys: "Symptoms", "Diagnosis", "Medical history" /
            "Medical History" / "History of Patient" / "Medical_history", and
            "Action plan" / "Action Plan" / "Plan of Action" / "Action_plan".
            Other keys are ignored.

    Returns:
        On success, a dict with keys "Symptoms", "Diagnosis", "Medical_history"
        and "Action_plan", each a list of non-empty strings (missing fields are
        ``[]``). If validation still fails, ``response`` is returned unchanged.
    """

    class MedicalRecord(BaseModel):
        Symptoms: List[str] = Field(default_factory=list)
        Diagnosis: List[str] = Field(default_factory=list)
        Medical_history: List[str] = Field(
            default_factory=list,
            validation_alias=AliasChoices(
                'Medical history', 'History of Patient', 'Medical History', 'Medical_history'
            ),
        )
        Action_plan: List[str] = Field(
            default_factory=list,
            validation_alias=AliasChoices(
                'Action plan', 'Plan of Action', 'Action Plan', 'Action_plan'
            ),
        )

        @field_validator('*', mode='before')
        @classmethod
        def ensure_list(cls, v):
            # Model output is loosely typed: coerce null, comma-separated
            # strings and lists with non-string items into list[str], dropping
            # empty entries, instead of letting them fail validation.
            if v is None:
                return []
            if isinstance(v, str):
                parts = (item.strip() for item in v.split(','))
                return [item for item in parts if item]
            if isinstance(v, (list, tuple)):
                parts = (str(item).strip() for item in v if item is not None)
                return [item for item in parts if item]
            return v

    try:
        # model_validate/model_dump are the pydantic v2 APIs (.dict() is deprecated).
        return MedicalRecord.model_validate(response).model_dump()
    except ValidationError:
        return response
|
|
|
|
|
|
|
|
|
def _parse_llama_response(response):
    """Parse the model's answer text into a ``{field name: value}`` dict.

    The prompt requests a JSON object, so the first JSON object in ``response``
    is decoded first. If that fails (e.g. the answer was cut off by the token
    budget), fall back to ``key: value`` lines with JSON quotes, brackets and
    trailing commas removed so keys can match the expected field names.
    ``None`` values are dropped; list items and scalars are converted to str.
    """
    parsed = None
    start = response.find("{")
    if start != -1:
        try:
            # raw_decode stops after the first complete object, ignoring any
            # text the model generated after it.
            parsed, _end = json.JSONDecoder().raw_decode(response, start)
        except json.JSONDecodeError:
            parsed = None

    if not isinstance(parsed, dict):
        parsed = {}
        for line in response.splitlines():
            if ': ' not in line:
                continue
            raw_key, raw_value = line.split(': ', 1)
            key = raw_key.strip(' \t"\'*-')
            raw_value = raw_value.strip().rstrip(',').strip()
            try:
                parsed[key] = json.loads(raw_value)
            except json.JSONDecodeError:
                parsed[key] = raw_value.strip('[]').replace('"', '').strip()

    cleaned = {}
    for key, value in parsed.items():
        if value is None:
            continue  # missing field: the validator's default ([]) applies
        if isinstance(value, (list, tuple)):
            value = [str(item) for item in value if item is not None]
        elif not isinstance(value, str):
            value = str(value)
        cleaned[str(key)] = value
    return cleaned


def predict_Llama(model, tokenizer, text, max_new_tokens=128):
    """Extract symptoms, diagnosis, medical history and action plan from ``text``.

    Args:
        model: the fine-tuned causal LM for medical entity extraction.
        tokenizer: tokenizer matching ``model``.
        text: conversation transcript to analyse.
        max_new_tokens: generation budget for the answer (default 128, as before).

    Returns:
        The result of ``validate_medical_record`` on the parsed answer, or ``{}``
        if generation or parsing raises (the error is printed).
    """
    inputs = tokenizer(prompt_format(text), return_tensors="pt", truncation=True).to(model.device)

    try:
        # NOTE(review): transformers applies ``temperature`` only when sampling
        # (do_sample) is enabled; this call does not set do_sample, so whether it
        # matters depends on the model's generation config — confirm intent.
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=0.2, use_cache=True)

        # Decode only the new tokens: the prompt itself contains a JSON template,
        # and truncation could remove its "### Response:" marker.
        prompt_length = inputs["input_ids"].shape[-1]
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        response = completion.split("### Response:", 1)[-1].strip()

        return validate_medical_record(_parse_llama_response(response))

    except Exception as e:
        print(f"Error during Llama prediction: {str(e)}")
        return {}
|
|
|
|
|
|
|
def _as_display_text(value):
    """Render one extracted medical-entity value as a Textbox string.

    Lists/tuples are joined with ", " (``None`` items skipped, others passed
    through ``str``); ``None`` becomes ""; any other value is ``str()``-ed.
    Calling ``", ".join`` on a bare string would split it into characters.
    """
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        return ", ".join(str(item) for item in value if item is not None)
    return str(value)


def process_audio(audio):
    """Transcribe ``audio`` and extract the seven fields shown in the form.

    Args:
        audio: a file path (what ``gr.Audio(type="filepath")`` supplies), raw
            audio bytes, or ``None``/empty when nothing was provided.

    Returns:
        A 7-tuple for Name, Age, Gender, Symptoms, Diagnosis, Medical History
        and Plan of Action; every field is "" when no audio was given.

    Raises:
        gr.Error: if the NuExtract output is not a JSON object. A single string
            return would not match the seven output components.
    """
    if audio is None or (isinstance(audio, (str, bytes)) and not audio):
        return ("",) * 7

    if isinstance(audio, str):
        # Already a file on disk: transcribe it directly rather than copying it.
        transcription = transcribe_audio(audio)
    else:
        # Raw bytes need a file for the transcriber; delete it afterwards so
        # recordings do not accumulate in the temp directory.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio)
            audio_path = temp_audio.name
        try:
            transcription = transcribe_audio(audio_path)
        finally:
            os.remove(audio_path)

    person_schema = """{"Name": "","Age": "","Gender": ""}"""
    person_entities_raw = predict_NuExtract(numind_model, numind_tokenizer, transcription, person_schema)

    try:
        person_entities = json.loads(person_entities_raw)
    except json.JSONDecodeError as e:
        raise gr.Error(f"Error in NuExtract response: {str(e)}") from e
    if not isinstance(person_entities, dict):
        raise gr.Error("Error in NuExtract response: expected a JSON object")

    medical_entities = predict_Llama(llama_model, llama_tokenizer, transcription)

    return (
        person_entities.get("Name", ""),
        person_entities.get("Age", ""),
        person_entities.get("Gender", ""),
        _as_display_text(medical_entities.get("Symptoms")),
        _as_display_text(medical_entities.get("Diagnosis")),
        _as_display_text(medical_entities.get("Medical_history")),
        _as_display_text(medical_entities.get("Action_plan")),
    )
|
|
|
|
|
|
|
|
|
def download_csv(name, age, gender, symptoms, diagnosis, medical_history, action_plan):
    """Write the form fields to a new temporary CSV file and return its path.

    The file holds a header row followed by one data row, encoded as UTF-8.
    It is created with ``delete=False`` so the path stays valid after this
    function returns; nothing here removes it.
    """
    # Write through the NamedTemporaryFile handle itself; the original opened it,
    # never closed it, and reopened the same path (leaking a file descriptor).
    with tempfile.NamedTemporaryFile(
        mode="w", newline="", encoding="utf-8", suffix=".csv", delete=False
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["Name", "Age", "Gender", "Symptoms", "Diagnosis", "Medical History", "Plan of Action"])
        writer.writerow([name, age, gender, symptoms, diagnosis, medical_history, action_plan])

    return csv_file.name
|
|
|
|
|
|
|
|
|
# Main form: one audio input, seven extracted-field outputs (in the order
# returned by process_audio).
demo = gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(type="filepath")],
    outputs=[
        gr.Textbox(label="Name"),
        gr.Textbox(label="Age"),
        gr.Textbox(label="Gender"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diagnosis"),
        gr.Textbox(label="Medical History"),
        gr.Textbox(label="Plan of Action"),
    ],
    title="Medical Diagnostic Form Assistant",
    description="Upload an audio file or record audio to generate a medical diagnostic form.",
)

# Add a CSV export button to the Interface. Its inputs are the seven output
# Textboxes above, passed positionally to download_csv's seven parameters, so
# no wrapper lambda is needed.
with demo:
    download_button = gr.Button("Download CSV")
    download_button.click(
        fn=download_csv,
        inputs=demo.output_components,
        outputs=gr.File(label="Download CSV"),
    )

# Start the server only when run as a script, so importing this module does not.
if __name__ == "__main__":
    demo.launch()
|
|
|
|