NCU / app.py
Łukasz Furman
update app.py
a59bdc5
from cProfile import run
import streamlit as st
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from complexRadar import ComplexRadar
import math
from zipfile import ZipFile
from glob import glob
import os
from BrainPulse import (dataset,
vector_space,
distance_matrix,
recurrence_quantification_analysis,
features_space,
plot)
# path
path = "./mne_data"
path2 = "./RPs"
# Remove the specified
# file path
try:
os.remove(path)
print("% s removed successfully" % path)
except:
pass
path = "./mne_data"
os.makedirs(path, exist_ok = True)
path1 = "./RPs"
os.makedirs(path1, exist_ok = True)
def run_computation(t_start, t_end, selected_subject, fir_filter, electrode_name, cut_freq, win_len, n_fft, percentile, run_list, options):
epochs, raw = dataset.eegbci_data(tmin=t_start, tmax=t_end,
subject=selected_subject,
filter_range=fir_filter,run_list=run_list)
s_rate = epochs.info['sfreq']
electrode_index = epochs.ch_names.index(electrode_name)
electrode_open = epochs.get_data()[0][electrode_index]
electrode_close = epochs.get_data()[1][electrode_index]
stft_open = vector_space.compute_stft(electrode_open,
n_fft=n_fft, win_len=win_len,
s_rate=epochs.info['sfreq'],
cut_freq=cut_freq)
stft_close = vector_space.compute_stft(electrode_close,
n_fft=n_fft, win_len=win_len,
s_rate=epochs.info['sfreq'],
cut_freq=cut_freq)
del raw
del electrode_open, electrode_close
# matrix_open = distance_matrix.EuclideanPyRQA_RP_stft(stft_open)
# matrix_close = distance_matrix.EuclideanPyRQA_RP_stft(stft_close)
matrix_open = distance_matrix.EuclideanPyRQA_RP_stft_cpu(stft_open)
matrix_close = distance_matrix.EuclideanPyRQA_RP_stft_cpu(stft_close)
nbr_open = np.percentile(matrix_open, percentile)
nbr_close = np.percentile(matrix_close, percentile)
matrix_open_binary = distance_matrix.set_epsilon(matrix_open,nbr_open)
matrix_close_binary = distance_matrix.set_epsilon(matrix_close,nbr_close)
del matrix_open, matrix_close
# matrix_open_to_plot = matrix_open_binary
# matrix_closed_to_plot = matrix_close_binary
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,figsize=(16,8),dpi=200)
ax1.imshow(matrix_open_binary, cmap='Greys', origin='lower') #cividis
ax1.set_xticks(np.linspace(0, matrix_open_binary.shape[0] , ax1.get_xticks().shape[0]))
ax1.set_yticks(np.linspace(0, matrix_open_binary.shape[0] , ax1.get_xticks().shape[0]))
ax1.set_xticklabels([str(np.around(x,decimals=0)) for x in np.linspace(0, matrix_open_binary.shape[0] / s_rate, ax1.get_xticks().shape[0])])
ax1.set_yticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_open_binary.shape[0] / s_rate, ax1.get_yticks().shape[0])])
ax1.set_title(options[0]+' window size = 240 samples, ε = '+str(np.round(nbr_open,4)))
ax1.set_xlabel('time (s)')
ax1.set_ylabel('time (s)')
ax2.imshow(matrix_close_binary, cmap='Greys', origin='lower')
ax2.set_xticks(np.linspace(0, matrix_close_binary.shape[0] , ax1.get_xticks().shape[0]))
ax2.set_yticks(np.linspace(0, matrix_close_binary.shape[0] , ax1.get_xticks().shape[0]))
ax2.set_xticklabels([str(np.around(x,decimals=0)) for x in np.linspace(0, matrix_close_binary.shape[0] / s_rate, ax1.get_xticks().shape[0])])
ax2.set_yticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_close_binary.shape[0] / s_rate, ax2.get_yticks().shape[0])])
ax2.set_title(options[1]+' window size = 240 samples, ε = '+str(np.round(nbr_close,4)))
ax2.set_xlabel('time (s)')
ax2.set_ylabel('time (s)')
return fig, matrix_open_binary, matrix_close_binary, epochs, stft_open, stft_close
def plot_rqa(matrix_open_binary, matrix_close_binary, min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len,options):
categories = ['RR', 'DET', 'L', 'Lmax', 'DIV', 'Lentr', 'DET_RR', 'LAM', 'V', 'Vmax', 'Ventr', 'LAM_DET', 'W', 'Wmax', 'Wentr', 'TT']
result_rqa_open = recurrence_quantification_analysis.get_results(matrix_open_binary,min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len)
result_rqa_closed = recurrence_quantification_analysis.get_results(matrix_close_binary,min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len)
data = pd.DataFrame([result_rqa_open,result_rqa_closed], columns=categories)
data = data.drop(['RR', 'DIV', 'Lmax'],axis=1)
# print(data)
min_max_per_variable = data.describe().T[['min', 'max']]
min_max_per_variable['min'] = min_max_per_variable['min'].apply(lambda x: int(x))
min_max_per_variable['max'] = min_max_per_variable['max'].apply(lambda x: math.ceil(x))
# print(min_max_per_variable)
variables = data.columns
ranges = list(min_max_per_variable.itertuples(index=False, name=None))
format_cfg = {
#'axes_args':{'facecolor':'#84A8CD'},
'rad_ln_args': {'visible':True, 'linestyle':'dotted'},
'angle_ln_args':{'linestyle':'dotted'},
'outer_ring': {'visible':True, 'linestyle':'dotted'},
'rgrid_tick_lbls_args': {'fontsize':6},
'theta_tick_lbls': {'fontsize':9, 'backgroundcolor':'#355C7D', 'color':'#FFFFFF'},
'theta_tick_lbls_pad':3
}
fig = plt.figure(figsize=(5,3),dpi=100)
radar = ComplexRadar(fig, variables, ranges,n_ring_levels=3 ,show_scales=True, format_cfg=format_cfg)
custom_colors = ['#F67280', '#6C5B7B', '#355C7D']
k=0
for g,c in zip(data.index, custom_colors):
# radar.plot(data.loc[g].values, label=f"condition {g}", color=c, marker='o')
radar.plot(data.loc[g].values, label=options[k], color=c, marker='o')
radar.fill(data.loc[g].values, alpha=0.5, color=c)
k+=1
radar.use_legend(loc='upper left', bbox_to_anchor=(-0.4, 1.1), fontsize = 'xx-small') #, bbox_to_anchor=(0.15, -0.25),ncol=radar.plot_counter
return fig
def waterfall_spectrum(stft1, stft2, s_rate, cut_freq, options):
fig = plt.figure(figsize=(14, 12), dpi=150)
grid = plt.GridSpec(8, 8, hspace=0.0, wspace=3.5)
spectrogram1 = fig.add_subplot(grid[0:3, 0:4])
spectrogram2 = fig.add_subplot(grid[0:3, 4:])
spectrogram1.pcolormesh(stft1.T,cmap='viridis')
spectrogram1.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft1.shape[0], 5)))
spectrogram1.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft1.shape[0] / s_rate, spectrogram1.get_xticks().shape[0])])
spectrogram1.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft1.shape[1], 5)))
spectrogram1.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
spectrogram1.set_ylabel('Freq (Hz)', )
spectrogram1.set_xlabel('Time (s)', )
spectrogram1.set_title(options[0] + ' Spectrogram', )
spectrogram2.pcolormesh(stft2.T,cmap='viridis')
spectrogram2.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft2.shape[0], 5)))
spectrogram2.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft2.shape[0] / s_rate, spectrogram2.get_xticks().shape[0])])
spectrogram2.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft2.shape[1], 5)))
spectrogram2.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
spectrogram2.set_ylabel('Freq (Hz)', )
spectrogram2.set_xlabel('Time (s)', )
spectrogram2.set_title(options[1] +' Spectrogram', )
return fig
def save(matrix_open_binary, matrix_close_binary):
file_name_open = './RPs/subject-'+str(selected_subject)+'_electrode-'+electrode_name+'_percentile-'+str(percentile)+'_run-open_binary.npy'
np.save(file_name_open, np.asarray(matrix_close_binary, dtype=np.ubyte))
file_name_close = './RPs/subject-'+str(selected_subject)+'_electrode-'+electrode_name+'_percentile-'+str(percentile)+'_run-close_binary.npy'
np.save(file_name_close, np.asarray(matrix_close_binary, dtype=np.ubyte))
def download():
file_paths = glob('./RPs/*')
with ZipFile('download.zip','w') as zip:
for file in file_paths:
# writing each file one by one
zip.write(file)
return open('download.zip', 'rb')
# ---------------Settings--------------------
st.set_page_config(layout="wide")
st.title('BrainPulse Playground')
sidebar = st.sidebar
selected_subject = sidebar.slider('Select Subject', 0, 100, 25)
electrode_name = sidebar.selectbox(
'Select Electrode',
('FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FT8', 'T7', 'T8', 'T9', 'T10', 'TP7', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2', 'Iz'))
t_start, t_end = sidebar.slider(
'Select a time range in seconds',
0, 60, (0, 30), step=1)
f1, f2 = sidebar.slider(
'Select a FIR filter range',
0, 60, (2, 50), step=1)
fir_filter = [f1, f2]
cut_freq = f2
win_len = sidebar.slider('FFT window size', 0, 512, 170, step=1)
n_fft = sidebar.slider('numer of FFT bins', 0, 1024, 512, step=1)
min_vert_line_len = sidebar.slider('Minimum vertical line length', 0, 250, 2, step=1)
min_diagonal_line_len = sidebar.slider('Minimum diagonal line length', 0, 250, 2, step=1)
min_white_vert_line_len = sidebar.slider('Minimum white vertical line length', 0, 250, 2, step=1)
percentile = sidebar.slider('Precentile', 0, 100, 24, step=1)
sidebar.download_button('Download file', download(),file_name='archive.zip')
# ---------------Plot RPs--------------------
# runs_ = ['Baseline open eyes', 'Baseline closed eyes', 'Motor execution: left vs right hand', 'Motor imagery: left vs right hand',
# 'Motor execution: hands vs feet', 'Motor imagery: hands vs feet']
#
# options = st.multiselect('Select two runs to compare', runs_, ['Baseline open eyes', 'Baseline closed eyes'])
# run_list = []
#
# for v in options:
# run_list.append(runs_.index(v)+1)
# if len(run_list) <= 1:
# run_list = [1,2]
st.markdown('Baseline open eyes vs Baseline closed eyes')
options = ['Baseline open eyes', 'Baseline closed eyes']
run_list = [1,2]
rp_plot, matrix_open_binary, matrix_close_binary, epochs, stft1, stft2 = run_computation(t_start, t_end, selected_subject, fir_filter, electrode_name, cut_freq, win_len, n_fft, percentile, run_list,options)
st.write(rp_plot)
# ---------------Plot Spectrum--------------------
st.write(waterfall_spectrum(stft1, stft2, 160, cut_freq, options))
# ---------------Save RPs--------------------
if st.button('Save RPs as *.npy'):
save(matrix_open_binary, matrix_close_binary)
# ---------------Plot Radar--------------------
rqa_radar = plot_rqa(matrix_open_binary, matrix_close_binary, min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len, options)
st.write(rqa_radar)