|
|
|
import os
import tempfile

import mne

import streamlit as st

import matplotlib.pyplot as plt

from braindecode import EEGClassifier
from braindecode.models import Deep4Net, ShallowFBCSPNet, EEGNetv4, TCN
from braindecode.training.losses import CroppedLoss

import torch
import numpy as np
|
|
|
def set_button_state(output, col):
    """Remember the latest prediction and show it as a coloured badge.

    The value is stored in ``st.session_state.output`` and then rendered as
    an HTML button inside ``col`` via ``col.markdown``.

    Parameters
    ----------
    output :
        Predicted label. ``0`` is shown as "Normal" (green), ``1`` as
        "Abnormal" (red); any other value is shown as "Unknown" (gray).
    col :
        Streamlit container (e.g. a column) exposing ``markdown``.
    """
    st.session_state.output = output
    current = st.session_state.output

    # Ordered (label, colour, caption) pairs. Compared with ``==`` instead of a
    # dict lookup so array-like predictions (e.g. a 1-element NumPy array
    # returned by ``clf.predict``) are matched exactly as before.
    button_color, button_text = "gray", "Unknown"
    for label, color, caption in ((0, "green", "Normal"), (1, "red", "Abnormal")):
        if current == label:
            button_color, button_text = color, caption
            break

    badge_html = f"""
    <style>
    .custom-button {{
        background-color: {button_color};
        color: black;
        padding: 10px 20px;
        border: none;
        border-radius: 5px;
        cursor: pointer;
    }}
    </style>
    <button class="custom-button">Output: {button_text}</button>
    """
    col.markdown(badge_html, unsafe_allow_html=True)
|
|
|
|
|
def predict(raw, clf, n_chans=21, n_times=6000):
    """Classify the opening window of a recording with a trained classifier.

    The first ``n_chans`` channels and the first ``n_times`` samples of
    ``raw`` are packed into a batch containing a single trial and passed to
    ``clf.predict``.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Recording to classify. Channels are used in their stored order; the
        caller must make sure that order matches what the model expects.
    clf :
        Classifier exposing ``predict`` (e.g. the object from ``build_model``).
    n_chans : int, optional
        Number of leading channels fed to the model (default 21).
    n_times : int, optional
        Number of leading samples fed to the model (default 6000).

    Returns
    -------
    Whatever ``clf.predict`` returns for a batch of one trial.
    """
    # Read only the samples we need instead of the whole recording; with a
    # non-preloaded Raw this avoids loading everything into memory.
    data = raw.get_data(stop=n_times)[:n_chans, :n_times]
    # get_data returns float64 while torch modules default to float32
    # parameters; torch refuses to mix the two, so cast explicitly.
    x = np.expand_dims(data.astype(np.float32, copy=False), axis=0)
    return clf.predict(x)
|
|
|
|
|
def build_model(model_name, n_classes, n_chans, input_window_samples,
                drop_prob=0.5, lr=0.01, pt_path=None):
    """Build a cropped-decoding Deep4Net classifier and load saved weights.

    Parameters
    ----------
    model_name : str
        Currently ignored: a Deep4Net is always built. Kept so existing
        callers keep working.
    n_classes : int
        Number of output classes.
    n_chans : int
        Number of input EEG channels.
    input_window_samples : int
        Number of time samples per input window.
    drop_prob : float, optional
        Dropout probability passed to Deep4Net.
    lr : float, optional
        AdamW learning rate. Only used if the classifier is trained further;
        this function itself only loads weights.
    pt_path : str or path-like, optional
        Parameter file passed to ``clf.load_params(f_params=...)``. Defaults
        to the Deep4Net TUH checkpoint in the working directory.

    Returns
    -------
    EEGClassifier
        Initialised classifier with the weights from ``pt_path`` loaded.
    """
    if pt_path is None:
        pt_path = './Deep4Net_trained_tuh_scaling_wN_WAug_DefArgs_index8_number2700_state_dict_100.pt'

    # Architecture hyper-parameters. They must describe the same network the
    # checkpoint was saved from, otherwise loading the state dict fails.
    n_start_chans = 25
    n_chan_factor = 2

    model = Deep4Net(
        n_chans, n_classes,
        n_filters_time=n_start_chans,
        n_filters_spat=n_start_chans,
        input_window_samples=input_window_samples,
        n_filters_2=int(n_start_chans * n_chan_factor),
        n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
        n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
        final_conv_length=1,
        stride_before_pool=True,
        drop_prob=drop_prob,
    )

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        optimizer__lr=lr,
        iterator_train__shuffle=False,
        callbacks=[],
    )

    # initialize() builds the module/optimizer so load_params has something
    # to load the saved parameters into.
    clf.initialize()
    clf.load_params(f_params=pt_path)

    return clf
|
|
|
|
|
def preprocessing_and_plotting(raw):
    """Render a 10-second overview plot of ``raw`` on the Streamlit page.

    Despite the name, ``raw`` itself is not modified; ``remove_dc=True`` only
    affects what is drawn.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Recording to display.
    """
    # show=False: the figure is handed to Streamlit instead of being shown in
    # a GUI window on the server.
    # NOTE(review): assumes MNE's matplotlib browser backend; with the Qt
    # browser backend raw.plot does not return a matplotlib Figure.
    fig = raw.plot(duration=10, scalings='auto', remove_dc=True,
                   show_scrollbars=False, show=False)
    st.pyplot(fig)
    # pyplot keeps every open figure alive; close it so repeated Streamlit
    # reruns do not accumulate figures in memory.
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_file(edf_file):
    """Load an uploaded EDF file into an MNE Raw object.

    ``mne.io.read_raw_edf`` needs a path on disk, so the uploaded bytes are
    written to a uniquely named temporary file. The file is fully loaded and
    then removed again.

    Parameters
    ----------
    edf_file :
        Uploaded file object providing ``getvalue()`` and ``name`` (e.g. the
        result of ``st.file_uploader``).

    Returns
    -------
    mne.io.BaseRaw
        The recording, preloaded into memory.
    """
    bytes_data = edf_file.getvalue()

    # A unique temp file (instead of a fixed './edf_file.edf') stops
    # concurrent sessions from overwriting each other's upload, and stops a
    # lazily-loaded Raw from reading data that belongs to a later upload.
    with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp:
        tmp.write(bytes_data)
        tmp_path = tmp.name

    try:
        # preload=True so the returned Raw no longer needs the file on disk.
        raw = mne.io.read_raw_edf(tmp_path, preload=True)
    finally:
        os.remove(tmp_path)

    st.write(f"Loaded {edf_file.name} with {raw.info['nchan']} channels")
    return raw
|
|