Łukasz Furman committed on
Commit a59bdc5
1 Parent(s): e592727

update app.py

.gitignore ADDED
@@ -0,0 +1,3 @@
+
+ *.pyc
+ *.zip
BrainPulse/.DS_Store ADDED
Binary file (6.15 kB).
 
BrainPulse/README.md ADDED
@@ -0,0 +1,2 @@
+ # BrainPulse
+
BrainPulse/__init__.py ADDED
File without changes
BrainPulse/dataset.py ADDED
@@ -0,0 +1,153 @@
+ from mne import Epochs, pick_types, events_from_annotations, create_info, EpochsArray
+ from mne.channels import make_standard_montage
+ from mne.io import concatenate_raws, read_raw_edf, RawArray
+ from mne.datasets import eegbci
+ import scipy.io
+ import numpy as np
+ import os
+ import mne
+
+
+ def eegbci_data(tmin, tmax, subject, filter_range=None, run_list=None):
+
+     event_id = dict(ev=0)
+     runs = run_list  # e.g. eyes-open vs. eyes-closed baseline runs
+
+     # raw_fnames = eegbci.load_data(subject, runs, path='./datasets', update_path=True)
+     raw_fnames = eegbci.load_data(subject, runs, update_path=True, path='../mne_data')
+     raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
+     eegbci.standardize(raw)  # set channel names
+     montage = make_standard_montage('standard_1005')
+     raw.set_montage(montage)
+
+     # strip channel names of "." characters
+     raw.rename_channels(lambda x: x.strip('.'))
+     raw.set_eeg_reference(projection=True)
+     raw.apply_proj()
+     # Apply band-pass filter
+     if filter_range is not None:
+         raw.filter(filter_range[0], filter_range[1], fir_design='firwin', skip_by_annotation='edge')
+
+     events, _ = events_from_annotations(raw, event_id=dict(T0=0))
+
+     picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
+                        exclude='bads')
+     # Read epochs
+     epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
+                     baseline=None, preload=True)
+     # print(epochs.get_data())
+     return epochs, raw
+
+
+ def eegMCI_data(stim_type, subject, run_type, path_to_data_folder):
+     """
+     Read 3-D data from a .mat file structured as (n_epochs, n_channels, n_samples).
+
+     :param stim_type: 'AV', 'A' or 'V'
+     :param subject: subject number to select from the dataset
+     :param run_type: run type of the selected subject; possible types: button.dat, test.dat_1, test.dat_2, train.dat
+     :param path_to_data_folder: root path of the data folder
+
+     :return: epochs and raw MNE objects
+     """
+
+     stim_type_long = ['audiovisual', 'audio', 'visual']
+
+     if stim_type == 'AV':
+         stim_typeL = stim_type_long[0]
+     elif stim_type == 'A':
+         stim_typeL = stim_type_long[1]
+     else:
+         stim_typeL = stim_type_long[2]
+
+     targets = ['allNTARGETS', 'allTARGETS']
+     path = os.path.join(path_to_data_folder, stim_typeL, 's' + str(subject) + '_' + stim_type + '_' + run_type + '.mat')
+
+     mat = scipy.io.loadmat(path)
+
+     raw_no_targets = mat[targets[0]].transpose((0, 2, 1))
+     raw_targets = mat[targets[1]].transpose((0, 2, 1))
+
+     electrodes_l = np.concatenate(mat['electrodes'])
+     electrodes_l = [str(x).replace('[', '').replace(']', '').replace("'", '') for x in electrodes_l]
+     ch_types = ['eeg'] * 16
+     info = create_info(electrodes_l, ch_types=ch_types, sfreq=512)
+     info.set_montage('standard_1020')
+
+     epochs_no_target = EpochsArray(raw_no_targets, info, tmin=-0.2)
+     epochs_target = EpochsArray(raw_targets, info, tmin=-0.2)
+
+     erp_no_target = epochs_no_target.average()
+     erp_target = epochs_target.average()
+
+     # np.object was removed in recent NumPy; the builtin object dtype is equivalent
+     epochs = np.array([erp_no_target.get_data(), erp_target.get_data()], dtype=object)
+     raw = [raw_no_targets, raw_targets]
+     print(electrodes_l)
+     return EpochsArray(epochs, info, tmin=-0.2), np.array(raw)
+
+
+ def eegMCI_data_epochs(stim_type1, stim_type2, subject, run_type, path_to_data_folder):
+     """
+     Read 3-D data for two stimulus types from .mat files structured as (n_epochs, n_channels, n_samples).
+
+     :param stim_type1: 'AV', 'A' or 'V'
+     :param stim_type2: 'AV', 'A' or 'V'
+     :param subject: subject number to select from the dataset
+     :param run_type: run type of the selected subject; possible types: button.dat, test.dat_1, test.dat_2, train.dat
+     :param path_to_data_folder: root path of the data folder
+
+     :return: evoked responses as an EpochsArray and a list of per-condition Epochs
+     """
+
+     stim_type_long = ['audiovisual', 'audio', 'visual']
+
+     if stim_type1 == 'AV':
+         stim_typeL1 = stim_type_long[0]
+     elif stim_type1 == 'A':
+         stim_typeL1 = stim_type_long[1]
+     else:
+         stim_typeL1 = stim_type_long[2]
+
+     if stim_type2 == 'AV':
+         stim_typeL2 = stim_type_long[0]
+     elif stim_type2 == 'A':
+         stim_typeL2 = stim_type_long[1]
+     else:
+         stim_typeL2 = stim_type_long[2]
+
+     targets = ['allNTARGETS', 'allTARGETS']
+     path1 = os.path.join(path_to_data_folder, stim_typeL1, 's' + str(subject) + '_' + stim_type1 + '_' + run_type + '.mat')
+     path2 = os.path.join(path_to_data_folder, stim_typeL2, 's' + str(subject) + '_' + stim_type2 + '_' + run_type + '.mat')
+
+     mat1 = scipy.io.loadmat(path1)
+     mat2 = scipy.io.loadmat(path2)
+
+     raw_no_targets1 = mat1[targets[0]].transpose((0, 2, 1))
+     raw_targets1 = mat1[targets[1]].transpose((0, 2, 1))
+
+     raw_no_targets2 = mat2[targets[0]].transpose((0, 2, 1))
+     raw_targets2 = mat2[targets[1]].transpose((0, 2, 1))
+
+     electrodes_l = np.concatenate(mat1['electrodes'])
+     electrodes_l = [str(x).replace('[', '').replace(']', '').replace("'", '') for x in electrodes_l]
+     ch_types = ['eeg'] * 16
+     info = create_info(electrodes_l, ch_types=ch_types, sfreq=512)
+     info.set_montage('standard_1020')
+
+     # epochs_no_target = EpochsArray(raw_no_targets1, info, tmin=-0.2)
+     epochs_target1 = EpochsArray(raw_targets1, info, tmin=-0.2, baseline=(-0.2, 0))
+     epochs_target2 = EpochsArray(raw_targets2, info, tmin=-0.2, baseline=(-0.2, 0))
+
+     # erp_no_target = epochs_no_target.average()
+     erp_target1 = epochs_target1.average()
+     erp_target2 = epochs_target2.average()
+
+     evoked = np.array([erp_target1.get_data(), erp_target2.get_data()], dtype=object)
+     epochs = [epochs_target1, epochs_target2]
+     print(electrodes_l)
+     return EpochsArray(evoked, info, tmin=-0.2), epochs
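
For orientation, a minimal usage sketch (assuming MNE can download the EEGBCI baseline runs; subject 1 and the 8-30 Hz band here are illustrative choices, not package defaults):

    from BrainPulse.dataset import eegbci_data

    # Runs 1 and 2 of the EEGBCI dataset are the eyes-open / eyes-closed baselines
    epochs, raw = eegbci_data(tmin=0.0, tmax=60.0, subject=1,
                              filter_range=(8.0, 30.0), run_list=[1, 2])
    print(epochs.get_data().shape)  # (n_epochs, n_channels, n_samples)
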
BrainPulse/dependency.py ADDED
@@ -0,0 +1,8 @@
+ import pynndescent
+ import numpy
+ import pyrqa
+ import pandas
+ import sklearn
+ import scipy
+ import torch
+ import umap
BrainPulse/distance_matrix.py ADDED
@@ -0,0 +1,104 @@
+ from numba import jit
+ from pynndescent.distances import wasserstein_1d, spearmanr, euclidean
+ import numpy as np
+ from pyrqa.computation import RPComputation
+ from pyrqa.time_series import TimeSeries, EmbeddedSeries
+ from pyrqa.settings import Settings
+ from pyrqa.analysis_type import Classic
+ from pyrqa.metric import EuclideanMetric
+ # from pyrqa.metric import Sigmoid
+ # from pyrqa.metric import Cosine
+ from pyrqa.neighbourhood import Unthresholded
+ import torch
+
+
+ # @jit(parallel=True)
+ @jit(nopython=True)
+ def EuclideanPyRQA_RP_stft_cpu(stft_):
+     # Pairwise Euclidean distances between all STFT frames, in square form
+     result = np.zeros((stft_.shape[0], stft_.shape[0]))
+
+     for i, fft in enumerate(stft_):
+         for j, v in enumerate(stft_):
+             result[i, j] = euclidean(fft, v)
+     return result
+
+
+ @jit(nopython=True)
+ def wasserstein_squereform(stft_):
+     # Pairwise 1-D Wasserstein distances, in square form
+     result = np.zeros((stft_.shape[0], stft_.shape[0]))
+
+     for i, fft in enumerate(stft_):
+         for j, v in enumerate(stft_):
+             result[i, j] = wasserstein_1d(fft, v)
+     return result
+
+
+ @jit(nopython=True)
+ def spearmanr_squereform(stft_):
+     # Pairwise Spearman rank correlations, in square form
+     result = np.zeros((stft_.shape[0], stft_.shape[0]))
+
+     for i, fft in enumerate(stft_):
+         for j, v in enumerate(stft_):
+             result[i, j] = spearmanr(fft, v)
+     return result
+
+
+ @jit(nopython=True)
+ def wasserstein_squereform_binary(stft_, eps_):
+     # Thresholded (binary) recurrence matrix: 1 where distance <= eps_
+     result = np.zeros((stft_.shape[0], stft_.shape[0]))
+
+     for i, fft in enumerate(stft_):
+         for j, v in enumerate(stft_):
+             d = wasserstein_1d(fft, v)
+             result[i, j] = 1 if d <= eps_ else 0
+     return result
+
+
+ @jit(nopython=True)
+ def wasserstein_1d_array(stft_):
+     # Flattened (1-D) version of the pairwise Wasserstein distance matrix
+     result = np.zeros((stft_.shape[0] * stft_.shape[0]))
+     k = 0
+     for i, fft in enumerate(stft_):
+         for j, v in enumerate(stft_):
+             result[k] = wasserstein_1d(fft, v)
+             k += 1
+
+     return result
+
+
+ def set_epsilon(matrix, eps):
+     # Binarize a distance matrix: heaviside(eps - d) is 1 where d < eps
+     return np.heaviside(eps - matrix, 0)
+
+
+ def EuclideanPyRQA_RP(signal, embedding=2, timedelay=9):
+     time_series = TimeSeries(signal,
+                              embedding_dimension=embedding,
+                              time_delay=timedelay)
+     settings = Settings(time_series,
+                         analysis_type=Classic,
+                         neighbourhood=Unthresholded(),
+                         similarity_measure=EuclideanMetric,
+                         theiler_corrector=1)
+
+     computation = RPComputation.create(settings)
+     result = computation.run()
+     return result.recurrence_matrix
+
+
+ def EuclideanPyRQA_RP_stft(signal, embedding=2, timedelay=9):
+     # The input is already embedded (e.g. STFT frames), so embedding and
+     # timedelay are unused here and kept only for API symmetry
+     time_series = EmbeddedSeries(signal)
+     settings = Settings(time_series,
+                         analysis_type=Classic,
+                         neighbourhood=Unthresholded(),
+                         similarity_measure=EuclideanMetric,
+                         theiler_corrector=1)
+
+     computation = RPComputation.create(settings)
+     result = computation.run()
+     return result.recurrence_matrix
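
A minimal sketch of how these functions compose (random data stands in for STFT frames; the eps value is arbitrary):

    import numpy as np
    from BrainPulse.distance_matrix import wasserstein_squereform, set_epsilon

    stft = np.abs(np.random.randn(200, 64))  # stand-in for 200 STFT frames
    dist = wasserstein_squereform(stft)      # pairwise 1-D Wasserstein distances
    rp = set_epsilon(dist, 0.5)              # binarized recurrence matrix
    print(rp.shape)                          # (200, 200)
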
BrainPulse/event.py ADDED
@@ -0,0 +1,377 @@
+ """Event segmentation using a Hidden Markov Model
+
+ Adapted from the brainiak package for this workshop.
+ See https://brainiak.org/ for full documentation."""
+
+ import numpy as np
+ from scipy import stats
+ import logging
+ import copy
+ from sklearn.base import BaseEstimator
+ from sklearn.utils.validation import check_is_fitted, check_array
+ from sklearn.exceptions import NotFittedError
+ import itertools
+
+
+ def masked_log(x):
+     y = np.empty(x.shape, dtype=x.dtype)
+     lim = x.shape[0]
+     for i in range(lim):
+         if x[i] <= 0:
+             y[i] = float('-inf')
+         else:
+             y[i] = np.log(x[i])
+     return y
+
+
+ class EventSegment(BaseEstimator):
+
+     def _default_var_schedule(step):
+         return 4 * (0.98 ** (step - 1))
+
+     def __init__(self, n_events=2,
+                  step_var=_default_var_schedule,
+                  n_iter=500, event_chains=None,
+                  split_merge=False, split_merge_proposals=1):
+         self.n_events = n_events
+         self.step_var = step_var
+         self.n_iter = n_iter
+         self.split_merge = split_merge
+         self.split_merge_proposals = split_merge_proposals
+         if event_chains is None:
+             self.event_chains = np.zeros(n_events)
+         else:
+             self.event_chains = event_chains
+
+     def _fit_validate(self, X):
+         if len(np.unique(self.event_chains)) > 1:
+             raise RuntimeError("Cannot fit chains, use set_event_patterns")
+
+         # Copy X into a list and transpose
+         X = copy.deepcopy(X)
+         if type(X) is not list:
+             X = [X]
+         for i in range(len(X)):
+             X[i] = check_array(X[i])
+             X[i] = X[i].T
+
+         # Check that number of voxels is consistent across datasets
+         n_dim = X[0].shape[0]
+         for i in range(len(X)):
+             assert (X[i].shape[0] == n_dim)
+
+         # Double-check that data is z-scored in time
+         for i in range(len(X)):
+             X[i] = stats.zscore(X[i], axis=1, ddof=1)
+
+         return X
+
+     def fit(self, X, y=None):
+         seg = []
+         X = self._fit_validate(X)
+         n_train = len(X)
+         n_dim = X[0].shape[0]
+         self.classes_ = np.arange(self.n_events)
+
+         # Initialize variables for fitting
+         log_gamma = []
+         for i in range(n_train):
+             log_gamma.append(np.zeros((X[i].shape[1], self.n_events)))
+         step = 1
+         best_ll = float("-inf")
+         self.ll_ = np.empty((0, n_train))
+         while step <= self.n_iter:
+             iteration_var = self.step_var(step)
+
+             # Based on the current segmentation, compute the mean pattern
+             # for each event
+             seg_prob = [np.exp(lg) / np.sum(np.exp(lg), axis=0)
+                         for lg in log_gamma]
+             mean_pat = np.empty((n_train, n_dim, self.n_events))
+             for i in range(n_train):
+                 mean_pat[i, :, :] = X[i].dot(seg_prob[i])
+             mean_pat = np.mean(mean_pat, axis=0)
+
+             # Based on the current mean patterns, compute the event
+             # segmentation
+             self.ll_ = np.append(self.ll_, np.empty((1, n_train)), axis=0)
+             for i in range(n_train):
+                 logprob = self._logprob_obs(X[i], mean_pat, iteration_var)
+                 log_gamma[i], self.ll_[-1, i] = self._forward_backward(logprob)
+
+             if step > 1 and self.split_merge:
+                 curr_ll = np.mean(self.ll_[-1, :])
+                 self.ll_[-1, :], log_gamma, mean_pat = \
+                     self._split_merge(X, log_gamma, iteration_var, curr_ll)
+
+             # If log-likelihood has started decreasing, undo last step and stop
+             if np.mean(self.ll_[-1, :]) < best_ll:
+                 self.ll_ = self.ll_[:-1, :]
+                 break
+
+             self.segments_ = [np.exp(lg) for lg in log_gamma]
+             self.event_var_ = iteration_var
+             self.event_pat_ = mean_pat
+             best_ll = np.mean(self.ll_[-1, :])
+
+             seg.append(self.segments_[0].copy())
+
+             step += 1
+
+         return seg
+
+     def _logprob_obs(self, data, mean_pat, var):
+         n_vox = data.shape[0]
+         t = data.shape[1]
+
+         # z-score both data and mean patterns in space, so that Gaussians
+         # are measuring Pearson correlations and are insensitive to overall
+         # activity changes
+         data_z = stats.zscore(data, axis=0, ddof=1)
+         mean_pat_z = stats.zscore(mean_pat, axis=0, ddof=1)
+
+         logprob = np.empty((t, self.n_events))
+
+         if type(var) is not np.ndarray:
+             var = var * np.ones(self.n_events)
+
+         for k in range(self.n_events):
+             logprob[:, k] = -0.5 * n_vox * np.log(
+                 2 * np.pi * var[k]) - 0.5 * np.sum(
+                 (data_z.T - mean_pat_z[:, k]).T ** 2, axis=0) / var[k]
+
+         logprob /= n_vox
+         return logprob
+
+     def _forward_backward(self, logprob):
+         logprob = copy.copy(logprob)
+         t = logprob.shape[0]
+         logprob = np.hstack((logprob, float("-inf") * np.ones((t, 1))))
+
+         # Initialize variables
+         log_scale = np.zeros(t)
+         log_alpha = np.zeros((t, self.n_events + 1))
+         log_beta = np.zeros((t, self.n_events + 1))
+
+         # Set up transition matrix, with final sink state
+         self.p_start = np.zeros(self.n_events + 1)
+         self.p_end = np.zeros(self.n_events + 1)
+         self.P = np.zeros((self.n_events + 1, self.n_events + 1))
+         label_ind = np.unique(self.event_chains, return_inverse=True)[1]
+         n_chains = np.max(label_ind) + 1
+
+         # For each chain of events, link them together and then to sink state
+         for c in range(n_chains):
+             chain_ind = np.nonzero(label_ind == c)[0]
+             self.p_start[chain_ind[0]] = 1 / n_chains
+             self.p_end[chain_ind[-1]] = 1 / n_chains
+
+             p_trans = (len(chain_ind) - 1) / t
+             if p_trans >= 1:
+                 raise ValueError('Too few timepoints')
+             for i in range(len(chain_ind)):
+                 self.P[chain_ind[i], chain_ind[i]] = 1 - p_trans
+                 if i < len(chain_ind) - 1:
+                     self.P[chain_ind[i], chain_ind[i + 1]] = p_trans
+                 else:
+                     self.P[chain_ind[i], -1] = p_trans
+         self.P[-1, -1] = 1
+
+         # Forward pass
+         for i in range(t):
+             if i == 0:
+                 log_alpha[0, :] = self._log(self.p_start) + logprob[0, :]
+             else:
+                 log_alpha[i, :] = self._log(np.exp(log_alpha[i - 1, :])
+                                             .dot(self.P)) + logprob[i, :]
+
+             log_scale[i] = np.logaddexp.reduce(log_alpha[i, :])
+             log_alpha[i] -= log_scale[i]
+
+         # Backward pass
+         log_beta[-1, :] = self._log(self.p_end) - log_scale[-1]
+         for i in reversed(range(t - 1)):
+             obs_weighted = log_beta[i + 1, :] + logprob[i + 1, :]
+             offset = np.max(obs_weighted)
+             log_beta[i, :] = offset + self._log(
+                 np.exp(obs_weighted - offset).dot(self.P.T)) - log_scale[i]
+
+         # Combine and normalize
+         log_gamma = log_alpha + log_beta
+         log_gamma -= np.logaddexp.reduce(log_gamma, axis=1, keepdims=True)
+
+         ll = np.sum(log_scale[:(t - 1)]) + np.logaddexp.reduce(
+             log_alpha[-1, :] + log_scale[-1] + self._log(self.p_end))
+
+         log_gamma = log_gamma[:, :-1]
+
+         return log_gamma, ll
+
+     def _log(self, x):
+         xshape = x.shape
+         _x = x.flatten()
+         y = masked_log(_x)
+         return y.reshape(xshape)
+
+     def set_event_patterns(self, event_pat):
+         if event_pat.shape[1] != self.n_events:
+             raise ValueError(("Number of columns of event_pat must match "
+                               "number of events"))
+         self.event_pat_ = event_pat.copy()
+
+     def find_events(self, testing_data, var=None, scramble=False):
+         if var is None:
+             if not hasattr(self, 'event_var_'):
+                 raise NotFittedError(("Event variance must be provided, if "
+                                       "not previously set by fit()"))
+             else:
+                 var = self.event_var_
+
+         if not hasattr(self, 'event_pat_'):
+             raise NotFittedError(("The event patterns must first be set "
+                                   "by fit() or set_event_patterns()"))
+         if scramble:
+             mean_pat = self.event_pat_[:, np.random.permutation(self.n_events)]
+         else:
+             mean_pat = self.event_pat_
+
+         logprob = self._logprob_obs(testing_data.T, mean_pat, var)
+         lg, test_ll = self._forward_backward(logprob)
+         segments = np.exp(lg)
+
+         return segments, test_ll
+
+     def predict(self, X):
+         check_is_fitted(self, ["event_pat_", "event_var_"])
+         X = check_array(X)
+         segments, test_ll = self.find_events(X)
+         return np.argmax(segments, axis=1)
+
+     def calc_weighted_event_var(self, D, weights, event_pat):
+         Dz = stats.zscore(D, axis=1, ddof=1)
+         ev_var = np.empty(event_pat.shape[1])
+         for e in range(event_pat.shape[1]):
+             # Only compute variances for weights > 0.1% of max weight
+             nz = weights[:, e] > np.max(weights[:, e]) / 1000
+             sumsq = np.dot(weights[nz, e],
+                            np.sum(np.square(Dz[nz, :] -
+                                             event_pat[:, e]), axis=1))
+             ev_var[e] = sumsq / (np.sum(weights[nz, e]) -
+                                  np.sum(np.square(weights[nz, e])) /
+                                  np.sum(weights[nz, e]))
+         ev_var = ev_var / D.shape[1]
+         return ev_var
+
+     def model_prior(self, t):
+         lg, test_ll = self._forward_backward(np.zeros((t, self.n_events)))
+         segments = np.exp(lg)
+
+         return segments, test_ll
+
+     def _split_merge(self, X, log_gamma, iteration_var, curr_ll):
+         # Compute current probabilities and mean patterns
+         n_train = len(X)
+         n_dim = X[0].shape[0]
+
+         seg_prob = [np.exp(lg) / np.sum(np.exp(lg), axis=0)
+                     for lg in log_gamma]
+         mean_pat = np.empty((n_train, n_dim, self.n_events))
+         for i in range(n_train):
+             mean_pat[i, :, :] = X[i].dot(seg_prob[i])
+         mean_pat = np.mean(mean_pat, axis=0)
+
+         # For each event, merge its probability distribution
+         # with the next event, and also split its probability
+         # distribution at its median into two separate events.
+         # Use these new event probability distributions to compute
+         # merged and split event patterns.
+         merge_pat = np.empty((n_train, n_dim, self.n_events))
+         split_pat = np.empty((n_train, n_dim, 2 * self.n_events))
+         for i, sp in enumerate(seg_prob):  # Iterate over datasets
+             m_evprob = np.zeros((sp.shape[0], sp.shape[1]))
+             s_evprob = np.zeros((sp.shape[0], 2 * sp.shape[1]))
+             cs = np.cumsum(sp, axis=0)
+             for e in range(sp.shape[1]):
+                 # Split distribution at midpoint and normalize each half
+                 mid = np.where(cs[:, e] >= 0.5)[0][0]
+                 cs_first = cs[mid, e] - sp[mid, e]
+                 cs_second = 1 - cs_first
+                 s_evprob[:mid, 2 * e] = sp[:mid, e] / cs_first
+                 s_evprob[mid:, 2 * e + 1] = sp[mid:, e] / cs_second
+
+                 # Merge distribution with next event distribution
+                 m_evprob[:, e] = sp[:, e:(e + 2)].mean(1)
+
+             # Weight data by distribution to get event patterns
+             merge_pat[i, :, :] = X[i].dot(m_evprob)
+             split_pat[i, :, :] = X[i].dot(s_evprob)
+
+         # Average across datasets
+         merge_pat = np.mean(merge_pat, axis=0)
+         split_pat = np.mean(split_pat, axis=0)
+
+         # Correlate the current event patterns with the split and
+         # merged patterns
+         merge_corr = np.zeros(self.n_events)
+         split_corr = np.zeros(self.n_events)
+         for e in range(self.n_events):
+             split_corr[e] = np.corrcoef(mean_pat[:, e],
+                                         split_pat[:, (2 * e):(2 * e + 2)],
+                                         rowvar=False)[0, 1:3].max()
+             merge_corr[e] = np.corrcoef(merge_pat[:, e],
+                                         mean_pat[:, e:(e + 2)],
+                                         rowvar=False)[0, 1:3].min()
+         merge_corr = merge_corr[:-1]
+
+         # Find best merge/split candidates
+         # A high value of merge_corr indicates that a pair of events are
+         # very similar to their merged pattern, and are good candidates for
+         # being merged.
+         # A low value of split_corr indicates that an event's pattern is
+         # very dissimilar from the patterns in its first and second half,
+         # and is a good candidate for being split.
+         best_merge = np.flipud(np.argsort(merge_corr))
+         best_merge = best_merge[:self.split_merge_proposals]
+         best_split = np.argsort(split_corr)
+         best_split = best_split[:self.split_merge_proposals]
+
+         # For every pair of merge/split candidates, attempt the merge/split
+         # and measure the log-likelihood. If any are better than curr_ll,
+         # accept this best merge/split
+         mean_pat_last = mean_pat.copy()
+         return_ll = curr_ll
+         return_lg = copy.deepcopy(log_gamma)
+         return_mp = mean_pat.copy()
+         for m_e, s_e in itertools.product(best_merge, best_split):
+             if m_e == s_e or m_e + 1 == s_e:
+                 # Don't attempt to merge/split same event
+                 continue
+
+             # Construct new set of patterns with merge/split
+             mean_pat_ms = np.delete(mean_pat_last, s_e, axis=1)
+             mean_pat_ms = np.insert(mean_pat_ms, [s_e, s_e],
+                                     split_pat[:, (2 * s_e):(2 * s_e + 2)],
+                                     axis=1)
+             mean_pat_ms = np.delete(mean_pat_ms,
+                                     [m_e + (s_e < m_e), m_e + (s_e < m_e) + 1],
+                                     axis=1)
+             mean_pat_ms = np.insert(mean_pat_ms, m_e + (s_e < m_e),
+                                     merge_pat[:, m_e], axis=1)
+
+             # Measure log-likelihood with these new patterns
+             ll_ms = np.zeros(n_train)
+             log_gamma_ms = list()
+             for i in range(n_train):
+                 logprob = self._logprob_obs(X[i],
+                                             mean_pat_ms, iteration_var)
+                 lg, ll_ms[i] = self._forward_backward(logprob)
+                 log_gamma_ms.append(lg)
+
+             # If better than best ll so far, save to return to fit()
+             if ll_ms.mean() > return_ll:
+                 return_mp = mean_pat_ms.copy()
+                 return_ll = ll_ms
+                 for i in range(n_train):
+                     return_lg[i] = log_gamma_ms[i].copy()
+
+         return return_ll, return_lg, return_mp
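
A minimal fitting sketch on synthetic data (shapes follow the timepoints x features layout that fit() validates; the event count is arbitrary):

    import numpy as np
    from BrainPulse.event import EventSegment

    rng = np.random.default_rng(0)
    data = rng.standard_normal((300, 20))  # 300 timepoints, 20 features

    es = EventSegment(n_events=5, n_iter=100)
    es.fit(data)
    labels = np.argmax(es.segments_[0], axis=1)  # most probable event per timepoint
    print(labels.shape)  # (300,)
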
BrainPulse/features_space.py ADDED
@@ -0,0 +1,282 @@
+ import glob
+ import numpy as np
+ import pandas as pd
+ from sklearn import svm
+ import sklearn.model_selection as model_selection
+ from sklearn.metrics import accuracy_score
+ from sklearn.metrics import confusion_matrix
+ from sklearn.impute import SimpleImputer
+ from sklearn.svm import SVC
+ from sklearn.model_selection import StratifiedKFold
+ from sklearn.feature_selection import RFECV
+ from sklearn.ensemble import RandomForestClassifier
+ import matplotlib.pyplot as plt
+
+
+ def save_features_as_csv():
+     return
+
+
+ def load_features_csv_concat(folder_path):
+     # Load every per-subject CSV in the folder and concatenate into one frame
+     df_list = []
+     for file in glob.glob(folder_path + "/*.csv"):
+         df_list.append(pd.read_csv(file))
+     df = pd.concat(df_list)
+     df = df.reset_index(drop=True)
+     return df
+
+
+ def exclude_subject(df, exluded_subjects):
+     # Drop rows belonging to the listed subjects (isin replaces the
+     # string-building/eval approach of the original version)
+     df_ex = df[~df['Subject'].isin(exluded_subjects)]
+     return df_ex.reset_index(drop=True)
+
+
+ def electrode_wise_dataframe(df, condition_list, id_vars=['Subject', 'Task', 'Electrode']):
+     stats_frame = df[
+         ['Subject', 'Task', 'Electrode', 'Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V', 'Vmax', 'Ventr', 'W', 'Wentr']
+     ]
+
+     # Pivot to one row per (Subject, Task) with per-electrode feature columns
+     stats = stats_frame.pivot_table(index=['Subject', 'Task'], columns='Electrode',
+                                     values=['Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V', 'Vmax', 'Ventr', 'W', 'Wentr']).reset_index()
+
+     # Encode the two conditions as 0/1 class labels
+     stats = stats.replace(condition_list[0], 0)
+     stats = stats.replace(condition_list[1], 1)
+     y = stats.Task.values
+     return stats, y
+
+
+ def electrode_wise_dataframe_epochs(df, condition_list, id_vars=['Subject', 'Task', 'Epoch_id', 'Electrode']):
+     stats_frame = df[
+         ['Subject', 'Task', 'Epoch_id', 'Electrode', 'Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V', 'Vmax', 'Ventr', 'W', 'Wentr']
+     ]
+
+     # Pivot to one row per (Subject, Task) with per-electrode, per-epoch columns
+     stats = stats_frame.pivot_table(index=['Subject', 'Task'], columns=['Electrode', 'Epoch_id'],
+                                     values=['Lentr', 'TT', 'L', 'RR', 'LAM', 'DET', 'V', 'Vmax', 'Ventr', 'W', 'Wentr']).reset_index()
+
+     stats = stats.replace(condition_list[0], 0)
+     stats = stats.replace(condition_list[1], 1)
+     y = stats.Task.values
+     return stats, y
+
+
+ def select_features_clean_and_normalize(df, features=['Lentr', 'TT', 'L', 'LAM', 'DET', 'V', 'Ventr', 'W', 'Wentr']):
+
+     stats_data = df[features].values
+
+     # Column-wise outlier rejection with the optional rcr package; if it is
+     # not installed, the data passes through unchanged
+     stats_data_cleaned = np.empty((stats_data.shape[0], stats_data.shape[1]))
+     stats_data_cleaned[:] = np.nan
+     try:
+         import rcr
+         r = rcr.RCR(rcr.SS_MEDIAN_DL)
+         for ii in range(stats_data.shape[1]):
+             r.performBulkRejection(stats_data[:, ii])
+             cleaned_data_indices = r.result.indices
+             stats_data_cleaned[cleaned_data_indices, ii] = stats_data[cleaned_data_indices, ii]
+     except ImportError:
+         stats_data_cleaned[:] = stats_data
+
+     df_stats_data_cleaned = pd.DataFrame(stats_data_cleaned)
+
+     # Impute the rejected (NaN) entries with the column mean
+     imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
+     imputer = imputer.fit(df_stats_data_cleaned)
+     stats_data_cleaned = imputer.transform(df_stats_data_cleaned)
+
+     # Min-max normalize each feature column to [0, 1]
+     stats_data_normed = np.empty((stats_data.shape[0], stats_data.shape[1]))
+     for ii in range(stats_data.shape[1]):
+         col = stats_data_cleaned[:, ii]
+         stats_data_normed[:, ii] = (col - col.min(axis=0)) / (col.max(axis=0) - col.min(axis=0))
+
+     return stats_data_normed
+
+
+ def clasyfication_SVM(df, y, cv=10, type='linear'):
+
+     clf = svm.SVC(kernel=type)
+     skf = StratifiedKFold(n_splits=cv)
+     folds = skf.split(df, y)
+     performance = np.zeros(skf.n_splits)
+     performance_open = np.zeros(skf.n_splits)
+     performance_closed = np.zeros(skf.n_splits)
+
+     for i, (train_idx, test_idx) in enumerate(folds):
+
+         X_train = df[train_idx, :]
+         y_train = y[train_idx]
+
+         X_test = df[test_idx, :]
+         y_test = y[test_idx]
+
+         # fit on the training fold and predict on the test fold
+         model = clf.fit(X=X_train, y=y_train)
+         y_hat = model.predict(X=X_test)
+
+         # fold accuracy, plus per-class accuracy from the row-normalized
+         # confusion matrix
+         performance[i] = accuracy_score(y_test, y_hat)
+         cm = confusion_matrix(y_test, y_hat)
+         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+         class_acuracy = cm.diagonal()
+         performance_open[i] = class_acuracy[0] * 100
+         performance_closed[i] = class_acuracy[1] * 100
+
+     # average accuracy across folds
+     print('Mean performance: %.3f' % np.mean(performance * 100))
+     print('Mean performance 1st class: %.3f' % np.mean(performance_open))
+     print('Mean performance 2nd class: %.3f' % np.mean(performance_closed))
+
+     # refit on the last fold's split and return that model
+     lin = svm.SVC(kernel=type).fit(X_train, y_train)
+     lin_pred = lin.predict(X_test)
+
+     return lin, lin_pred
+
+
+ def cross_validation(df, y, cv=10, title='cv job', type='linear'):
+
+     # Create the RFE object and compute a cross-validated score.
+     svc = SVC(kernel=type)
+     # The "accuracy" scoring shows the proportion of correct classifications
+
+     min_features_to_select = 4  # Minimum number of features to consider
+     rfecv = RFECV(
+         estimator=svc,
+         step=1,
+         cv=StratifiedKFold(n_splits=cv),
+         scoring="accuracy",
+         min_features_to_select=min_features_to_select,
+     )
+     rfecv.fit(df, y)
+
+     print("Optimal number of features : %d" % rfecv.n_features_)
+
+     # Plot number of features VS. cross-validation scores
+     plt.figure()
+     plt.xlabel("Number of features selected")
+     plt.ylabel("Cross validation score (accuracy)")
+     plt.plot(
+         range(min_features_to_select, len(rfecv.cv_results_['mean_test_score']) + min_features_to_select),
+         rfecv.cv_results_['mean_test_score'],
+     )
+     plt.title(title)
+     plt.show()
+     plt.savefig(title + ' classification with feature selection_more_features' + str(rfecv.n_features_) + '_' + str(round(max(rfecv.cv_results_['mean_test_score']) * 100, 2)) + '.png', dpi=150)
+     plt.close()
+
+     return rfecv.transform(df)
+
+
+ def compute_binary_SVM(df, y, predict_on_all_data=False, type='linear'):
+
+     X_train, X_test, y_train, y_test = model_selection.train_test_split(df, y, train_size=0.80, test_size=0.20,
+                                                                         random_state=101)
+
+     if predict_on_all_data:
+         print('SVM prediction on all data')
+         lin = svm.SVC(kernel=type).fit(X_train, y_train)
+
+         lin_pred = lin.predict(df)
+
+         lin_accuracy = accuracy_score(y, lin_pred)
+
+         print('Accuracy (Linear Kernel): ', "%.2f" % (lin_accuracy * 100))
+
+         cm = confusion_matrix(y, lin_pred)
+         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+         class_acuracy = cm.diagonal()
+         print('Accuracy (1st class): ', "%.2f" % (class_acuracy[0] * 100))
+         print('Accuracy (2nd class): ', "%.2f" % (class_acuracy[1] * 100))
+     else:
+         print('SVM prediction on test data')
+         lin = svm.SVC(kernel=type).fit(X_train, y_train)
+
+         lin_pred = lin.predict(X_test)
+
+         lin_accuracy = accuracy_score(y_test, lin_pred)
+
+         print('Accuracy (Linear Kernel): ', "%.2f" % (lin_accuracy * 100))
+         print('Y train:', y_train)
+         print('Y test:', y_test)
+         print('Y pred:', lin_pred)
+
+         cm = confusion_matrix(y_test, lin_pred)
+         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+         class_acuracy = cm.diagonal()
+         print('Accuracy (1st class): ', "%.2f" % (class_acuracy[0] * 100))
+         print('Accuracy (2nd class): ', "%.2f" % (class_acuracy[1] * 100))
+
+     return lin, lin_pred
+
+
+ def clasyfication_RFC(df, y, cv=10, max_depth=2):
+
+     clf = RandomForestClassifier(max_depth=max_depth, random_state=0)
+     skf = StratifiedKFold(n_splits=cv)
+     folds = skf.split(df, y)
+
+     performance = np.zeros(skf.n_splits)
+     performance_open = np.zeros(skf.n_splits)
+     performance_closed = np.zeros(skf.n_splits)
+
+     for i, (train_idx, test_idx) in enumerate(folds):
+
+         X_train = df[train_idx, :]
+         y_train = y[train_idx]
+
+         X_test = df[test_idx, :]
+         y_test = y[test_idx]
+
+         # fit on the training fold and predict on the test fold
+         model = clf.fit(X=X_train, y=y_train)
+         y_hat = model.predict(X=X_test)
+
+         performance[i] = accuracy_score(y_test, y_hat)
+         cm = confusion_matrix(y_test, y_hat)
+         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+         class_acuracy = cm.diagonal()
+         performance_open[i] = class_acuracy[0] * 100
+         performance_closed[i] = class_acuracy[1] * 100
+
+     print('Mean performance: %.3f' % np.mean(performance * 100))
+     print('Mean performance 1st class: %.3f' % np.mean(performance_open))
+     print('Mean performance 2nd class: %.3f' % np.mean(performance_closed))
+
+     # refit on the last fold's split and return that model
+     lin = RandomForestClassifier(max_depth=max_depth, random_state=0)
+     lin.fit(X=X_train, y=y_train)
+     lin_pred = lin.predict(X_test)
+
+     return lin, lin_pred
+
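
A hedged sketch of the intended pipeline (the CSV folder name and the condition labels 'open'/'closed' are placeholder assumptions):

    from BrainPulse import features_space as fs

    df = fs.load_features_csv_concat('./RQA_features')  # hypothetical folder of feature CSVs
    stats, y = fs.electrode_wise_dataframe(df, ['open', 'closed'])
    X = fs.select_features_clean_and_normalize(stats)
    model, preds = fs.clasyfication_SVM(X, y, cv=10, type='linear')
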
BrainPulse/frequency_recurrence.py ADDED
@@ -0,0 +1,124 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+ plt.style.use('classic')
+
+
+ def normalize(tSignal):
+     # copy the data if needed, omit and rename function argument if desired
+     signal = np.copy(tSignal)  # signal is in range [a;b]
+     signal -= np.min(signal)   # signal is in range [0;b-a]
+     signal /= np.max(signal)   # signal is normalized to [0;1]
+     signal -= 0.5              # signal is in range [-0.5;0.5]
+     signal *= 2                # signal is in range [-1;1]
+     return signal
+
+
+ def get_max_freqs(stft, cut_freq, crop=False, norm=True):
+     # Frequency (in Hz) of the dominant STFT bin in each time frame,
+     # optionally normalized to [0, 1] by cut_freq
+     if crop is False:
+         crop = len(stft)
+     id_max_stft = stft[:crop].argmax(axis=1) * (cut_freq / stft.shape[1])
+     if norm:
+         return id_max_stft * (1 / cut_freq)
+     return id_max_stft
+
+
+ def symmetrize(a):
+     return a + np.transpose(a, (1, 0, 2)) - np.diag(a.diagonal())
+
+
+ def symmetrized(m):
+     # Mirror the upper triangle of each color channel into the lower triangle
+     i_lower = np.tril_indices(m.shape[0], -1)
+     m[:, :, 0][i_lower] = m[:, :, 0].T[i_lower]
+     m[:, :, 1][i_lower] = m[:, :, 1].T[i_lower]
+     m[:, :, 2][i_lower] = m[:, :, 2].T[i_lower]
+
+     return m
+
+
+ def calc_color(id_max_norm, unique_values, colors, crop):
+     if crop is False:
+         crop = len(id_max_norm)
+     tmp = np.zeros((id_max_norm.shape[0], id_max_norm.shape[0], 3), dtype='float64')
+
+     for i, v1 in enumerate(id_max_norm):
+         for j, v2 in enumerate(id_max_norm):
+             tmp[i, j] = v1, v2, 0
+
+     tmp_color = np.zeros((id_max_norm.shape[0], id_max_norm.shape[0], 3), dtype='float64')
+     for i in range(crop):
+         for j in range(crop):
+             x = tmp[i, j][0]
+             y = tmp[i, j][1]
+
+             for k in unique_values:
+                 if x == k:
+                     id = unique_values.index(x)
+                     tmp_color[i, j][0] = colors[id][0]
+                 if y == k:
+                     id = unique_values.index(y)
+                     tmp_color[i, j][1] = colors[id][1]
+
+     return symmetrized(tmp_color), symmetrized(tmp)
+
+
+ def calc_color_raw(id_max_stft):
+     tmp = np.zeros((id_max_stft.shape[0], id_max_stft.shape[0], 3), dtype='float64')
+
+     for i, v1 in enumerate(id_max_stft):
+         for j, v2 in enumerate(id_max_stft):
+             tmp[i, j] = v1, v2, 0
+     return symmetrized(tmp)
+
+
+ def get_unique_colors(cut_freq, unique_values):
+     from matplotlib.colors import LinearSegmentedColormap
+
+     # Anchor the colormap at the classical EEG band edges (delta/theta/
+     # alpha/beta/gamma), scaled by the frequency cut-off
+     vmax = cut_freq
+     cmap = LinearSegmentedColormap.from_list('mycmap1', [(0 / vmax, "violet"),
+                                                          (4 / vmax, 'blue'),
+                                                          (8 / vmax, 'green'),
+                                                          (15 / vmax, 'yellow'),
+                                                          (30 / vmax, 'red'),
+                                                          (60 / vmax, 'black')
+                                                          ])
+
+     colors = [cmap(each) for each in unique_values]
+
+     return cmap, colors
+
+
+ def freqRP(stft, cut_freq, crop=False, norm=True):
+     # Build a color-coded "frequency recurrence plot" from STFT frames
+     id_max_stft = get_max_freqs(stft, cut_freq, crop, norm)
+
+     unique_values = np.unique(id_max_stft).tolist()
+
+     cmap, colors = get_unique_colors(cut_freq, unique_values)
+
+     color_matrix, raw_c_matrix = calc_color(id_max_stft, unique_values, colors, crop)
+
+     return color_matrix, raw_c_matrix
+
+
+ def plot_freqRP_interactive(color_matrix, raw_c_matrix, filename='', save=False):
+     import plotly.express as px
+
+     fig = px.imshow(color_matrix, origin='lower')
+     fig.update_traces(customdata=np.round((raw_c_matrix * 60), 2),
+                       hovertemplate="First frequency: %{customdata[0]}<br>Second frequency: %{customdata[1]}<extra></extra>"
+                       )
+     if save:
+         fig.write_html(filename)
+     else:
+         fig.show()
+
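
A minimal sketch producing and saving an interactive frequency RP (a synthetic spectrogram and a 60 Hz cut-off are assumed):

    import numpy as np
    from BrainPulse.frequency_recurrence import freqRP, plot_freqRP_interactive

    stft = np.abs(np.random.randn(150, 120))  # stand-in spectrogram: 150 frames x 120 bins
    color_m, raw_m = freqRP(stft, cut_freq=60)
    plot_freqRP_interactive(color_m, raw_m, filename='freq_rp.html', save=True)
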
BrainPulse/matrix_open_binary.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db7cd1f9b9fea62c6ecda8ec53c0e3759d6902e82590770080ef5c24531a8c0b
+ size 92236944
BrainPulse/model_SVM.py ADDED
@@ -0,0 +1,45 @@
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import seaborn as sns
+ from sklearn import svm, datasets
+ import sklearn.model_selection as model_selection
+ from sklearn.metrics import accuracy_score
+ from sklearn.metrics import f1_score
+ from sklearn.metrics import confusion_matrix
+ from sklearn.preprocessing import StandardScaler
+ import umap
+ import umap.plot
+ from umap.parametric_umap import ParametricUMAP
+ from sklearn.impute import SimpleImputer
+ from sklearn.pipeline import make_pipeline
+ from sklearn.preprocessing import QuantileTransformer
+ from mne.datasets import eegbci
+
+
+ def compute_Linear_Kernel(df):
+     stats_frame = df[
+         ['Subject', 'Task', 'Electrode', 'Lentr', 'TT', 'L', 'RR', 'LAM', 'DET']
+     ]
+
+     # Pivot to one row per (Subject, Task) with per-electrode feature columns
+     stats1 = stats_frame.pivot_table(index=['Subject', 'Task'], columns='Electrode',
+                                      values=['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']).reset_index()
+
+     y = stats1.Task.values
+     stats_data = stats1[['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']].values
+     X_train, X_test, y_train, y_test = model_selection.train_test_split(stats_data, y, train_size=0.80, test_size=0.20,
+                                                                         random_state=101)
+
+     lin = svm.SVC(kernel='linear').fit(X_train, y_train)
+
+     lin_pred = lin.predict(X_test)
+
+     lin_accuracy = accuracy_score(y_test, lin_pred)
+
+     print('Accuracy (Linear Kernel): ', "%.2f" % (lin_accuracy * 100))
+
+     cm = confusion_matrix(y_test, lin_pred)
+     cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+     class_acuracy = cm.diagonal()
+     print('Accuracy (open): ', "%.2f" % (class_acuracy[0] * 100))
+     print('Accuracy (close): ', "%.2f" % (class_acuracy[1] * 100))
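
For context, a usage sketch (the features DataFrame is assumed to come from features_space.load_features_csv_concat; the folder name is hypothetical):

    from BrainPulse.model_SVM import compute_Linear_Kernel
    from BrainPulse.features_space import load_features_csv_concat

    df = load_features_csv_concat('./RQA_features')
    compute_Linear_Kernel(df)  # prints overall and per-class accuracy
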
BrainPulse/plot.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from matplotlib import animation, cm
5
+ import matplotlib
6
+ import seaborn as sns
7
+ import seaborn as sns
8
+ import umap
9
+ import umap.plot
10
+ import pandas as pd
11
+ from .event import EventSegment
12
+ from sklearn.impute import SimpleImputer
13
+ from sklearn.pipeline import make_pipeline
14
+ from sklearn.preprocessing import QuantileTransformer
15
+
16
+ sns.set_style("whitegrid")
17
+
18
+ # plt.rcParams["font.family"] = "cursive"
19
+ # plt.rcParams.update({'font.sans-serif':'Times'})
20
+ # plt.rcParams.update({'font.family':'sans-serif'})
21
+ # plt.rcParams['font.size'] = 14
22
+ import matplotlib.font_manager as font_manager
23
+ font = font_manager.FontProperties(family='Times')
24
+
25
+ def explainer_(chan, stft, cut_freq, s_rate):
26
+
27
+ fig, axs = plt.subplots(4, figsize=(10, 14), dpi=150) # figsize=(12, 12),
28
+ time_crop = np.linspace(0, int(chan[:400].shape[0]), chan[:400].shape[0])
29
+
30
+ axs[0].plot(chan[:400],'k') # np.linspace(0, int(chan[:400].shape[0]/s_rate), chan[:400].shape[0]),
31
+ axs[0].fill_betweenx(y=[-210, 125], x1=time_crop[0],
32
+ x2=time_crop[240], color='white', alpha=0.9, edgecolor='red' )
33
+
34
+ axs[0].fill_betweenx(y=[-210, 130], x1=time_crop[2]+20,
35
+ x2=time_crop[260], color='white', alpha=0.9, edgecolor='green')
36
+
37
+ axs[0].fill_betweenx(y=[-210, 135], x1=time_crop[2]+40,
38
+ x2=time_crop[280], color='white', alpha=0.9, edgecolor='blue')
39
+
40
+ axs[0].annotate('$fft_{1}$', xy=(.25, 72), xycoords='data',
41
+ xytext=(0.05, 1.45), textcoords='axes fraction',
42
+ arrowprops=dict(arrowstyle="->",facecolor='black',color='black'),
43
+ horizontalalignment='right', verticalalignment='top',
44
+ )
45
+
46
+ axs[0].annotate('$fft_{2}$', xy=(23.35, 85), xycoords='data',
47
+ xytext=(0.15, 1.45), textcoords='axes fraction',
48
+ arrowprops=dict(arrowstyle="->",facecolor='black',color='black'),
49
+ horizontalalignment='right', verticalalignment='top',
50
+ )
51
+
52
+ axs[0].annotate('$fft_{3}$', xy=(43.45, 95), xycoords='data',
53
+ xytext=(0.25, 1.45), textcoords='axes fraction',
54
+ arrowprops=dict(arrowstyle="->",facecolor='black ',color='black'),
55
+ horizontalalignment='right', verticalalignment='top',
56
+ )
57
+ axs[0].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, chan[:400].shape[0], 5)))
58
+ axs[0].set_xticklabels(
59
+ [str(np.round(x, 1)) for x in np.linspace(0, int(chan[:400].shape[0] / s_rate), axs[0].get_xticks().shape[0])])
60
+ axs[0].set_ylabel('Amplitude (µV)', )
61
+ axs[0].set_xlabel('Time (s)', )
62
+ axs[0].set_title('(a)', )
63
+ axs[0].xaxis.grid()
64
+ axs[0].yaxis.grid()
65
+
66
+
67
+ axs[1].plot((stft[100]/stft.shape[1])**2, 'red',label='$fft_{1}$',marker="o",markersize=3)
68
+ axs[1].legend(prop=font)
69
+ axs[1].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
70
+ axs[1].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
71
+ axs[1].set_xlim([0, 100])
72
+ # axs[1].set_ylim([0, 250])
73
+ axs[1].set_ylabel('Power ($\mu V^{2}$)', )
74
+ axs[1].set_xlabel('Freq (Hz)', )
75
+ # axs[1].set_title('Frequency Domain ($fft_{1}$, $fft_{2}$, $fft_{3}$)', fontsize=10)
76
+ axs[1].set_title('(b)', )
77
+ axs[1].xaxis.grid()
78
+ axs[1].yaxis.grid()
79
+
80
+
81
+ axs[2].plot((stft[115]/stft.shape[1])**2, 'green', label='$fft_{2}$', marker="o", markersize=3)
82
+ axs[2].legend(prop=font)
83
+ axs[2].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
84
+ axs[2].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
85
+ axs[2].set_xlim([0, 100])
86
+ # axs[2].set_ylim([0, 250])
87
+ axs[2].set_ylabel('Power ($\mu V^{2}$)', )
88
+ axs[2].set_xlabel('Freq (Hz)', )
89
+ axs[2].set_title('(c)', )
90
+ axs[2].xaxis.grid()
91
+ axs[2].yaxis.grid()
92
+
93
+
94
+ axs[3].plot((stft[140]/stft.shape[1])**2, 'blue', label='$fft_{3}$', marker="o", markersize=3)
95
+ axs[3].legend(prop=font)
96
+ axs[3].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
97
+ axs[3].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
98
+ axs[3].set_xlim([0, 100])
99
+ axs[3].set_ylabel('Power ($\mu V^{2}$)', )
100
+ axs[3].set_xlabel('Freq (Hz)', )
101
+ axs[3].set_title('(d)', )
102
+ axs[3].xaxis.grid()
103
+ axs[3].yaxis.grid()
104
+
105
+ # plt.title('Frequency Domain ($fft_{1}$, $fft_{2}$, $fft_{3}$)', fontsize=10)
106
+ plt.tight_layout()
107
+
108
+ plt.savefig('fig_4.png')
109
+
110
+
111
+ def stft_collections(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args, max_indx = None, min_indx = None):
112
+ fig = plt.figure(figsize=(14, 12), dpi=150)
113
+ grid = plt.GridSpec(6, 8, hspace=0.0, wspace=3.5)
114
+ spectrogram = fig.add_subplot(grid[0:3, 0:4])
115
+ rp_plot = fig.add_subplot(grid[0:3, 4:])
116
+ fft_vector = fig.add_subplot(grid[4:, :])
117
+
118
+ if max_indx != None and min_indx != None:
119
+ max_index = max_indx
120
+ min_index = min_indx
121
+ else:
122
+ max_array = np.max(stft, axis=1)
123
+ max_value_stft = np.max(max_array, axis=0)
124
+ max_index = list(max_array).index(max_value_stft)
125
+
126
+ min_array = np.min(stft, axis=1)
127
+ min_value_stft = np.min(min_array, axis=0)
128
+ min_index = list(min_array).index(min_value_stft)
129
+
130
+
131
+
132
+ # ręczne ustawienie wskaźników
133
+ # max_index = int(1.52*s_rate)
134
+ # min_index = int(2.4*s_rate)
135
+
136
+
137
+ # top = np.triu(matrix)
138
+ # bottom = np.tril(matrix_binary)
139
+
140
+ # np.linspace(0, stft.shape[1], stft.shape[1]), np.linspace(0, stft.shape[0], cut_freq),
141
+ rp_plot.imshow(matrix_binary, cmap='Greys', origin='lower') # interpolation='none'
142
+ # axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
143
+ rp_plot.plot(max_index, max_index, 'orange', marker="o", markersize=9)
144
+ rp_plot.plot(min_index, min_index, 'red', marker="o", markersize=9)
145
+ # axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
146
+ # axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
147
+ rp_plot.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
148
+ rp_plot.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
149
+ rp_plot.set_xticklabels(
150
+ [str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_xticks().shape[0])])
151
+ rp_plot.set_yticklabels(
152
+ [str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_yticks().shape[0])])
153
+ rp_plot.set_xlabel('Time (s)', )
154
+ rp_plot.set_ylabel('Time (s)', )
155
+ rp_plot.set_title('(b) Recurrence Plot', )
156
+ rp_plot.xaxis.grid()
157
+ rp_plot.yaxis.grid()
158
+
159
+ spectrogram.pcolormesh(stft.T,cmap='viridis') #,vmax=max_value_stft
160
+ spectrogram.plot(max_index,2,'orange', marker="|", markersize=40)
161
+ spectrogram.plot(min_index,2,'red', marker="|", markersize=40)
162
+
163
+ spectrogram.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[0], 5)))
164
+ spectrogram.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft.shape[0] / s_rate, spectrogram.get_xticks().shape[0])])
165
+ spectrogram.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 5)))
166
+ spectrogram.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
167
+ spectrogram.set_ylabel('Freq (Hz)', )
168
+ spectrogram.set_xlabel('Time (s)', )
169
+ spectrogram.set_title('(a) Spectrogram', )
170
+ # spectrogram.xaxis.grid()
171
+ # spectrogram.yaxis.grid()
172
+ # fig.colorbar(im1, cax=spectrogram, orientation='vertical')
173
+
174
+
175
+ max_index_ = stft[max_index]/stft.shape[1]
176
+ min_index_ = stft[min_index]/stft.shape[1]
177
+ fft_vector.plot(max_index_**2,'orange',label='$fft_{t_{1}}$')#,marker="o", markersize=2
178
+ fft_vector.plot(min_index_**2,'red',label='$fft_{t_{2}}}$')#,marker="o", markersize=2
179
+ fft_vector.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
180
+ fft_vector.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
181
+ fft_vector.set_xlim([0,100])
182
+ fft_vector.set_ylabel('Power ($\mu V^{2}$)', )
183
+ fft_vector.set_xlabel('Freq (Hz)', )
184
+ fft_vector.set_title('(c) Frequency Domain', )
185
+ fft_vector.legend(prop=font)
186
+ fft_vector.xaxis.grid()
187
+ fft_vector.yaxis.grid()
188
+
189
+ # plt.suptitle( 'Condition: '+ task + '\n' + 'epsilon {}, FFT window size {} '.format(
190
+ # str(info_args['eps']), str(info_args['win_len'])) + '\n'
191
+ # + 'Subject {}, electrode {}, n_fft {}'.format(str(info_args['selected_subject']),str(info_args['electrode_name']),str(info_args['n_fft'])), fontsize=12 ,ha='left',va='top')
192
+ plt.tight_layout()
193
+ plt.savefig('fig_5.png')
194
+ # axs[0].imshow(matrix_binary, cmap='cividis', origin='lower') #interpolation='none'
195
+ # # axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
196
+ # axs[0].plot(max_index,max_index,'orange',marker="o", markersize=7)
197
+ # axs[0].plot(min_index,min_index,'red',marker="o", markersize=7)
198
+ # # axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
199
+ # # axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
200
+ # axs[0].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
201
+ # axs[0].yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
202
+ # axs[0].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_xticks().shape[0])])
203
+ # axs[0].set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_yticks().shape[0])])
204
+ # axs[0].set_xlabel('Time (s)')
205
+ # axs[0].set_ylabel('Time (s)')
206
+ # axs[0].set_title('Recurrence Plot', fontsize=12)
207
+
208
+ def diagnostic(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
209
+
210
+ fig, axs = plt.subplots(3,1, figsize=(7,12), gridspec_kw={'height_ratios':[6,2,1]},dpi=150)
211
+
212
+ # Set up the axes with gridspec
213
+
214
+
215
+ max_array = np.max(stft, axis=1)
216
+ max_value_stft = np.max(max_array, axis=0)
217
+ max_index = list(max_array).index(max_value_stft)
218
+
219
+ min_array = np.min(stft, axis=1)
220
+ min_value_stft = np.min(min_array, axis=0)
221
+ min_index = list(min_array).index(min_value_stft)
222
+
223
+ # top = np.triu(matrix)
224
+ # bottom = np.tril(matrix_binary)
225
+
226
+ axs[0].imshow(matrix_binary, cmap='cividis', origin='lower') #interpolation='none'
227
+ # axs[0].imshow(bottom, cmap='jet', origin='lower') #interpolation='none'
228
+
229
+ axs[0].plot(max_index,max_index,'orange',marker="o", markersize=7)
230
+ axs[0].plot(min_index,min_index,'red',marker="o", markersize=7)
231
+
232
+ axs[0].set_yticks(axs[0].get_yticks()[1:len(axs[0].get_yticks())-1])
233
+ axs[0].set_xticks(axs[0].get_xticks()[1:len(axs[0].get_xticks())-1])
234
+ axs[0].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
235
+ axs[0].yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
236
+ axs[0].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_xticks().shape[0])])
237
+ axs[0].set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[0].get_yticks().shape[0])])
238
+ axs[0].set_xlabel('Time (s)')
239
+ axs[0].set_ylabel('Time (s)')
240
+ axs[0].set_title('Recurrence Plot', )
241
+
242
+
243
+ # np.linspace(0, stft.shape[1], stft.shape[1]), np.linspace(0, stft.shape[0], cut_freq),
244
+
245
+
246
+ axs[1].pcolormesh(stft.T, shading='gouraud') #,vmax=max_value_stft
247
+ axs[1].plot(max_index,0,'orange', marker="o", markersize=7)
248
+ axs[1].plot(min_index,0,'red', marker="o", markersize=7)
249
+ axs[1].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
250
+ axs[1].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs[1].get_xticks().shape[0])])
251
+ axs[1].yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 5)))
252
+ axs[1].set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
253
+ axs[1].set_ylabel('Freq (Hz)')
254
+ axs[1].set_xlabel('Time (s)')
255
+ axs[1].set_title('Spectrogram', )
256
+
257
+ max_index_ = stft[max_index]/stft.shape[1]
258
+ min_index_ = stft[min_index]/stft.shape[1]
259
+ axs[2].plot(max_index_**2,'orange')#,marker="o", markersize=2
260
+ axs[2].plot(min_index_**2,'red')#,marker="o", markersize=2
261
+ axs[2].xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
262
+ axs[2].set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
263
+ axs[2].set_xlim([0,100])
264
+ axs[2].set_ylabel('Power (µV^2)')
265
+ axs[2].set_xlabel('Freq (Hz)')
266
+ axs[2].set_title('Frequency Domain',)
267
+
268
+ plt.suptitle( 'Condition: '+ task + '\n' + 'epsilon {}, FFT window size {} '.format(
269
+ str(info_args['eps']), str(info_args['win_len'])) + '\n'
270
+ + 'Subject {}, electrode {}, n_fft {}'.format(str(info_args['selected_subject']),str(info_args['electrode_name']),str(info_args['n_fft'])),
271
+ ha='left',va='top')
272
+ plt.tight_layout()
273
+
274
+ def RecurrencePlot(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
+
+     fig, axs = plt.subplots(figsize=(12, 12), dpi=200)
+
+     # locate the STFT frames with the globally strongest and weakest bins
+     max_array = np.max(stft, axis=1)
+     max_value_stft = np.max(max_array, axis=0)
+     max_index = list(max_array).index(max_value_stft)
+
+     min_array = np.min(stft, axis=1)
+     min_value_stft = np.min(min_array, axis=0)
+     min_index = list(min_array).index(min_value_stft)
+
+     axs.imshow(matrix_binary, cmap='cividis', origin='lower')
+     # axs.plot(max_index, max_index, 'orange', marker="o", markersize=7)
+     # axs.plot(min_index, min_index, 'red', marker="o", markersize=7)
+
+     axs.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     axs.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     axs.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs.get_xticks().shape[0])])
+     axs.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, axs.get_yticks().shape[0])])
+     axs.set_xlabel('Time (s)')
+     axs.set_ylabel('Time (s)')
+     axs.set_title('Recurrence Plot')
+
+ def features_hists(df, features_list, condition, dpi=200):
+     fig, axs = plt.subplots(len(features_list), figsize=(6, len(features_list) * 3), dpi=dpi)
+     abc = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']
+
+     for i, ax in enumerate(axs):
+         sns.histplot(data=df, x=features_list[i], hue=condition, alpha=0.8, element="bars", fill=False, ax=ax, kde=True)
+         # drop the bar containers so only the KDE curves remain
+         ax.containers[1].remove()
+         ax.containers[0].remove()
+         ax.xaxis.grid()
+         ax.yaxis.grid()
+         ax.set_title(abc[i])
+
+     plt.autoscale(enable=True, axis='both', tight=None)
+     fig.tight_layout()
+
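A toy call of features_hists; the DataFrame layout (one column per RQA feature plus a condition column) is inferred from how the function indexes df, and the values below are illustrative only.

# hypothetical data, just to exercise the plotting path
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
df = pd.DataFrame({'DET': rng.random(100),
                   'LAM': rng.random(100),
                   'Task': rng.integers(0, 2, 100)})   # 0 = open eyes, 1 = closed eyes
features_hists(df, ['DET', 'LAM'], condition='Task')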
+ def features_per_subjects_violin(df, features_list, condition, dpi=200):
+     fig, axs = plt.subplots(len(features_list), figsize=(14, len(features_list) * 2), dpi=dpi, sharex='col')
+
+     for i, ax in enumerate(axs):
+         sns.violinplot(data=df, x=df.Subject, y=features_list[i], hue=condition, ax=ax, split=True, linewidth=0.2)
+         ax.legend(loc='lower right')
+
+     axs[-1].set_xticklabels(axs[-1].get_xticklabels(), rotation=90)
+
+     plt.tick_params(axis='x', which='major', labelsize=16)
+     fig.tight_layout()
+
+ def umap_on_condition(df, y, title, labels_name, features_list=['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr'], random_state=70, n_neighbors=15, min_dist=0.25, metric="hamming", df_type=True):
+
+     fig, ax1 = plt.subplots(figsize=(8, 8), dpi=150)
+
+     if df_type:
+         stats_data = df
+     else:
+         stats_data = df[features_list].values
+
+     # Impute missing values and quantile-transform each feature
+     pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
+     X = pipe.fit_transform(stats_data.copy())
+
+     # Fit (supervised) UMAP to the processed data
+     manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)
+     umap.plot.points(manifold, labels=labels_name, ax=ax1, color_key=np.array(
+         [(0, 0.35, 0.73), (1, 0.83, 0)]))
+     ax1.set_title(title)
+
+ def umap_side_by_side_plot(df1, df2, features_list=['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr'], random_state=70, n_neighbors=15, min_dist=0.25, metric="hamming"):
+
+     fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 8), dpi=150)
+
+     stats_data = df1[features_list].values
+     y = df1.Task.values
+
+     # Impute missing values and quantile-transform each feature
+     pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
+     X = pipe.fit_transform(stats_data.copy())
+
+     # Fit (supervised) UMAP to the processed data
+     manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)
+     umap.plot.points(manifold, labels=y, ax=ax1, color_key=np.array(
+         [(0, 0.35, 0.73), (1, 0.83, 0)]))
+     ax1.set_xlabel('(a) STFT Condition 0 - open eyes, 1 - closed eyes')
+
+     stats_data = df2[features_list].values
+     y = df2.Task.values
+
+     pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
+     X = pipe.fit_transform(stats_data.copy())
+
+     manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)
+     umap.plot.points(manifold, labels=y, ax=ax2, color_key=np.array(
+         [(0, 0.35, 0.73), (1, 0.83, 0)]))
+     ax2.set_xlabel('(b) TDEMB Condition 0 - open eyes, 1 - closed eyes')
+
+     return fig
+
+ def SVM_histogram(df, lin, lin_pred, title):
+     stats_data = df
+     plt.figure(dpi=150)
+     # all_cechy: projections of the feature vectors onto the SVM weight vector
+     all_cechy = np.dot(stats_data, lin.coef_.T)
+     df_all = pd.DataFrame({'vectors': all_cechy.ravel(), 'Task': lin_pred})
+
+     a = sns.histplot(data=df_all, x='vectors', hue='Task', alpha=0.8, element="bars", fill=False, kde=True, kde_kws={'bw_adjust': 0.4}, palette=np.array([(0.3, 0.85, 0), (0.8, 0.0, 0.44)]))
+     # drop the bar containers so only the KDE curves remain
+     a.containers[1].remove()
+     a.containers[0].remove()
+     plt.title(title)
+     plt.xlabel('All')
+     plt.grid(visible=False)
+     plt.show()
+
+ def f_importances(coef, names):
+     # sort features by coefficient value and draw a horizontal bar chart
+     imp, names = zip(*sorted(zip(coef, names)))
+     plt.figure()
+     plt.barh(range(len(names)), imp, align='center')
+     plt.yticks(range(len(names)), names)
+     plt.show()
+
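f_importances only needs parallel sequences of coefficients and names; the values below are illustrative only.

# hypothetical coefficients for three RQA features
f_importances([0.4, -0.2, 0.1], ['DET', 'LAM', 'TT'])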
+ def SVM_features_importance(lin):
+
+     # one label per electrode (64 channels) for each of the six RQA features
+     lebel_ll = np.array([['TT'] * 64 + ['RR'] * 64 +
+                          ['DET'] * 64 + ['LAM'] * 64 +
+                          ['L'] * 64 + ['L_entr'] * 64])
+
+     # 64 electrode names, repeated once per RQA feature
+     electrodes_64 = ('Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz '
+                      'F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 Fz '
+                      'Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz '
+                      'T10 T7 T8 T9 Tp7 Tp8').split()
+     e_long = electrodes_64 * 6
+     y_e_long = np.array(np.unique(e_long, return_inverse=True)[1].tolist())
+
+     df = pd.DataFrame({'feature': lebel_ll[0],
+                        'electrode': e_long,
+                        'coef': lin.coef_[0],
+                        })
+     # df = df[(df.coef.values >= 0.15) | (df.coef.values <= -0.15)]
+
+     f_importances(df.coef, df.feature)
+     f_importances(df.coef, df.electrode)
+
+     sns.set_theme(style='darkgrid', rc={'figure.dpi': 120},
+                   font_scale=1.7)
+     fig, ax = plt.subplots(figsize=(16, 10))
+     ax.set_title('Weight of features by electrodes')
+     sns.barplot(x='feature', y='coef', data=df, ax=ax,
+                 ci=None,
+                 hue='electrode')
+     ax.legend(bbox_to_anchor=(1, 1), title='electrode', prop={'size': 7})
+
+ ##### HIDDEN MARKOV MODEL
+
+ def soft_bounds(T, seg):
+
+     # Identify soft event boundaries at each step of fitting
+     bounds_anim = []
+     K = seg[0].shape[1]
+     for it in range(1, len(seg)):
+         sb = np.zeros((T, T))
+         for k in range(K - 1):
+             # probability that the latent state index rises above k between t-1 and t
+             p_change = np.diff(seg[it][:, (k + 1):].sum(1))
+             sb[1:, 1:] += np.outer(p_change, seg[it][1:, k:(k + 2)].sum(1))
+         sb = np.maximum(sb, sb.T)
+         sb = sb / np.max(sb)
+         bounds_anim.append(sb)
+     return bounds_anim
+
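A minimal sketch of what soft_bounds expects: a list of per-iteration state-probability matrices of shape (T, K) with rows summing to one. The synthetic input below is a stand-in for real event-segmentation output, just to show the shapes.

import numpy as np

T, K = 100, 4
rng = np.random.default_rng(0)
seg = []
for _ in range(3):                      # three fake fitting iterations
    logits = rng.random((T, K))
    seg.append(logits / logits.sum(axis=1, keepdims=True))  # rows sum to 1

bounds = soft_bounds(T, seg)            # len(seg) - 1 matrices of shape (T, T)
print(len(bounds), bounds[0].shape)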
+ def fitting_animation(seg, matrix, s_rate, meta_tick, metastate_id, state_width, color_states_matrix):
+
+     bounds_anim = soft_bounds(matrix.shape[0], seg)
+
+     # Plot the timepoint-timepoint correlation matrix, with boundaries drawn on top
+     fig = plt.figure(figsize=(18, 12), dpi=300)
+     grid = plt.GridSpec(4, 12, hspace=0.0, wspace=3.5)
+     ax1 = fig.add_subplot(grid[:, 0:4])
+     ax2 = fig.add_subplot(grid[:, 4:8])
+     ax3 = fig.add_subplot(grid[:, 8:])
+
+     datamat = matrix  # np.corrcoef(D)
+     bk = cm.viridis((datamat - np.min(datamat)) / (np.max(datamat) - np.min(datamat)))
+     im = ax1.imshow(bk, interpolation='none', origin='lower')
+     # gray mask: average soft-boundary probability over the fitting iterations
+     fg = cm.gray(1 - (sum(bounds_anim) / len(bounds_anim)))
+     im.set_array(bk * fg)
+
+     ax1.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     ax1.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     ax1.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax1.get_xticks().shape[0])])
+     ax1.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax1.get_yticks().shape[0])])
+     ax1.set_xlabel('Time (s)')
+     ax1.set_ylabel('Time (s)')
+     ax1.set_title('Metastates plot over recurrence plot', fontsize=10)
+     ax1.scatter(meta_tick, meta_tick, s=2)
+
+     ax2.imshow(fg, interpolation='none', origin='lower')
+     ax2.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     ax2.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     ax2.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax2.get_xticks().shape[0])])
+     ax2.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax2.get_yticks().shape[0])])
+     ax2.set_xlabel('Time (s)')
+     ax2.set_ylabel('Time (s)')
+     ax2.set_title('Metastates plot', fontsize=10)
+     ax2.scatter(meta_tick, meta_tick, s=2)
+
+     for i, mstate in enumerate(metastate_id):
+         # label each metastate with its id and duration in ms (state_width / s_rate seconds)
+         ax2.annotate('s ' + str(mstate) + '| ' + str(int(((1 / s_rate) * state_width[i]) * 1000)) + 'ms',
+                      xy=(meta_tick[i], meta_tick[i] + (state_width[i] / 2)),
+                      xytext=(meta_tick[i], meta_tick[i] + (state_width[i] / 2) + 70),
+                      xycoords='data',
+                      textcoords='data',
+                      arrowprops=dict(arrowstyle="->", facecolor='blue'),
+                      horizontalalignment='right', verticalalignment='top', fontsize=5)
+
+     color_states = color_states_matrix
+     ax3.imshow(fg[:, :, :3] * color_states, interpolation='none', origin='lower')
+     ax3.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     ax3.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+     ax3.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax3.get_xticks().shape[0])])
+     ax3.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax3.get_yticks().shape[0])])
+     ax3.set_xlabel('Time (s)')
+     ax3.set_ylabel('Time (s)')
+     ax3.set_title('Metastates plot', fontsize=10)
+     ax3.scatter(meta_tick, meta_tick, s=2)
+
+     for i, mstate in enumerate(metastate_id):
+         ax3.annotate('s ' + str(mstate) + '| ' + str(int(((1 / s_rate) * state_width[i]) * 1000)) + 'ms',
+                      xy=(meta_tick[i], meta_tick[i] + (state_width[i] / 2)),
+                      xytext=(meta_tick[i], meta_tick[i] + (state_width[i] / 2) + 70),
+                      xycoords='data',
+                      textcoords='data',
+                      arrowprops=dict(arrowstyle="->", facecolor='blue'),
+                      horizontalalignment='right', verticalalignment='top', fontsize=5)
+
+     # def animate_func(i):
+     #     fg = cm.Greys(1 - bounds_anim[i])
+     #     im.set_array(np.minimum(np.maximum(bk + fg, 0), 1))
+     #     return [im]
+     #
+     # anim = animation.FuncAnimation(fig, animate_func,
+     #                                frames=len(bounds_anim), interval=1)
+
+     plt.savefig('Metastate.png')
+
+     return fig
+
+     # return HTML(anim.to_jshtml(default_mode='Once'))
+
+
+ def fit_HMM(matrix, n_events):
+     # EventSegment (an event-segmentation HMM) is expected to be imported at module level
+     return EventSegment(n_events=n_events).fit(matrix)
+
+
+ def metastates(seg, matrix, s_rate, meta_tick, metastate_id, state_width, color_states_matrix):
+     fitting_animation(seg, matrix, s_rate, meta_tick, metastate_id, state_width, color_states_matrix)
+
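For orientation, a hedged sketch of driving fit_HMM; it assumes EventSegment is a brainiak-style event-segmentation model (its import sits outside this diff), and uses a random matrix as a stand-in for a recurrence plot.

import numpy as np

matrix = np.random.rand(200, 200)      # stand-in for a recurrence/similarity matrix
ev = fit_HMM(matrix, n_events=8)       # wraps EventSegment(n_events=8).fit(matrix)
print(ev.segments_[0].shape)           # expected (200, 8): per-timepoint state probabilities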
+ # def diagnostic(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
+ #
+ #     # Set up the axes with gridspec
+ #     fig = plt.figure(figsize=(6, 6), dpi=120)
+ #     grid = plt.GridSpec(6, 6, hspace=1.0, wspace=1.0)
+ #     spectrogram = fig.add_subplot(grid[0:3, 0:3])
+ #     rp_plot = fig.add_subplot(grid[0:3, 3:])
+ #     fft_vector = fig.add_subplot(grid[3:, :])
+ #
+ #     max_array = np.max(stft, axis=1)
+ #     max_value_stft = np.max(max_array, axis=0)
+ #     max_index = list(max_array).index(max_value_stft)
+ #
+ #     min_array = np.min(stft, axis=1)
+ #     min_value_stft = np.min(min_array, axis=0)
+ #     min_index = list(min_array).index(min_value_stft)
+ #
+ #     rp_plot.imshow(matrix_binary, cmap='cividis', origin='lower')
+ #     rp_plot.plot(max_index, max_index, 'orange', marker="o", markersize=7)
+ #     rp_plot.plot(min_index, min_index, 'red', marker="o", markersize=7)
+ #     rp_plot.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+ #     rp_plot.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+ #     rp_plot.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_xticks().shape[0])])
+ #     rp_plot.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_yticks().shape[0])])
+ #     rp_plot.set_xlabel('Time (s)')
+ #     rp_plot.set_ylabel('Time (s)')
+ #     rp_plot.set_title('Recurrence Plot', fontsize=10)
+ #
+ #     spectrogram.pcolormesh(stft.T, shading='gouraud')
+ #     spectrogram.plot(max_index, 0, 'orange', marker="o", markersize=7)
+ #     spectrogram.plot(min_index, 0, 'red', marker="o", markersize=7)
+ #     spectrogram.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
+ #     spectrogram.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, spectrogram.get_xticks().shape[0])])
+ #     spectrogram.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 5)))
+ #     spectrogram.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
+ #     spectrogram.set_ylabel('Freq (Hz)')
+ #     spectrogram.set_xlabel('Time (s)')
+ #     spectrogram.set_title('Spectrogram', fontsize=10)
+ #
+ #     max_index_ = stft[max_index] / stft.shape[1]
+ #     min_index_ = stft[min_index] / stft.shape[1]
+ #     fft_vector.plot(max_index_**2, 'orange')
+ #     fft_vector.plot(min_index_**2, 'red')
+ #     fft_vector.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
+ #     fft_vector.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
+ #     fft_vector.set_xlim([0, 100])
+ #     fft_vector.set_ylabel('Power (µV^2)')
+ #     fft_vector.set_xlabel('Freq (Hz)')
+ #     fft_vector.set_title('Frequency Domain', size=10)
+ #
+ #     plt.suptitle('Condition: ' + task + '\n' + 'epsilon {}, FFT window size {}'.format(
+ #         str(info_args['eps']), str(info_args['win_len'])) + '\n'
+ #         + 'Subject {}, electrode {}, n_fft {}'.format(str(info_args['selected_subject']), str(info_args['electrode_name']), str(info_args['n_fft'])),
+ #         fontsize=8, ha='left', va='top')
+ #     plt.tight_layout()
BrainPulse/recurrence_quantification_analysis.py ADDED
@@ -0,0 +1,305 @@
+ from numba import jit
+ import numpy as np
+ from pyrqa.computation import RPComputation, RQAComputation
+ from pyrqa.time_series import TimeSeries, EmbeddedSeries
+ from pyrqa.settings import Settings
+ from pyrqa.analysis_type import Classic
+ from pyrqa.metric import EuclideanMetric
+ # from pyrqa.metric import Sigmoid
+ # from pyrqa.metric import Cosine
+ from pyrqa.neighbourhood import Unthresholded, FixedRadius
+
+ def get_results(recurrence_matrix,
+                 minimum_diagonal_line_length,
+                 minimum_vertical_line_length,
+                 minimum_white_vertical_line_length):
+     """Compute the full set of RQA measures for a binary recurrence matrix."""
+
+     number_of_vectors = recurrence_matrix.shape[0]
+     diagonal = diagonal_frequency_distribution(recurrence_matrix)
+     vertical = vertical_frequency_distribution(recurrence_matrix)
+     white = white_vertical_frequency_distribution(recurrence_matrix)
+
+     number_of_vert_lines = number_of_vertical_lines(vertical, minimum_vertical_line_length)
+     number_of_vert_lines_points = number_of_vertical_lines_points(vertical, minimum_vertical_line_length)
+
+     RR = recurrence_rate(recurrence_matrix)
+     DET = determinism(number_of_vectors, diagonal, minimum_diagonal_line_length)
+     L = average_diagonal_line_length(number_of_vectors, diagonal, minimum_diagonal_line_length)
+     Lmax = longest_diagonal_line_length(number_of_vectors, diagonal)
+     DIV = divergence(Lmax)
+     Lentr = entropy_diagonal_lines(number_of_vectors, diagonal, minimum_diagonal_line_length)
+     DET_RR = ratio_determinism_recurrence_rate(DET, RR)
+     LAM = laminarity(number_of_vectors, vertical, minimum_vertical_line_length)
+     V = average_vertical_line_length(number_of_vectors, vertical, minimum_vertical_line_length)
+     Vmax = longest_vertical_line_length(number_of_vectors, vertical)
+     Ventr = entropy_vertical_lines(number_of_vectors, vertical, minimum_vertical_line_length)
+     LAM_DET = laminarity_determinism(LAM, DET)
+     W = average_white_vertical_line_length(number_of_vectors, white, minimum_white_vertical_line_length)
+     Wmax = longest_white_vertical_line_length(number_of_vectors, white)
+     Wentr = entropy_white_vertical_lines(number_of_vectors, white, minimum_white_vertical_line_length)
+     TT = trapping_time(number_of_vert_lines_points, number_of_vert_lines)
+
+     return [RR, DET, L, Lmax, DIV, Lentr, DET_RR, LAM, V, Vmax, Ventr, LAM_DET, W, Wmax, Wentr, TT]
+
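A quick sanity check of get_results on a toy symmetric recurrence plot; the values are illustrative only, and the line of identity makes the diagonal measures non-trivial.

import numpy as np

rng = np.random.default_rng(1)
rp = (rng.random((50, 50)) < 0.2).astype(np.int64)
rp = np.maximum(rp, rp.T)                      # symmetric, like a real RP
np.fill_diagonal(rp, 1)                        # line of identity

measures = get_results(rp,
                       minimum_diagonal_line_length=2,
                       minimum_vertical_line_length=2,
                       minimum_white_vertical_line_length=2)
# order: RR, DET, L, Lmax, DIV, Lentr, DET_RR, LAM, V, Vmax, Ventr, LAM_DET, W, Wmax, Wentr, TT
print(dict(zip(['RR', 'DET', 'L', 'Lmax'], measures[:4])))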
+ @jit(nopython=True)
+ def diagonal_frequency_distribution(recurrence_matrix):
+     # Calculating the number of states - N
+     number_of_vectors = recurrence_matrix.shape[0]
+     diagonal_frequency_distribution = np.zeros(number_of_vectors + 1)
+
+     # Calculating the diagonal frequency distribution - P(l)
+     for i in range(number_of_vectors - 1, -1, -1):
+         diagonal_line_length = 0
+         for j in range(0, number_of_vectors - i):
+             if recurrence_matrix[i + j, j] == 1:
+                 diagonal_line_length += 1
+                 if j == (number_of_vectors - i - 1):
+                     diagonal_frequency_distribution[diagonal_line_length] += 1.0
+             else:
+                 if diagonal_line_length != 0:
+                     diagonal_frequency_distribution[diagonal_line_length] += 1.0
+                     diagonal_line_length = 0
+     for k in range(1, number_of_vectors):
+         diagonal_line_length = 0
+         for i in range(number_of_vectors - k):
+             j = i + k
+             if recurrence_matrix[i, j] == 1:
+                 diagonal_line_length += 1
+                 if j == (number_of_vectors - 1):
+                     diagonal_frequency_distribution[diagonal_line_length] += 1.0
+             else:
+                 if diagonal_line_length != 0:
+                     diagonal_frequency_distribution[diagonal_line_length] += 1.0
+                     diagonal_line_length = 0
+
+     return diagonal_frequency_distribution
+
+
+ @jit(nopython=True)
+ def vertical_frequency_distribution(recurrence_matrix):
+     number_of_vectors = recurrence_matrix.shape[0]
+
+     # Calculating the vertical frequency distribution - P(v)
+     vertical_frequency_distribution = np.zeros(number_of_vectors + 1)
+     for i in range(number_of_vectors):
+         vertical_line_length = 0
+         for j in range(number_of_vectors):
+             if recurrence_matrix[i, j] == 1:
+                 vertical_line_length += 1
+                 if j == (number_of_vectors - 1):
+                     vertical_frequency_distribution[vertical_line_length] += 1.0
+             else:
+                 if vertical_line_length != 0:
+                     vertical_frequency_distribution[vertical_line_length] += 1.0
+                     vertical_line_length = 0
+
+     return vertical_frequency_distribution
+
+
+ @jit(nopython=True)
+ def white_vertical_frequency_distribution(recurrence_matrix):
+     number_of_vectors = recurrence_matrix.shape[0]
+
+     # Calculating the white vertical frequency distribution - P(w)
+     white_vertical_frequency_distribution = np.zeros(number_of_vectors + 1)
+     for i in range(number_of_vectors):
+         white_vertical_line_length = 0
+         for j in range(number_of_vectors):
+             if recurrence_matrix[i, j] == 0:
+                 white_vertical_line_length += 1
+                 if j == (number_of_vectors - 1):
+                     white_vertical_frequency_distribution[white_vertical_line_length] += 1.0
+             else:
+                 if white_vertical_line_length != 0:
+                     white_vertical_frequency_distribution[white_vertical_line_length] += 1.0
+                     white_vertical_line_length = 0
+
+     return white_vertical_frequency_distribution
+
+ @jit(nopython=True)
+ def recurrence_rate(recurrence_matrix):
+     # Calculating the recurrence rate - RR
+     number_of_vectors = recurrence_matrix.shape[0]
+     return float(np.sum(recurrence_matrix)) / number_of_vectors ** 2
+
+
+ def determinism(number_of_vectors, diagonal_frequency_distribution_, minimum_diagonal_line_length):
+     # Calculating the determinism - DET
+     numerator = np.sum(
+         [l * diagonal_frequency_distribution_[l] for l in range(minimum_diagonal_line_length, number_of_vectors)])
+     denominator = np.sum([l * diagonal_frequency_distribution_[l] for l in range(1, number_of_vectors)])
+     return numerator / denominator
+
+
+ def average_diagonal_line_length(number_of_vectors, diagonal_frequency_distribution_, minimum_diagonal_line_length):
+     # Calculating the average diagonal line length - L
+     numerator = np.sum(
+         [l * diagonal_frequency_distribution_[l] for l in range(minimum_diagonal_line_length, number_of_vectors)])
+     denominator = np.sum(
+         [diagonal_frequency_distribution_[l] for l in range(minimum_diagonal_line_length, number_of_vectors)])
+     return numerator / denominator
+
+
+ @jit(nopython=True)
+ def longest_diagonal_line_length(number_of_vectors, diagonal_frequency_distribution_):
+     # Calculating the longest diagonal line length - Lmax
+     longest_diagonal_line_length_ = 0    # guard against an empty distribution
+     for l in range(number_of_vectors - 1, 0, -1):
+         if diagonal_frequency_distribution_[l] != 0:
+             longest_diagonal_line_length_ = l
+             break
+     return longest_diagonal_line_length_
+
+
+ @jit(nopython=True)
+ def divergence(longest_diagonal_line_length_):
+     # Calculating the divergence - DIV
+     return 1. / longest_diagonal_line_length_
+
+ @jit(nopython=True)
+ def entropy_diagonal_lines(number_of_vectors, diagonal_frequency_distribution_, minimum_diagonal_line_length):
+     # Calculating the entropy of diagonal lines - Lentr
+     sum_diagonal_frequency_distribution = float(
+         np.sum(diagonal_frequency_distribution_[minimum_diagonal_line_length:-1]))
+     entropy_diagonal_lines_ = 0
+     for l in range(minimum_diagonal_line_length, number_of_vectors):
+         if diagonal_frequency_distribution_[l] != 0:
+             entropy_diagonal_lines_ += (diagonal_frequency_distribution_[l] / sum_diagonal_frequency_distribution) * np.log(
+                 diagonal_frequency_distribution_[l] / sum_diagonal_frequency_distribution)
+     entropy_diagonal_lines_ *= -1
+     return entropy_diagonal_lines_
+
+
+ @jit(nopython=True)
+ def ratio_determinism_recurrence_rate(determinism_, recurrence_rate_):
+     # Calculating the DET/RR ratio
+     return determinism_ / recurrence_rate_
+
+ def laminarity(number_of_vectors, vertical_frequency_distribution_, minimum_vertical_line_length):
+     # Calculating the laminarity - LAM
+     numerator = np.sum(
+         [v * vertical_frequency_distribution_[v] for v in range(minimum_vertical_line_length, number_of_vectors + 1)])
+     denominator = np.sum([v * vertical_frequency_distribution_[v] for v in range(1, number_of_vectors + 1)])
+     return numerator / denominator
+
+
+ def average_vertical_line_length(number_of_vectors, vertical_frequency_distribution_, minimum_vertical_line_length):
+     # Calculating the average vertical line length - V
+     numerator = np.sum(
+         [v * vertical_frequency_distribution_[v] for v in range(minimum_vertical_line_length, number_of_vectors + 1)])
+     denominator = np.sum(
+         [vertical_frequency_distribution_[v] for v in range(minimum_vertical_line_length, number_of_vectors + 1)])
+     return numerator / denominator
+
+
+ @jit(nopython=True)
+ def longest_vertical_line_length(number_of_vectors, vertical_frequency_distribution_):
+     # Calculating the longest vertical line length - Vmax
+     longest_vertical_line_length_ = 0
+     for v in range(number_of_vectors, 0, -1):
+         if vertical_frequency_distribution_[v] != 0:
+             longest_vertical_line_length_ = v
+             break
+     return longest_vertical_line_length_
+
+ @jit(nopython=True)
+ def entropy_vertical_lines(number_of_vectors, vertical_frequency_distribution_, minimum_vertical_line_length):
+     # Calculating the entropy of vertical lines - Ventr
+     sum_vertical_frequency_distribution = float(
+         np.sum(vertical_frequency_distribution_[minimum_vertical_line_length:]))
+     entropy_vertical_lines_ = 0
+     for v in range(minimum_vertical_line_length, number_of_vectors + 1):
+         if vertical_frequency_distribution_[v] != 0:
+             entropy_vertical_lines_ += (vertical_frequency_distribution_[v] / sum_vertical_frequency_distribution) * np.log(
+                 vertical_frequency_distribution_[v] / sum_vertical_frequency_distribution)
+     entropy_vertical_lines_ *= -1
+     return entropy_vertical_lines_
+
+
+ @jit(nopython=True)
+ def laminarity_determinism(laminarity_, determinism_):
+     # Calculating the ratio LAM/DET
+     return laminarity_ / determinism_
+
+ def average_white_vertical_line_length(number_of_vectors, white_vertical_frequency_distribution_,
+                                        minimum_white_vertical_line_length):
+     # Calculating the average white vertical line length - W
+     numerator = np.sum([w * white_vertical_frequency_distribution_[w] for w in
+                         range(minimum_white_vertical_line_length, number_of_vectors + 1)])
+     denominator = np.sum([white_vertical_frequency_distribution_[w] for w in
+                           range(minimum_white_vertical_line_length, number_of_vectors + 1)])
+     return numerator / denominator
+
+
+ @jit(nopython=True)
+ def longest_white_vertical_line_length(number_of_vectors, white_vertical_frequency_distribution_):
+     # Calculating the longest white vertical line length - Wmax
+     longest_white_vertical_line_length_ = 0
+     for w in range(number_of_vectors, 0, -1):
+         if white_vertical_frequency_distribution_[w] != 0:
+             longest_white_vertical_line_length_ = w
+             break
+     return longest_white_vertical_line_length_
+
+ @jit(nopython=True)
+ def entropy_white_vertical_lines(number_of_vectors, white_vertical_frequency_distribution_,
+                                  minimum_white_vertical_line_length):
+     # Calculating the entropy of white vertical lines - Wentr
+     sum_white_vertical_frequency_distribution = float(
+         np.sum(white_vertical_frequency_distribution_[minimum_white_vertical_line_length:]))
+     entropy_white_vertical_lines_ = 0
+     for w in range(minimum_white_vertical_line_length, number_of_vectors + 1):
+         if white_vertical_frequency_distribution_[w] != 0:
+             entropy_white_vertical_lines_ += (white_vertical_frequency_distribution_[w] / sum_white_vertical_frequency_distribution) * np.log(
+                 white_vertical_frequency_distribution_[w] / sum_white_vertical_frequency_distribution)
+     entropy_white_vertical_lines_ *= -1
+     return entropy_white_vertical_lines_
+
+ def number_of_vertical_lines(vertical_frequency_distribution_, minimum_vertical_line_length):
+     # distribution indices equal line lengths, so count entries from v_min upwards
+     if minimum_vertical_line_length > 0:
+         return np.sum(vertical_frequency_distribution_[minimum_vertical_line_length:])
+     return np.uint(0)
+
+
+ def number_of_vertical_lines_points(vertical_frequency_distribution_, minimum_vertical_line_length):
+     # total number of recurrence points contained in those lines: sum of v * P(v)
+     if minimum_vertical_line_length > 0:
+         lengths = np.arange(vertical_frequency_distribution_.size)
+         return np.sum((lengths * vertical_frequency_distribution_)[minimum_vertical_line_length:])
+     return np.uint(0)
+
+
+ @jit(nopython=True)
+ def trapping_time(number_of_vertical_lines_points_, number_of_vertical_lines_):
+     """
+     Trapping time (TT): mean length of the vertical lines no shorter than v_min.
+     """
+     if number_of_vertical_lines_ == 0:
+         return np.float32(0)
+     return np.float32(number_of_vertical_lines_points_ / number_of_vertical_lines_)
+
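A worked micro-example of the vertical-line helpers above: a distribution holding three lines of length 2 and one line of length 5, with v_min = 2, gives TT = (3·2 + 1·5) / (3 + 1) = 2.75.

import numpy as np

dist = np.zeros(11)
dist[2] = 3    # three vertical lines of length 2
dist[5] = 1    # one vertical line of length 5
n_lines = number_of_vertical_lines(dist, 2)           # -> 4.0
n_points = number_of_vertical_lines_points(dist, 2)   # -> 11.0
print(trapping_time(n_points, n_lines))               # -> 2.75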
+ def return_pyRQA_results(signal, nbr):
+     # Run a classic fixed-radius RQA with PyRQA on pre-embedded vectors
+     time_series = EmbeddedSeries(signal)
+     settings = Settings(time_series,
+                         analysis_type=Classic,
+                         neighbourhood=FixedRadius(nbr),
+                         similarity_measure=EuclideanMetric,
+                         theiler_corrector=1)
+     computation = RQAComputation.create(settings,
+                                         verbose=True)
+     result = computation.run()
+     return result
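A hedged usage sketch for return_pyRQA_results: it needs a working PyRQA/OpenCL runtime, the vectors below are random stand-ins for STFT frames, and the printed attribute names follow PyRQA's result object.

import numpy as np

vectors = np.random.rand(300, 20).astype(np.float32)   # 300 frames, 20 frequency bins
result = return_pyRQA_results(vectors, nbr=0.5)        # FixedRadius epsilon = 0.5
print(result.recurrence_rate, result.determinism)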
BrainPulse/requirements.txt ADDED
File without changes
BrainPulse/vector_space.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+
+ def compute_stft(epoch_, n_fft, win_len, s_rate, cut_freq):
+     # STFT with a hop of one sample: one Hann-windowed spectrum per time step
+     signal_tensor = torch.tensor(epoch_, dtype=torch.float)
+     stft_tensor = torch.stft(signal_tensor, n_fft=n_fft, win_length=win_len, hop_length=1,
+                              return_complex=True, window=torch.hann_window(win_len))
+
+     sft = torch.abs(stft_tensor).numpy()
+
+     # map cut_freq (in Hz) to a bin index and keep only the bins below it
+     freq_to_take = (((n_fft / 2) + 1) * cut_freq) / ((s_rate / 2) + 1)
+     sft = sft[:int(freq_to_take), :]
+
+     return sft.T
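A small sketch of what compute_stft returns, using a synthetic 10 Hz sine at the eegbci sampling rate; the numbers are illustrative only.

import numpy as np

s_rate = 160                                   # eegbci sampling rate
t = np.arange(0, 4, 1 / s_rate)                # 4 s of signal
sig = np.sin(2 * np.pi * 10 * t)

stft = compute_stft(sig, n_fft=512, win_len=170, s_rate=s_rate, cut_freq=50)
print(stft.shape)   # (n_frames, n_bins): one spectrum per time step, bins up to ~50 Hz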
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ # syntax=docker/dockerfile:1
+
+ FROM python:3.8-slim-buster
+ RUN apt-get update && apt-get install -y git
+
+ COPY ./requirements.txt /app/requirements.txt
+
+ WORKDIR /app
+
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ EXPOSE 5000
+
+ COPY . /app
+
+ # ENTRYPOINT [ "python" ]
+ CMD ["streamlit", "run", "BrainPulseAPP.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Łukasz Furman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,2 @@
- ---
- title: NCU
- emoji: 📚
- colorFrom: indigo
- colorTo: yellow
- sdk: streamlit
- sdk_version: 1.10.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # BrainPulse_webapp
+
app.py ADDED
@@ -0,0 +1,260 @@
+ import streamlit as st
+ import matplotlib.pyplot as plt
+ import matplotlib
+ import numpy as np
+ import pandas as pd
+ import plotly.graph_objects as go
+ from complexRadar import ComplexRadar
+ import math
+ from zipfile import ZipFile
+ from glob import glob
+ import os
+ from BrainPulse import (dataset,
+                         vector_space,
+                         distance_matrix,
+                         recurrence_quantification_analysis,
+                         features_space,
+                         plot)
+
+ # working directories
+ path = "./mne_data"
+ path1 = "./RPs"
+
+ # Remove a stale dataset file if one exists (os.remove fails harmlessly on a directory)
+ try:
+     os.remove(path)
+     print("%s removed successfully" % path)
+ except OSError:
+     pass
+
+ os.makedirs(path, exist_ok=True)
+ os.makedirs(path1, exist_ok=True)
+
+ def run_computation(t_start, t_end, selected_subject, fir_filter, electrode_name, cut_freq, win_len, n_fft, percentile, run_list, options):
+
+     epochs, raw = dataset.eegbci_data(tmin=t_start, tmax=t_end,
+                                       subject=selected_subject,
+                                       filter_range=fir_filter, run_list=run_list)
+
+     s_rate = epochs.info['sfreq']
+
+     electrode_index = epochs.ch_names.index(electrode_name)
+
+     electrode_open = epochs.get_data()[0][electrode_index]
+     electrode_close = epochs.get_data()[1][electrode_index]
+
+     stft_open = vector_space.compute_stft(electrode_open,
+                                           n_fft=n_fft, win_len=win_len,
+                                           s_rate=epochs.info['sfreq'],
+                                           cut_freq=cut_freq)
+
+     stft_close = vector_space.compute_stft(electrode_close,
+                                            n_fft=n_fft, win_len=win_len,
+                                            s_rate=epochs.info['sfreq'],
+                                            cut_freq=cut_freq)
+     del raw
+     del electrode_open, electrode_close
+
+     # matrix_open = distance_matrix.EuclideanPyRQA_RP_stft(stft_open)
+     # matrix_close = distance_matrix.EuclideanPyRQA_RP_stft(stft_close)
+     matrix_open = distance_matrix.EuclideanPyRQA_RP_stft_cpu(stft_open)
+     matrix_close = distance_matrix.EuclideanPyRQA_RP_stft_cpu(stft_close)
+
+     # epsilon is taken as a percentile of each distance matrix
+     nbr_open = np.percentile(matrix_open, percentile)
+     nbr_close = np.percentile(matrix_close, percentile)
+
+     matrix_open_binary = distance_matrix.set_epsilon(matrix_open, nbr_open)
+     matrix_close_binary = distance_matrix.set_epsilon(matrix_close, nbr_close)
+
+     del matrix_open, matrix_close
+
+     fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 8), dpi=200)
+     ax1.imshow(matrix_open_binary, cmap='Greys', origin='lower')
+     ax1.set_xticks(np.linspace(0, matrix_open_binary.shape[0], ax1.get_xticks().shape[0]))
+     ax1.set_yticks(np.linspace(0, matrix_open_binary.shape[0], ax1.get_xticks().shape[0]))
+     ax1.set_xticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_open_binary.shape[0] / s_rate, ax1.get_xticks().shape[0])])
+     ax1.set_yticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_open_binary.shape[0] / s_rate, ax1.get_yticks().shape[0])])
+     ax1.set_title(options[0] + ' window size = ' + str(win_len) + ' samples, ε = ' + str(np.round(nbr_open, 4)))
+     ax1.set_xlabel('time (s)')
+     ax1.set_ylabel('time (s)')
+
+     ax2.imshow(matrix_close_binary, cmap='Greys', origin='lower')
+     ax2.set_xticks(np.linspace(0, matrix_close_binary.shape[0], ax2.get_xticks().shape[0]))
+     ax2.set_yticks(np.linspace(0, matrix_close_binary.shape[0], ax2.get_xticks().shape[0]))
+     ax2.set_xticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_close_binary.shape[0] / s_rate, ax2.get_xticks().shape[0])])
+     ax2.set_yticklabels([str(np.around(x, decimals=0)) for x in np.linspace(0, matrix_close_binary.shape[0] / s_rate, ax2.get_yticks().shape[0])])
+     ax2.set_title(options[1] + ' window size = ' + str(win_len) + ' samples, ε = ' + str(np.round(nbr_close, 4)))
+     ax2.set_xlabel('time (s)')
+     ax2.set_ylabel('time (s)')
+
+     return fig, matrix_open_binary, matrix_close_binary, epochs, stft_open, stft_close
+
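The ε threshold above is data-driven: it is the chosen percentile of all pairwise distances, so the resulting recurrence rate tracks percentile/100. A minimal NumPy sketch of that binarization step (the helper below is hypothetical and only mirrors how set_epsilon is called here, not its actual implementation):

import numpy as np

def binarize_by_percentile(dist, percentile):
    # eps = given percentile of all distances; points within eps are recurrences
    eps = np.percentile(dist, percentile)
    return (dist <= eps).astype(np.ubyte), eps

dist = np.random.rand(100, 100)              # stand-in for an STFT distance matrix
rp, eps = binarize_by_percentile(dist, 24)   # 24 is the app's default percentile
print(eps, rp.mean())                        # recurrence rate ≈ 0.24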
+ def plot_rqa(matrix_open_binary, matrix_close_binary, min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len, options):
+
+     categories = ['RR', 'DET', 'L', 'Lmax', 'DIV', 'Lentr', 'DET_RR', 'LAM', 'V', 'Vmax', 'Ventr', 'LAM_DET', 'W', 'Wmax', 'Wentr', 'TT']
+
+     # get_results expects the minimum line lengths in the order: diagonal, vertical, white vertical
+     result_rqa_open = recurrence_quantification_analysis.get_results(matrix_open_binary, min_diagonal_line_len, min_vert_line_len, min_white_vert_line_len)
+     result_rqa_closed = recurrence_quantification_analysis.get_results(matrix_close_binary, min_diagonal_line_len, min_vert_line_len, min_white_vert_line_len)
+
+     data = pd.DataFrame([result_rqa_open, result_rqa_closed], columns=categories)
+     data = data.drop(['RR', 'DIV', 'Lmax'], axis=1)
+
+     # per-variable axis ranges for the radar chart
+     min_max_per_variable = data.describe().T[['min', 'max']]
+     min_max_per_variable['min'] = min_max_per_variable['min'].apply(lambda x: int(x))
+     min_max_per_variable['max'] = min_max_per_variable['max'].apply(lambda x: math.ceil(x))
+
+     variables = data.columns
+     ranges = list(min_max_per_variable.itertuples(index=False, name=None))
+
+     format_cfg = {
+         # 'axes_args': {'facecolor': '#84A8CD'},
+         'rad_ln_args': {'visible': True, 'linestyle': 'dotted'},
+         'angle_ln_args': {'linestyle': 'dotted'},
+         'outer_ring': {'visible': True, 'linestyle': 'dotted'},
+         'rgrid_tick_lbls_args': {'fontsize': 6},
+         'theta_tick_lbls': {'fontsize': 9, 'backgroundcolor': '#355C7D', 'color': '#FFFFFF'},
+         'theta_tick_lbls_pad': 3
+     }
+
+     fig = plt.figure(figsize=(5, 3), dpi=100)
+     radar = ComplexRadar(fig, variables, ranges, n_ring_levels=3, show_scales=True, format_cfg=format_cfg)
+
+     custom_colors = ['#F67280', '#6C5B7B', '#355C7D']
+     k = 0
+     for g, c in zip(data.index, custom_colors):
+         radar.plot(data.loc[g].values, label=options[k], color=c, marker='o')
+         radar.fill(data.loc[g].values, alpha=0.5, color=c)
+         k += 1
+
+     radar.use_legend(loc='upper left', bbox_to_anchor=(-0.4, 1.1), fontsize='xx-small')
+
+     return fig
+
+ def waterfall_spectrum(stft1, stft2, s_rate, cut_freq, options):
+
+     fig = plt.figure(figsize=(14, 12), dpi=150)
+     grid = plt.GridSpec(8, 8, hspace=0.0, wspace=3.5)
+     spectrogram1 = fig.add_subplot(grid[0:3, 0:4])
+     spectrogram2 = fig.add_subplot(grid[0:3, 4:])
+
+     spectrogram1.pcolormesh(stft1.T, cmap='viridis')
+     spectrogram1.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft1.shape[0], 5)))
+     spectrogram1.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft1.shape[0] / s_rate, spectrogram1.get_xticks().shape[0])])
+     spectrogram1.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft1.shape[1], 5)))
+     spectrogram1.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
+     spectrogram1.set_ylabel('Freq (Hz)')
+     spectrogram1.set_xlabel('Time (s)')
+     spectrogram1.set_title(options[0] + ' Spectrogram')
+
+     spectrogram2.pcolormesh(stft2.T, cmap='viridis')
+     spectrogram2.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft2.shape[0], 5)))
+     spectrogram2.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, stft2.shape[0] / s_rate, spectrogram2.get_xticks().shape[0])])
+     spectrogram2.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft2.shape[1], 5)))
+     spectrogram2.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
+     spectrogram2.set_ylabel('Freq (Hz)')
+     spectrogram2.set_xlabel('Time (s)')
+     spectrogram2.set_title(options[1] + ' Spectrogram')
+     return fig
+
+ def save(matrix_open_binary, matrix_close_binary):
+
+     file_name_open = './RPs/subject-' + str(selected_subject) + '_electrode-' + electrode_name + '_percentile-' + str(percentile) + '_run-open_binary.npy'
+     np.save(file_name_open, np.asarray(matrix_open_binary, dtype=np.ubyte))
+     file_name_close = './RPs/subject-' + str(selected_subject) + '_electrode-' + electrode_name + '_percentile-' + str(percentile) + '_run-close_binary.npy'
+     np.save(file_name_close, np.asarray(matrix_close_binary, dtype=np.ubyte))
+
+
+ def download():
+
+     file_paths = glob('./RPs/*')
+
+     # write each saved recurrence plot into one archive
+     with ZipFile('download.zip', 'w') as zf:
+         for file in file_paths:
+             zf.write(file)
+
+     return open('download.zip', 'rb')
+
+ # ---------------Settings--------------------
+
+ st.set_page_config(layout="wide")
+ st.title('BrainPulse Playground')
+ sidebar = st.sidebar
+
+ selected_subject = sidebar.slider('Select Subject', 0, 100, 25)
+
+ electrode_name = sidebar.selectbox(
+     'Select Electrode',
+     ('FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FT8', 'T7', 'T8', 'T9', 'T10', 'TP7', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2', 'Iz'))
+
+ t_start, t_end = sidebar.slider(
+     'Select a time range in seconds',
+     0, 60, (0, 30), step=1)
+
+ f1, f2 = sidebar.slider(
+     'Select a FIR filter range',
+     0, 60, (2, 50), step=1)
+ fir_filter = [f1, f2]
+
+ # the spectrogram is cut at the upper edge of the FIR band
+ cut_freq = f2
+
+ win_len = sidebar.slider('FFT window size', 0, 512, 170, step=1)
+
+ n_fft = sidebar.slider('Number of FFT bins', 0, 1024, 512, step=1)
+
+ min_vert_line_len = sidebar.slider('Minimum vertical line length', 0, 250, 2, step=1)
+
+ min_diagonal_line_len = sidebar.slider('Minimum diagonal line length', 0, 250, 2, step=1)
+
+ min_white_vert_line_len = sidebar.slider('Minimum white vertical line length', 0, 250, 2, step=1)
+
+ percentile = sidebar.slider('Percentile', 0, 100, 24, step=1)
+
+ sidebar.download_button('Download file', download(), file_name='archive.zip')
+
+ # ---------------Plot RPs--------------------
+ # runs_ = ['Baseline open eyes', 'Baseline closed eyes', 'Motor execution: left vs right hand', 'Motor imagery: left vs right hand',
+ #          'Motor execution: hands vs feet', 'Motor imagery: hands vs feet']
+ # options = st.multiselect('Select two runs to compare', runs_, ['Baseline open eyes', 'Baseline closed eyes'])
+ # run_list = []
+ # for v in options:
+ #     run_list.append(runs_.index(v) + 1)
+ # if len(run_list) <= 1:
+ #     run_list = [1, 2]
+
+ st.markdown('Baseline open eyes vs Baseline closed eyes')
+ options = ['Baseline open eyes', 'Baseline closed eyes']
+ run_list = [1, 2]
+
+ rp_plot, matrix_open_binary, matrix_close_binary, epochs, stft1, stft2 = run_computation(t_start, t_end, selected_subject, fir_filter, electrode_name, cut_freq, win_len, n_fft, percentile, run_list, options)
+ st.write(rp_plot)
+
+ # ---------------Plot Spectrum--------------------
+ st.write(waterfall_spectrum(stft1, stft2, epochs.info['sfreq'], cut_freq, options))
+
+ # ---------------Save RPs--------------------
+ if st.button('Save RPs as *.npy'):
+     save(matrix_open_binary, matrix_close_binary)
+
+ # ---------------Plot Radar--------------------
+ rqa_radar = plot_rqa(matrix_open_binary, matrix_close_binary, min_vert_line_len, min_diagonal_line_len, min_white_vert_line_len, options)
+ st.write(rqa_radar)
complexRadar.py ADDED
@@ -0,0 +1,168 @@
+ import numpy as np
+ import textwrap
+
+ class ComplexRadar():
+     """
+     Create a complex radar chart with a different scale for each variable.
+
+     Parameters
+     ----------
+     fig : figure object
+         A matplotlib figure object to add the axes to
+     variables : list
+         A list of variables
+     ranges : list
+         A list of tuples (min, max), one per variable
+     n_ring_levels : int, defaults to 5
+         Number of ordinate or ring levels to draw
+     show_scales : bool, defaults to True
+         Indicates whether the ranges for each variable are plotted
+     format_cfg : dict, defaults to None
+         A dictionary with formatting configurations
+     """
+     def __init__(self, fig, variables, ranges, n_ring_levels=5, show_scales=True, format_cfg=None):
+
+         # Default formatting
+         self.format_cfg = {
+             # Axes
+             # https://matplotlib.org/stable/api/figure_api.html
+             'axes_args': {},
+             # Tick labels on the scales
+             # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.rgrids.html
+             'rgrid_tick_lbls_args': {'fontsize': 8},
+             # Radial (circle) lines
+             # https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.grid.html
+             'rad_ln_args': {},
+             # Angle lines
+             # https://matplotlib.org/3.2.2/api/_as_gen/matplotlib.lines.Line2D.html#matplotlib.lines.Line2D
+             'angle_ln_args': {},
+             # Include last value (endpoint) on scale
+             'incl_endpoint': False,
+             # Variable labels (ThetaTickLabel)
+             'theta_tick_lbls': {'va': 'top', 'ha': 'center'},
+             'theta_tick_lbls_txt_wrap': 15,
+             'theta_tick_lbls_brk_lng_wrds': False,
+             'theta_tick_lbls_pad': 25,
+             # Outer ring
+             # https://matplotlib.org/stable/api/spines_api.html
+             'outer_ring': {'visible': True, 'color': '#d6d6d6'}
+         }
+
+         if format_cfg is not None:
+             self.format_cfg = {k: (format_cfg[k]) if k in format_cfg.keys() else (self.format_cfg[k])
+                                for k in self.format_cfg.keys()}
+
+         # Calculate angles and create an axes object for each variable.
+         # Note the trick: the first axes element exists twice, hence len + 1.
+         angles = np.arange(0, 360, 360. / len(variables))
+         axes = [fig.add_axes([0.1, 0.1, 0.9, 0.9],
+                              polar=True,
+                              label="axes{}".format(i),
+                              **self.format_cfg['axes_args']) for i in range(len(variables) + 1)]
+
+         # Ensure clockwise rotation (first variable at the top, pointing north)
+         for ax in axes:
+             ax.set_theta_zero_location('N')
+             ax.set_theta_direction(-1)
+             ax.set_axisbelow(True)
+
+         # Write the ranges on each axes
+         for i, ax in enumerate(axes):
+
+             # Here we do the trick by repeating the first iteration
+             j = 0 if (i == 0 or i == 1) else i - 1
+             ax.set_ylim(*ranges[j])
+             # Set endpoint to True if you would like values right before the last circle
+             grid = np.linspace(*ranges[j], num=n_ring_levels,
+                                endpoint=self.format_cfg['incl_endpoint'])
+             gridlabel = ["{}".format(round(x, 2)) for x in grid]
+             gridlabel[0] = ""  # remove values from the center
+             lines, labels = ax.set_rgrids(grid,
+                                           labels=gridlabel,
+                                           angle=angles[j],
+                                           **self.format_cfg['rgrid_tick_lbls_args'])
+
+             ax.set_ylim(*ranges[j])
+             ax.spines["polar"].set_visible(False)
+             ax.grid(visible=False)
+
+             if show_scales == False:
+                 ax.set_yticklabels([])
+
+         # Make all axes except the first one invisible
+         for ax in axes[1:]:
+             ax.patch.set_visible(False)
+             ax.xaxis.set_visible(False)
+
+         # Setting the attributes
+         self.angle = np.deg2rad(np.r_[angles, angles[0]])
+         self.ranges = ranges
+         self.ax = axes[0]
+         self.ax1 = axes[1]
+         self.plot_counter = 0
+
+         # Draw (inner) circles and lines
+         self.ax.yaxis.grid(**self.format_cfg['rad_ln_args'])
+         # Draw outer circle
+         self.ax.spines['polar'].set(**self.format_cfg['outer_ring'])
+         # Draw angle lines
+         self.ax.xaxis.grid(**self.format_cfg['angle_ln_args'])
+
+         # ax1 is the duplicate of axes[0] (self.ax);
+         # remove everything from ax1 except the plot itself
+         self.ax1.axis('off')
+         self.ax1.set_zorder(9)
+
+         # Create the outer labels for each variable
+         l, text = self.ax.set_thetagrids(angles, labels=variables)
+
+         # Beautify them
+         labels = [t.get_text() for t in self.ax.get_xticklabels()]
+         labels = ['\n'.join(textwrap.wrap(l, self.format_cfg['theta_tick_lbls_txt_wrap'],
+                             break_long_words=self.format_cfg['theta_tick_lbls_brk_lng_wrds'])) for l in labels]
+         self.ax.set_xticklabels(labels, **self.format_cfg['theta_tick_lbls'])
+
+         for t, a in zip(self.ax.get_xticklabels(), angles):
+             if a == 0:
+                 t.set_ha('center')
+             elif a > 0 and a < 180:
+                 t.set_ha('left')
+             elif a == 180:
+                 t.set_ha('center')
+             else:
+                 t.set_ha('right')
+
+         self.ax.tick_params(axis='both', pad=self.format_cfg['theta_tick_lbls_pad'])
+
+     def _scale_data(self, data, ranges):
+         """Scale data[1:] to the scale of ranges[0]"""
+         for d, (y1, y2) in zip(data[1:], ranges[1:]):
+             assert (y1 <= d <= y2) or (y2 <= d <= y1)
+         x1, x2 = ranges[0]
+         d = data[0]
+         sdata = [d]
+         for d, (y1, y2) in zip(data[1:], ranges[1:]):
+             sdata.append((d - y1) / (y2 - y1) * (x2 - x1) + x1)
+         return sdata
+
+     def plot(self, data, *args, **kwargs):
+         """Plot a line"""
+         sdata = self._scale_data(data, self.ranges)
+         self.ax1.plot(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
+         self.plot_counter = self.plot_counter + 1
+
+     def fill(self, data, *args, **kwargs):
+         """Plot an area"""
+         sdata = self._scale_data(data, self.ranges)
+         self.ax1.fill(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
+
+     def use_legend(self, *args, **kwargs):
+         """Show a legend"""
+         self.ax1.legend(*args, **kwargs)
+
+     def set_title(self, title, pad=25, **kwargs):
+         """Set a title"""
+         self.ax.set_title(title, pad=pad, **kwargs)
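A minimal usage sketch for ComplexRadar, mirroring how app.py drives it; the variables and ranges below are made up.

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(5, 3), dpi=100)
variables = ['DET', 'LAM', 'TT']
ranges = [(0.0, 1.0), (0.0, 1.0), (0.0, 10.0)]   # one (min, max) per variable
radar = ComplexRadar(fig, variables, ranges, n_ring_levels=3, show_scales=True)
radar.plot([0.8, 0.6, 4.2], label='open eyes', marker='o')
radar.fill([0.8, 0.6, 4.2], alpha=0.5)
radar.use_legend(loc='upper left')
plt.show()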
papaerspace_config.txt ADDED
@@ -0,0 +1,5 @@
+ image: datalab108/cookiewriter64:test
+ port: 5005
+ resources:
+   replicas: 1
+   instanceType: P4000
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ torch
+ torchvision
+ torchaudio
+ umap-learn
+ umap-learn[plot]
+ PyRQA
+ pandas
+ mne
+ plotly
+ streamlit
+ # additional packages imported by app.py and BrainPulse:
+ numpy
+ matplotlib
+ scipy
+ seaborn
+ scikit-learn
+ numba
test.py ADDED
@@ -0,0 +1,5 @@
+ import time
+
+ while True:
+     print("Hello World")
+     time.sleep(10)