"""Gradio demo: multimodal ECG classification from a WFDB record plus clinical notes.

The user uploads a WFDB header (.hea) and data (.dat) file and optionally types
clinical notes.  The notes are embedded with a clinical BERT model (English or
German, chosen by language detection), combined with the ECG signal, and passed
to a multimodal CNN or RNN.  The app shows per-class probabilities and a plot
of the ECG signal.
"""
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
import wfdb
import torch
from wfdb.plot.plot import plot_wfdb
from wfdb.io.record import Record, rdrecord
from models.CNN import CNN, MMCNN_CAT
from models.RNN import MMRNN
from utils.helper_functions import predict
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before importing pyplot
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from langdetect import DetectorFactory, LangDetectException, detect

# langdetect is randomised by default (see its README); pin the seed so the
# same note is always routed to the same language model.
DetectorFactory.seed = 0

# edit this before Running
CWD = os.getcwd()

# CKPT paths
MMCNN_CAT_ckpt_path = f"{CWD}/demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt"
MMRNN_ckpt_path = f"{CWD}/demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt"

# Model names offered in the UI and accepted by `infer`.
SUPPORTED_MODELS = ('CNN', 'RNN')

# Class labels in the order of the model's output indices (0..4).
CLASS_LABELS = (
    'Conduction Disturbance',
    'Hypertrophy',
    'Myocardial Infarction',
    'Normal ECG',
    'ST/T Change',
)

# Upper bound on note length fed to the BERT encoders; longer notes are
# truncated.  NOTE(review): 512 is the usual BERT position limit - confirm it
# matches both checkpoints below.
MAX_NOTE_TOKENS = 512

# Define clinical models and tokenizers
en_clin_bert = 'emilyalsentzer/Bio_ClinicalBERT'
ger_clin_bert = 'smanjil/German-MedBERT'

en_tokenizer = AutoTokenizer.from_pretrained(en_clin_bert)
en_model = AutoModel.from_pretrained(en_clin_bert)

g_tokenizer = AutoTokenizer.from_pretrained(ger_clin_bert)
g_model = AutoModel.from_pretrained(ger_clin_bert)


def preprocess(data_file_path):
    """Read a WFDB record and return its signal with a leading axis of length 1.

    Args:
        data_file_path: Record path without file extension, as expected by
            ``wfdb.rdsamp``.

    Returns:
        A NumPy array holding the record's signal, wrapped in one extra
        leading dimension.
    """
    signal, _meta = wfdb.rdsamp(data_file_path)
    return np.asarray(signal)[np.newaxis, ...]


def embed(notes):
    """Embed clinical notes as the mean of a BERT model's last hidden states.

    English notes go through the English clinical model; everything else,
    including text whose language cannot be detected (e.g. empty notes), goes
    through the German model.

    Args:
        notes: Free-text clinical notes.

    Returns:
        A 1-D tensor: the token-averaged last hidden state.
    """
    try:
        is_english = detect(notes) == 'en'
    except LangDetectException:
        # Raised when the text has no usable features (empty, digits only...).
        # Fall back to the non-English branch instead of failing the request.
        is_english = False

    if is_english:
        tokenizer, encoder = en_tokenizer, en_model
    else:
        tokenizer, encoder = g_tokenizer, g_model

    # Truncate so over-long notes cannot exceed the encoder's position limit.
    tokens = tokenizer(
        notes,
        return_tensors='pt',
        truncation=True,
        max_length=MAX_NOTE_TOKENS,
    )
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        outputs = encoder(**tokens)

    return torch.mean(outputs.last_hidden_state, dim=1).squeeze(0)


def plot_ecg(path):
    """Return a matplotlib figure of the WFDB record at ``path`` (no extension)."""
    record = rdrecord(path)
    return plot_wfdb(
        record=record,
        title='ECG Signal Graph',
        figsize=(12, 10),
        return_fig=True,
    )


def _load_model(model_name):
    """Build the requested network, load its checkpoint on CPU, set eval mode."""
    if model_name == "CNN":
        net = MMCNN_CAT()
        ckpt_path = MMCNN_CAT_ckpt_path
    elif model_name == "RNN":
        net = MMRNN(device='cpu')
        ckpt_path = MMRNN_ckpt_path
    else:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {SUPPORTED_MODELS}"
        )
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


def infer(model, data, notes):
    """Run the selected multimodal model and return per-class probabilities.

    Args:
        model: Model name, ``"CNN"`` or ``"RNN"``.
        data: ECG signal array as produced by ``preprocess``.
        notes: Free-text clinical notes.

    Returns:
        Dict mapping each class label to its sigmoid probability, rounded to
        two decimals.

    Raises:
        ValueError: If ``model`` is not a supported model name.
    """
    # Fail fast with a clear message before doing any expensive work.
    if model not in SUPPORTED_MODELS:
        raise ValueError(
            f"Unknown model {model!r}; expected one of {SUPPORTED_MODELS}"
        )

    embed_notes = embed(notes).unsqueeze(0)
    data = torch.tensor(data)
    if model == "CNN":
        # Swap axes 1 and 2 for the CNN (presumably channels-first input -
        # TODO confirm against MMCNN_CAT).
        data = data.transpose(1, 2).float()
    else:
        data = data.float()

    net = _load_model(model)
    outputs, _predicted = predict(net, data, embed_notes, device='cpu')
    probabilities = torch.sigmoid(outputs)[0]
    return {
        label: round(probabilities[index].item(), 2)
        for index, label in enumerate(CLASS_LABELS)
    }


def run(model_name, header_file, data_file, notes):
    """Gradio callback: classify the uploaded ECG record and plot it.

    The uploaded header and data files are copied side by side into a private
    temporary directory so wfdb can resolve them as one record; the directory
    is removed even if reading or plotting fails.

    Args:
        model_name: Selected model, ``"CNN"`` or ``"RNN"``.
        header_file: Uploaded ``.hea`` file object (exposes ``.name``).
        data_file: Uploaded ``.dat`` file object (exposes ``.name``).
        notes: Free-text clinical notes.

    Returns:
        Tuple ``(probabilities, figure)`` for the label and plot components.

    Raises:
        gr.Error: If no model is selected or a file is missing.
    """
    if model_name not in SUPPORTED_MODELS:
        raise gr.Error("Please select a model (CNN or RNN).")
    if header_file is None or data_file is None:
        raise gr.Error("Please upload both the header (.hea) and data (.dat) files.")

    hdr_basename = os.path.basename(header_file.name)
    data_basename = os.path.basename(data_file.name)
    # splitext strips only the final extension, so dotted record names survive.
    record_name = os.path.splitext(hdr_basename)[0]

    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_file.name, os.path.join(work_dir, data_basename))
        shutil.copyfile(header_file.name, os.path.join(work_dir, hdr_basename))
        record_path = os.path.join(work_dir, record_name)
        data = preprocess(record_path)
        ECG_graph = plot_ecg(record_path)

    output = infer(model_name, data, notes)
    return output, ECG_graph


with gr.Blocks() as demo:
    with gr.Row():
        model = gr.Radio(list(SUPPORTED_MODELS), label="Select Model")
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
            notes = gr.Textbox(label="Clinical Notes")
        with gr.Column(scale=1):
            output_prob = gr.Label(
                {
                    'Normal ECG': 0,
                    'Myocardial Infarction': 0,
                    'ST/T Change': 0,
                    'Conduction Disturbance': 0,
                    'Hypertrophy': 0,
                },
                show_label=False,
            )
    with gr.Row():
        ecg_graph = gr.Plot(label="ECG Signal Visualisation")
    with gr.Row():
        predict_btn = gr.Button("Predict Class")
        predict_btn.click(
            fn=run,
            inputs=[model, header_file, data_file, notes],
            outputs=[output_prob, ecg_graph],
        )
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    f"{CWD}/demo_data/test/00001_lr.hea",
                    f"{CWD}/demo_data/test/00001_lr.dat",
                    "sinusrhythmus periphere niederspannung",
                ],
                [
                    f"{CWD}/demo_data/test/00008_lr.hea",
                    f"{CWD}/demo_data/test/00008_lr.dat",
                    "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest.",
                ],
                [
                    f"{CWD}/demo_data/test/00045_lr.hea",
                    f"{CWD}/demo_data/test/00045_lr.dat",
                    "sinusrhythmus unvollstÄndiger rechtsschenkelblock sonst normales ekg",
                ],
                [
                    f"{CWD}/demo_data/test/00257_lr.hea",
                    f"{CWD}/demo_data/test/00257_lr.dat",
                    "premature atrial contraction(s). sinus rhythm. left atrial enlargement. qs complexes in v2. st segments are slightly elevated in v2,3. st segments are depressed in i, avl. t waves are low or flat in i, v5,6 and inverted in avl. consistent with ischaemic h",
                ],
            ],
            inputs=[header_file, data_file, notes],
        )

if __name__ == "__main__":
    demo.launch()