File size: 7,942 Bytes
cd33ba8
 
 
 
 
 
 
 
 
7370511
 
 
 
 
cd33ba8
 
 
 
 
7a56ff8
cd33ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7370511
cd33ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7370511
cd33ba8
 
 
 
 
 
 
 
 
 
f80f102
cd33ba8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import gradio as gr
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, AliasChoices, field_validator, ValidationError
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import csv
import json
import tempfile
import torch
import os

# Environment setup; this assignment runs at import time, before any model is loaded.
# NOTE(review): presumably intended to disable Gradio's request queue. Nothing in
# this file reads COMMANDLINE_ARGS, so confirm the variable actually has an effect.
os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue"



# Checkpoint identifiers for the three models used by this app:
#   numind_checkpoint  -> model/tokenizer passed to predict_NuExtract (Name/Age/Gender)
#   llama_checkpoint   -> model/tokenizer passed to predict_Llama (medical entities)
#   whisper_checkpoint -> model used by transcribe_audio (speech-to-text)
numind_checkpoint = "numind/NuExtract-tiny"
llama_checkpoint = "Atereoyin/Llama3_finetuned_for_medical_entity_extraction"
whisper_checkpoint = "base"

# Shared quantization settings: both causal LMs below are loaded in 8-bit.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

# Load all models once at import time.
# Whisper is placed on CUDA explicitly; the two causal LMs are loaded with the
# 8-bit quantization config and no explicit device argument. The prediction
# helpers later move their tokenized inputs to "cuda", so a GPU is required.
# NOTE: trust_remote_code=True lets the checkpoint repository supply Python code
# that is executed during loading; only use it with repositories you trust.
whisper_model = WhisperModel(whisper_checkpoint, device="cuda")
numind_model = AutoModelForCausalLM.from_pretrained(numind_checkpoint, quantization_config=quantization_config, torch_dtype=torch.float16, trust_remote_code=True)
numind_tokenizer = AutoTokenizer.from_pretrained(numind_checkpoint)
llama_model = AutoModelForCausalLM.from_pretrained(llama_checkpoint, quantization_config=quantization_config, trust_remote_code=True)
llama_tokenizer = AutoTokenizer.from_pretrained(llama_checkpoint)

# Speech-to-text step
def transcribe_audio(audio_file_path):
    """Transcribe the audio file at ``audio_file_path`` with the module-level Whisper model.

    The file is transcribed with ``beam_size=5`` and the text of every segment
    is concatenated without a separator.

    Returns:
        The transcription text. If anything raises along the way, the
        exception's message is returned instead, so no exception reaches
        the caller.
    """
    try:
        segments, _info = whisper_model.transcribe(audio_file_path, beam_size=5)
        pieces = []
        for segment in segments:
            pieces.append(segment.text)
        return "".join(pieces)
    except Exception as error:
        # Best effort: the error text takes the place of the transcription.
        return str(error)

# Functions for Person entity extraction
def _build_nuextract_prompt(text, schema, example):
    """Assemble the NuExtract prompt: template, optional examples, then the text."""
    template = json.dumps(json.loads(schema), indent=4)
    parts = ["<|input|>\n### Template:\n" + template + "\n"]
    # `example or ()` lets callers pass None to mean "no examples".
    for sample in example or ():
        if sample != "":
            parts.append("### Example:\n" + json.dumps(json.loads(sample), indent=4) + "\n")
    parts.append("### Text:\n" + text + "\n<|output|>\n")
    return "".join(parts)


def predict_NuExtract(model, tokenizer, text, schema, example=("", "", "")):
    """Run a NuExtract-style model on ``text`` and return its raw output section.

    Args:
        model: Object whose ``generate(**inputs)`` returns a sequence of token
            id sequences; only the first one is decoded.
        tokenizer: Callable that tokenizes the prompt (the result must support
            ``.to("cuda")`` and ``**`` unpacking) and offers ``decode``.
        text: Text to extract entities from.
        schema: JSON string describing the extraction template. It is parsed
            and re-serialised with a 4-space indent before being embedded.
        example: Iterable of JSON strings used as examples; empty strings are
            skipped. ``None`` is accepted and means "no examples". The default
            is an immutable tuple so it cannot be shared and mutated between
            calls.

    Returns:
        The decoded text between the first ``<|output|>`` marker and the
        following ``<|end-output|>`` marker (or the end of the text when the
        end marker is absent). The value is not parsed; callers decode the JSON.

    Raises:
        json.JSONDecodeError: If ``schema`` or a non-empty example is not valid JSON.
        IndexError: If the decoded text contains no ``<|output|>`` marker
            (for instance when truncation at 4000 tokens cut it off).
    """
    input_llm = _build_nuextract_prompt(text, schema, example)
    input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=4000).to("cuda")

    output = tokenizer.decode(model.generate(**input_ids)[0], skip_special_tokens=True)
    return output.split("<|output|>")[1].split("<|end-output|>")[0]


# Prompt builder for the Llama medical-entity model
def prompt_format(text):
    """Return the instruction prompt for the Llama model with ``text`` as its input.

    The prompt has three sections: a fixed instruction asking for Symptoms,
    Diagnosis, Medical History and Action Plan in a JSON layout, the input
    (``text``), and an empty response section for the model to complete.
    ``text`` is inserted as a format argument, so braces inside it are kept
    literally.
    """
    task = """Extract the following entities from the medical conversation:
                    * **Symptoms:** List all the symptoms the patient mentions.
                    * **Diagnosis:** List the doctor's diagnosis or potential diagnoses.
                    * **Medical History:** Summarize the patient's relevant medical history.
                    * **Action Plan:** List the recommended actions or treatment plan.

                    Provide the result in the following JSON format:
                    {
                        "Symptoms": [...],
                        "Diagnosis": [...],
                        "Medical history": [...],
                        "Action plan": [...]
                    }"""

    # The indentation inside both literals is part of the prompt sent to the model.
    template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

              ### Instruction:
              {}

              ### Input:
              {}

              ### Response:
              {}"""

    # The response slot is filled with an empty string so the prompt ends right
    # after the "### Response:" header line.
    return template.format(task, text, "")


#Pydantic Validator to validate Llama's response
def validate_medical_record(response):
    """Validate and normalise the entity mapping parsed from the Llama response.

    Args:
        response: Mapping of entity name to value. Values may be lists or
            comma-separated strings. The history and plan entries are read from
            the keys listed in their ``AliasChoices``; keys that match no field
            or alias are not part of the result.

    Returns:
        On success, a dict with the keys ``Symptoms``, ``Diagnosis``,
        ``Medical_history`` and ``Action_plan``, each a list of strings
        (empty when the entry is missing). If validation fails, ``response``
        is returned unchanged as a best effort.
    """

    class MedicalRecord(BaseModel):
        # Every field is a list of strings and defaults to an empty list.
        Symptoms: List[str] = Field(default_factory=list)
        Diagnosis: List[str] = Field(default_factory=list)
        # Populated from either of the two input keys named in AliasChoices.
        Medical_history: List[str] = Field(
            default_factory=list,
            validation_alias=AliasChoices('Medical history', 'History of Patient')
        )
        Action_plan: List[str] = Field(
            default_factory=list,
            validation_alias=AliasChoices('Action plan', 'Plan of Action')
        )

        @field_validator('*', mode='before')
        @classmethod
        def ensure_list(cls, v):
            """Turn a comma-separated string into a list of trimmed, non-empty items."""
            if isinstance(v, str):
                # Skip blanks so "a, b," or "" do not yield empty entries.
                return [item.strip() for item in v.split(',') if item.strip()]
            return v

    try:
        validated_data = MedicalRecord(**response)
        # model_dump() is the pydantic v2 replacement for the deprecated .dict().
        return validated_data.model_dump()
    except ValidationError:
        # Deliberate fallback: hand the unvalidated mapping back to the caller.
        return response



# Medical entity extraction with the Llama model
def predict_Llama(model, tokenizer, text):
    """Generate a response for ``text`` and parse it into an entity mapping.

    The prompt from ``prompt_format`` is tokenized and moved to "cuda"; up to
    128 new tokens are generated. Everything after the first "### Response:"
    marker in the decoded text is read line by line: each line containing
    ": " is split at its first occurrence into a key and a value (both
    stripped; a repeated key keeps its last value). The mapping is then passed
    through ``validate_medical_record``.

    Returns:
        The result of ``validate_medical_record``; ``{}`` if generation,
        decoding, parsing or validation raises (the error is printed).
        Errors raised while building or tokenizing the prompt are not caught.

    NOTE(review): ``temperature`` is passed without ``do_sample``; whether it
    has any effect depends on the model's generation settings -- confirm.
    """
    prompt = prompt_format(text)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to("cuda")

    try:
        generated = model.generate(**inputs, max_new_tokens=128, temperature=0.2, use_cache=True)
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Keep only what follows the first "### Response:" marker.
        response = decoded.split("### Response:", 1)[-1].strip()

        response_dict = {}
        for line in response.splitlines():
            if ': ' not in line:
                continue
            key, _sep, value = line.partition(': ')
            response_dict[key.strip()] = value.strip()

        return validate_medical_record(response_dict)
    except Exception as e:
        print(f"Error during Llama prediction: {str(e)}")
        return {}


#Control function that coordinates the other functions to map entities to form fields
def _join_entities(value):
    """Render an entity value (a list of items or an already-joined string) as text."""
    if not value:
        return ""
    if isinstance(value, str):
        # Joining a str would interleave ", " between its characters.
        return value
    return ", ".join(str(item) for item in value)


def process_audio(audio):
    """Transcribe ``audio`` and extract the values for the seven form fields.

    Args:
        audio: Path of an audio file, the raw audio bytes, or ``None`` when
            no audio was supplied.

    Returns:
        Always a 7-tuple in the order (Name, Age, Gender, Symptoms, Diagnosis,
        Medical history, Action plan), matching the seven output components of
        the interface. All fields are empty when ``audio`` is ``None``. If the
        NuExtract output is not valid JSON, the first field carries the error
        message and the remaining fields are empty.
    """
    if audio is None:
        return ("",) * 7

    if isinstance(audio, str):
        with open(audio, 'rb') as f:
            audio_bytes = f.read()
    else:
        audio_bytes = audio

    # The data is copied to a named temp file so it can be handed over by path;
    # leaving the `with` block closes (and thereby flushes) the handle.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio.write(audio_bytes)
        audio_path = temp_audio.name

    try:
        transcription = transcribe_audio(audio_path)
    finally:
        # delete=False keeps the file after closing, so remove it explicitly.
        try:
            os.remove(audio_path)
        except OSError:
            # Cleanup is best effort; a leftover temp file must not fail the request.
            pass

    person_schema = """{"Name": "","Age": "","Gender": ""}"""
    person_entities_raw = predict_NuExtract(numind_model, numind_tokenizer, transcription, person_schema)

    try:
        person_entities = json.loads(person_entities_raw)
    except json.JSONDecodeError as e:
        # Keep the tuple shape: a bare string would not match the seven outputs.
        return (f"Error in NuExtract response: {str(e)}",) + ("",) * 6
    if not isinstance(person_entities, dict):
        # Valid JSON that is not an object carries no usable fields.
        person_entities = {}

    medical_entities = predict_Llama(llama_model, llama_tokenizer, transcription)

    return (
        person_entities.get("Name", ""),
        person_entities.get("Age", ""),
        person_entities.get("Gender", ""),
        _join_entities(medical_entities.get("Symptoms", [])),
        _join_entities(medical_entities.get("Diagnosis", [])),
        _join_entities(medical_entities.get("Medical_history", [])),
        _join_entities(medical_entities.get("Action_plan", [])),
    )



#Function that allows users to download information
def download_csv(name, age, gender, symptoms, diagnosis, medical_history, action_plan):
    """Write the form values to a temporary CSV file and return its path.

    The file contains a header row followed by one data row holding the seven
    arguments in order. It is created with ``delete=False`` so it still exists
    after this function returns; removing it is up to the caller.

    Returns:
        Path of the created ``.csv`` file.
    """
    header = ["Name", "Age", "Gender", "Symptoms", "Diagnosis", "Medical History", "Plan of Action"]
    row = [name, age, gender, symptoms, diagnosis, medical_history, action_plan]

    # Write through the temp file's own handle so it is closed on exit, instead
    # of leaving it open and reopening the file by name. newline='' is what the
    # csv module expects; UTF-8 keeps non-ASCII text writable on any locale.
    with tempfile.NamedTemporaryFile(
        mode='w', newline='', encoding='utf-8', suffix=".csv", delete=False
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(header)
        writer.writerow(row)

    return csv_file.name



# Gradio interface: a web form that takes one audio input and fills the
# medical diagnostic form by calling process_audio.
# The audio component hands the recording/upload to process_audio as a file
# path, and the seven textboxes receive the values process_audio returns, in
# the order listed here.
demo = gr.Interface(
    fn=process_audio,
    inputs=[
        gr.Audio(type="filepath")
    ],
    outputs=[
        gr.Textbox(label="Name"),
        gr.Textbox(label="Age"),
        gr.Textbox(label="Gender"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diagnosis"),
        gr.Textbox(label="Medical History"),
        gr.Textbox(label="Plan of Action"),
    ],
    title="Medical Diagnostic Form Assistant",
    description="Upload an audio file or record audio to generate a medical diagnostic form."
)

# Re-enter the interface's context to add a download button. On click, the
# current values of the seven output textboxes are passed to download_csv and
# the returned file path is shown in a File component.
with demo:
    download_button = gr.Button("Download CSV")
    download_button.click(
        fn=lambda name, age, gender, symptoms, diagnosis, medical_history, action_plan: download_csv(name, age, gender, symptoms, diagnosis, medical_history, action_plan),
        inputs=demo.output_components,
        outputs=gr.File(label="Download CSV")
    )

# Start the app. This runs whenever the module is executed or imported
# (there is no __main__ guard).
demo.launch()