Mohammad Javad Darvishi
commited on
Commit
•
13f97c3
1
Parent(s):
db8f37b
'first working version of the demo'
Browse files
app.py
CHANGED
@@ -3,17 +3,35 @@ import streamlit as st
|
|
3 |
import mne
|
4 |
import matplotlib.pyplot as plt
|
5 |
import os
|
|
|
|
|
|
|
6 |
from misc import *
|
7 |
|
|
|
|
|
|
|
|
|
8 |
|
|
|
9 |
# Load the edf file
|
10 |
-
edf_file =
|
|
|
|
|
11 |
|
12 |
|
13 |
if edf_file is not None:
|
14 |
-
|
15 |
# Read the file
|
16 |
raw = read_file(edf_file)
|
17 |
|
18 |
# Preprocess and plot the data
|
19 |
preprocessing_and_plotting(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import mne
|
4 |
import matplotlib.pyplot as plt
|
5 |
import os
|
6 |
+
import streamlit as st
|
7 |
+
import random
|
8 |
+
|
9 |
from misc import *
|
10 |
|
11 |
+
import streamlit as st
|
12 |
+
|
13 |
+
# Create two columns with st.columns (new way)
|
14 |
+
col1, col2 = st.columns(2)
|
15 |
|
16 |
+
# Create the upload button in the first column
|
17 |
# Load the edf file
|
18 |
+
edf_file = col1.file_uploader("Upload an EEG edf file", type="edf")
|
19 |
+
# Create the result placeholder button in the second column
|
20 |
+
col2.button('Result:')
|
21 |
|
22 |
|
23 |
if edf_file is not None:
|
24 |
+
|
25 |
# Read the file
|
26 |
raw = read_file(edf_file)
|
27 |
|
28 |
# Preprocess and plot the data
|
29 |
preprocessing_and_plotting(raw)
|
30 |
+
|
31 |
+
# Build the model
|
32 |
+
clf = build_model(model_name='deep4net', n_classes=2, n_chans=21, input_window_samples=6000)
|
33 |
+
|
34 |
+
output = predict(raw,clf)
|
35 |
+
|
36 |
+
# # Print the output
|
37 |
+
set_button_state (output,col2)
|
misc.py
CHANGED
@@ -3,7 +3,99 @@ import mne
|
|
3 |
import streamlit as st
|
4 |
import matplotlib.pyplot as plt
|
5 |
|
|
|
|
|
|
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
def preprocessing_and_plotting(raw):
|
9 |
# Select the first channel
|
|
|
3 |
import streamlit as st
|
4 |
import matplotlib.pyplot as plt
|
5 |
|
6 |
+
from braindecode import EEGClassifier
|
7 |
+
from braindecode.models import Deep4Net,ShallowFBCSPNet,EEGNetv4, TCN
|
8 |
+
from braindecode.training.losses import CroppedLoss
|
9 |
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
def set_button_state(output,col):
|
14 |
+
# Generate a random output value of 0 or 1
|
15 |
+
# output = 2023 #random.randint(0, 1)
|
16 |
+
|
17 |
+
# Store the output value in session state
|
18 |
+
st.session_state.output = output
|
19 |
+
|
20 |
+
# Define the button color and text based on the output value
|
21 |
+
if st.session_state.output == 0:
|
22 |
+
button_color = "green"
|
23 |
+
button_text = "Normal"
|
24 |
+
elif st.session_state.output == 1:
|
25 |
+
button_color = "red"
|
26 |
+
button_text = "Abnormal"
|
27 |
+
# elif st.session_state.output == 3:
|
28 |
+
# button_color = "yellow"
|
29 |
+
# button_text = "Waiting"
|
30 |
+
else:
|
31 |
+
button_color = "gray"
|
32 |
+
button_text = "Unknown"
|
33 |
+
|
34 |
+
# Create a custom HTML button with CSS styling
|
35 |
+
col.markdown(f"""
|
36 |
+
<style>
|
37 |
+
.custom-button {{
|
38 |
+
background-color: {button_color};
|
39 |
+
color: black;
|
40 |
+
padding: 10px 20px;
|
41 |
+
border: none;
|
42 |
+
border-radius: 5px;
|
43 |
+
cursor: pointer;
|
44 |
+
}}
|
45 |
+
</style>
|
46 |
+
<button class="custom-button">Output: {button_text}</button>
|
47 |
+
""", unsafe_allow_html=True)
|
48 |
+
|
49 |
+
|
50 |
+
def predict(raw,clf):
|
51 |
+
x = np.expand_dims(raw.get_data()[:21, :6000], axis=0)
|
52 |
+
output = clf.predict(x)
|
53 |
+
return output
|
54 |
+
|
55 |
+
|
56 |
+
def build_model(model_name, n_classes, n_chans, input_window_samples, drop_prob=0.5, lr=0.01):#, weight_decay, batch_size, n_epochs, wandb_run, checkpoint, optimizer__param_groups, window_train_set, window_val):
|
57 |
+
n_start_chans = 25
|
58 |
+
final_conv_length = 1
|
59 |
+
n_chan_factor = 2
|
60 |
+
stride_before_pool = True
|
61 |
+
# input_window_samples =6000
|
62 |
+
model = Deep4Net(
|
63 |
+
n_chans, n_classes,
|
64 |
+
n_filters_time=n_start_chans,
|
65 |
+
n_filters_spat=n_start_chans,
|
66 |
+
input_window_samples=input_window_samples,
|
67 |
+
n_filters_2=int(n_start_chans * n_chan_factor),
|
68 |
+
n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
|
69 |
+
n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
|
70 |
+
final_conv_length=final_conv_length,
|
71 |
+
stride_before_pool=stride_before_pool,
|
72 |
+
drop_prob=drop_prob)
|
73 |
+
|
74 |
+
clf = EEGClassifier(
|
75 |
+
model,
|
76 |
+
cropped=True,
|
77 |
+
criterion=CroppedLoss,
|
78 |
+
# criterion=CroppedLoss_sd,
|
79 |
+
criterion__loss_function=torch.nn.functional.nll_loss,
|
80 |
+
optimizer=torch.optim.AdamW,
|
81 |
+
optimizer__lr=lr,
|
82 |
+
iterator_train__shuffle=False,
|
83 |
+
# iterator_train__sampler = ImbalancedDatasetSampler(window_train_set, labels=window_train_set.get_metadata().target),
|
84 |
+
# batch_size=batch_size,
|
85 |
+
callbacks=[
|
86 |
+
# EarlyStopping(patience=5),
|
87 |
+
# StochasticWeightAveraging(swa_utils, swa_start=1, verbose=1, swa_lr=lr),
|
88 |
+
# "accuracy", "balanced_accuracy","f1",("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
|
89 |
+
# checkpoint,
|
90 |
+
], #"accuracy",
|
91 |
+
# device='cuda'
|
92 |
+
)
|
93 |
+
clf.initialize()
|
94 |
+
pt_path = './Deep4Net_trained_tuh_scaling_wN_WAug_DefArgs_index8_number2700_state_dict_100.pt'
|
95 |
+
clf.load_params(f_params=pt_path)
|
96 |
+
|
97 |
+
return clf
|
98 |
+
|
99 |
|
100 |
def preprocessing_and_plotting(raw):
|
101 |
# Select the first channel
|