Committing App
- app.py +114 -0
- demo_data/.gitkeep +0 -0
- demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt +3 -0
- demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt +3 -0
- demo_data/test/00001_lr.dat +0 -0
- demo_data/test/00001_lr.hea +13 -0
- demo_data/test/00008_lr.dat +0 -0
- demo_data/test/00008_lr.hea +13 -0
- demo_data/test/00045_lr.dat +0 -0
- demo_data/test/00045_lr.hea +13 -0
- demo_data/test/00257_lr.dat +0 -0
- demo_data/test/00257_lr.hea +13 -0
- models/CNN.py +213 -0
- models/RNN.py +71 -0
- models/__pycache__/CNN.cpython-39.pyc +0 -0
- models/__pycache__/RNN.cpython-39.pyc +0 -0
- requirements.txt +13 -0
- utils/RNN_utils.py +198 -0
- utils/__pycache__/RNN_utils.cpython-39.pyc +0 -0
- utils/__pycache__/helper_functions.cpython-39.pyc +0 -0
- utils/__pycache__/trainer.cpython-39.pyc +0 -0
- utils/helper_functions.py +86 -0
- utils/trainer.py +141 -0
app.py
ADDED
@@ -0,0 +1,114 @@
import os
import shutil
import gradio as gr
import numpy as np
import wfdb
import torch
from wfdb.plot.plot import plot_wfdb
from wfdb.io.record import Record, rdrecord

from models.CNN import CNN, MMCNN_CAT
from models.RNN import MMRNN
from utils.helper_functions import predict

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModel
from langdetect import detect

# Edit this before running
CWD = os.getcwd()

# Checkpoint paths
MMCNN_CAT_ckpt_path = f"{CWD}/demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt"
MMRNN_ckpt_path = f"{CWD}/demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt"

# Define clinical models and tokenizers
en_clin_bert = 'emilyalsentzer/Bio_ClinicalBERT'
ger_clin_bert = 'smanjil/German-MedBERT'

en_tokenizer = AutoTokenizer.from_pretrained(en_clin_bert)
en_model = AutoModel.from_pretrained(en_clin_bert)

g_tokenizer = AutoTokenizer.from_pretrained(ger_clin_bert)
g_model = AutoModel.from_pretrained(ger_clin_bert)


def preprocess(data_file_path):
    # Read the WFDB record into a (1, samples, leads) array
    data = [wfdb.rdsamp(data_file_path)]
    data = np.array([signal for signal, meta in data])
    return data


def embed(notes):
    # Route English notes to Bio_ClinicalBERT, everything else to German-MedBERT
    if detect(notes) == 'en':
        tokens = en_tokenizer(notes, return_tensors='pt')
        outputs = en_model(**tokens)
    else:
        tokens = g_tokenizer(notes, return_tensors='pt')
        outputs = g_model(**tokens)

    # Mean-pool the token embeddings into a single 768-dim vector
    embeddings = outputs.last_hidden_state
    embedding = torch.mean(embeddings, dim=1).squeeze(0)

    return embedding


def plot_ecg(path):
    record100 = rdrecord(path)
    return plot_wfdb(record=record100, title='ECG Signal Graph', figsize=(12, 10), return_fig=True)


def infer(model_name, data, notes):
    embed_notes = embed(notes).unsqueeze(0)
    data = torch.tensor(data)
    if model_name == "CNN":
        model = MMCNN_CAT()
        checkpoint = torch.load(MMCNN_CAT_ckpt_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        data = data.transpose(1, 2).float()
    elif model_name == "RNN":
        model = MMRNN(device='cpu')
        model.load_state_dict(torch.load(MMRNN_ckpt_path)['model_state_dict'])
        data = data.float()
    model.eval()
    outputs, predicted = predict(model, data, embed_notes, device='cpu')
    outputs = torch.sigmoid(outputs)[0]
    return {'Conduction Disturbance': round(outputs[0].item(), 2),
            'Hypertrophy': round(outputs[1].item(), 2),
            'Myocardial Infarction': round(outputs[2].item(), 2),
            'Normal ECG': round(outputs[3].item(), 2),
            'ST/T Change': round(outputs[4].item(), 2)}


def run(model_name, header_file, data_file, notes):
    # Copy the uploaded .hea/.dat pair into demo_data so wfdb can resolve them as one record
    demo_dir = f"{CWD}/demo_data"
    hdr_dirname, hdr_basename = os.path.split(header_file.name)
    data_dirname, data_basename = os.path.split(data_file.name)
    shutil.copyfile(data_file.name, f"{demo_dir}/{data_basename}")
    shutil.copyfile(header_file.name, f"{demo_dir}/{hdr_basename}")
    data = preprocess(f"{demo_dir}/{hdr_basename.split('.')[0]}")
    ECG_graph = plot_ecg(f"{demo_dir}/{hdr_basename.split('.')[0]}")
    os.remove(f"{demo_dir}/{data_basename}")
    os.remove(f"{demo_dir}/{hdr_basename}")
    output = infer(model_name, data, notes)
    return output, ECG_graph


with gr.Blocks() as demo:
    with gr.Row():
        model = gr.Radio(['CNN', 'RNN'], label="Select Model")
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
            notes = gr.Textbox(label="Clinical Notes")
        with gr.Column(scale=1):
            output_prob = gr.Label({'Normal ECG': 0, 'Myocardial Infarction': 0, 'ST/T Change': 0, 'Conduction Disturbance': 0, 'Hypertrophy': 0}, show_label=False)
    with gr.Row():
        ecg_graph = gr.Plot(label="ECG Signal Visualisation")
    with gr.Row():
        predict_btn = gr.Button("Predict Class")
        predict_btn.click(fn=run, inputs=[model, header_file, data_file, notes], outputs=[output_prob, ecg_graph])
    with gr.Row():
        gr.Examples(examples=[[f"{CWD}/demo_data/test/00001_lr.hea", f"{CWD}/demo_data/test/00001_lr.dat", "sinusrhythmus periphere niederspannung"],
                              [f"{CWD}/demo_data/test/00008_lr.hea", f"{CWD}/demo_data/test/00008_lr.dat", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."],
                              [f"{CWD}/demo_data/test/00045_lr.hea", f"{CWD}/demo_data/test/00045_lr.dat", "sinusrhythmus unvollstÄndiger rechtsschenkelblock sonst normales ekg"],
                              [f"{CWD}/demo_data/test/00257_lr.hea", f"{CWD}/demo_data/test/00257_lr.dat", "premature atrial contraction(s). sinus rhythm. left atrial enlargement. qs complexes in v2. st segments are slightly elevated in v2,3. st segments are depressed in i, avl. t waves are low or flat in i, v5,6 and inverted in avl. consistent with ischaemic h"],
                              ],
                    inputs=[header_file, data_file, notes])

if __name__ == "__main__":
    demo.launch()
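For reference, a minimal sketch of the wfdb calls that preprocess() and plot_ecg() wrap, assuming the bundled demo records are present; record paths are passed without the .hea/.dat extension:

import numpy as np
import wfdb
from wfdb.plot.plot import plot_wfdb
from wfdb.io.record import rdrecord

record_path = "demo_data/test/00001_lr"    # extensionless WFDB record name
signal, meta = wfdb.rdsamp(record_path)    # (1000, 12) signal array plus metadata dict
batch = np.array([signal])                 # (1, 1000, 12), the shape preprocess() returns
fig = plot_wfdb(record=rdrecord(record_path), title='ECG Signal Graph',
                figsize=(12, 10), return_fig=True)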
demo_data/.gitkeep
ADDED
File without changes
demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3735115ddc15ecab4844a13124616f339364795349aeef0476491accfa8b4eda
size 25392011
demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7cfaa76908e6246b051fd5725152bca28b5111a83ada18fec5848816d8bd6e7a
size 1340343
demo_data/test/00001_lr.dat
ADDED
Binary file (24 kB)
demo_data/test/00001_lr.hea
ADDED
@@ -0,0 +1,13 @@
00001_lr 12 100 1000
00001_lr.dat 16 1000.0(0)/mV 16 0 -119 1508 0 I
00001_lr.dat 16 1000.0(0)/mV 16 0 -55 723 0 II
00001_lr.dat 16 1000.0(0)/mV 16 0 64 64758 0 III
00001_lr.dat 16 1000.0(0)/mV 16 0 86 64423 0 AVR
00001_lr.dat 16 1000.0(0)/mV 16 0 -91 1211 0 AVL
00001_lr.dat 16 1000.0(0)/mV 16 0 4 7 0 AVF
00001_lr.dat 16 1000.0(0)/mV 16 0 -69 63827 0 V1
00001_lr.dat 16 1000.0(0)/mV 16 0 -31 6999 0 V2
00001_lr.dat 16 1000.0(0)/mV 16 0 0 63759 0 V3
00001_lr.dat 16 1000.0(0)/mV 16 0 -26 61447 0 V4
00001_lr.dat 16 1000.0(0)/mV 16 0 -39 64979 0 V5
00001_lr.dat 16 1000.0(0)/mV 16 0 -79 832 0 V6
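Each .hea file here is a plain-text WFDB header. The first line gives the record name, the number of signals (12 leads), the sampling frequency (100 Hz), and the samples per signal (1000); each remaining line describes one lead: data file, storage format (16), ADC gain and baseline (1000.0(0)/mV), ADC resolution, ADC zero, first sample value, checksum, block size, and lead name. A small sketch of inspecting a header with wfdb, assuming the files are available locally:

import wfdb

header = wfdb.rdheader("demo_data/test/00001_lr")
print(header.n_sig, header.fs, header.sig_len)   # 12 100 1000
print(header.sig_name)                           # the 12 lead names, 'I' through 'V6'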
demo_data/test/00008_lr.dat
ADDED
Binary file (24 kB)
demo_data/test/00008_lr.hea
ADDED
@@ -0,0 +1,13 @@
00008_lr 12 100 1000
00008_lr.dat 16 1000.0(0)/mV 16 0 -41 2321 0 I
00008_lr.dat 16 1000.0(0)/mV 16 0 -80 4548 0 II
00008_lr.dat 16 1000.0(0)/mV 16 0 -39 2234 0 III
00008_lr.dat 16 1000.0(0)/mV 16 0 60 62047 0 AVR
00008_lr.dat 16 1000.0(0)/mV 16 0 -1 0 0 AVL
00008_lr.dat 16 1000.0(0)/mV 16 0 -60 3352 0 AVF
00008_lr.dat 16 1000.0(0)/mV 16 0 45 232 0 V1
00008_lr.dat 16 1000.0(0)/mV 16 0 -5 65262 0 V2
00008_lr.dat 16 1000.0(0)/mV 16 0 5 63785 0 V3
00008_lr.dat 16 1000.0(0)/mV 16 0 -55 58960 0 V4
00008_lr.dat 16 1000.0(0)/mV 16 0 -70 3471 0 V5
00008_lr.dat 16 1000.0(0)/mV 16 0 -40 2065 0 V6
demo_data/test/00045_lr.dat
ADDED
Binary file (24 kB)
demo_data/test/00045_lr.hea
ADDED
@@ -0,0 +1,13 @@
00045_lr 12 100 1000
00045_lr.dat 16 1000.0(0)/mV 16 0 -181 1318 0 I
00045_lr.dat 16 1000.0(0)/mV 16 0 -438 5652 0 II
00045_lr.dat 16 1000.0(0)/mV 16 0 -257 4356 0 III
00045_lr.dat 16 1000.0(0)/mV 16 0 310 62008 0 AVR
00045_lr.dat 16 1000.0(0)/mV 16 0 38 64012 0 AVL
00045_lr.dat 16 1000.0(0)/mV 16 0 -347 4979 0 AVF
00045_lr.dat 16 1000.0(0)/mV 16 0 121 3953 0 V1
00045_lr.dat 16 1000.0(0)/mV 16 0 51 64138 0 V2
00045_lr.dat 16 1000.0(0)/mV 16 0 82 61158 0 V3
00045_lr.dat 16 1000.0(0)/mV 16 0 -58 63682 0 V4
00045_lr.dat 16 1000.0(0)/mV 16 0 -52 65025 0 V5
00045_lr.dat 16 1000.0(0)/mV 16 0 -134 193 0 V6
demo_data/test/00257_lr.dat
ADDED
Binary file (24 kB)
demo_data/test/00257_lr.hea
ADDED
@@ -0,0 +1,13 @@
00257_lr 12 100 1000
00257_lr.dat 16 1000.0(0)/mV 16 0 -8 8043 0 I
00257_lr.dat 16 1000.0(0)/mV 16 0 24 3049 0 II
00257_lr.dat 16 1000.0(0)/mV 16 0 32 60557 0 III
00257_lr.dat 16 1000.0(0)/mV 16 0 -9 59959 0 AVR
00257_lr.dat 16 1000.0(0)/mV 16 0 -20 6506 0 AVL
00257_lr.dat 16 1000.0(0)/mV 16 0 28 64558 0 AVF
00257_lr.dat 16 1000.0(0)/mV 16 0 -29 60014 0 V1
00257_lr.dat 16 1000.0(0)/mV 16 0 -24 64087 0 V2
00257_lr.dat 16 1000.0(0)/mV 16 0 138 1192 0 V3
00257_lr.dat 16 1000.0(0)/mV 16 0 34 65087 0 V4
00257_lr.dat 16 1000.0(0)/mV 16 0 26 65386 0 V5
00257_lr.dat 16 1000.0(0)/mV 16 0 32 59612 0 V6
models/CNN.py
ADDED
@@ -0,0 +1,213 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

# Not in use yet
class Conv1d_layer(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size) -> None:
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size)
        self.batch_norm = nn.BatchNorm1d(out_channel)
        self.dropout = nn.Dropout1d(p=0.5)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.dropout(x)
        return x


class CNN(nn.Module):
    def __init__(self, ecg_channels=12):
        super(CNN, self).__init__()
        self.name = "CNN"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        self.activation = nn.ReLU()

    def forward(self, x, notes=None):
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        x = x.squeeze(1)
        return x


class MMCNN_SUM(nn.Module):
    def __init__(self, ecg_channels=12):
        super(MMCNN_SUM, self).__init__()
        # ECG processing layers
        self.name = "MMCNN_SUM"
        self.conv1 = Conv1d_layer(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = Conv1d_layer(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = Conv1d_layer(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)

        self.activation = nn.ReLU()

    def forward(self, x, notes):
        # ECG processing
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))

        # Notes processing
        notes = notes.view(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))

        x = self.fc2(self.norm(x + notes))
        x = x.squeeze(1)
        return x


class MMCNN_CAT(nn.Module):
    def __init__(self, ecg_channels=12):
        super(MMCNN_CAT, self).__init__()
        # ECG processing layers
        self.name = "MMCNN_CAT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(256, 5)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)

        self.activation = nn.ReLU()

    def forward(self, x, notes):
        # ECG processing
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))

        # Notes processing
        notes = notes.view(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))

        x = self.fc2(torch.cat((x, notes), dim=1))
        x = x.squeeze(1)
        return x


class MMCNN_ATT(nn.Module):
    def __init__(self, ecg_channels=12):
        super(MMCNN_ATT, self).__init__()
        # ECG processing layers
        self.name = "MMCNN_ATT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        self.norm1 = nn.LayerNorm(128)
        self.norm2 = nn.LayerNorm(128)

        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        # ECG processing
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))
        x = self.norm1(x)

        # Notes processing
        notes = notes.view(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))
        notes = self.norm2(notes)
        notes = notes.unsqueeze(1)
        x = x.unsqueeze(1)
        x, _ = self.attention(notes, x, x)
        x = self.fc2(x)
        x = x.squeeze(1)
        return x


class MMCNN_SUM_ATT(nn.Module):
    def __init__(self, ecg_channels=12):
        super(MMCNN_SUM_ATT, self).__init__()
        # ECG processing layers
        self.name = "MMCNN_SUM_ATT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)

        # Clinical notes processing layers
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)

        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        # ECG processing
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))

        # Notes processing
        notes = notes.view(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))
        x = self.norm(x + notes)

        x = x.unsqueeze(1)
        x, _ = self.attention(x, x, x)

        x = self.fc2(x)
        x = x.squeeze(1)
        return x


if __name__ == "__main__":
    model = CNN()
    # model = Conv1d_layer(12, 16, 7)
    summary(model, input_size=(1, 12, 1000))
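The hard-coded fc0 width of 5856 follows from the conv/pool stack on 1000-sample records: 1000 -> 994 (conv, k=7) -> 497 (pool) -> 493 (conv, k=5) -> 246 (pool) -> 244 (conv, k=3) -> 122 (pool), and 48 channels x 122 samples = 5856. A quick shape check with random stand-in tensors:

import torch
from models.CNN import MMCNN_CAT

model = MMCNN_CAT()
ecg = torch.randn(1, 12, 1000)   # (batch, leads, samples), the layout infer() feeds the CNN
note = torch.randn(1, 768)       # stand-in for a mean-pooled BERT embedding
print(model(ecg, note).shape)    # torch.Size([1, 5]), one logit per class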
models/RNN.py
ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, cuda=True, device='cuda'):
        super(RNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device

        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x, notes):
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        c = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)

        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        h = h.to(self.device)
        c = c.to(self.device)
        x = x.to(self.device)

        output, _ = self.lstm(x, (h, c))

        out = self.fc2(self.relu(self.fc1(output[:, -1, :])))

        return out


class MMRNN(nn.ModuleList):
    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, embed_size=768, device="cuda"):
        super(MMRNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device

        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, embed_size)
        self.fc2 = nn.Linear(embed_size, num_classes)

        self.lnorm_out = nn.LayerNorm(embed_size)
        self.lnorm_embed = nn.LayerNorm(embed_size)

    def forward(self, x, note):
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
        c = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)

        nn.init.xavier_normal_(h)
        nn.init.xavier_normal_(c)
        h = h.to(self.device)
        c = c.to(self.device)
        x = x.to(self.device)
        note = note.to(self.device)

        output, _ = self.lstm(x, (h, c))
        # Take last hidden state
        out = self.fc1(output[:, -1, :])

        note = self.lnorm_embed(note)
        out = self.lnorm_out(out)
        out = note + out

        out = self.fc2(out)

        return out.squeeze(1)
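MMRNN consumes the ECG as (batch, time, leads), since its LSTM is batch_first and infer() skips the transpose applied for the CNN. A minimal CPU smoke test with random inputs:

import torch
from models.RNN import MMRNN

model = MMRNN(device='cpu')
ecg = torch.randn(1, 1000, 12)   # 1000 time steps of 12 leads
note = torch.randn(1, 768)       # clinical-note embedding
print(model(ecg, note).shape)    # torch.Size([1, 5])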
models/__pycache__/CNN.cpython-39.pyc
ADDED
Binary file (6.49 kB)
models/__pycache__/RNN.cpython-39.pyc
ADDED
Binary file (2.34 kB)
requirements.txt
ADDED
@@ -0,0 +1,13 @@
gradio==3.25.0
langdetect==1.0.9
matplotlib==3.6.3
numpy==1.24.2
pandas==1.5.3
PyWavelets==1.4.1
scikit_learn==1.2.1
torch==1.12.1
torchinfo==1.7.2
torchvision==0.13.1
tqdm==4.64.1
transformers==4.28.1
wfdb==4.1.0
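To reproduce the Space locally, these pins can be installed with pip install -r requirements.txt before starting the demo with python app.py.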
utils/RNN_utils.py
ADDED
@@ -0,0 +1,198 @@
import torch
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import pywt
import os


def display_eval(epoch, epochs, tlength, global_step, tcorrect, tsamples, t_valid_samples, average_train_loss, average_valid_loss, total_acc_val):
    tqdm.write(
        f'Epoch: [{epoch + 1}/{epochs}], Step [{global_step}/{epochs*tlength}] '
        f'| Train Loss: {average_train_loss: .3f} '
        f'| Train Accuracy: {tcorrect / tsamples: .3f} '
        f'| Val Loss: {average_valid_loss: .3f} '
        f'| Val Accuracy: {total_acc_val / t_valid_samples: .3f}')


def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    torch.save({'valid_loss': valid_loss,
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'optimizer': optimizer.state_dict()
                }, path)
    tqdm.write(f'Model saved to ==> {path}')


def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
    torch.save({'train_loss_list': train_loss_list,
                'valid_loss_list': valid_loss_list,
                'global_steps_list': global_steps_list,
                }, path)


def plot_losses(metrics_save_name='metrics', save_dir='./'):
    path = f'{save_dir}metrics_{metrics_save_name}.pt'
    state = torch.load(path)

    train_loss_list = state['train_loss_list']
    valid_loss_list = state['valid_loss_list']
    global_steps_list = state['global_steps_list']

    plt.plot(global_steps_list, train_loss_list, label='Train')
    plt.plot(global_steps_list, valid_loss_list, label='Valid')
    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def train_RNN(epochs, train_loader, valid_loader, model, loss_fn, optimizer, eval_every=0.25, best_valid_loss=float("Inf"), device='cuda', model_save_name='', save_dir='./'):
    model.train()

    running_loss = 0.0
    valid_running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    wavelet = 'db4'
    level = 3

    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        t_correct = 0
        t_samples = 0
        for images, labels, notes in train_loader:
            optimizer.zero_grad()

            coeffs = pywt.wavedec(images, wavelet, level=level, axis=1)
            threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
            denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold) for c in coeffs]
            images = pywt.waverec(denoised_coeffs, wavelet, axis=1)

            images = torch.tensor(images).float().to(device)
            labels = labels.to(device)
            notes = notes.to(device)

            output = model(images, notes)

            loss = loss_fn(output, labels.float())
            running_loss += loss.item()*len(labels)
            loss.backward()
            global_step += 1*len(images)

            optimizer.step()

            values, indices = torch.max(output, dim=1)
            t_correct += sum(1 for s, i in enumerate(indices) if labels[s][i] == 1)
            t_samples += len(indices)

            if (global_step % (int(eval_every*len(train_loader.dataset)))) < train_loader.batch_size:
                model.eval()
                valid_running_loss = 0.0
                total_acc_val = 0
                with torch.no_grad():

                    for images, labels, notes in valid_loader:

                        coeffs = pywt.wavedec(images, wavelet, level=level, axis=1)
                        threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
                        denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold) for c in coeffs]
                        images = pywt.waverec(denoised_coeffs, wavelet, axis=1)

                        images = torch.tensor(images).float().to(device)
                        labels = labels.to(device)
                        notes = notes.to(device)
                        output = model(images, notes)

                        loss = loss_fn(output, labels.float()).item()
                        valid_running_loss += loss*len(images)
                        values, indices = torch.max(output, dim=1)
                        total_acc_val += sum(1 for s, i in enumerate(indices) if labels[s][i] == 1)

                # evaluation
                average_train_loss = running_loss / t_samples
                average_valid_loss = valid_running_loss / len(valid_loader.dataset)
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(average_valid_loss)
                global_steps_list.append(global_step)

                display_eval(epoch, epochs, len(train_loader.dataset), global_step, t_correct, t_samples,
                             len(valid_loader.dataset), average_train_loss, average_valid_loss, total_acc_val)

                # resetting running values
                model.train()

                if best_valid_loss > average_valid_loss:
                    best_valid_loss = average_valid_loss
                    save_model(model, optimizer, best_valid_loss, epoch,
                               path=f'{save_dir}model_{model_save_name}.pt')
                    save_metrics(train_loss_list, valid_loss_list,
                                 global_steps_list, path=f'{save_dir}metrics_{model_save_name}.pt')

    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=f'{save_dir}metrics_{model_save_name}.pt')
    print("Training complete.")
    return model


def evaluate_RNN(model, test_loader, device="cuda"):
    model.eval()
    y_pred = []
    y_true = []

    wavelet = 'db4'
    level = 3

    total_acc_test = 0
    with torch.no_grad():
        for images, labels, notes in test_loader:
            coeffs = pywt.wavedec(images, wavelet, level=level, axis=1)
            threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
            denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold) for c in coeffs]
            images = pywt.waverec(denoised_coeffs, wavelet, axis=1)

            images = torch.tensor(images).float().to(device)
            labels = labels.to(device)
            notes = notes.to(device)
            output = model(images, notes)

            values, indices = torch.max(output, dim=1)
            y_pred.extend(indices.tolist())
            y_true.extend(labels.tolist())
            total_acc_test += sum(1 for s, i in enumerate(indices) if labels[s][i] == 1)

    test_accuracy = total_acc_test / len(test_loader.dataset)
    print(f'Test Accuracy: {test_accuracy: .3f}')

    return test_accuracy


def rename_with_acc(save_name, save_dir, acc):
    acc = round(acc*100)
    # Rename model
    new_model_name = f'{save_dir}model_{save_name}_acc_{acc}.pt'
    new_metrics_name = f'{save_dir}metrics_{save_name}_acc_{acc}.pt'

    if os.path.isfile(new_model_name):
        os.remove(new_model_name)
    if os.path.isfile(new_metrics_name):
        os.remove(new_metrics_name)

    os.rename(f'{save_dir}model_{save_name}.pt',
              f'{save_dir}model_{save_name}_acc_{acc}.pt')
    # Rename metrics
    os.rename(f'{save_dir}metrics_{save_name}.pt',
              f'{save_dir}metrics_{save_name}_acc_{acc}.pt')
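The block repeated in train_RNN and evaluate_RNN is a standard wavelet-shrinkage denoiser. A standalone sketch on random stand-in data, mirroring the 0.1 x median threshold used above:

import numpy as np
import pywt

signal = np.random.randn(8, 1000, 12)                  # (batch, time, leads)
coeffs = pywt.wavedec(signal, 'db4', level=3, axis=1)  # multilevel decomposition over time
threshold = 0.1 * np.median(np.abs(coeffs[-1]))        # scale taken from the finest detail band
denoised = [pywt.threshold(c, value=threshold, mode='hard') for c in coeffs]
print(pywt.waverec(denoised, 'db4', axis=1).shape)     # (8, 1000, 12)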
utils/__pycache__/RNN_utils.cpython-39.pyc
ADDED
Binary file (5.71 kB)
utils/__pycache__/helper_functions.cpython-39.pyc
ADDED
Binary file (2.91 kB)
utils/__pycache__/trainer.cpython-39.pyc
ADDED
Binary file (3.28 kB)
utils/helper_functions.py
ADDED
@@ -0,0 +1,86 @@
import torch

def define_optimizer(model, lr, alpha):
    # Define optimizer
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, alpha=alpha)
    optimizer.zero_grad()
    return optimizer

def tuple_of_tensors_to_tensor(tuple_of_tensors):
    return torch.stack(list(tuple_of_tensors), dim=0)

def predict(model, inputs, notes, device):
    outputs = model.forward(inputs, notes)
    predicted = torch.sigmoid(outputs)
    predicted = (predicted > 0.5).float()
    return outputs, predicted

def display_train(epoch, num_epochs, i, model, correct, total, loss, train_loader, valid_loader, device):
    print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Train Loss: {loss.item():.4f}')
    train_accuracy = correct/total
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Accuracy: {train_accuracy:.4f}')
    valid_loss, valid_accuracy = eval_valid(model, valid_loader, epoch, num_epochs, device)
    return train_accuracy, valid_accuracy, valid_loss

def eval_valid(model, valid_loader, epoch, num_epochs, device):
    # Compute validation accuracy and loss after all training samples have been seen
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        running_loss = 0
        for inputs, labels, notes in valid_loader:
            # Get inputs, labels, and notes from the validation loader
            inputs = inputs.transpose(1, 2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            # Forward pass and predict class using max
            outputs, predicted = predict(model, inputs, notes, device)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
            running_loss += loss.item()*len(labels)

            # Check if the top-scoring class matches a positive label and count the number of correct predictions
            total += labels.size(0)
            # TODO: change accuracy criterion
            values, indices = torch.max(outputs, dim=1)
            correct += sum(1 for s, i in enumerate(indices)
                           if labels[s][i] == 1)

        # Compute final accuracy and display
        valid_accuracy = correct/total
        validation_loss = running_loss/total
        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {valid_accuracy:.4f}, Validation Loss: {validation_loss:.4f}')
        return validation_loss, valid_accuracy


def eval_test(model, test_loader, device):
    # Compute model test accuracy after training
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels, notes in test_loader:
            # Get inputs, labels, and notes from the test loader
            inputs = inputs.transpose(1, 2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            # Forward pass and predict class using max
            outputs, predicted = predict(model, inputs, notes, device)

            # Check if the top-scoring class matches a positive label and count the number of correct predictions
            total += labels.size(0)
            # TODO: change accuracy criterion
            values, indices = torch.max(outputs, dim=1)
            correct += sum(1 for s, i in enumerate(indices)
                           if labels[s][i] == 1)

        # Compute final accuracy and display
        test_accuracy = correct/total
        print(f'Ended Training, Test Accuracy: {test_accuracy:.4f}')
        return test_accuracy
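predict() makes multilabel decisions: each logit is passed through a sigmoid and thresholded at 0.5 independently, so one record can carry several diagnoses. A worked example of the same rule on hypothetical logits:

import torch

logits = torch.tensor([[1.2, -0.4, 0.1, -2.0, 0.9]])
probs = torch.sigmoid(logits)    # tensor([[0.7685, 0.4013, 0.5250, 0.1192, 0.7109]])
print((probs > 0.5).float())     # tensor([[1., 0., 1., 0., 1.]])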
utils/trainer.py
ADDED
@@ -0,0 +1,141 @@
import torch
from .helper_functions import define_optimizer, predict, display_train, eval_test
from tqdm import tqdm
import matplotlib.pyplot as plt


def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    torch.save({'valid_loss': valid_loss,
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'optimizer': optimizer.state_dict()
                }, path)
    tqdm.write(f'Model saved to ==> {path}')


def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
    torch.save({'train_loss_list': train_loss_list,
                'valid_loss_list': valid_loss_list,
                'global_steps_list': global_steps_list,
                }, path)

def plot_losses(metrics_save_name='metrics', save_dir='./'):
    path = f'{save_dir}metrics_{metrics_save_name}.pt'
    state = torch.load(path)

    train_loss_list = state['train_loss_list']
    valid_loss_list = state['valid_loss_list']
    global_steps_list = state['global_steps_list']

    plt.plot(global_steps_list, train_loss_list, label='Train')
    plt.plot(global_steps_list, valid_loss_list, label='Valid')
    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def trainer(model, train_loader, test_loader, valid_loader, num_epochs=10, lr=0.01, alpha=0.99, eval_interval=10, model_save_name='', save_dir='./'):

    # Use GPU if available, else use CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # History for train acc, valid acc
    train_accs = []
    valid_accs = []
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []
    best_valid_loss = float("inf")

    # Define optimizer
    optimizer = define_optimizer(model, lr, alpha)

    # Training model
    for epoch in range(num_epochs):
        # Go through all samples in the train dataset
        model.train()
        running_loss = 0
        correct = 0
        total = 0
        for i, (inputs, labels, notes) in enumerate(train_loader):
            # Get from dataloader and send to device
            inputs = inputs.transpose(1, 2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            # Forward pass
            outputs, predicted = predict(model, inputs, notes, device)

            # Check if the top-scoring class matches a positive label and count the number of correct predictions
            total += labels.size(0)
            # TODO: change accuracy criterion
            values, indices = torch.max(outputs, dim=1)
            correct += sum(1 for s, i in enumerate(indices)
                           if labels[s][i] == 1)
            # Compute loss: the raw logits go straight into the BCE-with-logits loss
            loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
            running_loss += loss.item()*len(labels)
            global_step += 1*len(inputs)
            # Backward and optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Display losses over iterations and evaluate on the validation set
            if (i+1) % eval_interval == 0:
                train_accuracy, valid_accuracy, valid_loss = display_train(epoch, num_epochs, i, model,
                                                                           correct, total, loss,
                                                                           train_loader, valid_loader, device)

                average_train_loss = running_loss / total
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(valid_loss)
                global_steps_list.append(global_step)

                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    save_model(model, optimizer, best_valid_loss, epoch, path=f'{save_dir}model_{model_save_name}.pt')
                    save_metrics(train_loss_list, valid_loss_list, global_steps_list, path=f'{save_dir}metrics_{model_save_name}.pt')

        if (len(train_loader) % eval_interval != 0):
            train_accuracy, valid_accuracy, valid_loss = display_train(epoch, num_epochs, i, model,
                                                                       correct, total, loss,
                                                                       train_loader, valid_loader, device)

            average_train_loss = running_loss / total
            train_loss_list.append(average_train_loss)
            valid_loss_list.append(valid_loss)
            global_steps_list.append(global_step)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                save_model(model, optimizer, best_valid_loss, epoch, path=f'{save_dir}model_{model_save_name}.pt')
                save_metrics(train_loss_list, valid_loss_list, global_steps_list, path=f'{save_dir}metrics_{model_save_name}.pt')
        # Append accuracies to the lists at the end of each epoch
        train_accs.append(train_accuracy)
        valid_accs.append(valid_accuracy)
    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=f'{save_dir}metrics_{model_save_name}.pt')
    # Load best model
    checkpoint = torch.load(f'{save_dir}model_{model_save_name}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    # Evaluate on test after training has completed
    test_acc = eval_test(model, test_loader, device)
    return train_accs, valid_accs, test_acc
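A hedged sketch of driving trainer() end to end on random stand-in data (the real Space pairs ECG records with BERT note embeddings). Note that trainer() moves batches, not the model, to the detected device, so this sketch assumes a CPU-only machine; the 'demo' save name is arbitrary:

import torch
from torch.utils.data import DataLoader, TensorDataset
from models.CNN import MMCNN_CAT
from utils.trainer import trainer

ecgs = torch.randn(32, 1000, 12)             # loaders must yield (ecg, label, note) triples
labels = (torch.rand(32, 5) > 0.7).float()   # random multilabel targets
notes = torch.randn(32, 768)                 # stand-in note embeddings
loader = DataLoader(TensorDataset(ecgs, labels, notes), batch_size=8)

model = MMCNN_CAT()
train_accs, valid_accs, test_acc = trainer(model, loader, loader, loader,
                                           num_epochs=1, eval_interval=2,
                                           model_save_name='demo')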