Łukasz Furman
commited on
Commit
•
a59bdc5
1
Parent(s):
e592727
update app.py
Browse files- .gitignore +3 -0
- BrainPulse/.DS_Store +0 -0
- BrainPulse/README.md +2 -0
- BrainPulse/__init__.py +0 -0
- BrainPulse/dataset.py +153 -0
- BrainPulse/dependency.py +8 -0
- BrainPulse/distance_matrix.py +104 -0
- BrainPulse/event.py +377 -0
- BrainPulse/features_space.py +282 -0
- BrainPulse/frequency_recurrence.py +124 -0
- BrainPulse/matrix_open_binary.npy +3 -0
- BrainPulse/model_SVM.py +45 -0
- BrainPulse/plot.py +634 -0
- BrainPulse/recurrence_quantification_analysis.py +305 -0
- BrainPulse/requirements.txt +0 -0
- BrainPulse/vector_space.py +16 -0
- Dockerfile +21 -0
- LICENSE +21 -0
- README.md +2 -13
- app.py +260 -0
- complexRadar.py +168 -0
- papaerspace_config.txt +5 -0
- requirements.txt +10 -0
- test.py +5 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
*.pyc
|
3 |
+
*.zip
|
BrainPulse/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
BrainPulse/README.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# BrainPulse
|
2 |
+
|
BrainPulse/__init__.py
ADDED
File without changes
|
BrainPulse/dataset.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mne import Epochs, pick_types, events_from_annotations, create_info, EpochsArray
|
2 |
+
from mne.channels import make_standard_montage
|
3 |
+
from mne.io import concatenate_raws, read_raw_edf, RawArray
|
4 |
+
from mne.datasets import eegbci
|
5 |
+
import scipy.io
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import mne
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
def eegbci_data(tmin, tmax, subject, filter_range = None, run_list = None):
|
13 |
+
|
14 |
+
event_id = dict(ev=0)
|
15 |
+
runs = run_list # open eyes vs closed eyes
|
16 |
+
|
17 |
+
# raw_fnames = eegbci.load_data(subject, runs, path='./datasets', update_path=True)
|
18 |
+
raw_fnames = eegbci.load_data(subject, runs, update_path=True, path='../mne_data')
|
19 |
+
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
|
20 |
+
eegbci.standardize(raw) # set channel names
|
21 |
+
montage = make_standard_montage('standard_1005')
|
22 |
+
raw.set_montage(montage)
|
23 |
+
|
24 |
+
# strip channel names of "." characters
|
25 |
+
raw.rename_channels(lambda x: x.strip('.'))
|
26 |
+
raw.set_eeg_reference(projection=True)
|
27 |
+
raw.apply_proj()
|
28 |
+
# Apply band-pass filter
|
29 |
+
if filter_range != None:
|
30 |
+
raw.filter(filter_range[0], filter_range[1], fir_design='firwin', skip_by_annotation='edge')
|
31 |
+
|
32 |
+
events, _ = events_from_annotations(raw, event_id=dict(T0=0))
|
33 |
+
|
34 |
+
picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
|
35 |
+
exclude='bads')
|
36 |
+
# Read epochs
|
37 |
+
epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
|
38 |
+
baseline=None, preload=True)
|
39 |
+
# print(epochs.get_data())
|
40 |
+
return epochs, raw
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def eegMCI_data(stim_type, subject, run_type, path_to_data_folder):
|
46 |
+
"""
|
47 |
+
eegMCI_data does read 3d data from .mat file structured as ( n epochs, n channels, n samples ).
|
48 |
+
|
49 |
+
:param stim_type: 'AV','A','V'
|
50 |
+
:param subject: select a subject from dataset
|
51 |
+
:param run_type: chose run type of selected subject, possible types ( button.dat, test.dat_1, test.dat_2, train.dat )
|
52 |
+
:param path_to_data_folder: root path to data folder
|
53 |
+
|
54 |
+
:return: epochs, raw mne objects
|
55 |
+
"""
|
56 |
+
|
57 |
+
stim_type_long = ['audiovisual','audio','visual']
|
58 |
+
|
59 |
+
if stim_type == 'AV':
|
60 |
+
stim_typeL = stim_type_long[0]
|
61 |
+
elif stim_type == 'A':
|
62 |
+
stim_typeL = stim_type_long[1]
|
63 |
+
else:
|
64 |
+
stim_typeL = stim_type_long[2]
|
65 |
+
|
66 |
+
targets = ['allNTARGETS','allTARGETS']
|
67 |
+
path = os.path.join(path_to_data_folder,stim_typeL,'s'+str(subject)+'_'+stim_type+'_'+run_type+'.mat')
|
68 |
+
|
69 |
+
mat = scipy.io.loadmat(path)
|
70 |
+
|
71 |
+
raw_no_targets = mat[targets[0]].transpose((0, 2, 1))
|
72 |
+
raw_targets = mat[targets[1]].transpose((0, 2, 1))
|
73 |
+
|
74 |
+
electrodes_l = np.concatenate(mat['electrodes'])
|
75 |
+
electrodes_l = [str(x).replace('[','').replace(']','').replace("'",'') for x in electrodes_l]
|
76 |
+
ch_types = ['eeg'] * 16
|
77 |
+
info = create_info(electrodes_l, ch_types=ch_types, sfreq=512)
|
78 |
+
info.set_montage('standard_1020')
|
79 |
+
|
80 |
+
epochs_no_target = EpochsArray(raw_no_targets, info, tmin=-0.2)
|
81 |
+
epochs_target = EpochsArray(raw_targets, info, tmin=-0.2)
|
82 |
+
|
83 |
+
erp_no_target = epochs_no_target.average()
|
84 |
+
erp_target = epochs_target.average()
|
85 |
+
|
86 |
+
epochs = np.array([erp_no_target.get_data(),erp_target.get_data()],dtype=np.object)
|
87 |
+
raw = [raw_no_targets,raw_targets]
|
88 |
+
print(electrodes_l)
|
89 |
+
return EpochsArray(epochs,info,tmin=-0.2), np.array(raw)
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
def eegMCI_data_epochs(stim_type1,stim_type2, subject, run_type, path_to_data_folder):
|
95 |
+
"""
|
96 |
+
eegMCI_data does read 3d data from .mat file structured as ( n epochs, n channels, n samples ).
|
97 |
+
|
98 |
+
:param stim_type1: 'AV','A','V'
|
99 |
+
:param stim_type2: 'AV','A','V'
|
100 |
+
:param subject: select a subject from dataset
|
101 |
+
:param run_type: chose run type of selected subject, possible types ( button.dat, test.dat_1, test.dat_2, train.dat )
|
102 |
+
:param path_to_data_folder: root path to data folder
|
103 |
+
|
104 |
+
:return: epochs, raw mne objects
|
105 |
+
"""
|
106 |
+
|
107 |
+
stim_type_long = ['audiovisual','audio','visual']
|
108 |
+
|
109 |
+
if stim_type1 == 'AV':
|
110 |
+
stim_typeL1 = stim_type_long[0]
|
111 |
+
elif stim_type1 == 'A':
|
112 |
+
stim_typeL1 = stim_type_long[1]
|
113 |
+
else:
|
114 |
+
stim_typeL1 = stim_type_long[2]
|
115 |
+
|
116 |
+
if stim_type2 == 'AV':
|
117 |
+
stim_typeL2 = stim_type_long[0]
|
118 |
+
elif stim_type2 == 'A':
|
119 |
+
stim_typeL2 = stim_type_long[1]
|
120 |
+
else:
|
121 |
+
stim_typeL2 = stim_type_long[2]
|
122 |
+
|
123 |
+
targets = ['allNTARGETS','allTARGETS']
|
124 |
+
path1 = os.path.join(path_to_data_folder,stim_typeL1,'s'+str(subject)+'_'+stim_type1+'_'+run_type+'.mat')
|
125 |
+
path2 = os.path.join(path_to_data_folder,stim_typeL2,'s'+str(subject)+'_'+stim_type2+'_'+run_type+'.mat')
|
126 |
+
|
127 |
+
mat1 = scipy.io.loadmat(path1)
|
128 |
+
mat2 = scipy.io.loadmat(path2)
|
129 |
+
|
130 |
+
raw_no_targets1 = mat1[targets[0]].transpose((0, 2, 1))
|
131 |
+
raw_targets1 = mat1[targets[1]].transpose((0, 2, 1))
|
132 |
+
|
133 |
+
raw_no_targets2 = mat2[targets[0]].transpose((0, 2, 1))
|
134 |
+
raw_targets2 = mat2[targets[1]].transpose((0, 2, 1))
|
135 |
+
|
136 |
+
electrodes_l = np.concatenate(mat1['electrodes'])
|
137 |
+
electrodes_l = [str(x).replace('[','').replace(']','').replace("'",'') for x in electrodes_l]
|
138 |
+
ch_types = ['eeg'] * 16
|
139 |
+
info = create_info(electrodes_l, ch_types=ch_types, sfreq=512)
|
140 |
+
info.set_montage('standard_1020')
|
141 |
+
|
142 |
+
# epochs_no_target = EpochsArray(raw_no_targets1, info, tmin=-0.2)
|
143 |
+
epochs_target1 = EpochsArray(raw_targets1, info, tmin=-0.2, baseline=(-0.2,0))
|
144 |
+
epochs_target2 = EpochsArray(raw_targets2, info, tmin=-0.2, baseline=(-0.2,0))
|
145 |
+
|
146 |
+
# erp_no_target = epochs_no_target.average()
|
147 |
+
erp_target1 = epochs_target1.average()
|
148 |
+
erp_target2 = epochs_target2.average()
|
149 |
+
|
150 |
+
evoked = np.array([erp_target1.get_data(),erp_target2.get_data()],dtype=np.object)
|
151 |
+
epochs = [epochs_target1,epochs_target2]
|
152 |
+
print(electrodes_l)
|
153 |
+
return EpochsArray(evoked,info,tmin=-0.2), epochs
|
BrainPulse/dependency.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pynndescent
|
2 |
+
import numpy
|
3 |
+
import pyrqa
|
4 |
+
import pandas
|
5 |
+
import sklearn
|
6 |
+
import scipy
|
7 |
+
import torch
|
8 |
+
import umap
|
BrainPulse/distance_matrix.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numba import jit
|
2 |
+
from pynndescent.distances import wasserstein_1d, spearmanr, euclidean
|
3 |
+
import numpy as np
|
4 |
+
from pyrqa.computation import RPComputation
|
5 |
+
from pyrqa.time_series import TimeSeries, EmbeddedSeries
|
6 |
+
from pyrqa.settings import Settings
|
7 |
+
from pyrqa.analysis_type import Classic
|
8 |
+
from pyrqa.metric import EuclideanMetric
|
9 |
+
# from pyrqa.metric import Sigmoid
|
10 |
+
# from pyrqa.metric import Cosine
|
11 |
+
from pyrqa.neighbourhood import Unthresholded
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
#
|
16 |
+
# @jit(parallel=True)
|
17 |
+
@jit(nopython=True)
|
18 |
+
def EuclideanPyRQA_RP_stft_cpu(stft_):
|
19 |
+
result = np.zeros((stft_.shape[0], stft_.shape[0]))
|
20 |
+
|
21 |
+
for i, fft in enumerate(stft_):
|
22 |
+
for j, v in enumerate(stft_):
|
23 |
+
d = euclidean(fft, v)
|
24 |
+
result[i, j] = d
|
25 |
+
return result
|
26 |
+
|
27 |
+
@jit(nopython=True)
|
28 |
+
def wasserstein_squereform(stft_):
|
29 |
+
result = np.zeros((stft_.shape[0], stft_.shape[0]))
|
30 |
+
|
31 |
+
for i, fft in enumerate(stft_):
|
32 |
+
for j, v in enumerate(stft_):
|
33 |
+
d = wasserstein_1d(fft, v)
|
34 |
+
result[i, j] = d
|
35 |
+
return result
|
36 |
+
|
37 |
+
@jit(nopython=True)
|
38 |
+
def spearmanr_squereform(stft_):
|
39 |
+
result = np.zeros((stft_.shape[0], stft_.shape[0]))
|
40 |
+
|
41 |
+
for i, fft in enumerate(stft_):
|
42 |
+
for j, v in enumerate(stft_):
|
43 |
+
d = spearmanr(fft, v)
|
44 |
+
result[i, j] = d
|
45 |
+
return result
|
46 |
+
|
47 |
+
@jit(nopython=True)
|
48 |
+
def wasserstein_squereform_binary(stft_, eps_):
|
49 |
+
result = np.zeros((stft_.shape[0], stft_.shape[0]))
|
50 |
+
|
51 |
+
for i, fft in enumerate(stft_):
|
52 |
+
for j, v in enumerate(stft_):
|
53 |
+
d = wasserstein_1d(fft, v)
|
54 |
+
if d <= eps_:
|
55 |
+
dist = 1
|
56 |
+
else:
|
57 |
+
dist = 0
|
58 |
+
|
59 |
+
result[i, j] = dist
|
60 |
+
return result
|
61 |
+
|
62 |
+
|
63 |
+
@jit(nopython=True)
|
64 |
+
def wasserstein_1d_array(stft_):
|
65 |
+
result = np.zeros((stft_.shape[0] * stft_.shape[0]))
|
66 |
+
k = 0
|
67 |
+
for i, fft in enumerate(stft_):
|
68 |
+
for j, v in enumerate(stft_):
|
69 |
+
d = wasserstein_1d(fft, v)
|
70 |
+
result[k] = d
|
71 |
+
k += 1
|
72 |
+
|
73 |
+
return result
|
74 |
+
|
75 |
+
def set_epsilon(matrix, eps):
|
76 |
+
return np.heaviside(eps - matrix, 0)
|
77 |
+
|
78 |
+
|
79 |
+
def EuclideanPyRQA_RP(signal, embedding = 2,timedelay = 9):
|
80 |
+
time_series = TimeSeries(signal,
|
81 |
+
embedding_dimension=embedding,
|
82 |
+
time_delay=timedelay)
|
83 |
+
settings = Settings(time_series,
|
84 |
+
analysis_type=Classic,
|
85 |
+
neighbourhood=Unthresholded(),
|
86 |
+
similarity_measure=EuclideanMetric,
|
87 |
+
theiler_corrector=1)
|
88 |
+
|
89 |
+
computation = RPComputation.create(settings)
|
90 |
+
result = computation.run()
|
91 |
+
return result.recurrence_matrix
|
92 |
+
|
93 |
+
|
94 |
+
def EuclideanPyRQA_RP_stft(signal, embedding = 2,timedelay = 9):
|
95 |
+
time_series = EmbeddedSeries(signal)
|
96 |
+
settings = Settings(time_series,
|
97 |
+
analysis_type=Classic,
|
98 |
+
neighbourhood=Unthresholded(),
|
99 |
+
similarity_measure=EuclideanMetric,
|
100 |
+
theiler_corrector=1)
|
101 |
+
|
102 |
+
computation = RPComputation.create(settings)
|
103 |
+
result = computation.run()
|
104 |
+
return result.recurrence_matrix
|
BrainPulse/event.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Event segmentation using a Hidden Markov Model
|
2 |
+
|
3 |
+
Adapted from the brainiak package for this workshop.
|
4 |
+
See https://brainiak.org/ for full documentation."""
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from scipy import stats
|
8 |
+
import logging
|
9 |
+
import copy
|
10 |
+
from sklearn.base import BaseEstimator
|
11 |
+
from sklearn.utils.validation import check_is_fitted, check_array
|
12 |
+
from sklearn.exceptions import NotFittedError
|
13 |
+
import itertools
|
14 |
+
|
15 |
+
|
16 |
+
def masked_log(x):
|
17 |
+
y = np.empty(x.shape, dtype=x.dtype)
|
18 |
+
lim = x.shape[0]
|
19 |
+
for i in range(lim):
|
20 |
+
if x[i] <= 0:
|
21 |
+
y[i] = float('-inf')
|
22 |
+
else:
|
23 |
+
y[i] = np.log(x[i])
|
24 |
+
return y
|
25 |
+
|
26 |
+
|
27 |
+
class EventSegment(BaseEstimator):
|
28 |
+
|
29 |
+
def _default_var_schedule(step):
|
30 |
+
return 4 * (0.98 ** (step - 1))
|
31 |
+
|
32 |
+
def __init__(self, n_events=2,
|
33 |
+
step_var=_default_var_schedule,
|
34 |
+
n_iter=500, event_chains=None,
|
35 |
+
split_merge=False, split_merge_proposals=1):
|
36 |
+
self.n_events = n_events
|
37 |
+
self.step_var = step_var
|
38 |
+
self.n_iter = n_iter
|
39 |
+
self.split_merge = split_merge
|
40 |
+
self.split_merge_proposals = split_merge_proposals
|
41 |
+
if event_chains is None:
|
42 |
+
self.event_chains = np.zeros(n_events)
|
43 |
+
else:
|
44 |
+
self.event_chains = event_chains
|
45 |
+
|
46 |
+
def _fit_validate(self, X):
|
47 |
+
if len(np.unique(self.event_chains)) > 1:
|
48 |
+
raise RuntimeError("Cannot fit chains, use set_event_patterns")
|
49 |
+
|
50 |
+
# Copy X into a list and transpose
|
51 |
+
X = copy.deepcopy(X)
|
52 |
+
if type(X) is not list:
|
53 |
+
X = [X]
|
54 |
+
for i in range(len(X)):
|
55 |
+
X[i] = check_array(X[i])
|
56 |
+
X[i] = X[i].T
|
57 |
+
|
58 |
+
# Check that number of voxels is consistent across datasets
|
59 |
+
n_dim = X[0].shape[0]
|
60 |
+
for i in range(len(X)):
|
61 |
+
assert (X[i].shape[0] == n_dim)
|
62 |
+
|
63 |
+
# Double-check that data is z-scored in time
|
64 |
+
for i in range(len(X)):
|
65 |
+
X[i] = stats.zscore(X[i], axis=1, ddof=1)
|
66 |
+
|
67 |
+
return X
|
68 |
+
|
69 |
+
def fit(self, X, y=None):
|
70 |
+
seg = []
|
71 |
+
X = self._fit_validate(X)
|
72 |
+
n_train = len(X)
|
73 |
+
n_dim = X[0].shape[0]
|
74 |
+
self.classes_ = np.arange(self.n_events)
|
75 |
+
|
76 |
+
# Initialize variables for fitting
|
77 |
+
log_gamma = []
|
78 |
+
for i in range(n_train):
|
79 |
+
log_gamma.append(np.zeros((X[i].shape[1], self.n_events)))
|
80 |
+
step = 1
|
81 |
+
best_ll = float("-inf")
|
82 |
+
self.ll_ = np.empty((0, n_train))
|
83 |
+
while step <= self.n_iter:
|
84 |
+
iteration_var = self.step_var(step)
|
85 |
+
|
86 |
+
# Based on the current segmentation, compute the mean pattern
|
87 |
+
# for each event
|
88 |
+
seg_prob = [np.exp(lg) / np.sum(np.exp(lg), axis=0)
|
89 |
+
for lg in log_gamma]
|
90 |
+
mean_pat = np.empty((n_train, n_dim, self.n_events))
|
91 |
+
for i in range(n_train):
|
92 |
+
mean_pat[i, :, :] = X[i].dot(seg_prob[i])
|
93 |
+
mean_pat = np.mean(mean_pat, axis=0)
|
94 |
+
|
95 |
+
# Based on the current mean patterns, compute the event
|
96 |
+
# segmentation
|
97 |
+
self.ll_ = np.append(self.ll_, np.empty((1, n_train)), axis=0)
|
98 |
+
for i in range(n_train):
|
99 |
+
logprob = self._logprob_obs(X[i], mean_pat, iteration_var)
|
100 |
+
log_gamma[i], self.ll_[-1, i] = self._forward_backward(logprob)
|
101 |
+
|
102 |
+
if step > 1 and self.split_merge:
|
103 |
+
curr_ll = np.mean(self.ll_[-1, :])
|
104 |
+
self.ll_[-1, :], log_gamma, mean_pat = \
|
105 |
+
self._split_merge(X, log_gamma, iteration_var, curr_ll)
|
106 |
+
|
107 |
+
# If log-likelihood has started decreasing, undo last step and stop
|
108 |
+
if np.mean(self.ll_[-1, :]) < best_ll:
|
109 |
+
self.ll_ = self.ll_[:-1, :]
|
110 |
+
break
|
111 |
+
|
112 |
+
self.segments_ = [np.exp(lg) for lg in log_gamma]
|
113 |
+
self.event_var_ = iteration_var
|
114 |
+
self.event_pat_ = mean_pat
|
115 |
+
best_ll = np.mean(self.ll_[-1, :])
|
116 |
+
|
117 |
+
seg.append(self.segments_[0].copy())
|
118 |
+
|
119 |
+
step += 1
|
120 |
+
|
121 |
+
return seg
|
122 |
+
|
123 |
+
def _logprob_obs(self, data, mean_pat, var):
|
124 |
+
n_vox = data.shape[0]
|
125 |
+
t = data.shape[1]
|
126 |
+
|
127 |
+
# z-score both data and mean patterns in space, so that Gaussians
|
128 |
+
# are measuring Pearson correlations and are insensitive to overall
|
129 |
+
# activity changes
|
130 |
+
data_z = stats.zscore(data, axis=0, ddof=1)
|
131 |
+
mean_pat_z = stats.zscore(mean_pat, axis=0, ddof=1)
|
132 |
+
|
133 |
+
logprob = np.empty((t, self.n_events))
|
134 |
+
|
135 |
+
if type(var) is not np.ndarray:
|
136 |
+
var = var * np.ones(self.n_events)
|
137 |
+
|
138 |
+
for k in range(self.n_events):
|
139 |
+
logprob[:, k] = -0.5 * n_vox * np.log(
|
140 |
+
2 * np.pi * var[k]) - 0.5 * np.sum(
|
141 |
+
(data_z.T - mean_pat_z[:, k]).T ** 2, axis=0) / var[k]
|
142 |
+
|
143 |
+
logprob /= n_vox
|
144 |
+
return logprob
|
145 |
+
|
146 |
+
def _forward_backward(self, logprob):
|
147 |
+
logprob = copy.copy(logprob)
|
148 |
+
t = logprob.shape[0]
|
149 |
+
logprob = np.hstack((logprob, float("-inf") * np.ones((t, 1))))
|
150 |
+
|
151 |
+
# Initialize variables
|
152 |
+
log_scale = np.zeros(t)
|
153 |
+
log_alpha = np.zeros((t, self.n_events + 1))
|
154 |
+
log_beta = np.zeros((t, self.n_events + 1))
|
155 |
+
|
156 |
+
# Set up transition matrix, with final sink state
|
157 |
+
self.p_start = np.zeros(self.n_events + 1)
|
158 |
+
self.p_end = np.zeros(self.n_events + 1)
|
159 |
+
self.P = np.zeros((self.n_events + 1, self.n_events + 1))
|
160 |
+
label_ind = np.unique(self.event_chains, return_inverse=True)[1]
|
161 |
+
n_chains = np.max(label_ind) + 1
|
162 |
+
|
163 |
+
# For each chain of events, link them together and then to sink state
|
164 |
+
for c in range(n_chains):
|
165 |
+
chain_ind = np.nonzero(label_ind == c)[0]
|
166 |
+
self.p_start[chain_ind[0]] = 1 / n_chains
|
167 |
+
self.p_end[chain_ind[-1]] = 1 / n_chains
|
168 |
+
|
169 |
+
p_trans = (len(chain_ind) - 1) / t
|
170 |
+
if p_trans >= 1:
|
171 |
+
raise ValueError('Too few timepoints')
|
172 |
+
for i in range(len(chain_ind)):
|
173 |
+
self.P[chain_ind[i], chain_ind[i]] = 1 - p_trans
|
174 |
+
if i < len(chain_ind) - 1:
|
175 |
+
self.P[chain_ind[i], chain_ind[i+1]] = p_trans
|
176 |
+
else:
|
177 |
+
self.P[chain_ind[i], -1] = p_trans
|
178 |
+
self.P[-1, -1] = 1
|
179 |
+
|
180 |
+
# Forward pass
|
181 |
+
for i in range(t):
|
182 |
+
if i == 0:
|
183 |
+
log_alpha[0, :] = self._log(self.p_start) + logprob[0, :]
|
184 |
+
else:
|
185 |
+
log_alpha[i, :] = self._log(np.exp(log_alpha[i - 1, :])
|
186 |
+
.dot(self.P)) + logprob[i, :]
|
187 |
+
|
188 |
+
log_scale[i] = np.logaddexp.reduce(log_alpha[i, :])
|
189 |
+
log_alpha[i] -= log_scale[i]
|
190 |
+
|
191 |
+
# Backward pass
|
192 |
+
log_beta[-1, :] = self._log(self.p_end) - log_scale[-1]
|
193 |
+
for i in reversed(range(t - 1)):
|
194 |
+
obs_weighted = log_beta[i + 1, :] + logprob[i + 1, :]
|
195 |
+
offset = np.max(obs_weighted)
|
196 |
+
log_beta[i, :] = offset + self._log(
|
197 |
+
np.exp(obs_weighted - offset).dot(self.P.T)) - log_scale[i]
|
198 |
+
|
199 |
+
# Combine and normalize
|
200 |
+
log_gamma = log_alpha + log_beta
|
201 |
+
log_gamma -= np.logaddexp.reduce(log_gamma, axis=1, keepdims=True)
|
202 |
+
|
203 |
+
ll = np.sum(log_scale[:(t - 1)]) + np.logaddexp.reduce(
|
204 |
+
log_alpha[-1, :] + log_scale[-1] + self._log(self.p_end))
|
205 |
+
|
206 |
+
log_gamma = log_gamma[:, :-1]
|
207 |
+
|
208 |
+
return log_gamma, ll
|
209 |
+
|
210 |
+
def _log(self, x):
|
211 |
+
xshape = x.shape
|
212 |
+
_x = x.flatten()
|
213 |
+
y = masked_log(_x)
|
214 |
+
return y.reshape(xshape)
|
215 |
+
|
216 |
+
def set_event_patterns(self, event_pat):
|
217 |
+
if event_pat.shape[1] != self.n_events:
|
218 |
+
raise ValueError(("Number of columns of event_pat must match "
|
219 |
+
"number of events"))
|
220 |
+
self.event_pat_ = event_pat.copy()
|
221 |
+
|
222 |
+
def find_events(self, testing_data, var=None, scramble=False):
|
223 |
+
if var is None:
|
224 |
+
if not hasattr(self, 'event_var_'):
|
225 |
+
raise NotFittedError(("Event variance must be provided, if "
|
226 |
+
"not previously set by fit()"))
|
227 |
+
else:
|
228 |
+
var = self.event_var_
|
229 |
+
|
230 |
+
if not hasattr(self, 'event_pat_'):
|
231 |
+
raise NotFittedError(("The event patterns must first be set "
|
232 |
+
"by fit() or set_event_patterns()"))
|
233 |
+
if scramble:
|
234 |
+
mean_pat = self.event_pat_[:, np.random.permutation(self.n_events)]
|
235 |
+
else:
|
236 |
+
mean_pat = self.event_pat_
|
237 |
+
|
238 |
+
logprob = self._logprob_obs(testing_data.T, mean_pat, var)
|
239 |
+
lg, test_ll = self._forward_backward(logprob)
|
240 |
+
segments = np.exp(lg)
|
241 |
+
|
242 |
+
return segments, test_ll
|
243 |
+
|
244 |
+
def predict(self, X):
|
245 |
+
check_is_fitted(self, ["event_pat_", "event_var_"])
|
246 |
+
X = check_array(X)
|
247 |
+
segments, test_ll = self.find_events(X)
|
248 |
+
return np.argmax(segments, axis=1)
|
249 |
+
|
250 |
+
def calc_weighted_event_var(self, D, weights, event_pat):
|
251 |
+
Dz = stats.zscore(D, axis=1, ddof=1)
|
252 |
+
ev_var = np.empty(event_pat.shape[1])
|
253 |
+
for e in range(event_pat.shape[1]):
|
254 |
+
# Only compute variances for weights > 0.1% of max weight
|
255 |
+
nz = weights[:, e] > np.max(weights[:, e])/1000
|
256 |
+
sumsq = np.dot(weights[nz, e],
|
257 |
+
np.sum(np.square(Dz[nz, :] -
|
258 |
+
event_pat[:, e]), axis=1))
|
259 |
+
ev_var[e] = sumsq/(np.sum(weights[nz, e]) -
|
260 |
+
np.sum(np.square(weights[nz, e])) /
|
261 |
+
np.sum(weights[nz, e]))
|
262 |
+
ev_var = ev_var / D.shape[1]
|
263 |
+
return ev_var
|
264 |
+
|
265 |
+
def model_prior(self, t):
|
266 |
+
lg, test_ll = self._forward_backward(np.zeros((t, self.n_events)))
|
267 |
+
segments = np.exp(lg)
|
268 |
+
|
269 |
+
return segments, test_ll
|
270 |
+
|
271 |
+
def _split_merge(self, X, log_gamma, iteration_var, curr_ll):
|
272 |
+
# Compute current probabilities and mean patterns
|
273 |
+
n_train = len(X)
|
274 |
+
n_dim = X[0].shape[0]
|
275 |
+
|
276 |
+
seg_prob = [np.exp(lg) / np.sum(np.exp(lg), axis=0)
|
277 |
+
for lg in log_gamma]
|
278 |
+
mean_pat = np.empty((n_train, n_dim, self.n_events))
|
279 |
+
for i in range(n_train):
|
280 |
+
mean_pat[i, :, :] = X[i].dot(seg_prob[i])
|
281 |
+
mean_pat = np.mean(mean_pat, axis=0)
|
282 |
+
|
283 |
+
# For each event, merge its probability distribution
|
284 |
+
# with the next event, and also split its probability
|
285 |
+
# distribution at its median into two separate events.
|
286 |
+
# Use these new event probability distributions to compute
|
287 |
+
# merged and split event patterns.
|
288 |
+
merge_pat = np.empty((n_train, n_dim, self.n_events))
|
289 |
+
split_pat = np.empty((n_train, n_dim, 2 * self.n_events))
|
290 |
+
for i, sp in enumerate(seg_prob): # Iterate over datasets
|
291 |
+
m_evprob = np.zeros((sp.shape[0], sp.shape[1]))
|
292 |
+
s_evprob = np.zeros((sp.shape[0], 2 * sp.shape[1]))
|
293 |
+
cs = np.cumsum(sp, axis=0)
|
294 |
+
for e in range(sp.shape[1]):
|
295 |
+
# Split distribution at midpoint and normalize each half
|
296 |
+
mid = np.where(cs[:, e] >= 0.5)[0][0]
|
297 |
+
cs_first = cs[mid, e] - sp[mid, e]
|
298 |
+
cs_second = 1 - cs_first
|
299 |
+
s_evprob[:mid, 2 * e] = sp[:mid, e] / cs_first
|
300 |
+
s_evprob[mid:, 2 * e + 1] = sp[mid:, e] / cs_second
|
301 |
+
|
302 |
+
# Merge distribution with next event distribution
|
303 |
+
m_evprob[:, e] = sp[:, e:(e + 2)].mean(1)
|
304 |
+
|
305 |
+
# Weight data by distribution to get event patterns
|
306 |
+
merge_pat[i, :, :] = X[i].dot(m_evprob)
|
307 |
+
split_pat[i, :, :] = X[i].dot(s_evprob)
|
308 |
+
|
309 |
+
# Average across datasets
|
310 |
+
merge_pat = np.mean(merge_pat, axis=0)
|
311 |
+
split_pat = np.mean(split_pat, axis=0)
|
312 |
+
|
313 |
+
# Correlate the current event patterns with the split and
|
314 |
+
# merged patterns
|
315 |
+
merge_corr = np.zeros(self.n_events)
|
316 |
+
split_corr = np.zeros(self.n_events)
|
317 |
+
for e in range(self.n_events):
|
318 |
+
split_corr[e] = np.corrcoef(mean_pat[:, e],
|
319 |
+
split_pat[:, (2 * e):(2 * e + 2)],
|
320 |
+
rowvar=False)[0, 1:3].max()
|
321 |
+
merge_corr[e] = np.corrcoef(merge_pat[:, e],
|
322 |
+
mean_pat[:, e:(e + 2)],
|
323 |
+
rowvar=False)[0, 1:3].min()
|
324 |
+
merge_corr = merge_corr[:-1]
|
325 |
+
|
326 |
+
# Find best merge/split candidates
|
327 |
+
# A high value of merge_corr indicates that a pair of events are
|
328 |
+
# very similar to their merged pattern, and are good candidates for
|
329 |
+
# being merged.
|
330 |
+
# A low value of split_corr indicates that an event's pattern is
|
331 |
+
# very dissimilar from the patterns in its first and second half,
|
332 |
+
# and is a good candidate for being split.
|
333 |
+
best_merge = np.flipud(np.argsort(merge_corr))
|
334 |
+
best_merge = best_merge[:self.split_merge_proposals]
|
335 |
+
best_split = np.argsort(split_corr)
|
336 |
+
best_split = best_split[:self.split_merge_proposals]
|
337 |
+
|
338 |
+
# For every pair of merge/split candidates, attempt the merge/split
|
339 |
+
# and measure the log-likelihood. If any are better than curr_ll,
|
340 |
+
# accept this best merge/split
|
341 |
+
mean_pat_last = mean_pat.copy()
|
342 |
+
return_ll = curr_ll
|
343 |
+
return_lg = copy.deepcopy(log_gamma)
|
344 |
+
return_mp = mean_pat.copy()
|
345 |
+
for m_e, s_e in itertools.product(best_merge, best_split):
|
346 |
+
if m_e == s_e or m_e+1 == s_e:
|
347 |
+
# Don't attempt to merge/split same event
|
348 |
+
continue
|
349 |
+
|
350 |
+
# Construct new set of patterns with merge/split
|
351 |
+
mean_pat_ms = np.delete(mean_pat_last, s_e, axis=1)
|
352 |
+
mean_pat_ms = np.insert(mean_pat_ms, [s_e, s_e],
|
353 |
+
split_pat[:, (2 * s_e):(2 * s_e + 2)],
|
354 |
+
axis=1)
|
355 |
+
mean_pat_ms = np.delete(mean_pat_ms,
|
356 |
+
[m_e + (s_e < m_e), m_e + (s_e < m_e) + 1],
|
357 |
+
axis=1)
|
358 |
+
mean_pat_ms = np.insert(mean_pat_ms, m_e + (s_e < m_e),
|
359 |
+
merge_pat[:, m_e], axis=1)
|
360 |
+
|
361 |
+
# Measure log-likelihood with these new patterns
|
362 |
+
ll_ms = np.zeros(n_train)
|
363 |
+
log_gamma_ms = list()
|
364 |
+
for i in range(n_train):
|
365 |
+
logprob = self._logprob_obs(X[i],
|
366 |
+
mean_pat_ms, iteration_var)
|
367 |
+
lg, ll_ms[i] = self._forward_backward(logprob)
|
368 |
+
log_gamma_ms.append(lg)
|
369 |
+
|
370 |
+
# If better than best ll so far, save to return to fit()
|
371 |
+
if ll_ms.mean() > return_ll:
|
372 |
+
return_mp = mean_pat_ms.copy()
|
373 |
+
return_ll = ll_ms
|
374 |
+
for i in range(n_train):
|
375 |
+
return_lg[i] = log_gamma_ms[i].copy()
|
376 |
+
|
377 |
+
return return_ll, return_lg, return_mp
|
BrainPulse/features_space.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
from sklearn import svm
|
5 |
+
import sklearn.model_selection as model_selection
|
6 |
+
from sklearn.metrics import accuracy_score
|
7 |
+
from sklearn.metrics import confusion_matrix
|
8 |
+
from sklearn.impute import SimpleImputer
|
9 |
+
from sklearn.svm import SVC
|
10 |
+
from sklearn.model_selection import StratifiedKFold
|
11 |
+
from sklearn.feature_selection import RFECV
|
12 |
+
from sklearn.ensemble import RandomForestClassifier
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
# import rcr
|
16 |
+
|
17 |
+
def save_features_as_csv():
|
18 |
+
return
|
19 |
+
|
20 |
+
def load_features_csv_concat(folder_path):
|
21 |
+
df_list = []
|
22 |
+
for file in glob.glob(folder_path+"/*.csv"):
|
23 |
+
df_ = pd.read_csv(file)
|
24 |
+
df_list.append(df_)
|
25 |
+
df = pd.concat(df_list)
|
26 |
+
df = df.reset_index(drop=True)
|
27 |
+
return df
|
28 |
+
|
29 |
+
def exclude_subject(df,exluded_subjects):
|
30 |
+
condition_string = ''
|
31 |
+
for ex_sub in exluded_subjects:
|
32 |
+
condition_string += "(df['Subject'] !='" +ex_sub+"') & "
|
33 |
+
evaluation_string = 'df['+condition_string[:len(condition_string)-2]+']'
|
34 |
+
df_ex = eval(evaluation_string)
|
35 |
+
return df_ex.reset_index(drop=True)
|
36 |
+
|
37 |
+
def electrode_wise_dataframe(df, condition_list, id_vars = ['Subject', 'Task', 'Electrode']):
|
38 |
+
stats_frame = df[
|
39 |
+
['Subject', 'Task', 'Electrode','Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V','Vmax', 'Ventr', 'W','Wentr']
|
40 |
+
]
|
41 |
+
|
42 |
+
stats_frame.melt(id_vars=id_vars, var_name='RQA_feature', value_name='feature_value')
|
43 |
+
stats = stats_frame.pivot_table(index=['Subject', 'Task'], columns='Electrode',
|
44 |
+
values=['Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V','Vmax', 'Ventr', 'W', 'Wentr']).reset_index()
|
45 |
+
|
46 |
+
stats = stats.replace(condition_list[0], 0)
|
47 |
+
stats = stats.replace(condition_list[1], 1)
|
48 |
+
y = stats.Task.values
|
49 |
+
return stats, y
|
50 |
+
|
51 |
+
|
52 |
+
def electrode_wise_dataframe_epochs(df, condition_list, id_vars = ['Subject', 'Task', 'Epoch_id','Electrode']):
|
53 |
+
stats_frame = df[
|
54 |
+
['Subject', 'Task','Epoch_id','Electrode','Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V','Vmax', 'Ventr', 'W','Wentr']
|
55 |
+
]
|
56 |
+
|
57 |
+
stats_frame.melt(id_vars=id_vars, var_name='RQA_feature', value_name='feature_value')
|
58 |
+
stats = stats_frame.pivot_table(index=['Subject', 'Task'], columns=['Electrode', 'Epoch_id'],
|
59 |
+
values=['Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V','Vmax', 'Ventr', 'W', 'Wentr']).reset_index()
|
60 |
+
|
61 |
+
stats = stats.replace(condition_list[0], 0)
|
62 |
+
stats = stats.replace(condition_list[1], 1)
|
63 |
+
y = stats.Task.values
|
64 |
+
return stats, y
|
65 |
+
|
66 |
+
|
67 |
+
def select_features_clean_and_normalize(df,features=['Lentr', 'TT', 'L', 'LAM', 'DET','V', 'Ventr', 'W','Wentr']):
|
68 |
+
|
69 |
+
stats_data = df[features].values
|
70 |
+
|
71 |
+
#rcr
|
72 |
+
stats_data_cleaned=np.empty((stats_data.shape[0],stats_data.shape[1]))
|
73 |
+
stats_data_cleaned[:]=np.nan
|
74 |
+
# r = rcr.RCR(rcr.SS_MEDIAN_DL)
|
75 |
+
r = stats_data_cleaned
|
76 |
+
|
77 |
+
for ii in range(stats_data.shape[1]):
|
78 |
+
# fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,figsize=(16,8),dpi=200)
|
79 |
+
# ax1.hist(stats_data[:,ii])
|
80 |
+
# ax1.set_title('Raw')
|
81 |
+
|
82 |
+
r.performBulkRejection(stats_data[:,ii])
|
83 |
+
cleaned_data_indices = r.result.indices
|
84 |
+
stats_data_cleaned[cleaned_data_indices,ii]=stats_data[cleaned_data_indices,ii]
|
85 |
+
|
86 |
+
# ax2.hist(stats_data_cleaned[:,ii][~np.isnan(stats_data_cleaned[:,ii])])
|
87 |
+
# ax2.set_title('Cleaned')
|
88 |
+
|
89 |
+
# plt.savefig('Feature_nr_'+str(ii)+'jpg')
|
90 |
+
# plt.close()
|
91 |
+
|
92 |
+
|
93 |
+
df_stats_data_cleaned=pd.DataFrame(stats_data_cleaned)
|
94 |
+
# df_stats_data_cleaned=df_stats_data_cleaned.fillna(method='mean', axis=0)#+df_stats_data_cleaned.fillna(method='bfill', axis=0))/2
|
95 |
+
# df_stats_data_cleaned.interpolate(limit=5, inplace=True)
|
96 |
+
|
97 |
+
imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
|
98 |
+
imputer = imputer.fit(df_stats_data_cleaned)
|
99 |
+
|
100 |
+
stats_data_cleaned = imputer.transform(df_stats_data_cleaned)
|
101 |
+
|
102 |
+
####normalize#########
|
103 |
+
stats_data_normed=np.empty((stats_data.shape[0],stats_data.shape[1]))
|
104 |
+
for ii in range(stats_data.shape[1]):
|
105 |
+
stats_data_normed[:,ii] = (stats_data_cleaned[:,ii]-stats_data_cleaned[:,ii].min(axis=0))/ (stats_data_cleaned[:,ii].max(axis=0)-stats_data_cleaned[:,ii].min(axis=0)) #stats_data[:,ii]-stats_data[:,ii].mean(axis=0))/ stats_data[:,ii].std(axis=0)
|
106 |
+
|
107 |
+
return stats_data_normed
|
108 |
+
|
109 |
+
|
110 |
+
def clasyfication_SVM(df,y,cv=10,type='linear'):
|
111 |
+
|
112 |
+
|
113 |
+
clf=svm.SVC(kernel=type)
|
114 |
+
skf = StratifiedKFold(n_splits=cv)
|
115 |
+
# run split() again to generate folds
|
116 |
+
folds = skf.split(df, y)
|
117 |
+
print('folds shape ', folds)
|
118 |
+
performance = np.zeros(skf.n_splits)
|
119 |
+
performance_open= np.zeros(skf.n_splits)
|
120 |
+
performance_closed= np.zeros(skf.n_splits)
|
121 |
+
|
122 |
+
for i, (train_idx, test_idx) in enumerate(folds):
|
123 |
+
|
124 |
+
X_train = df[train_idx,:]
|
125 |
+
y_train = y[train_idx]
|
126 |
+
|
127 |
+
X_test = df[test_idx,:]
|
128 |
+
y_test = y[test_idx]
|
129 |
+
|
130 |
+
# call fit (on train) and predict (on test)
|
131 |
+
model = clf.fit(X=X_train, y=y_train)
|
132 |
+
y_hat = model.predict(X=X_test)
|
133 |
+
|
134 |
+
# calculate accuracy
|
135 |
+
performance[i] = accuracy_score(y_test, y_hat)
|
136 |
+
cm = confusion_matrix(y_test, y_hat)
|
137 |
+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
138 |
+
# class_acuracy = cm.diagonal()
|
139 |
+
class_acuracy = cm.diagonal()
|
140 |
+
performance_open[i]=class_acuracy[0]*100
|
141 |
+
performance_closed[i]=class_acuracy[1]*100
|
142 |
+
|
143 |
+
# calculate average accuracy
|
144 |
+
print('Mean performance: %.3f' % np.mean(performance*100))
|
145 |
+
print('Mean performance 1st class: %.3f' % np.mean(performance_open))
|
146 |
+
print('Mean performance 2nd class: %.3f' % np.mean(performance_closed))
|
147 |
+
|
148 |
+
|
149 |
+
lin = svm.SVC(kernel=type).fit(X_train, y_train)
|
150 |
+
lin_pred = lin.predict(X_test)
|
151 |
+
|
152 |
+
return lin, lin_pred
|
153 |
+
|
154 |
+
def cross_validation(df,y,cv=10,title = 'cv job',type='linear'):
|
155 |
+
|
156 |
+
# Create the RFE object and compute a cross-validated score.
|
157 |
+
svc = SVC(kernel=type)
|
158 |
+
# The "accuracy" scoring shows the proportion of correct classifications
|
159 |
+
|
160 |
+
min_features_to_select = 4 # Minimum number of features to consider
|
161 |
+
rfecv = RFECV(
|
162 |
+
estimator=svc,
|
163 |
+
step=1,
|
164 |
+
cv=StratifiedKFold(n_splits=cv),
|
165 |
+
scoring="accuracy",
|
166 |
+
min_features_to_select=min_features_to_select,
|
167 |
+
)
|
168 |
+
rfecv.fit(df, y)
|
169 |
+
|
170 |
+
print("Optimal number of features : %d" % rfecv.n_features_)
|
171 |
+
|
172 |
+
# Plot number of features VS. cross-validation scores
|
173 |
+
plt.figure()
|
174 |
+
plt.xlabel("Number of features selected")
|
175 |
+
plt.ylabel("Cross validation score (accuracy)")
|
176 |
+
plt.plot(
|
177 |
+
range(min_features_to_select, len(rfecv.cv_results_['mean_test_score']) + min_features_to_select),
|
178 |
+
rfecv.cv_results_['mean_test_score'],
|
179 |
+
)
|
180 |
+
plt.title(title)
|
181 |
+
plt.show()
|
182 |
+
plt.savefig(title+' classification with feature selection_more_features'+str(rfecv.n_features_)+'_'+str(round(max(rfecv.cv_results_['mean_test_score'])*100,2))+'.png', dpi=150)
|
183 |
+
plt.close()
|
184 |
+
|
185 |
+
return rfecv.transform(df)
|
186 |
+
|
187 |
+
|
188 |
+
|
189 |
+
def compute_binary_SVM(df,y,predict_on_all_data = False,type='linear'):
|
190 |
+
|
191 |
+
# stats_data = df[['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']].values
|
192 |
+
X_train, X_test, y_train, y_test = model_selection.train_test_split(df, y, train_size=0.80, test_size=0.20,
|
193 |
+
random_state=101)
|
194 |
+
global lin
|
195 |
+
|
196 |
+
if predict_on_all_data:
|
197 |
+
print('SVM prediction on all data')
|
198 |
+
lin = svm.SVC(kernel=type).fit(X_train, y_train)
|
199 |
+
|
200 |
+
lin_pred = lin.predict(df)
|
201 |
+
|
202 |
+
lin_accuracy = accuracy_score(y, lin_pred)
|
203 |
+
|
204 |
+
print('Accuracy (Linear Kernel): ', "%.2f" % (lin_accuracy * 100))
|
205 |
+
|
206 |
+
cm = confusion_matrix(y, lin_pred)
|
207 |
+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
208 |
+
class_acuracy = cm.diagonal()
|
209 |
+
print('Accuracy (1st class): ', "%.2f" % (class_acuracy[0] * 100))
|
210 |
+
print('Accuracy (2nd class): ', "%.2f" % (class_acuracy[1] * 100))
|
211 |
+
else:
|
212 |
+
print('SVM prediction on test data')
|
213 |
+
lin = svm.SVC(kernel=type).fit(X_train, y_train)
|
214 |
+
|
215 |
+
lin_pred = lin.predict(X_test)
|
216 |
+
|
217 |
+
lin_accuracy = accuracy_score(y_test, lin_pred)
|
218 |
+
|
219 |
+
print('Accuracy (Linear Kernel): ', "%.2f" % (lin_accuracy * 100))
|
220 |
+
print('Y train:', y_train)
|
221 |
+
print('Y test:', y_test)
|
222 |
+
print('Y pred:', lin_pred)
|
223 |
+
|
224 |
+
cm = confusion_matrix(y_test, lin_pred)
|
225 |
+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
226 |
+
class_acuracy = cm.diagonal()
|
227 |
+
print('Accuracy (1st class): ', "%.2f" % (class_acuracy[0] * 100))
|
228 |
+
print('Accuracy (2nd class): ', "%.2f" % (class_acuracy[1] * 100))
|
229 |
+
|
230 |
+
return lin, lin_pred
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
def clasyfication_RFC(df,y,cv=10,max_depth=2):
|
235 |
+
|
236 |
+
clf = RandomForestClassifier(max_depth=max_depth, random_state=0)
|
237 |
+
skf = StratifiedKFold(n_splits=cv)
|
238 |
+
# run split() again to generate folds
|
239 |
+
folds = skf.split(df, y)
|
240 |
+
|
241 |
+
performance = np.zeros(skf.n_splits)
|
242 |
+
performance_open= np.zeros(skf.n_splits)
|
243 |
+
performance_closed= np.zeros(skf.n_splits)
|
244 |
+
|
245 |
+
for i, (train_idx, test_idx) in enumerate(folds):
|
246 |
+
|
247 |
+
X_train = df[train_idx,:]
|
248 |
+
y_train = y[train_idx]
|
249 |
+
|
250 |
+
X_test = df[test_idx,:]
|
251 |
+
y_test = y[test_idx]
|
252 |
+
|
253 |
+
# call fit (on train) and predict (on test)
|
254 |
+
model = clf.fit(X=X_train, y=y_train)
|
255 |
+
y_hat = model.predict(X=X_test)
|
256 |
+
|
257 |
+
# calculate accuracy
|
258 |
+
performance[i] = accuracy_score(y_test, y_hat)
|
259 |
+
cm = confusion_matrix(y_test, y_hat)
|
260 |
+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
261 |
+
# class_acuracy = cm.diagonal()
|
262 |
+
class_acuracy = cm.diagonal()
|
263 |
+
performance_open[i]=class_acuracy[0]*100
|
264 |
+
performance_closed[i]=class_acuracy[1]*100
|
265 |
+
|
266 |
+
# calculate average accuracy
|
267 |
+
print('Mean performance: %.3f' % np.mean(performance*100))
|
268 |
+
print('Mean performance 1st class: %.3f' % np.mean(performance_open))
|
269 |
+
print('Mean performance 2nd class: %.3f' % np.mean(performance_closed))
|
270 |
+
|
271 |
+
|
272 |
+
lin = RandomForestClassifier(max_depth=max_depth, random_state=0)
|
273 |
+
lin.fit(X=X_train, y=y_train)
|
274 |
+
lin_pred = lin.predict(X_test)
|
275 |
+
|
276 |
+
return lin, lin_pred
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
|
BrainPulse/frequency_recurrence.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
# %matplotlib inline
|
3 |
+
# %load_ext autotime
|
4 |
+
plt.style.use('classic')
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
def normalize(tSignal):
|
10 |
+
# copy the data if needed, omit and rename function argument if desired
|
11 |
+
signal = np.copy(tSignal) # signal is in range [a;b]
|
12 |
+
signal -= np.min(signal) # signal is in range to [0;b-a]
|
13 |
+
signal /= np.max(signal) # signal is normalized to [0;1]
|
14 |
+
signal -= 0.5 # signal is in range [-0.5;0.5]
|
15 |
+
signal *=2 # signal is in range [-1;1]
|
16 |
+
return signal
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
def get_max_freqs(stft, cut_freq, crop=False, norm=True):
|
22 |
+
if crop==False:
|
23 |
+
crop=len(stft)
|
24 |
+
id_max_stft = stft[:crop].argmax(axis=1)*(cut_freq/stft.shape[1])
|
25 |
+
if norm:
|
26 |
+
id_max_norm=id_max_stft*(1/cut_freq)
|
27 |
+
return id_max_norm
|
28 |
+
else:
|
29 |
+
return id_max_stft
|
30 |
+
|
31 |
+
|
32 |
+
def symmetrize(a):
|
33 |
+
return a + np.transpose(a, (1, 0, 2)) - np.diag(a.diagonal())
|
34 |
+
|
35 |
+
def symmetrized(m):
|
36 |
+
|
37 |
+
import numpy as np
|
38 |
+
|
39 |
+
i_lower = np.tril_indices(m.shape[0], -1)
|
40 |
+
m[:,:,0][i_lower] = m[:,:,0].T[i_lower]
|
41 |
+
m[:,:,1][i_lower] = m[:,:,1].T[i_lower]
|
42 |
+
m[:,:,2][i_lower] = m[:,:,2].T[i_lower]
|
43 |
+
|
44 |
+
return m
|
45 |
+
|
46 |
+
def calc_color(id_max_norm, unique_values,colors,crop):
|
47 |
+
if crop==False:
|
48 |
+
crop=len(id_max_norm)
|
49 |
+
tmp = np.zeros((id_max_norm.shape[0],id_max_norm.shape[0],3), dtype = 'float64')
|
50 |
+
|
51 |
+
for i, v1 in enumerate(id_max_norm):
|
52 |
+
for j, v2 in enumerate(id_max_norm):
|
53 |
+
tmp[i,j] = v1,v2,0
|
54 |
+
|
55 |
+
tmp_color = np.zeros((id_max_norm.shape[0],id_max_norm.shape[0],3), dtype = 'float64')
|
56 |
+
for i in range(crop):
|
57 |
+
for j in range(crop):
|
58 |
+
x = tmp[i,j][0]
|
59 |
+
y = tmp[i,j][1]
|
60 |
+
|
61 |
+
for k in unique_values:
|
62 |
+
|
63 |
+
if x == k:
|
64 |
+
id = (unique_values).index(x)
|
65 |
+
|
66 |
+
tmp_color[i,j][0] = colors[id][0]
|
67 |
+
if y == k:
|
68 |
+
id = (unique_values).index(y)
|
69 |
+
tmp_color[i,j][1] = colors[id][1]
|
70 |
+
|
71 |
+
|
72 |
+
return symmetrized(tmp_color),symmetrized(tmp)
|
73 |
+
|
74 |
+
def calc_color_raw(id_max_stft):
|
75 |
+
|
76 |
+
tmp = np.zeros((id_max_stft.shape[0],id_max_stft.shape[0],3), dtype = 'float64')
|
77 |
+
|
78 |
+
for i, v1 in enumerate(id_max_stft):
|
79 |
+
for j, v2 in enumerate(id_max_stft):
|
80 |
+
tmp[i,j] = v1,v2,0
|
81 |
+
return symmetrized(tmp)
|
82 |
+
|
83 |
+
def get_unique_colors(cut_freq,unique_values):
|
84 |
+
from matplotlib.colors import LinearSegmentedColormap
|
85 |
+
|
86 |
+
vmax=cut_freq
|
87 |
+
cmap = LinearSegmentedColormap.from_list('mycmap1', [(0 / vmax, "violet"),
|
88 |
+
(4 / vmax, 'blue'),
|
89 |
+
(8 / vmax, 'green'),
|
90 |
+
(15 / vmax, 'yellow'),
|
91 |
+
(30 / vmax, 'red'),
|
92 |
+
(60 / vmax, 'black')
|
93 |
+
])
|
94 |
+
|
95 |
+
colors=[cmap(each) for each in unique_values]
|
96 |
+
|
97 |
+
return cmap, colors
|
98 |
+
|
99 |
+
|
100 |
+
def freqRP(stft, cut_freq, crop=False, norm=True):
|
101 |
+
|
102 |
+
id_max_stft=get_max_freqs(stft,cut_freq,crop,norm)
|
103 |
+
|
104 |
+
unique_values = np.unique(id_max_stft).tolist()
|
105 |
+
|
106 |
+
cmap,colors=get_unique_colors(cut_freq, unique_values)
|
107 |
+
|
108 |
+
color_matrix, raw_c_matrix = calc_color(id_max_stft, unique_values,colors,crop)
|
109 |
+
|
110 |
+
return color_matrix, raw_c_matrix
|
111 |
+
|
112 |
+
|
113 |
+
def plot_freqRP_interactive(color_matrix, raw_c_matrix, filename='', save=False):
|
114 |
+
import plotly.express as px
|
115 |
+
import numpy as np
|
116 |
+
fig=px.imshow(color_matrix, origin='lower')
|
117 |
+
fig.update_traces(customdata=np.round((raw_c_matrix*60),2),
|
118 |
+
hovertemplate="First frequency: %{customdata[0]}<br>Second frequency: %{customdata[1]}<extra></extra>"
|
119 |
+
)
|
120 |
+
if save:
|
121 |
+
fig.write_html(filename)
|
122 |
+
else:
|
123 |
+
fig.show()
|
124 |
+
|
BrainPulse/matrix_open_binary.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:db7cd1f9b9fea62c6ecda8ec53c0e3759d6902e82590770080ef5c24531a8c0b
|
3 |
+
size 92236944
|
BrainPulse/model_SVM.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
import seaborn as sns
|
5 |
+
from sklearn import svm, datasets
|
6 |
+
import sklearn.model_selection as model_selection
|
7 |
+
from sklearn.metrics import accuracy_score
|
8 |
+
from sklearn.metrics import f1_score
|
9 |
+
from sklearn.metrics import confusion_matrix
|
10 |
+
from sklearn.preprocessing import StandardScaler
|
11 |
+
import umap
|
12 |
+
import umap.plot
|
13 |
+
from umap.parametric_umap import ParametricUMAP
|
14 |
+
from sklearn.impute import SimpleImputer
|
15 |
+
from sklearn.pipeline import make_pipeline
|
16 |
+
from sklearn.preprocessing import QuantileTransformer
|
17 |
+
from mne.datasets import eegbci
|
18 |
+
|
19 |
+
def compute_Linear_Kernel(df):
|
20 |
+
stats_frame = df[
|
21 |
+
['Subject', 'Task', 'Electrode', 'Lentr', 'TT', 'L', 'RR', 'LAM', 'DET']
|
22 |
+
]
|
23 |
+
|
24 |
+
stats_frame.melt(id_vars=['Subject', 'Task', 'Electrode'], var_name='RQA_feature', value_name='feature_value')
|
25 |
+
stats1 = stats_frame.pivot_table(index=['Subject', 'Task'], columns='Electrode',
|
26 |
+
values=['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']).reset_index()
|
27 |
+
|
28 |
+
y = stats1.Task.values
|
29 |
+
stats_data = stats1[['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']].values
|
30 |
+
X_train, X_test, y_train, y_test = model_selection.train_test_split(stats_data, y, train_size=0.80, test_size=0.20,
|
31 |
+
random_state=101)
|
32 |
+
|
33 |
+
lin = svm.SVC(kernel='linear').fit(X_train, y_train)
|
34 |
+
|
35 |
+
lin_pred = lin.predict(X_test)
|
36 |
+
|
37 |
+
lin_accuracy = accuracy_score(y_test, lin_pred)
|
38 |
+
|
39 |
+
print('Accuracy (Linear Kernel): ', "%.2f" % (lin_accuracy * 100))
|
40 |
+
|
41 |
+
cm = confusion_matrix(y_test, lin_pred)
|
42 |
+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
43 |
+
class_acuracy = cm.diagonal()
|
44 |
+
print('Accuracy (open): ', "%.2f" % (class_acuracy[0] * 100))
|
45 |
+
print('Accuracy (close): ', "%.2f" % (class_acuracy[1] * 100))
|
BrainPulse/plot.py
ADDED
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from matplotlib import animation, cm
|
5 |
+
import matplotlib
|
6 |
+
import seaborn as sns
|
7 |
+
import seaborn as sns
|
8 |
+
import umap
|
9 |
+
import umap.plot
|
10 |
+
import pandas as pd
|
11 |
+
from .event import EventSegment
|
12 |
+
from sklearn.impute import SimpleImputer
|
13 |
+
from sklearn.pipeline import make_pipeline
|
14 |
+
from sklearn.preprocessing import QuantileTransformer
|
15 |
+
|
16 |
+
sns.set_style("whitegrid")
|
17 |
+
|
18 |
+
# plt.rcParams["font.family"] = "cursive"
|
19 |
+
# plt.rcParams.update({'font.sans-serif':'Times'})
|
20 |
+
# plt.rcParams.update({'font.family':'sans-serif'})
|
21 |
+
# plt.rcParams['font.size'] = 14
|
22 |
+
import matplotlib.font_manager as font_manager
|
23 |
+
font = font_manager.FontProperties(family='Times')
|
24 |
+
|
25 |
+
def explainer_(chan, stft, cut_freq, s_rate):
|
26 |
+
|
27 |
+
fig, axs = plt.subplots(4, figsize=(10, 14), dpi=150) # figsize=(12, 12),
|
28 |
+
time_crop = np.linspace(0, int(chan[:400].shape[0]), chan[:400].shape[0])
|
29 |
+
|
30 |
+
axs[0].plot(chan[:400],'k') # np.linspace(0, int(chan[:400].shape[0]/s_rate), chan[:400].shape[0]),
|
31 |
+
axs[0].fill_betweenx(y=[-210, 125], x1=time_crop[0],
|
32 |
+
x2=time_crop[240], color='white', alpha=0.9, edgecolor='red' )
|
33 |
+
|
34 |
+
axs[0].fill_betweenx(y=[-210, 130], x1=time_crop[2]+20,
|
35 |
+
x2=time_crop[260], color='white', alpha=0.9, edgecolor='green')
|
36 |
+
|
37 |
+
axs[0].fill_betweenx(y=[-210, 135], x1=time_crop[2]+40,
|
38 |
+
x2=time_crop[280], color='white', alpha=0.9, edgecolor='blue')
|
39 |
+
|
40 |
+
axs[0].annotate('$fft_{1}$', xy=(.25, 72), xycoords='data',
|
41 |
+
xytext=(0.05, 1.45), textcoords='axes fraction',
|
42 |
+
arrowprops=dict(arrowstyle="->",facecolor='black',color='black'),
|
43 |
+
horizontalalignment='right', verticalalignment='top',
|
44 |
+
)
|
45 |
+
|
46 |
+
axs[0].annotate('$fft_{2}$', xy=(23.35, 85), xycoords='data',
|
47 |
+
xytext=(0.15, 1.45), textcoords='axes fraction',
|
48 |
+
arrowprops=dict(arrowstyle="->",facecolor='black',color='black'),
|
49 |
+
horizontalalignment='right', verticalalignment='top',
|
50 |
+
)
|
51 |
+
|
52 |
+
axs[0].annotate('$fft_{3}$', xy=(43.45, 95), xycoords='data',
|
53 |
+
xytext=(0.25, 1.45), textcoords='axes fraction',
|
54 |
+
arrowprops=dict(arrowstyle="->",facecolor='black ',color='black'),
|
55 |
+
horizontalalignment='right', verticalalignment='top',
|
56 |
+
)
|
57 |
+
axs[0].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, chan[:400].shape[0], 5)))
|
58 |
+
axs[0].set_xticklabels(
|
59 |
+
[str(np.round(x, 1)) for x in np.linspace(0, int(chan[:400].shape[0] / s_rate), axs[0].get_xticks().shape[0])])
|
60 |
+
axs[0].set_ylabel('Amplitude (µV)', )
|
61 |
+
axs[0].set_xlabel('Time (s)', )
|
62 |
+
axs[0].set_title('(a)', )
|
63 |
+
axs[0].xaxis.grid()
|
64 |
+
axs[0].yaxis.grid()
|
65 |
+
|
66 |
+
|
67 |
+
axs[1].plot((stft[100]/stft.shape[1])**2, 'red',label='$fft_{1}$',marker="o",markersize=3)
|
68 |
+
axs[1].legend(prop=font)
|
69 |
+
axs[1].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
|
70 |
+
axs[1].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
|
71 |
+
axs[1].set_xlim([0, 100])
|
72 |
+
# axs[1].set_ylim([0, 250])
|
73 |
+
axs[1].set_ylabel('Power ($\mu V^{2}$)', )
|
74 |
+
axs[1].set_xlabel('Freq (Hz)', )
|
75 |
+
# axs[1].set_title('Frequency Domain ($fft_{1}$, $fft_{2}$, $fft_{3}$)', fontsize=10)
|
76 |
+
axs[1].set_title('(b)', )
|
77 |
+
axs[1].xaxis.grid()
|
78 |
+
axs[1].yaxis.grid()
|
79 |
+
|
80 |
+
|
81 |
+
axs[2].plot((stft[115]/stft.shape[1])**2, 'green', label='$fft_{2}$', marker="o", markersize=3)
|
82 |
+
axs[2].legend(prop=font)
|
83 |
+
axs[2].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
|
84 |
+
axs[2].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
|
85 |
+
axs[2].set_xlim([0, 100])
|
86 |
+
# axs[2].set_ylim([0, 250])
|
87 |
+
axs[2].set_ylabel('Power ($\mu V^{2}$)', )
|
88 |
+
axs[2].set_xlabel('Freq (Hz)', )
|
89 |
+
axs[2].set_title('(c)', )
|
90 |
+
axs[2].xaxis.grid()
|
91 |
+
axs[2].yaxis.grid()
|
92 |
+
|
93 |
+
|
94 |
+
axs[3].plot((stft[140]/stft.shape[1])**2, 'blue', label='$fft_{3}$', marker="o", markersize=3)
|
95 |
+
axs[3].legend(prop=font)
|
96 |
+
axs[3].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
|
97 |
+
axs[3].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
|
98 |
+
axs[3].set_xlim([0, 100])
|
99 |
+
axs[3].set_ylabel('Power ($\mu V^{2}$)', )
|
100 |
+
axs[3].set_xlabel('Freq (Hz)', )
|
101 |
+
axs[3].set_title('(d)', )
|
102 |
+
axs[3].xaxis.grid()
|
103 |
+
axs[3].yaxis.grid()
|
104 |
+
|
105 |
+
# plt.title('Frequency Domain ($fft_{1}$, $fft_{2}$, $fft_{3}$)', fontsize=10)
|
106 |
+
plt.tight_layout()
|
107 |
+
|
108 |
+
plt.savefig('fig_4.png')
|
109 |
+
|
110 |
+
|
111 |
+
def stft_collections(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args, max_indx = None, min_indx = None):
|
112 |
+
fig = plt.figure(figsize=(14, 12), dpi=150)
|
113 |
+
grid = plt.GridSpec(6, 8, hspace=0.0, wspace=3.5)
|
114 |
+
spectrogram = fig.add_subplot(grid[0:3, 0:4])
|
115 |
+
rp_plot = fig.add_subplot(grid[0:3, 4:])
|
116 |
+
fft_vector = fig.add_subplot(grid[4:, :])
|
117 |
+
|
118 |
+
if max_indx != None and min_indx != None:
|
119 |
+
max_index = max_indx
|
120 |
+
min_index = min_indx
|
121 |
+
else:
|
122 |
+
max_array = np.max(stft, axis=1)
|
123 |
+
max_value_stft = np.max(max_array, axis=0)
|
124 |
+
max_index = list(max_array).index(max_value_stft)
|
125 |
+
|
126 |
+
min_array = np.min(stft, axis=1)
|
127 |
+
min_value_stft = np.min(min_array, axis=0)
|
128 |
+
min_index = list(min_array).index(min_value_stft)
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
# ręczne ustawienie wskaźników
|
133 |
+
# max_index = int(1.52*s_rate)
|
134 |
+
# min_index = int(2.4*s_rate)
|
135 |
+
|
136 |
+
|
137 |
+
# top = np.triu(matrix)
|
138 |
+
# bottom = np.tril(matrix_binary)
|
139 |
+
|
140 |
+
# np.linspace(0, stft.shape[1], stft.shape[1]), np.linspace(0, stft.shape[0], cut_freq),
|
141 |
+
rp_plot.imshow(matrix_binary, cmap='Greys', origin='lower') # interpolation='none'
|
142 |
+
# axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
|
143 |
+
rp_plot.plot(max_index, max_index, 'orange', marker="o", markersize=9)
|
144 |
+
rp_plot.plot(min_index, min_index, 'red', marker="o", markersize=9)
|
145 |
+
# axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
|
146 |
+
# axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
|
147 |
+
rp_plot.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
148 |
+
rp_plot.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
149 |
+
rp_plot.set_xticklabels(
|
150 |
+
[str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_xticks().shape[0])])
|
151 |
+
rp_plot.set_yticklabels(
|
152 |
+
[str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_yticks().shape[0])])
|
153 |
+
rp_plot.set_xlabel('Time (s)', )
|
154 |
+
rp_plot.set_ylabel('Time (s)', )
|
155 |
+
rp_plot.set_title('(b) Recurrence Plot', )
|
156 |
+
rp_plot.xaxis.grid()
|
157 |
+
rp_plot.yaxis.grid()
|
158 |
+
|
159 |
+
spectrogram.pcolormesh(stft.T,cmap='viridis') #,vmax=max_value_stft
|
160 |
+
spectrogram.plot(max_index,2,'orange', marker="|", markersize=40)
|
161 |
+
spectrogram.plot(min_index,2,'red', marker="|", markersize=40)
|
162 |
+
|
163 |
+
spectrogram.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[0], 5)))
|
164 |
+
spectrogram.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft.shape[0] / s_rate, spectrogram.get_xticks().shape[0])])
|
165 |
+
spectrogram.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 5)))
|
166 |
+
spectrogram.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
|
167 |
+
spectrogram.set_ylabel('Freq (Hz)', )
|
168 |
+
spectrogram.set_xlabel('Time (s)', )
|
169 |
+
spectrogram.set_title('(a) Spectrogram', )
|
170 |
+
# spectrogram.xaxis.grid()
|
171 |
+
# spectrogram.yaxis.grid()
|
172 |
+
# fig.colorbar(im1, cax=spectrogram, orientation='vertical')
|
173 |
+
|
174 |
+
|
175 |
+
max_index_ = stft[max_index]/stft.shape[1]
|
176 |
+
min_index_ = stft[min_index]/stft.shape[1]
|
177 |
+
fft_vector.plot(max_index_**2,'orange',label='$fft_{t_{1}}$')#,marker="o", markersize=2
|
178 |
+
fft_vector.plot(min_index_**2,'red',label='$fft_{t_{2}}}$')#,marker="o", markersize=2
|
179 |
+
fft_vector.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
|
180 |
+
fft_vector.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
|
181 |
+
fft_vector.set_xlim([0,100])
|
182 |
+
fft_vector.set_ylabel('Power ($\mu V^{2}$)', )
|
183 |
+
fft_vector.set_xlabel('Freq (Hz)', )
|
184 |
+
fft_vector.set_title('(c) Frequency Domain', )
|
185 |
+
fft_vector.legend(prop=font)
|
186 |
+
fft_vector.xaxis.grid()
|
187 |
+
fft_vector.yaxis.grid()
|
188 |
+
|
189 |
+
# plt.suptitle( 'Condition: '+ task + '\n' + 'epsilon {}, FFT window size {} '.format(
|
190 |
+
# str(info_args['eps']), str(info_args['win_len'])) + '\n'
|
191 |
+
# + 'Subject {}, electrode {}, n_fft {}'.format(str(info_args['selected_subject']),str(info_args['electrode_name']),str(info_args['n_fft'])), fontsize=12 ,ha='left',va='top')
|
192 |
+
plt.tight_layout()
|
193 |
+
plt.savefig('fig_5.png')
|
194 |
+
# axs[0].imshow(matrix_binary, cmap='cividis', origin='lower') #interpolation='none'
|
195 |
+
# # axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
|
196 |
+
# axs[0].plot(max_index,max_index,'orange',marker="o", markersize=7)
|
197 |
+
# axs[0].plot(min_index,min_index,'red',marker="o", markersize=7)
|
198 |
+
# # axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
|
199 |
+
# # axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
|
200 |
+
# axs[0].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
201 |
+
# axs[0].yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
202 |
+
# axs[0].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_xticks().shape[0])])
|
203 |
+
# axs[0].set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_yticks().shape[0])])
|
204 |
+
# axs[0].set_xlabel('Time (s)')
|
205 |
+
# axs[0].set_ylabel('Time (s)')
|
206 |
+
# axs[0].set_title('Recurrence Plot', fontsize=12)
|
207 |
+
|
208 |
+
def diagnostic(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
|
209 |
+
|
210 |
+
fig, axs = plt.subplots(3,1, figsize=(7,12), gridspec_kw={'height_ratios':[6,2,1]},dpi=150)
|
211 |
+
|
212 |
+
# Set up the axes with gridspec
|
213 |
+
|
214 |
+
|
215 |
+
max_array = np.max(stft, axis=1)
|
216 |
+
max_value_stft = np.max(max_array, axis=0)
|
217 |
+
max_index = list(max_array).index(max_value_stft)
|
218 |
+
|
219 |
+
min_array = np.min(stft, axis=1)
|
220 |
+
min_value_stft = np.min(min_array, axis=0)
|
221 |
+
min_index = list(min_array).index(min_value_stft)
|
222 |
+
|
223 |
+
# top = np.triu(matrix)
|
224 |
+
# bottom = np.tril(matrix_binary)
|
225 |
+
|
226 |
+
axs[0].imshow(matrix_binary, cmap='cividis', origin='lower') #interpolation='none'
|
227 |
+
# axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
|
228 |
+
|
229 |
+
axs[0].plot(max_index,max_index,'orange',marker="o", markersize=7)
|
230 |
+
axs[0].plot(min_index,min_index,'red',marker="o", markersize=7)
|
231 |
+
|
232 |
+
axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
|
233 |
+
axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
|
234 |
+
axs[0].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
235 |
+
axs[0].yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
236 |
+
axs[0].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_xticks().shape[0])])
|
237 |
+
axs[0].set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_yticks().shape[0])])
|
238 |
+
axs[0].set_xlabel('Time (s)')
|
239 |
+
axs[0].set_ylabel('Time (s)')
|
240 |
+
axs[0].set_title('Recurrence Plot', )
|
241 |
+
|
242 |
+
|
243 |
+
# np.linspace(0, stft.shape[1], stft.shape[1]), np.linspace(0, stft.shape[0], cut_freq),
|
244 |
+
|
245 |
+
|
246 |
+
axs[1].pcolormesh(stft.T, shading='gouraud') #,vmax=max_value_stft
|
247 |
+
axs[1].plot(max_index,0,'orange', marker="o", markersize=7)
|
248 |
+
axs[1].plot(min_index,0,'red', marker="o", markersize=7)
|
249 |
+
axs[1].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
250 |
+
axs[1].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[1].get_xticks().shape[0])])
|
251 |
+
axs[1].yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 5)))
|
252 |
+
axs[1].set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
|
253 |
+
axs[1].set_ylabel('Freq (Hz)')
|
254 |
+
axs[1].set_xlabel('Time (s)')
|
255 |
+
axs[1].set_title('Spectrogram', )
|
256 |
+
|
257 |
+
max_index_ = stft[max_index]/stft.shape[1]
|
258 |
+
min_index_ = stft[min_index]/stft.shape[1]
|
259 |
+
axs[2].plot(max_index_**2,'orange')#,marker="o", markersize=2
|
260 |
+
axs[2].plot(min_index_**2,'red')#,marker="o", markersize=2
|
261 |
+
axs[2].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
|
262 |
+
axs[2].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
|
263 |
+
axs[2].set_xlim([0,100])
|
264 |
+
axs[2].set_ylabel('Power (µV^2)')
|
265 |
+
axs[2].set_xlabel('Freq (Hz)')
|
266 |
+
axs[2].set_title('Frequency Domain',)
|
267 |
+
|
268 |
+
plt.suptitle( 'Condition: '+ task + '\n' + 'epsilon {}, FFT window size {} '.format(
|
269 |
+
str(info_args['eps']), str(info_args['win_len'])) + '\n'
|
270 |
+
+ 'Subject {}, electrode {}, n_fft {}'.format(str(info_args['selected_subject']),str(info_args['electrode_name']),str(info_args['n_fft'])),
|
271 |
+
ha='left',va='top')
|
272 |
+
plt.tight_layout()
|
273 |
+
|
274 |
+
def RecurrencePlot(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
|
275 |
+
|
276 |
+
fig, axs = plt.subplots( figsize=(12,12),dpi=200)
|
277 |
+
|
278 |
+
# Set up the axes with gridspec
|
279 |
+
|
280 |
+
|
281 |
+
max_array = np.max(stft, axis=1)
|
282 |
+
max_value_stft = np.max(max_array, axis=0)
|
283 |
+
max_index = list(max_array).index(max_value_stft)
|
284 |
+
|
285 |
+
min_array = np.min(stft, axis=1)
|
286 |
+
min_value_stft = np.min(min_array, axis=0)
|
287 |
+
min_index = list(min_array).index(min_value_stft)
|
288 |
+
|
289 |
+
# top = np.triu(matrix)
|
290 |
+
# bottom = np.tril(matrix_binary)
|
291 |
+
|
292 |
+
axs.imshow(matrix_binary, cmap='cividis', origin='lower') #interpolation='none'
|
293 |
+
# axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
|
294 |
+
|
295 |
+
# axs[0].plot(max_index,max_index,'orange',marker="o", markersize=7)
|
296 |
+
# axs[0].plot(min_index,min_index,'red',marker="o", markersize=7)
|
297 |
+
|
298 |
+
# axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
|
299 |
+
# axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
|
300 |
+
axs.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
301 |
+
axs.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
302 |
+
axs.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs.get_xticks().shape[0])])
|
303 |
+
axs.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs.get_yticks().shape[0])])
|
304 |
+
axs.set_xlabel('Time (s)')
|
305 |
+
axs.set_ylabel('Time (s)')
|
306 |
+
axs.set_title('Recurrence Plot')
|
307 |
+
|
308 |
+
|
309 |
+
# np.linspace(0, stft.shape[1], stft.shape[1]), np.linspace(0, stft.shape[0], cut_freq),
|
310 |
+
|
311 |
+
def features_hists(df, features_list, condition, dpi = 200):
|
312 |
+
fig, axs = plt.subplots(len(features_list),figsize=(6, len(features_list)*3), dpi=dpi)
|
313 |
+
abc = ['(a)','(b)','(c)','(d)','(e)','(f)']
|
314 |
+
|
315 |
+
for i,ax in enumerate(axs):
|
316 |
+
sns.histplot(data=df, x=features_list[i], hue=condition, alpha=0.8, element="bars", fill=False, ax=ax, kde=True)
|
317 |
+
ax.containers[1].remove()
|
318 |
+
ax.containers[0].remove()
|
319 |
+
ax.xaxis.grid()
|
320 |
+
ax.yaxis.grid()
|
321 |
+
ax.set_title(abc[i])
|
322 |
+
# plt.grid(b=None)
|
323 |
+
|
324 |
+
plt.autoscale(enable=True, axis='both', tight=None)
|
325 |
+
fig.tight_layout()
|
326 |
+
|
327 |
+
def features_per_subjects_violin(df, features_list, condition, dpi = 200):
|
328 |
+
fig, axs = plt.subplots(len(features_list),figsize=(14, len(features_list)*2), dpi=dpi,sharex='col')
|
329 |
+
|
330 |
+
for i,ax in enumerate(axs):
|
331 |
+
sns.violinplot(data=df, x=df.Subject, y=features_list[i], hue=condition, ax=ax, split=True,linewidth=0.2)
|
332 |
+
ax.legend(loc='lower right')
|
333 |
+
|
334 |
+
axs[len(features_list)-1].set_xticklabels(axs[len(features_list)-1].get_xticklabels(), rotation=90)
|
335 |
+
# axs.set_ylim([0,1])
|
336 |
+
|
337 |
+
|
338 |
+
plt.tick_params(axis='x', which='major', labelsize=16)
|
339 |
+
fig.tight_layout()
|
340 |
+
|
341 |
+
def umap_on_condition(df,y, title,labels_name,features_list=['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr'], random_state = 70, n_neighbors = 15, min_dist = 0.25, metric = "hamming", df_type=True):
|
342 |
+
|
343 |
+
|
344 |
+
fig, ax1 = plt.subplots(figsize=(8, 8), dpi=150)
|
345 |
+
|
346 |
+
if df_type:
|
347 |
+
stats_data = df
|
348 |
+
else:
|
349 |
+
stats_data = df[features_list].values
|
350 |
+
|
351 |
+
# Preprocess again
|
352 |
+
pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
|
353 |
+
X = pipe.fit_transform(stats_data.copy())
|
354 |
+
|
355 |
+
# Fit UMAP to processed data
|
356 |
+
manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)
|
357 |
+
# X_reduced_2 = manifold.transform(X)
|
358 |
+
umap.plot.points(manifold, labels=labels_name, ax=ax1, color_key=np.array(
|
359 |
+
[(0, 0.35, 0.73), (1, 0.83, 0)])) # ,color_key=np.array([(1,0.83,0),(0,0.35,0.73)])
|
360 |
+
ax1.set_title(title)
|
361 |
+
|
362 |
+
def umap_side_by_side_plot(df1, df2, features_list=['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr'], random_state = 70, n_neighbors = 15, min_dist = 0.25, metric = "hamming"):
|
363 |
+
|
364 |
+
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,figsize=(16,8),dpi=150)
|
365 |
+
|
366 |
+
stats_data = df1[features_list].values
|
367 |
+
y = df1.Task.values
|
368 |
+
|
369 |
+
# Preprocess again
|
370 |
+
pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
|
371 |
+
X = pipe.fit_transform(stats_data.copy())
|
372 |
+
|
373 |
+
# Fit UMAP to processed data
|
374 |
+
manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)
|
375 |
+
# X_reduced_2 = manifold.transform(X)
|
376 |
+
umap.plot.points(manifold, labels=y, ax=ax1, color_key=np.array(
|
377 |
+
[(0, 0.35, 0.73), (1, 0.83, 0)])) # ,color_key=np.array([(1,0.83,0),(0,0.35,0.73)])
|
378 |
+
ax1.set_xlabel('(a) STFT Condition 0 - open eyes, 1 - closed eyes')
|
379 |
+
|
380 |
+
stats_data = df2[features_list].values
|
381 |
+
y = df2.Task.values
|
382 |
+
|
383 |
+
# Preprocess again
|
384 |
+
pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
|
385 |
+
X = pipe.fit_transform(stats_data.copy())
|
386 |
+
|
387 |
+
# Fit UMAP to processed data
|
388 |
+
manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)
|
389 |
+
# X_reduced_2 = manifold.transform(X)
|
390 |
+
umap.plot.points(manifold, labels=y, ax=ax2, color_key=np.array(
|
391 |
+
[(0, 0.35, 0.73), (1, 0.83, 0)])) # ,color_key=np.array([(1,0.83,0),(0,0.35,0.73)])
|
392 |
+
ax2.set_xlabel('(b) TDEMB Condition 0 - open eyes, 1 - closed eyes')
|
393 |
+
|
394 |
+
return
|
395 |
+
|
396 |
+
def SVM_histogram(df, lin, lin_pred,title):
|
397 |
+
stats_data = df #[features_list].values
|
398 |
+
plt.figure(dpi=150)
|
399 |
+
all_cechy=np.dot(stats_data, lin.coef_.T)
|
400 |
+
df_all=pd.DataFrame({'vectors':all_cechy.ravel(), 'Task':lin_pred})
|
401 |
+
|
402 |
+
|
403 |
+
a = sns.histplot(data=df_all, x='vectors', hue='Task', alpha=0.8, element="bars", fill=False,kde=True, kde_kws={'bw_adjust':0.4},palette=np.array([(0.3,0.85,0),(0.8,0.0,0.44)]))
|
404 |
+
a.containers[1].remove()
|
405 |
+
a.containers[0].remove()
|
406 |
+
# a = sns.kdeplot(data=df_all, x='vectors', hue='Task', alpha=0.8, bw_adjust=0.4,palette=np.array([(0.3,0.85,0),(0.8,0.0,0.44)]))
|
407 |
+
plt.title(title)
|
408 |
+
plt.xlabel('All')
|
409 |
+
plt.grid(b=None)
|
410 |
+
plt.show()
|
411 |
+
|
412 |
+
def f_importances(coef, names):
|
413 |
+
imp = coef
|
414 |
+
imp,names = zip(*sorted(zip(imp,names)))
|
415 |
+
plt.figure()
|
416 |
+
plt.barh(range(len(names)), imp, align='center')
|
417 |
+
plt.yticks(range(len(names)), names)
|
418 |
+
plt.show()
|
419 |
+
|
420 |
+
def SVM_features_importance(lin):
|
421 |
+
|
422 |
+
|
423 |
+
lebel_ll = np.array([['TT']*int(64)+ ['RR']*int(64)+
|
424 |
+
['DET']*int(64)+ ['LAM']*int(64)+
|
425 |
+
['L']*int(64)+ ['L_entr']*int(64)])
|
426 |
+
|
427 |
+
e_long = "Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 Fz Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz T10 T7 T8 T9 Tp7 Tp8 Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 Fz Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz T10 T7 T8 T9 Tp7 Tp8 Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 Fz Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz T10 T7 T8 T9 Tp7 Tp8 Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 Fz Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz T10 T7 T8 T9 Tp7 Tp8 Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 Fz Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz T10 T7 T8 T9 Tp7 Tp8 Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 Fz Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz T10 T7 T8 T9 Tp7 Tp8".replace('\t',',').split(",")
|
428 |
+
y_e_long = np.array(np.unique(e_long, return_inverse=True)[1].tolist())
|
429 |
+
|
430 |
+
df = pd.DataFrame({'feature':lebel_ll[0],
|
431 |
+
'electrode':e_long,
|
432 |
+
'coef':lin.coef_[0],
|
433 |
+
})
|
434 |
+
# df = df[(df.coef.values >= 0.15) | (df.coef.values <= -0.15)]
|
435 |
+
|
436 |
+
f_importances(df.coef, df.feature)
|
437 |
+
f_importances(df.coef, df.electrode)
|
438 |
+
|
439 |
+
|
440 |
+
sns.set_theme(style='darkgrid', rc={'figure.dpi': 120},
|
441 |
+
font_scale=1.7)
|
442 |
+
fig, ax = plt.subplots(figsize=(16, 10))
|
443 |
+
ax.set_title('Weight of features by electrodes')
|
444 |
+
sns.barplot(x='feature', y='coef', data=df, ax=ax,
|
445 |
+
ci=None,
|
446 |
+
hue='electrode')
|
447 |
+
ax.legend(bbox_to_anchor=(1, 1), title='electrode',prop={'size': 7})
|
448 |
+
|
449 |
+
##### HIDDEN MARKOV MODEL
|
450 |
+
|
451 |
+
def soft_bounds(T,seg):
|
452 |
+
|
453 |
+
# Identify soft boundaries at each step of fitting
|
454 |
+
bounds_anim = []
|
455 |
+
K = seg[0].shape[1]
|
456 |
+
for it in range(1,len(seg)):
|
457 |
+
sb = np.zeros((T,T))
|
458 |
+
for k in range(K-1):
|
459 |
+
p_change = np.diff(seg[it][:,(k+1):].sum(1))
|
460 |
+
sb[1:,1:] += np.outer(p_change, seg[it][1:,k:(k+2)].sum(1))
|
461 |
+
sb = np.maximum(sb,sb.T)
|
462 |
+
sb = sb/np.max(sb)
|
463 |
+
bounds_anim.append(sb)
|
464 |
+
return bounds_anim
|
465 |
+
|
466 |
+
|
467 |
+
def fitting_animation(seg,matrix,s_rate,meta_tick,metastate_id, state_width,color_states_matrix):
|
468 |
+
|
469 |
+
bounds_anim = soft_bounds(matrix.shape[0],seg)
|
470 |
+
|
471 |
+
# Plot timepoint-timepoint correation matrix, with boundaries animated on top
|
472 |
+
|
473 |
+
fig = plt.figure(figsize=(18, 12), dpi=300)
|
474 |
+
grid = plt.GridSpec(4, 12, hspace=0.0, wspace=3.5)
|
475 |
+
ax1 = fig.add_subplot(grid[:, 0:4])
|
476 |
+
ax2 = fig.add_subplot(grid[:, 4:8])
|
477 |
+
ax3 = fig.add_subplot(grid[:, 8:])
|
478 |
+
|
479 |
+
|
480 |
+
# fig, axs = plt.subplots(2,figsize=(8,8), dpi=120)
|
481 |
+
datamat = matrix # np.corrcoef(D)
|
482 |
+
bk = cm.viridis((datamat-np.min(datamat))/(np.max(datamat)-np.min(datamat)))
|
483 |
+
im = ax1.imshow(bk, interpolation='none',origin='lower')
|
484 |
+
fg = cm.gray(1-(sum(bounds_anim)/len(bounds_anim)))
|
485 |
+
# im.set_array(np.minimum(np.maximum(bk + fg, 0), 1))
|
486 |
+
im.set_array(bk * fg)
|
487 |
+
|
488 |
+
|
489 |
+
ax1.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
490 |
+
ax1.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
491 |
+
ax1.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax1.get_xticks().shape[0])])
|
492 |
+
ax1.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax1.get_yticks().shape[0])])
|
493 |
+
ax1.set_xlabel('Time (s)')
|
494 |
+
ax1.set_ylabel('Time (s)')
|
495 |
+
ax1.set_title('Metastates plot over recurrence plot', fontsize=10)
|
496 |
+
ax1.scatter(meta_tick,meta_tick,s=2)
|
497 |
+
|
498 |
+
|
499 |
+
|
500 |
+
ax2.imshow(fg, interpolation='none',origin='lower')
|
501 |
+
ax2.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
502 |
+
ax2.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
503 |
+
ax2.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax2.get_xticks().shape[0])])
|
504 |
+
ax2.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax2.get_yticks().shape[0])])
|
505 |
+
ax2.set_xlabel('Time (s)')
|
506 |
+
ax2.set_ylabel('Time (s)')
|
507 |
+
ax2.set_title('Metastates plot', fontsize=10)
|
508 |
+
ax2.scatter(meta_tick,meta_tick,s=2)
|
509 |
+
|
510 |
+
text_kwargs = dict(ha='center', va='center', fontsize=4, color='0')
|
511 |
+
for i,mstate in enumerate(metastate_id):
|
512 |
+
# ax1.text(meta_tick[i]-35, meta_tick[i]+(state_width[i]/2)+45, 's'+str(mstate)+'| '+ str(int(((1/160)*state_width[i])*1000)) + 'ms', **text_kwargs)
|
513 |
+
ax2.annotate('s '+str(mstate)+'| '+ str(int(((1/160)*state_width[i])*1000)) + 'ms', xy=(meta_tick[i], meta_tick[i]+(state_width[i]/2)),
|
514 |
+
xytext =(meta_tick[i], meta_tick[i]+(state_width[i]/2)+70),
|
515 |
+
xycoords='data',
|
516 |
+
textcoords='data',
|
517 |
+
arrowprops=dict(arrowstyle="->",facecolor='blue'),
|
518 |
+
horizontalalignment='right', verticalalignment='top', fontsize=5
|
519 |
+
)
|
520 |
+
|
521 |
+
|
522 |
+
color_states = color_states_matrix
|
523 |
+
ax3.imshow(fg[:,:,:3]*color_states, interpolation='none',origin='lower')
|
524 |
+
ax3.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
525 |
+
ax3.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
526 |
+
ax3.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax3.get_xticks().shape[0])])
|
527 |
+
ax3.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax3.get_yticks().shape[0])])
|
528 |
+
ax3.set_xlabel('Time (s)')
|
529 |
+
ax3.set_ylabel('Time (s)')
|
530 |
+
ax3.set_title('Metastates plot', fontsize=10)
|
531 |
+
ax3.scatter(meta_tick,meta_tick,s=2)
|
532 |
+
|
533 |
+
text_kwargs = dict(ha='center', va='center', fontsize=4, color='0')
|
534 |
+
for i,mstate in enumerate(metastate_id):
|
535 |
+
# ax1.text(meta_tick[i]-35, meta_tick[i]+(state_width[i]/2)+45, 's'+str(mstate)+'| '+ str(int(((1/160)*state_width[i])*1000)) + 'ms', **text_kwargs)
|
536 |
+
ax3.annotate('s '+str(mstate)+'| '+ str(int(((1/160)*state_width[i])*1000)) + 'ms', xy=(meta_tick[i], meta_tick[i]+(state_width[i]/2)),
|
537 |
+
xytext =(meta_tick[i], meta_tick[i]+(state_width[i]/2)+70),
|
538 |
+
xycoords='data',
|
539 |
+
textcoords='data',
|
540 |
+
arrowprops=dict(arrowstyle="->",facecolor='blue'),
|
541 |
+
horizontalalignment='right', verticalalignment='top', fontsize=5
|
542 |
+
)
|
543 |
+
|
544 |
+
|
545 |
+
# def animate_func(i):
|
546 |
+
# fg = cm.Greys(1-bounds_anim[i])
|
547 |
+
# im.set_array(np.minimum(np.maximum(bk + fg,0),1))
|
548 |
+
# return [im]
|
549 |
+
#
|
550 |
+
# anim = animation.FuncAnimation(fig, animate_func,
|
551 |
+
# frames = len(bounds_anim), interval = 1)
|
552 |
+
#
|
553 |
+
#
|
554 |
+
plt.savefig('Metastate.png')
|
555 |
+
# plt.close("all")
|
556 |
+
|
557 |
+
return fig
|
558 |
+
|
559 |
+
# return HTML(anim.to_jshtml(default_mode='Once'))
|
560 |
+
|
561 |
+
def fit_HMM(matrix,n_events):
|
562 |
+
return EventSegment(n_events=n_events).fit(matrix)
|
563 |
+
|
564 |
+
def metastates(seg,matrix,s_rate,meta_tick,metastate_id, state_width,color_states_matrix):
|
565 |
+
fitting_animation(seg,matrix,s_rate,meta_tick,metastate_id, state_width,color_states_matrix)
|
566 |
+
|
567 |
+
# def diagnostic(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
|
568 |
+
#
|
569 |
+
# # fig, axs = plt.subplots(3,1, figsize=(4,8), gridspec_kw={'height_ratios':[6,2,1]},dpi=120)
|
570 |
+
#
|
571 |
+
# # Set up the axes with gridspec
|
572 |
+
# fig = plt.figure(figsize=(6, 6),dpi=120)
|
573 |
+
# grid = plt.GridSpec(6, 6, hspace=1.0, wspace=1.0)
|
574 |
+
# spectrogram = fig.add_subplot(grid[0:3, 0:3])
|
575 |
+
# rp_plot = fig.add_subplot(grid[0:3, 3:])
|
576 |
+
# fft_vector = fig.add_subplot(grid[3:,:])
|
577 |
+
#
|
578 |
+
# max_array = np.max(stft, axis=1)
|
579 |
+
# max_value_stft = np.max(max_array, axis=0)
|
580 |
+
# max_index = list(max_array).index(max_value_stft)
|
581 |
+
#
|
582 |
+
# min_array = np.min(stft, axis=1)
|
583 |
+
# min_value_stft = np.min(min_array, axis=0)
|
584 |
+
# min_index = list(min_array).index(min_value_stft)
|
585 |
+
#
|
586 |
+
# # top = np.triu(matrix)
|
587 |
+
# # bottom = np.tril(matrix_binary)
|
588 |
+
#
|
589 |
+
# rp_plot.imshow(matrix_binary, cmap='cividis', origin='lower') #interpolation='none'
|
590 |
+
# # axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
|
591 |
+
# rp_plot.plot(max_index,max_index,'orange',marker="o", markersize=7)
|
592 |
+
# rp_plot.plot(min_index,min_index,'red',marker="o", markersize=7)
|
593 |
+
# # axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
|
594 |
+
# # axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
|
595 |
+
# rp_plot.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
596 |
+
# rp_plot.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
597 |
+
# rp_plot.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_xticks().shape[0])])
|
598 |
+
# rp_plot.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_yticks().shape[0])])
|
599 |
+
# rp_plot.set_xlabel('Time (s)')
|
600 |
+
# rp_plot.set_ylabel('Time (s)')
|
601 |
+
# rp_plot.set_title('Recurrence Plot', fontsize=10)
|
602 |
+
#
|
603 |
+
#
|
604 |
+
# # np.linspace(0, stft.shape[1], stft.shape[1]), np.linspace(0, stft.shape[0], cut_freq),
|
605 |
+
#
|
606 |
+
#
|
607 |
+
# spectrogram.pcolormesh(stft.T, shading='gouraud') #,vmax=max_value_stft
|
608 |
+
# spectrogram.plot(max_index,0,'orange', marker="o", markersize=7)
|
609 |
+
# spectrogram.plot(min_index,0,'red', marker="o", markersize=7)
|
610 |
+
# spectrogram.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
|
611 |
+
# spectrogram.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, spectrogram.get_xticks().shape[0])])
|
612 |
+
# spectrogram.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 5)))
|
613 |
+
# spectrogram.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
|
614 |
+
# spectrogram.set_ylabel('Freq (Hz)')
|
615 |
+
# spectrogram.set_xlabel('Time (s)')
|
616 |
+
# spectrogram.set_title('Spectrogram', fontsize=10)
|
617 |
+
#
|
618 |
+
# max_index_ = stft[max_index]/stft.shape[1]
|
619 |
+
# min_index_ = stft[min_index]/stft.shape[1]
|
620 |
+
# fft_vector.plot(max_index_**2,'orange')#,marker="o", markersize=2
|
621 |
+
# fft_vector.plot(min_index_**2,'red')#,marker="o", markersize=2
|
622 |
+
# fft_vector.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
|
623 |
+
# fft_vector.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
|
624 |
+
# fft_vector.set_xlim([0,100])
|
625 |
+
# fft_vector.set_ylabel('Power (µV^2)')
|
626 |
+
# fft_vector.set_xlabel('Freq (Hz)')
|
627 |
+
# fft_vector.set_title('Frequency Domain', size=10)
|
628 |
+
#
|
629 |
+
# plt.suptitle( 'Condition: '+ task + '\n' + 'epsilon {}, FFT window size {} '.format(
|
630 |
+
# str(info_args['eps']), str(info_args['win_len'])) + '\n'
|
631 |
+
# + 'Subject {}, electrode {}, n_fft {}'.format(str(info_args['selected_subject']),str(info_args['electrode_name']),str(info_args['n_fft'])),
|
632 |
+
# fontsize=8,ha='left',va='top')
|
633 |
+
# plt.tight_layout()
|
634 |
+
|
BrainPulse/recurrence_quantification_analysis.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numba import jit
|
2 |
+
import numpy as np
|
3 |
+
from pyrqa.computation import RPComputation, RQAComputation
|
4 |
+
from pyrqa.time_series import TimeSeries, EmbeddedSeries
|
5 |
+
from pyrqa.settings import Settings
|
6 |
+
from pyrqa.analysis_type import Classic
|
7 |
+
from pyrqa.metric import EuclideanMetric
|
8 |
+
# from pyrqa.metric import Sigmoid
|
9 |
+
# from pyrqa.metric import Cosine
|
10 |
+
from pyrqa.neighbourhood import Unthresholded,FixedRadius
|
11 |
+
|
12 |
+
def get_results(recurrence_matrix,
|
13 |
+
minimum_diagonal_line_length,
|
14 |
+
minimum_vertical_line_length,
|
15 |
+
minimum_white_vertical_line_length):
|
16 |
+
|
17 |
+
number_of_vectors = recurrence_matrix.shape[0]
|
18 |
+
diagonal = diagonal_frequency_distribution(recurrence_matrix)
|
19 |
+
vertical = vertical_frequency_distribution(recurrence_matrix)
|
20 |
+
white = white_vertical_frequency_distribution(recurrence_matrix)
|
21 |
+
|
22 |
+
number_of_vert_lines = number_of_vertical_lines(vertical, minimum_vertical_line_length)
|
23 |
+
number_of_vert_lines_points = number_of_vertical_lines_points(vertical, minimum_vertical_line_length)
|
24 |
+
|
25 |
+
RR = recurrence_rate(recurrence_matrix)
|
26 |
+
DET = determinism(number_of_vectors, diagonal, minimum_diagonal_line_length)
|
27 |
+
L = average_diagonal_line_length(number_of_vectors, diagonal, minimum_diagonal_line_length)
|
28 |
+
Lmax = longest_diagonal_line_length(number_of_vectors, diagonal)
|
29 |
+
DIV = divergence(Lmax)
|
30 |
+
Lentr = entropy_diagonal_lines(number_of_vectors, diagonal, minimum_diagonal_line_length)
|
31 |
+
DET_RR = ratio_determinism_recurrence_rate(DET, RR)
|
32 |
+
LAM = laminarity(number_of_vectors, vertical, minimum_vertical_line_length)
|
33 |
+
V = average_vertical_line_length(number_of_vectors, vertical, minimum_vertical_line_length)
|
34 |
+
Vmax = longest_vertical_line_length(number_of_vectors, vertical)
|
35 |
+
Ventr = entropy_vertical_lines(number_of_vectors, vertical, minimum_vertical_line_length)
|
36 |
+
LAM_DET = laminarity_determinism(LAM, DET)
|
37 |
+
W = average_white_vertical_line_length(number_of_vectors, white, minimum_white_vertical_line_length)
|
38 |
+
Wmax = longest_white_vertical_line_length(number_of_vectors, white)
|
39 |
+
Wentr = entropy_white_vertical_lines(number_of_vectors, white, minimum_white_vertical_line_length)
|
40 |
+
TT = trapping_time(number_of_vert_lines_points, number_of_vert_lines)
|
41 |
+
|
42 |
+
return [RR, DET, L, Lmax, DIV, Lentr, DET_RR, LAM, V, Vmax, Ventr, LAM_DET, W, Wmax, Wentr, TT]
|
43 |
+
|
44 |
+
|
45 |
+
@jit(nopython=True)
|
46 |
+
def diagonal_frequency_distribution(recurrence_matrix):
|
47 |
+
# Calculating the number of states - N
|
48 |
+
number_of_vectors = recurrence_matrix.shape[0]
|
49 |
+
diagonal_frequency_distribution = np.zeros(number_of_vectors + 1)
|
50 |
+
|
51 |
+
# Calculating the diagonal frequency distribution - P(l)
|
52 |
+
for i in range(number_of_vectors - 1, -1, -1):
|
53 |
+
diagonal_line_length = 0
|
54 |
+
for j in range(0, number_of_vectors - i):
|
55 |
+
if recurrence_matrix[i + j, j] == 1:
|
56 |
+
diagonal_line_length += 1
|
57 |
+
if j == (number_of_vectors - i - 1):
|
58 |
+
diagonal_frequency_distribution[diagonal_line_length] += 1.0
|
59 |
+
else:
|
60 |
+
if diagonal_line_length != 0:
|
61 |
+
diagonal_frequency_distribution[diagonal_line_length] += 1.0
|
62 |
+
diagonal_line_length = 0
|
63 |
+
for k in range(1, number_of_vectors):
|
64 |
+
diagonal_line_length = 0
|
65 |
+
for i in range(number_of_vectors - k):
|
66 |
+
j = i + k
|
67 |
+
if recurrence_matrix[i, j] == 1:
|
68 |
+
diagonal_line_length += 1
|
69 |
+
if j == (number_of_vectors - 1):
|
70 |
+
diagonal_frequency_distribution[diagonal_line_length] += 1.0
|
71 |
+
else:
|
72 |
+
if diagonal_line_length != 0:
|
73 |
+
diagonal_frequency_distribution[diagonal_line_length] += 1.0
|
74 |
+
diagonal_line_length = 0
|
75 |
+
|
76 |
+
return diagonal_frequency_distribution
|
77 |
+
|
78 |
+
|
79 |
+
@jit(nopython=True)
|
80 |
+
def vertical_frequency_distribution(recurrence_matrix):
|
81 |
+
number_of_vectors = recurrence_matrix.shape[0]
|
82 |
+
|
83 |
+
# Calculating the vertical frequency distribution - P(v)
|
84 |
+
vertical_frequency_distribution = np.zeros(number_of_vectors + 1)
|
85 |
+
for i in range(number_of_vectors):
|
86 |
+
vertical_line_length = 0
|
87 |
+
for j in range(number_of_vectors):
|
88 |
+
if recurrence_matrix[i, j] == 1:
|
89 |
+
vertical_line_length += 1
|
90 |
+
if j == (number_of_vectors - 1):
|
91 |
+
vertical_frequency_distribution[vertical_line_length] += 1.0
|
92 |
+
else:
|
93 |
+
if vertical_line_length != 0:
|
94 |
+
vertical_frequency_distribution[vertical_line_length] += 1.0
|
95 |
+
vertical_line_length = 0
|
96 |
+
|
97 |
+
return vertical_frequency_distribution
|
98 |
+
|
99 |
+
|
100 |
+
@jit(nopython=True)
|
101 |
+
def white_vertical_frequency_distribution(recurrence_matrix):
|
102 |
+
number_of_vectors = recurrence_matrix.shape[0]
|
103 |
+
|
104 |
+
# Calculating the white vertical frequency distribution - P(w)
|
105 |
+
white_vertical_frequency_distribution = np.zeros(number_of_vectors + 1)
|
106 |
+
for i in range(number_of_vectors):
|
107 |
+
white_vertical_line_length = 0
|
108 |
+
for j in range(number_of_vectors):
|
109 |
+
if recurrence_matrix[i, j] == 0:
|
110 |
+
white_vertical_line_length += 1
|
111 |
+
if j == (number_of_vectors - 1):
|
112 |
+
white_vertical_frequency_distribution[white_vertical_line_length] += 1.0
|
113 |
+
else:
|
114 |
+
if white_vertical_line_length != 0:
|
115 |
+
white_vertical_frequency_distribution[white_vertical_line_length] += 1.0
|
116 |
+
white_vertical_line_length = 0
|
117 |
+
|
118 |
+
return white_vertical_frequency_distribution
|
119 |
+
|
120 |
+
|
121 |
+
@jit(nopython=True)
|
122 |
+
def recurrence_rate(recurrence_matrix):
|
123 |
+
# Calculating the recurrence rate - RR
|
124 |
+
number_of_vectors = recurrence_matrix.shape[0]
|
125 |
+
return np.float(np.sum(recurrence_matrix)) / np.power(number_of_vectors, 2)
|
126 |
+
|
127 |
+
|
128 |
+
def determinism(number_of_vectors, diagonal_frequency_distribution_, minimum_diagonal_line_length):
|
129 |
+
# Calculating the determinism - DET
|
130 |
+
numerator = np.sum(
|
131 |
+
[l * diagonal_frequency_distribution_[l] for l in range(minimum_diagonal_line_length, number_of_vectors)])
|
132 |
+
denominator = np.sum([l * diagonal_frequency_distribution_[l] for l in range(1, number_of_vectors)])
|
133 |
+
return numerator / denominator
|
134 |
+
|
135 |
+
|
136 |
+
def average_diagonal_line_length(number_of_vectors, diagonal_frequency_distribution_, minimum_diagonal_line_length):
|
137 |
+
# Calculating the average diagonal line length - L
|
138 |
+
numerator = np.sum(
|
139 |
+
[l * diagonal_frequency_distribution_[l] for l in range(minimum_diagonal_line_length, number_of_vectors)])
|
140 |
+
denominator = np.sum(
|
141 |
+
[diagonal_frequency_distribution_[l] for l in range(minimum_diagonal_line_length, number_of_vectors)])
|
142 |
+
return numerator / denominator
|
143 |
+
|
144 |
+
|
145 |
+
@jit(nopython=True)
|
146 |
+
def longest_diagonal_line_length(number_of_vectors, diagonal_frequency_distribution_):
|
147 |
+
# Calculating the longest diagonal line length - Lmax
|
148 |
+
for l in range(number_of_vectors - 1, 0, -1):
|
149 |
+
if diagonal_frequency_distribution_[l] != 0:
|
150 |
+
longest_diagonal_line_length = l
|
151 |
+
break
|
152 |
+
return longest_diagonal_line_length
|
153 |
+
|
154 |
+
|
155 |
+
@jit(nopython=True)
|
156 |
+
def divergence(longest_diagonal_line_length_):
|
157 |
+
# Calculating the divergence - DIV
|
158 |
+
return 1. / longest_diagonal_line_length_
|
159 |
+
|
160 |
+
|
161 |
+
@jit(nopython=True)
|
162 |
+
def entropy_diagonal_lines(number_of_vectors, diagonal_frequency_distribution_, minimum_diagonal_line_length):
|
163 |
+
# Calculating the entropy diagonal lines - Lentr
|
164 |
+
sum_diagonal_frequency_distribution = np.float(
|
165 |
+
np.sum(diagonal_frequency_distribution_[minimum_diagonal_line_length:-1]))
|
166 |
+
entropy_diagonal_lines = 0
|
167 |
+
for l in range(minimum_diagonal_line_length, number_of_vectors):
|
168 |
+
if diagonal_frequency_distribution_[l] != 0:
|
169 |
+
entropy_diagonal_lines += (diagonal_frequency_distribution_[
|
170 |
+
l] / sum_diagonal_frequency_distribution) * np.log(
|
171 |
+
diagonal_frequency_distribution_[l] / sum_diagonal_frequency_distribution)
|
172 |
+
entropy_diagonal_lines *= -1
|
173 |
+
return entropy_diagonal_lines
|
174 |
+
|
175 |
+
|
176 |
+
@jit(nopython=True)
|
177 |
+
def ratio_determinism_recurrence_rate(determinism_, recurrence_rate_):
|
178 |
+
# Calculating the divergence - DIV
|
179 |
+
return determinism_ / recurrence_rate_
|
180 |
+
|
181 |
+
|
182 |
+
def laminarity(number_of_vectors, vertical_frequency_distribution_, minimum_vertical_line_length):
|
183 |
+
# Calculating the laminarity - LAM
|
184 |
+
numerator = np.sum(
|
185 |
+
[v * vertical_frequency_distribution_[v] for v in range(minimum_vertical_line_length, number_of_vectors + 1)])
|
186 |
+
denominator = np.sum([v * vertical_frequency_distribution_[v] for v in range(1, number_of_vectors + 1)])
|
187 |
+
return numerator / denominator
|
188 |
+
|
189 |
+
|
190 |
+
def average_vertical_line_length(number_of_vectors, vertical_frequency_distribution_, minimum_vertical_line_length):
|
191 |
+
# Calculating the average vertical line length - V
|
192 |
+
numerator = np.sum(
|
193 |
+
[v * vertical_frequency_distribution_[v] for v in range(minimum_vertical_line_length, number_of_vectors + 1)])
|
194 |
+
denominator = np.sum(
|
195 |
+
[vertical_frequency_distribution_[v] for v in range(minimum_vertical_line_length, number_of_vectors + 1)])
|
196 |
+
return numerator / denominator
|
197 |
+
|
198 |
+
|
199 |
+
@jit(nopython=True)
|
200 |
+
def longest_vertical_line_length(number_of_vectors, vertical_frequency_distribution_):
|
201 |
+
# Calculating the longest vertical line length - Vmax
|
202 |
+
longest_vertical_line_length_ = 0
|
203 |
+
for v in range(number_of_vectors, 0, -1):
|
204 |
+
if vertical_frequency_distribution_[v] != 0:
|
205 |
+
longest_vertical_line_length_ = v
|
206 |
+
break
|
207 |
+
return longest_vertical_line_length_
|
208 |
+
|
209 |
+
|
210 |
+
@jit(nopython=True)
|
211 |
+
def entropy_vertical_lines(number_of_vectors, vertical_frequency_distribution_, minimum_vertical_line_length):
|
212 |
+
# Calculating the entropy vertical lines - Ventr
|
213 |
+
sum_vertical_frequency_distribution = np.float(
|
214 |
+
np.sum(vertical_frequency_distribution_[minimum_vertical_line_length:]))
|
215 |
+
entropy_vertical_lines_ = 0
|
216 |
+
for v in range(minimum_vertical_line_length, number_of_vectors + 1):
|
217 |
+
if vertical_frequency_distribution_[v] != 0:
|
218 |
+
entropy_vertical_lines_ += (vertical_frequency_distribution_[
|
219 |
+
v] / sum_vertical_frequency_distribution) * np.log(
|
220 |
+
vertical_frequency_distribution_[v] / sum_vertical_frequency_distribution)
|
221 |
+
entropy_vertical_lines_ *= -1
|
222 |
+
return entropy_vertical_lines_
|
223 |
+
|
224 |
+
|
225 |
+
@jit(nopython=True)
|
226 |
+
def laminarity_determinism(laminarity_, determinism_):
|
227 |
+
# Calculating the ratio laminarity_determinism - LAM/DET
|
228 |
+
return laminarity_ / determinism_
|
229 |
+
|
230 |
+
|
231 |
+
def average_white_vertical_line_length(number_of_vectors, white_vertical_frequency_distribution_,
|
232 |
+
minimum_white_vertical_line_length):
|
233 |
+
# Calculating the average white vertical line length - W
|
234 |
+
numerator = np.sum([w * white_vertical_frequency_distribution_[w] for w in
|
235 |
+
range(minimum_white_vertical_line_length, number_of_vectors + 1)])
|
236 |
+
denominator = np.sum([white_vertical_frequency_distribution_[w] for w in
|
237 |
+
range(minimum_white_vertical_line_length, number_of_vectors + 1)])
|
238 |
+
return numerator / denominator
|
239 |
+
|
240 |
+
|
241 |
+
@jit(nopython=True)
|
242 |
+
def longest_white_vertical_line_length(number_of_vectors, white_vertical_frequency_distribution_):
|
243 |
+
# Calculating the longest white vertical line length - Wmax
|
244 |
+
longest_white_vertical_line_length_ = 0
|
245 |
+
for w in range(number_of_vectors, 0, -1):
|
246 |
+
if white_vertical_frequency_distribution_[w] != 0:
|
247 |
+
longest_white_vertical_line_length_ = w
|
248 |
+
break
|
249 |
+
return longest_white_vertical_line_length_
|
250 |
+
|
251 |
+
|
252 |
+
@jit(nopython=True)
|
253 |
+
def entropy_white_vertical_lines(number_of_vectors, white_vertical_frequency_distribution_,
|
254 |
+
minimum_white_vertical_line_length):
|
255 |
+
# Calculating the entropy white vertical lines - Wentr
|
256 |
+
sum_white_vertical_frequency_distribution = np.float(
|
257 |
+
np.sum(white_vertical_frequency_distribution_[minimum_white_vertical_line_length:]))
|
258 |
+
entropy_white_vertical_lines_ = 0
|
259 |
+
for w in range(minimum_white_vertical_line_length, number_of_vectors + 1):
|
260 |
+
if white_vertical_frequency_distribution_[w] != 0:
|
261 |
+
entropy_white_vertical_lines_ += (white_vertical_frequency_distribution_[
|
262 |
+
w] / sum_white_vertical_frequency_distribution) * np.log(
|
263 |
+
white_vertical_frequency_distribution_[w] / sum_white_vertical_frequency_distribution)
|
264 |
+
entropy_white_vertical_lines_ *= -1
|
265 |
+
return entropy_white_vertical_lines_
|
266 |
+
|
267 |
+
def number_of_vertical_lines(vertical_frequency_distribution_, minimum_vertical_line_length):
|
268 |
+
if minimum_vertical_line_length > 0:
|
269 |
+
return np.sum(vertical_frequency_distribution_[minimum_vertical_line_length - 1:])
|
270 |
+
|
271 |
+
return np.uint(0)
|
272 |
+
|
273 |
+
|
274 |
+
def number_of_vertical_lines_points(vertical_frequency_distribution_, minimum_vertical_line_length):
|
275 |
+
if minimum_vertical_line_length > 0:
|
276 |
+
return np.sum(
|
277 |
+
((np.arange(vertical_frequency_distribution_.size) + 1) * vertical_frequency_distribution_)[minimum_vertical_line_length - 1:])
|
278 |
+
|
279 |
+
return np.uint(0)
|
280 |
+
|
281 |
+
@jit(nopython=True)
|
282 |
+
def trapping_time(number_of_vertical_lines_points_, number_of_vertical_lines_):
|
283 |
+
"""
|
284 |
+
Trapping time (TT).
|
285 |
+
"""
|
286 |
+
try:
|
287 |
+
return np.float32(number_of_vertical_lines_points_ / number_of_vertical_lines_)
|
288 |
+
except:
|
289 |
+
return 0
|
290 |
+
|
291 |
+
|
292 |
+
|
293 |
+
|
294 |
+
|
295 |
+
def return_pyRQA_results(signal, nbr):
|
296 |
+
time_series = EmbeddedSeries(signal)
|
297 |
+
settings = Settings(time_series,
|
298 |
+
analysis_type=Classic,
|
299 |
+
neighbourhood=FixedRadius(nbr),
|
300 |
+
similarity_measure=EuclideanMetric,
|
301 |
+
theiler_corrector=1)
|
302 |
+
computation = RQAComputation.create(settings,
|
303 |
+
verbose=True)
|
304 |
+
result = computation.run()
|
305 |
+
return result
|
BrainPulse/requirements.txt
ADDED
File without changes
|
BrainPulse/vector_space.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def compute_stft(epoch_, n_fft, win_len, s_rate, cut_freq):
|
5 |
+
# stft_time = time.time()
|
6 |
+
|
7 |
+
signal_tensor = torch.tensor(epoch_, dtype=torch.float)
|
8 |
+
stft_tensor = torch.stft(signal_tensor,n_fft=n_fft, win_length=win_len, hop_length=1,return_complex=True,window=torch.hann_window(win_len))
|
9 |
+
|
10 |
+
sft = torch.abs(stft_tensor).numpy()
|
11 |
+
|
12 |
+
freq_to_take = (((n_fft/2)+1)*cut_freq) / ((s_rate/2)+1)
|
13 |
+
|
14 |
+
sft = sft[:int(freq_to_take),::]
|
15 |
+
|
16 |
+
return sft.T
|
Dockerfile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# syntax=docker/dockerfile:1
|
2 |
+
|
3 |
+
FROM python:3.8-slim-buster
|
4 |
+
RUN apt-get update && apt-get install -y git
|
5 |
+
|
6 |
+
COPY ./requirements.txt /app/requirements.txt
|
7 |
+
|
8 |
+
WORKDIR /app
|
9 |
+
|
10 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
11 |
+
|
12 |
+
EXPOSE 5000
|
13 |
+
|
14 |
+
COPY . /app
|
15 |
+
|
16 |
+
# ENTRYPOINT [ "python" ]
|
17 |
+
# ENTRYPOINT [ "python" ]
|
18 |
+
# ENTRYPOINT [ "streamlit run BrainPulseAPP.py" ]
|
19 |
+
|
20 |
+
# CMD [ "BrainPulseAPP.py" ]
|
21 |
+
CMD ["streamlit", "run", "BrainPulseAPP.py"]
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Łukasz Furman
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
emoji: 📚
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: yellow
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.10.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: mit
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
# BrainPulse_webapp
|
2 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cProfile import run
|
2 |
+
import streamlit as st
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import matplotlib
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
from complexRadar import ComplexRadar
|
9 |
+
import math
|
10 |
+
from zipfile import ZipFile
|
11 |
+
from glob import glob
|
12 |
+
import os
|
13 |
+
from BrainPulse import (dataset,
|
14 |
+
vector_space,
|
15 |
+
distance_matrix,
|
16 |
+
recurrence_quantification_analysis,
|
17 |
+
features_space,
|
18 |
+
plot)
|
19 |
+
|
20 |
+
# path
|
21 |
+
path = "./mne_data"
|
22 |
+
path2 = "./RPs"
|
23 |
+
|
24 |
+
# Remove the specified
|
25 |
+
# file path
|
26 |
+
try:
|
27 |
+
os.remove(path)
|
28 |
+
print("% s removed successfully" % path)
|
29 |
+
except:
|
30 |
+
pass
|
31 |
+
|
32 |
+
path = "./mne_data"
|
33 |
+
os.makedirs(path, exist_ok = True)
|
34 |
+
path1 = "./RPs"
|
35 |
+
os.makedirs(path1, exist_ok = True)
|
36 |
+
|
37 |
+
def run_computation(t_start, t_end, selected_subject, fir_filter, electrode_name, cut_freq, win_len, n_fft, percentile, run_list, options):
|
38 |
+
|
39 |
+
epochs, raw = dataset.eegbci_data(tmin=t_start, tmax=t_end,
|
40 |
+
subject=selected_subject,
|
41 |
+
filter_range=fir_filter,run_list=run_list)
|
42 |
+
|
43 |
+
s_rate = epochs.info['sfreq']
|
44 |
+
|
45 |
+
electrode_index = epochs.ch_names.index(electrode_name)
|
46 |
+
|
47 |
+
electrode_open = epochs.get_data()[0][electrode_index]
|
48 |
+
electrode_close = epochs.get_data()[1][electrode_index]
|
49 |
+
|
50 |
+
stft_open = vector_space.compute_stft(electrode_open,
|
51 |
+
n_fft=n_fft, win_len=win_len,
|
52 |
+
s_rate=epochs.info['sfreq'],
|
53 |
+
cut_freq=cut_freq)
|
54 |
+
|
55 |
+
stft_close = vector_space.compute_stft(electrode_close,
|
56 |
+
n_fft=n_fft, win_len=win_len,
|
57 |
+
s_rate=epochs.info['sfreq'],
|
58 |
+
cut_freq=cut_freq)
|
59 |
+
del raw
|
60 |
+
del electrode_open, electrode_close
|
61 |
+
# matrix_open = distance_matrix.EuclideanPyRQA_RP_stft(stft_open)
|
62 |
+
# matrix_close = distance_matrix.EuclideanPyRQA_RP_stft(stft_close)
|
63 |
+
matrix_open = distance_matrix.EuclideanPyRQA_RP_stft_cpu(stft_open)
|
64 |
+
matrix_close = distance_matrix.EuclideanPyRQA_RP_stft_cpu(stft_close)
|
65 |
+
|
66 |
+
nbr_open = np.percentile(matrix_open, percentile)
|
67 |
+
nbr_close = np.percentile(matrix_close, percentile)
|
68 |
+
|
69 |
+
matrix_open_binary = distance_matrix.set_epsilon(matrix_open,nbr_open)
|
70 |
+
matrix_close_binary = distance_matrix.set_epsilon(matrix_close,nbr_close)
|
71 |
+
|
72 |
+
del matrix_open, matrix_close
|
73 |
+
# matrix_open_to_plot = matrix_open_binary
|
74 |
+
# matrix_closed_to_plot = matrix_close_binary
|
75 |
+
|
76 |
+
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2,figsize=(16,8),dpi=200)
|
77 |
+
ax1.imshow(matrix_open_binary, cmap='Greys', origin='lower') #cividis
|
78 |
+
ax1.set_xticks(np.linspace(0, matrix_open_binary.shape[0] , ax1.get_xticks().shape[0]))
|
79 |
+
ax1.set_yticks(np.linspace(0, matrix_open_binary.shape[0] , ax1.get_xticks().shape[0]))
|
80 |
+
ax1.set_xticklabels([str(np.around(x,decimals=0)) for x in np.linspace(0, matrix_open_binary.shape[0] / s_rate, ax1.get_xticks().shape[0])])
|
81 |
+
ax1.set_yticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_open_binary.shape[0] / s_rate, ax1.get_yticks().shape[0])])
|
82 |
+
ax1.set_title(options[0]+' window size = 240 samples, ε = '+str(np.round(nbr_open,4)))
|
83 |
+
ax1.set_xlabel('time (s)')
|
84 |
+
ax1.set_ylabel('time (s)')
|
85 |
+
|
86 |
+
ax2.imshow(matrix_close_binary, cmap='Greys', origin='lower')
|
87 |
+
ax2.set_xticks(np.linspace(0, matrix_close_binary.shape[0] , ax1.get_xticks().shape[0]))
|
88 |
+
ax2.set_yticks(np.linspace(0, matrix_close_binary.shape[0] , ax1.get_xticks().shape[0]))
|
89 |
+
ax2.set_xticklabels([str(np.around(x,decimals=0)) for x in np.linspace(0, matrix_close_binary.shape[0] / s_rate, ax1.get_xticks().shape[0])])
|
90 |
+
ax2.set_yticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_close_binary.shape[0] / s_rate, ax2.get_yticks().shape[0])])
|
91 |
+
ax2.set_title(options[1]+' window size = 240 samples, ε = '+str(np.round(nbr_close,4)))
|
92 |
+
ax2.set_xlabel('time (s)')
|
93 |
+
ax2.set_ylabel('time (s)')
|
94 |
+
|
95 |
+
return fig, matrix_open_binary, matrix_close_binary, epochs, stft_open, stft_close
|
96 |
+
|
97 |
+
|
98 |
+
def plot_rqa(matrix_open_binary, matrix_close_binary, min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len,options):
|
99 |
+
|
100 |
+
categories = ['RR', 'DET', 'L', 'Lmax', 'DIV', 'Lentr', 'DET_RR', 'LAM', 'V', 'Vmax', 'Ventr', 'LAM_DET', 'W', 'Wmax', 'Wentr', 'TT']
|
101 |
+
|
102 |
+
result_rqa_open = recurrence_quantification_analysis.get_results(matrix_open_binary,min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len)
|
103 |
+
result_rqa_closed = recurrence_quantification_analysis.get_results(matrix_close_binary,min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len)
|
104 |
+
|
105 |
+
data = pd.DataFrame([result_rqa_open,result_rqa_closed], columns=categories)
|
106 |
+
|
107 |
+
data = data.drop(['RR', 'DIV', 'Lmax'],axis=1)
|
108 |
+
# print(data)
|
109 |
+
min_max_per_variable = data.describe().T[['min', 'max']]
|
110 |
+
min_max_per_variable['min'] = min_max_per_variable['min'].apply(lambda x: int(x))
|
111 |
+
min_max_per_variable['max'] = min_max_per_variable['max'].apply(lambda x: math.ceil(x))
|
112 |
+
# print(min_max_per_variable)
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
variables = data.columns
|
117 |
+
ranges = list(min_max_per_variable.itertuples(index=False, name=None))
|
118 |
+
|
119 |
+
format_cfg = {
|
120 |
+
#'axes_args':{'facecolor':'#84A8CD'},
|
121 |
+
'rad_ln_args': {'visible':True, 'linestyle':'dotted'},
|
122 |
+
'angle_ln_args':{'linestyle':'dotted'},
|
123 |
+
'outer_ring': {'visible':True, 'linestyle':'dotted'},
|
124 |
+
'rgrid_tick_lbls_args': {'fontsize':6},
|
125 |
+
'theta_tick_lbls': {'fontsize':9, 'backgroundcolor':'#355C7D', 'color':'#FFFFFF'},
|
126 |
+
'theta_tick_lbls_pad':3
|
127 |
+
}
|
128 |
+
|
129 |
+
|
130 |
+
fig = plt.figure(figsize=(5,3),dpi=100)
|
131 |
+
radar = ComplexRadar(fig, variables, ranges,n_ring_levels=3 ,show_scales=True, format_cfg=format_cfg)
|
132 |
+
|
133 |
+
|
134 |
+
custom_colors = ['#F67280', '#6C5B7B', '#355C7D']
|
135 |
+
k=0
|
136 |
+
for g,c in zip(data.index, custom_colors):
|
137 |
+
# radar.plot(data.loc[g].values, label=f"condition {g}", color=c, marker='o')
|
138 |
+
radar.plot(data.loc[g].values, label=options[k], color=c, marker='o')
|
139 |
+
radar.fill(data.loc[g].values, alpha=0.5, color=c)
|
140 |
+
k+=1
|
141 |
+
|
142 |
+
radar.use_legend(loc='upper left', bbox_to_anchor=(-0.4, 1.1), fontsize = 'xx-small') #, bbox_to_anchor=(0.15, -0.25),ncol=radar.plot_counter
|
143 |
+
|
144 |
+
return fig
|
145 |
+
|
146 |
+
def waterfall_spectrum(stft1, stft2, s_rate, cut_freq, options):
|
147 |
+
|
148 |
+
fig = plt.figure(figsize=(14, 12), dpi=150)
|
149 |
+
grid = plt.GridSpec(8, 8, hspace=0.0, wspace=3.5)
|
150 |
+
spectrogram1 = fig.add_subplot(grid[0:3, 0:4])
|
151 |
+
spectrogram2 = fig.add_subplot(grid[0:3, 4:])
|
152 |
+
|
153 |
+
spectrogram1.pcolormesh(stft1.T,cmap='viridis')
|
154 |
+
spectrogram1.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft1.shape[0], 5)))
|
155 |
+
spectrogram1.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft1.shape[0] / s_rate, spectrogram1.get_xticks().shape[0])])
|
156 |
+
spectrogram1.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft1.shape[1], 5)))
|
157 |
+
spectrogram1.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
|
158 |
+
spectrogram1.set_ylabel('Freq (Hz)', )
|
159 |
+
spectrogram1.set_xlabel('Time (s)', )
|
160 |
+
spectrogram1.set_title(options[0] + ' Spectrogram', )
|
161 |
+
|
162 |
+
spectrogram2.pcolormesh(stft2.T,cmap='viridis')
|
163 |
+
spectrogram2.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft2.shape[0], 5)))
|
164 |
+
spectrogram2.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft2.shape[0] / s_rate, spectrogram2.get_xticks().shape[0])])
|
165 |
+
spectrogram2.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft2.shape[1], 5)))
|
166 |
+
spectrogram2.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
|
167 |
+
spectrogram2.set_ylabel('Freq (Hz)', )
|
168 |
+
spectrogram2.set_xlabel('Time (s)', )
|
169 |
+
spectrogram2.set_title(options[1] +' Spectrogram', )
|
170 |
+
return fig
|
171 |
+
|
172 |
+
def save(matrix_open_binary, matrix_close_binary):
|
173 |
+
|
174 |
+
file_name_open = './RPs/subject-'+str(selected_subject)+'_electrode-'+electrode_name+'_percentile-'+str(percentile)+'_run-open_binary.npy'
|
175 |
+
np.save(file_name_open, np.asarray(matrix_close_binary, dtype=np.ubyte))
|
176 |
+
file_name_close = './RPs/subject-'+str(selected_subject)+'_electrode-'+electrode_name+'_percentile-'+str(percentile)+'_run-close_binary.npy'
|
177 |
+
np.save(file_name_close, np.asarray(matrix_close_binary, dtype=np.ubyte))
|
178 |
+
|
179 |
+
def download():
|
180 |
+
|
181 |
+
file_paths = glob('./RPs/*')
|
182 |
+
|
183 |
+
with ZipFile('download.zip','w') as zip:
|
184 |
+
for file in file_paths:
|
185 |
+
# writing each file one by one
|
186 |
+
zip.write(file)
|
187 |
+
|
188 |
+
return open('download.zip', 'rb')
|
189 |
+
|
190 |
+
# ---------------Settings--------------------
|
191 |
+
|
192 |
+
st.set_page_config(layout="wide")
|
193 |
+
st.title('BrainPulse Playground')
|
194 |
+
sidebar = st.sidebar
|
195 |
+
|
196 |
+
selected_subject = sidebar.slider('Select Subject', 0, 100, 25)
|
197 |
+
|
198 |
+
electrode_name = sidebar.selectbox(
|
199 |
+
'Select Electrode',
|
200 |
+
('FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FT8', 'T7', 'T8', 'T9', 'T10', 'TP7', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2', 'Iz'))
|
201 |
+
|
202 |
+
t_start, t_end = sidebar.slider(
|
203 |
+
'Select a time range in seconds',
|
204 |
+
0, 60, (0, 30), step=1)
|
205 |
+
|
206 |
+
f1, f2 = sidebar.slider(
|
207 |
+
'Select a FIR filter range',
|
208 |
+
0, 60, (2, 50), step=1)
|
209 |
+
fir_filter = [f1, f2]
|
210 |
+
|
211 |
+
cut_freq = f2
|
212 |
+
|
213 |
+
win_len = sidebar.slider('FFT window size', 0, 512, 170, step=1)
|
214 |
+
|
215 |
+
n_fft = sidebar.slider('numer of FFT bins', 0, 1024, 512, step=1)
|
216 |
+
|
217 |
+
min_vert_line_len = sidebar.slider('Minimum vertical line length', 0, 250, 2, step=1)
|
218 |
+
|
219 |
+
min_diagonal_line_len = sidebar.slider('Minimum diagonal line length', 0, 250, 2, step=1)
|
220 |
+
|
221 |
+
min_white_vert_line_len = sidebar.slider('Minimum white vertical line length', 0, 250, 2, step=1)
|
222 |
+
|
223 |
+
percentile = sidebar.slider('Precentile', 0, 100, 24, step=1)
|
224 |
+
|
225 |
+
sidebar.download_button('Download file', download(),file_name='archive.zip')
|
226 |
+
|
227 |
+
# ---------------Plot RPs--------------------
|
228 |
+
# runs_ = ['Baseline open eyes', 'Baseline closed eyes', 'Motor execution: left vs right hand', 'Motor imagery: left vs right hand',
|
229 |
+
# 'Motor execution: hands vs feet', 'Motor imagery: hands vs feet']
|
230 |
+
#
|
231 |
+
# options = st.multiselect('Select two runs to compare', runs_, ['Baseline open eyes', 'Baseline closed eyes'])
|
232 |
+
|
233 |
+
|
234 |
+
# run_list = []
|
235 |
+
#
|
236 |
+
# for v in options:
|
237 |
+
# run_list.append(runs_.index(v)+1)
|
238 |
+
# if len(run_list) <= 1:
|
239 |
+
# run_list = [1,2]
|
240 |
+
st.markdown('Baseline open eyes vs Baseline closed eyes')
|
241 |
+
options = ['Baseline open eyes', 'Baseline closed eyes']
|
242 |
+
run_list = [1,2]
|
243 |
+
|
244 |
+
rp_plot, matrix_open_binary, matrix_close_binary, epochs, stft1, stft2 = run_computation(t_start, t_end, selected_subject, fir_filter, electrode_name, cut_freq, win_len, n_fft, percentile, run_list,options)
|
245 |
+
st.write(rp_plot)
|
246 |
+
|
247 |
+
# ---------------Plot Spectrum--------------------
|
248 |
+
st.write(waterfall_spectrum(stft1, stft2, 160, cut_freq, options))
|
249 |
+
|
250 |
+
# ---------------Save RPs--------------------
|
251 |
+
if st.button('Save RPs as *.npy'):
|
252 |
+
save(matrix_open_binary, matrix_close_binary)
|
253 |
+
|
254 |
+
# ---------------Plot Radar--------------------
|
255 |
+
rqa_radar = plot_rqa(matrix_open_binary, matrix_close_binary, min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len, options)
|
256 |
+
st.write(rqa_radar)
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
complexRadar.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import textwrap
|
3 |
+
|
4 |
+
class ComplexRadar():
|
5 |
+
"""
|
6 |
+
Create a complex radar chart with different scales for each variable
|
7 |
+
Parameters
|
8 |
+
----------
|
9 |
+
fig : figure object
|
10 |
+
A matplotlib figure object to add the axes on
|
11 |
+
variables : list
|
12 |
+
A list of variables
|
13 |
+
ranges : list
|
14 |
+
A list of tuples (min, max) for each variable
|
15 |
+
n_ring_levels: int, defaults to 5
|
16 |
+
Number of ordinate or ring levels to draw
|
17 |
+
show_scales: bool, defaults to True
|
18 |
+
Indicates if we the ranges for each variable are plotted
|
19 |
+
format_cfg: dict, defaults to None
|
20 |
+
A dictionary with formatting configurations
|
21 |
+
"""
|
22 |
+
def __init__(self, fig, variables, ranges, n_ring_levels=5, show_scales=True, format_cfg=None):
|
23 |
+
|
24 |
+
# Default formatting
|
25 |
+
self.format_cfg = {
|
26 |
+
# Axes
|
27 |
+
# https://matplotlib.org/stable/api/figure_api.html
|
28 |
+
'axes_args': {},
|
29 |
+
# Tick labels on the scales
|
30 |
+
# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.rgrids.html
|
31 |
+
'rgrid_tick_lbls_args': {'fontsize':8},
|
32 |
+
# Radial (circle) lines
|
33 |
+
# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.grid.html
|
34 |
+
'rad_ln_args': {},
|
35 |
+
# Angle lines
|
36 |
+
# https://matplotlib.org/3.2.2/api/_as_gen/matplotlib.lines.Line2D.html#matplotlib.lines.Line2D
|
37 |
+
'angle_ln_args': {},
|
38 |
+
# Include last value (endpoint) on scale
|
39 |
+
'incl_endpoint':False,
|
40 |
+
# Variable labels (ThetaTickLabel)
|
41 |
+
'theta_tick_lbls':{'va':'top', 'ha':'center'},
|
42 |
+
'theta_tick_lbls_txt_wrap':15,
|
43 |
+
'theta_tick_lbls_brk_lng_wrds':False,
|
44 |
+
'theta_tick_lbls_pad':25,
|
45 |
+
# Outer ring
|
46 |
+
# https://matplotlib.org/stable/api/spines_api.html
|
47 |
+
'outer_ring':{'visible':True, 'color':'#d6d6d6'}
|
48 |
+
}
|
49 |
+
|
50 |
+
if format_cfg is not None:
|
51 |
+
self.format_cfg = { k:(format_cfg[k]) if k in format_cfg.keys() else (self.format_cfg[k])
|
52 |
+
for k in self.format_cfg.keys()}
|
53 |
+
|
54 |
+
|
55 |
+
# Calculate angles and create for each variable an axes
|
56 |
+
# Consider here the trick with having the first axes element twice (len+1)
|
57 |
+
angles = np.arange(0, 360, 360./len(variables))
|
58 |
+
axes = [fig.add_axes([0.1,0.1,0.9,0.9],
|
59 |
+
polar=True,
|
60 |
+
label = "axes{}".format(i),
|
61 |
+
**self.format_cfg['axes_args']) for i in range(len(variables)+1)]
|
62 |
+
|
63 |
+
# Ensure clockwise rotation (first variable at the top N)
|
64 |
+
for ax in axes:
|
65 |
+
ax.set_theta_zero_location('N')
|
66 |
+
ax.set_theta_direction(-1)
|
67 |
+
ax.set_axisbelow(True)
|
68 |
+
|
69 |
+
# Writing the ranges on each axes
|
70 |
+
for i, ax in enumerate(axes):
|
71 |
+
|
72 |
+
# Here we do the trick by repeating the first iteration
|
73 |
+
j = 0 if (i==0 or i==1) else i-1
|
74 |
+
ax.set_ylim(*ranges[j])
|
75 |
+
# Set endpoint to True if you like to have values right before the last circle
|
76 |
+
grid = np.linspace(*ranges[j], num=n_ring_levels,
|
77 |
+
endpoint=self.format_cfg['incl_endpoint'])
|
78 |
+
gridlabel = ["{}".format(round(x,2)) for x in grid]
|
79 |
+
gridlabel[0] = "" # remove values from the center
|
80 |
+
lines, labels = ax.set_rgrids(grid,
|
81 |
+
labels=gridlabel,
|
82 |
+
angle=angles[j],
|
83 |
+
**self.format_cfg['rgrid_tick_lbls_args']
|
84 |
+
)
|
85 |
+
|
86 |
+
ax.set_ylim(*ranges[j])
|
87 |
+
ax.spines["polar"].set_visible(False)
|
88 |
+
ax.grid(visible=False)
|
89 |
+
|
90 |
+
if show_scales == False:
|
91 |
+
ax.set_yticklabels([])
|
92 |
+
|
93 |
+
# Set all axes except the first one unvisible
|
94 |
+
for ax in axes[1:]:
|
95 |
+
ax.patch.set_visible(False)
|
96 |
+
ax.xaxis.set_visible(False)
|
97 |
+
|
98 |
+
# Setting the attributes
|
99 |
+
self.angle = np.deg2rad(np.r_[angles, angles[0]])
|
100 |
+
self.ranges = ranges
|
101 |
+
self.ax = axes[0]
|
102 |
+
self.ax1 = axes[1]
|
103 |
+
self.plot_counter = 0
|
104 |
+
|
105 |
+
|
106 |
+
# Draw (inner) circles and lines
|
107 |
+
self.ax.yaxis.grid(**self.format_cfg['rad_ln_args'])
|
108 |
+
# Draw outer circle
|
109 |
+
self.ax.spines['polar'].set(**self.format_cfg['outer_ring'])
|
110 |
+
# Draw angle lines
|
111 |
+
self.ax.xaxis.grid(**self.format_cfg['angle_ln_args'])
|
112 |
+
|
113 |
+
# ax1 is the duplicate of axes[0] (self.ax)
|
114 |
+
# Remove everything from ax1 except the plot itself
|
115 |
+
self.ax1.axis('off')
|
116 |
+
self.ax1.set_zorder(9)
|
117 |
+
|
118 |
+
# Create the outer labels for each variable
|
119 |
+
l, text = self.ax.set_thetagrids(angles, labels=variables)
|
120 |
+
|
121 |
+
# Beautify them
|
122 |
+
labels = [t.get_text() for t in self.ax.get_xticklabels()]
|
123 |
+
labels = ['\n'.join(textwrap.wrap(l, self.format_cfg['theta_tick_lbls_txt_wrap'],
|
124 |
+
break_long_words=self.format_cfg['theta_tick_lbls_brk_lng_wrds'])) for l in labels]
|
125 |
+
self.ax.set_xticklabels(labels, **self.format_cfg['theta_tick_lbls'])
|
126 |
+
|
127 |
+
for t,a in zip(self.ax.get_xticklabels(),angles):
|
128 |
+
if a == 0:
|
129 |
+
t.set_ha('center')
|
130 |
+
elif a > 0 and a < 180:
|
131 |
+
t.set_ha('left')
|
132 |
+
elif a == 180:
|
133 |
+
t.set_ha('center')
|
134 |
+
else:
|
135 |
+
t.set_ha('right')
|
136 |
+
|
137 |
+
self.ax.tick_params(axis='both', pad=self.format_cfg['theta_tick_lbls_pad'])
|
138 |
+
|
139 |
+
|
140 |
+
def _scale_data(self, data, ranges):
|
141 |
+
"""Scales data[1:] to ranges[0]"""
|
142 |
+
for d, (y1, y2) in zip(data[1:], ranges[1:]):
|
143 |
+
assert (y1 <= d <= y2) or (y2 <= d <= y1)
|
144 |
+
x1, x2 = ranges[0]
|
145 |
+
d = data[0]
|
146 |
+
sdata = [d]
|
147 |
+
for d, (y1, y2) in zip(data[1:], ranges[1:]):
|
148 |
+
sdata.append((d-y1) / (y2-y1) * (x2 - x1) + x1)
|
149 |
+
return sdata
|
150 |
+
|
151 |
+
def plot(self, data, *args, **kwargs):
|
152 |
+
"""Plots a line"""
|
153 |
+
sdata = self._scale_data(data, self.ranges)
|
154 |
+
self.ax1.plot(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
|
155 |
+
self.plot_counter = self.plot_counter+1
|
156 |
+
|
157 |
+
def fill(self, data, *args, **kwargs):
|
158 |
+
"""Plots an area"""
|
159 |
+
sdata = self._scale_data(data, self.ranges)
|
160 |
+
self.ax1.fill(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
|
161 |
+
|
162 |
+
def use_legend(self, *args, **kwargs):
|
163 |
+
"""Shows a legend"""
|
164 |
+
self.ax1.legend(*args, **kwargs)
|
165 |
+
|
166 |
+
def set_title(self, title, pad=25, **kwargs):
|
167 |
+
"""Set a title"""
|
168 |
+
self.ax.set_title(title,pad=pad, **kwargs)
|
papaerspace_config.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
image: datalab108/cookiewriter64:test
|
2 |
+
port: 5005
|
3 |
+
resources:
|
4 |
+
replicas: 1
|
5 |
+
instanceType: P4000
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
torchaudio
|
4 |
+
umap-learn
|
5 |
+
umap-learn[plot]
|
6 |
+
PyRQA
|
7 |
+
pandas
|
8 |
+
mne
|
9 |
+
plotly
|
10 |
+
streamlit
|
test.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
while True:
|
4 |
+
print("Hello World")
|
5 |
+
time.sleep(10)
|