|
import time
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib import animation, cm
|
|
import matplotlib
|
|
import seaborn as sns
|
|
import seaborn as sns
|
|
import umap
|
|
import umap.plot
|
|
import pandas as pd
|
|
from .event import EventSegment
|
|
from sklearn.impute import SimpleImputer
|
|
from sklearn.pipeline import make_pipeline
|
|
from sklearn.preprocessing import QuantileTransformer
|
|
|
|
# Module-level side effect: switch seaborn to its "whitegrid" style when this file is imported.
sns.set_style("whitegrid")
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib.font_manager as font_manager
|
|
# Font properties requesting the 'Times' family; passed as ``prop=font`` to the legend calls in this module.
font = font_manager.FontProperties(family='Times')
|
|
|
|
def explainer_(chan, stft, cut_freq, s_rate):
    """Draw the four-panel STFT explainer figure and save it as ``fig_4.png``.

    Panel (a) plots the first 400 samples of ``chan`` with three overlapping
    outlined boxes and ``fft`` annotations.  Panels (b)-(d) plot
    ``(stft[row] / stft.shape[1]) ** 2`` for the fixed rows 100, 115 and 140.

    Args:
        chan: 1-D array-like with a ``shape`` attribute; only ``chan[:400]``
            is drawn and at least 281 samples are needed for the box edges.
        stft: 2-D array indexed as ``stft[row]``; at least 141 rows are
            required.  Presumably (time frames, frequency bins) -- confirm
            against the caller.
        cut_freq: Upper end of the frequency tick labels on panels (b)-(d).
        s_rate: Divisor turning the sample count of panel (a) into the
            time tick labels.

    Returns:
        None.  The figure is written to ``fig_4.png`` in the working directory.
    """
    _fig, axs = plt.subplots(4, figsize=(10, 14), dpi=150)

    segment = chan[:400]
    n_samples = segment.shape[0]
    time_crop = np.linspace(0, int(n_samples), n_samples)

    # --- panel (a): raw signal with three sliding-window outlines -----------
    signal_ax = axs[0]
    signal_ax.plot(segment, 'k')

    # (left edge, right edge, top of box, outline colour)
    window_boxes = (
        (time_crop[0], time_crop[240], 125, 'red'),
        (time_crop[2] + 20, time_crop[260], 130, 'green'),
        (time_crop[2] + 40, time_crop[280], 135, 'blue'),
    )
    for left, right, top, outline in window_boxes:
        signal_ax.fill_betweenx(y=[-210, top], x1=left,
                                x2=right, color='white', alpha=0.9, edgecolor=outline)

    # (label, arrow tip in data coords, text position in axes fraction)
    window_labels = (
        ('$fft_{1}$', (.25, 72), (0.05, 1.45)),
        ('$fft_{2}$', (23.35, 85), (0.15, 1.45)),
        ('$fft_{3}$', (43.45, 95), (0.25, 1.45)),
    )
    for label, tip, text_pos in window_labels:
        signal_ax.annotate(label, xy=tip, xycoords='data',
                           xytext=text_pos, textcoords='axes fraction',
                           arrowprops=dict(arrowstyle="->", facecolor='black', color='black'),
                           horizontalalignment='right', verticalalignment='top',
                           )

    signal_ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, n_samples, 5)))
    signal_ax.set_xticklabels(
        [str(np.round(x, 1))
         for x in np.linspace(0, int(n_samples / s_rate), signal_ax.get_xticks().shape[0])])
    signal_ax.set_ylabel('Amplitude (µV)')
    signal_ax.set_xlabel('Time (s)')
    signal_ax.set_title('(a)')
    signal_ax.xaxis.grid()
    signal_ax.yaxis.grid()

    # --- panels (b)-(d): one spectrum per highlighted window ----------------
    # (axes, stft row, line colour, legend label, panel title)
    spectrum_panels = (
        (axs[1], 100, 'red', '$fft_{1}$', '(b)'),
        (axs[2], 115, 'green', '$fft_{2}$', '(c)'),
        (axs[3], 140, 'blue', '$fft_{3}$', '(d)'),
    )
    for ax, row, colour, label, title in spectrum_panels:
        ax.plot((stft[row] / stft.shape[1]) ** 2, colour, label=label, marker="o", markersize=3)
        ax.legend(prop=font)
        ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
        ax.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
        ax.set_xlim([0, 100])
        # Raw string: '\m' is not a valid escape sequence in a normal literal.
        ax.set_ylabel(r'Power ($\mu V^{2}$)')
        ax.set_xlabel('Freq (Hz)')
        ax.set_title(title)
        ax.xaxis.grid()
        ax.yaxis.grid()

    plt.tight_layout()
    plt.savefig('fig_4.png')
|
|
|
|
|
|
def stft_collections(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args, max_indx=None, min_indx=None):
    """Draw spectrogram, recurrence plot and two spectra; save as ``fig_5.png``.

    Two rows of ``stft`` are highlighted in all three panels.  They are taken
    from ``max_indx`` / ``min_indx`` when BOTH are given; otherwise the row
    holding the global maximum and the row holding the global minimum of
    ``stft`` are used (first occurrence on ties).

    Args:
        matrix: Array whose ``shape[0]`` sets the recurrence-plot tick range.
        matrix_binary: Image shown in the recurrence-plot panel.
        s_rate: Divisor converting sample/frame counts to time tick labels.
        stft: 2-D array; presumably (time frames, frequency bins) -- confirm
            against the caller.
        cut_freq: Upper end of the frequency tick labels.
        task: Unused by this function.
        info_args: Unused by this function.
        max_indx: Optional row index for the first (orange) marker.
        min_indx: Optional row index for the second (red) marker.

    Returns:
        None.  The figure is written to ``fig_5.png`` in the working directory.
    """
    fig = plt.figure(figsize=(14, 12), dpi=150)
    grid = plt.GridSpec(6, 8, hspace=0.0, wspace=3.5)
    spectrogram = fig.add_subplot(grid[0:3, 0:4])
    rp_plot = fig.add_subplot(grid[0:3, 4:])
    fft_vector = fig.add_subplot(grid[4:, :])

    if max_indx is not None and min_indx is not None:
        max_index = max_indx
        min_index = min_indx
    else:
        # Row containing the overall maximum / minimum of the STFT.
        max_index = int(np.argmax(np.max(stft, axis=1)))
        min_index = int(np.argmin(np.min(stft, axis=1)))

    # --- (b) recurrence plot -------------------------------------------------
    rp_plot.imshow(matrix_binary, cmap='Greys', origin='lower')
    rp_plot.plot(max_index, max_index, 'orange', marker="o", markersize=9)
    rp_plot.plot(min_index, min_index, 'red', marker="o", markersize=9)

    rp_plot.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
    rp_plot.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
    rp_plot.set_xticklabels(
        [str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_xticks().shape[0])])
    rp_plot.set_yticklabels(
        [str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, rp_plot.get_yticks().shape[0])])
    rp_plot.set_xlabel('Time (s)')
    rp_plot.set_ylabel('Time (s)')
    rp_plot.set_title('(b) Recurrence Plot')
    rp_plot.xaxis.grid()
    rp_plot.yaxis.grid()

    # --- (a) spectrogram -----------------------------------------------------
    spectrogram.pcolormesh(stft.T, cmap='viridis')
    spectrogram.plot(max_index, 2, 'orange', marker="|", markersize=40)
    spectrogram.plot(min_index, 2, 'red', marker="|", markersize=40)

    spectrogram.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[0], 5)))
    spectrogram.set_xticklabels(
        [str(np.round(x, 1)) for x in np.linspace(0, stft.shape[0] / s_rate, spectrogram.get_xticks().shape[0])])
    spectrogram.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 5)))
    spectrogram.set_yticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 5)])
    spectrogram.set_ylabel('Freq (Hz)')
    spectrogram.set_xlabel('Time (s)')
    spectrogram.set_title('(a) Spectrogram')

    # --- (c) spectra of the two highlighted rows -----------------------------
    max_spectrum = stft[max_index] / stft.shape[1]
    min_spectrum = stft[min_index] / stft.shape[1]
    fft_vector.plot(max_spectrum ** 2, 'orange', label='$fft_{t_{1}}$')
    # The original label had one closing brace too many, which is not valid mathtext.
    fft_vector.plot(min_spectrum ** 2, 'red', label='$fft_{t_{2}}$')
    fft_vector.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, stft.shape[1], 9)))
    fft_vector.set_xticklabels([str(np.round(x, 1)) for x in np.linspace(0, cut_freq, 9)])
    fft_vector.set_xlim([0, 100])
    # Raw string: '\m' is not a valid escape sequence in a normal literal.
    fft_vector.set_ylabel(r'Power ($\mu V^{2}$)')
    fft_vector.set_xlabel('Freq (Hz)')
    fft_vector.set_title('(c) Frequency Domain')
    fft_vector.legend(prop=font)
    fft_vector.xaxis.grid()
    fft_vector.yaxis.grid()

    plt.tight_layout()
    plt.savefig('fig_5.png')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def diagnostic(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
    """Build a three-row diagnostic figure on a new matplotlib figure.

    Rows, top to bottom: the recurrence plot (``matrix_binary``), the
    spectrogram (``stft.T``) and the squared, bin-count-normalised spectra of
    the two ``stft`` rows holding the global maximum and global minimum.
    The figure is neither saved nor returned.

    Args:
        matrix: Array whose ``shape[0]`` sets the time tick range.
        matrix_binary: Image shown in the top panel.
        s_rate: Divisor converting ``matrix.shape[0]`` to time tick labels.
        stft: 2-D array; presumably (time frames, frequency bins) -- confirm
            against the caller.
        cut_freq: Upper end of the frequency tick labels.
        task: String placed after ``'Condition: '`` in the figure title.
        info_args: Mapping providing ``'eps'``, ``'win_len'``,
            ``'selected_subject'``, ``'electrode_name'`` and ``'n_fft'``.
    """
    _fig, (rp_ax, spec_ax, fft_ax) = plt.subplots(
        3, 1, figsize=(7, 12), gridspec_kw={'height_ratios': [6, 2, 1]}, dpi=150)

    def fixed_ticks(stop, count):
        # Evenly spaced tick positions from 0 to ``stop``.
        return matplotlib.ticker.FixedLocator(np.linspace(0, stop, count))

    def time_labels(count):
        # Tick texts in seconds spanning the recurrence matrix.
        return [str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, count)]

    def freq_labels(count):
        # Tick texts in Hz from 0 up to the cut-off frequency.
        return [str(np.round(x, 1)) for x in np.linspace(0, cut_freq, count)]

    # Rows of the STFT containing its overall maximum and minimum.
    row_maxima = np.max(stft, axis=1)
    max_index = list(row_maxima).index(np.max(row_maxima, axis=0))
    row_minima = np.min(stft, axis=1)
    min_index = list(row_minima).index(np.min(row_minima, axis=0))

    # --- recurrence plot -----------------------------------------------------
    rp_ax.imshow(matrix_binary, cmap='cividis', origin='lower')
    rp_ax.plot(max_index, max_index, 'orange', marker="o", markersize=7)
    rp_ax.plot(min_index, min_index, 'red', marker="o", markersize=7)

    # Drop the first and last automatic tick before installing fixed ones.
    rp_ax.set_yticks(rp_ax.get_yticks()[1:-1])
    rp_ax.set_xticks(rp_ax.get_xticks()[1:-1])
    rp_ax.xaxis.set_major_locator(fixed_ticks(matrix.shape[0], 5))
    rp_ax.yaxis.set_major_locator(fixed_ticks(matrix.shape[0], 5))
    rp_ax.set_xticklabels(time_labels(rp_ax.get_xticks().shape[0]))
    rp_ax.set_yticklabels(time_labels(rp_ax.get_yticks().shape[0]))
    rp_ax.set_xlabel('Time (s)')
    rp_ax.set_ylabel('Time (s)')
    rp_ax.set_title('Recurrence Plot')

    # --- spectrogram ---------------------------------------------------------
    spec_ax.pcolormesh(stft.T, shading='gouraud')
    spec_ax.plot(max_index, 0, 'orange', marker="o", markersize=7)
    spec_ax.plot(min_index, 0, 'red', marker="o", markersize=7)
    spec_ax.xaxis.set_major_locator(fixed_ticks(matrix.shape[0], 5))
    spec_ax.set_xticklabels(time_labels(spec_ax.get_xticks().shape[0]))
    spec_ax.yaxis.set_major_locator(fixed_ticks(stft.shape[1], 5))
    spec_ax.set_yticklabels(freq_labels(5))
    spec_ax.set_ylabel('Freq (Hz)')
    spec_ax.set_xlabel('Time (s)')
    spec_ax.set_title('Spectrogram')

    # --- spectra of the two highlighted rows ---------------------------------
    peak_spectrum = stft[max_index] / stft.shape[1]
    trough_spectrum = stft[min_index] / stft.shape[1]
    fft_ax.plot(peak_spectrum ** 2, 'orange')
    fft_ax.plot(trough_spectrum ** 2, 'red')
    fft_ax.xaxis.set_major_locator(fixed_ticks(stft.shape[1], 9))
    fft_ax.set_xticklabels(freq_labels(9))
    fft_ax.set_xlim([0, 100])
    fft_ax.set_ylabel('Power (µV^2)')
    fft_ax.set_xlabel('Freq (Hz)')
    fft_ax.set_title('Frequency Domain')

    title_lines = [
        'Condition: ' + task,
        'epsilon {}, FFT window size {} '.format(str(info_args['eps']), str(info_args['win_len'])),
        'Subject {}, electrode {}, n_fft {}'.format(
            str(info_args['selected_subject']), str(info_args['electrode_name']), str(info_args['n_fft'])),
    ]
    plt.suptitle('\n'.join(title_lines), ha='left', va='top')
    plt.tight_layout()
|
|
|
|
def RecurrencePlot(matrix, matrix_binary, s_rate, stft, cut_freq, task, info_args):
    """Show ``matrix_binary`` as a stand-alone recurrence plot on a new figure.

    Both axes get five evenly spaced ticks labelled in seconds
    (``matrix.shape[0] / s_rate`` at the far end).  The figure is neither
    saved nor returned.

    Args:
        matrix: Array whose ``shape[0]`` sets the tick range.
        matrix_binary: Image to display.
        s_rate: Divisor converting ``matrix.shape[0]`` to time tick labels.
        stft: Unused; kept so the signature matches the sibling plot functions.
        cut_freq: Unused.
        task: Unused.
        info_args: Unused.
    """
    # The original also located the max/min rows of ``stft`` but never used
    # them; that dead computation has been dropped.
    _fig, ax = plt.subplots(figsize=(12, 12), dpi=200)

    ax.imshow(matrix_binary, cmap='cividis', origin='lower')

    ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
    ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, matrix.shape[0], 5)))
    ax.set_xticklabels(
        [str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax.get_xticks().shape[0])])
    ax.set_yticklabels(
        [str(np.round(x, 1)) for x in np.linspace(0, matrix.shape[0] / s_rate, ax.get_yticks().shape[0])])
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Time (s)')
    ax.set_title('Recurrence Plot')
|
|
|
|
|
|
|
|
|
|
def features_hists(df, features_list, condition, dpi=200):
    """Plot one KDE curve panel per feature, coloured by ``condition``.

    Each panel is produced with ``sns.histplot(..., kde=True)`` and the
    histogram bar containers are then removed so only the density curves
    remain.  Panels are titled ``(a)``, ``(b)``, ... in order.

    Args:
        df: Data passed to seaborn as ``data``.
        features_list: Column names, one subplot each.  A single feature and
            more than six features are both supported.
        condition: Column name used as ``hue``.
        dpi: Resolution of the created figure.
    """
    n_features = len(features_list)
    # squeeze=False always yields a 2-D array of axes, so one feature works too.
    fig, axs = plt.subplots(n_features, figsize=(6, n_features * 3), dpi=dpi, squeeze=False)

    for i, (ax, feature) in enumerate(zip(axs[:, 0], features_list)):
        sns.histplot(data=df, x=feature, hue=condition, alpha=0.8, element="bars", fill=False, ax=ax, kde=True)
        # Remove every bar container, however many hue levels there are;
        # iterate over a copy because remove() edits ax.containers.
        for container in list(ax.containers):
            container.remove()
        ax.xaxis.grid()
        ax.yaxis.grid()
        # Letter labels generated on the fly instead of a fixed six-item list.
        ax.set_title('({})'.format(chr(ord('a') + i)))

    plt.autoscale(enable=True, axis='both', tight=None)
    fig.tight_layout()
|
|
|
|
def features_per_subjects_violin(df, features_list, condition, dpi=200):
    """Plot split violins of each feature per subject, one row per feature.

    Args:
        df: Data passed to seaborn; must expose a ``Subject`` attribute/column
            used for the x axis.
        features_list: Column names, one subplot row each (a single feature
            is supported).
        condition: Column name used as ``hue``; ``split=True`` is passed to
            seaborn, which presumably requires exactly two levels -- confirm.
        dpi: Resolution of the created figure.
    """
    n_features = len(features_list)
    # squeeze=False always yields a 2-D array of axes, so one feature works too.
    fig, axs = plt.subplots(n_features, figsize=(14, n_features * 2), dpi=dpi, sharex='col', squeeze=False)
    axes = axs[:, 0]

    for ax, feature in zip(axes, features_list):
        sns.violinplot(data=df, x=df.Subject, y=feature, hue=condition, ax=ax, split=True, linewidth=0.2)
        ax.legend(loc='lower right')

    # The x axis is shared column-wise, so only the bottom row's labels are rotated.
    bottom_ax = axes[-1]
    bottom_ax.set_xticklabels(bottom_ax.get_xticklabels(), rotation=90)

    plt.tick_params(axis='x', which='major', labelsize=16)
    fig.tight_layout()
|
|
|
|
def umap_on_condition(df, y, title, labels_name, features_list=None, random_state=70, n_neighbors=15, min_dist=0.25, metric="hamming", df_type=True):
    """Fit a supervised UMAP embedding and scatter-plot it on a new figure.

    The data are mean-imputed and quantile-transformed before
    ``umap.UMAP(...).fit(X, y)``; the result is drawn with
    ``umap.plot.points`` using a fixed two-colour key.

    Args:
        df: Feature data.  Used as-is when ``df_type`` is true; otherwise the
            ``features_list`` columns are selected and converted to an array.
        y: Target values passed to ``UMAP.fit``.
        title: Title of the axes.
        labels_name: Labels passed to ``umap.plot.points`` for colouring.
        features_list: Columns to select when ``df_type`` is false.  ``None``
            means ``['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']``.
        random_state: Forwarded to ``umap.UMAP``.
        n_neighbors: Forwarded to ``umap.UMAP``.
        min_dist: Forwarded to ``umap.UMAP``.
        metric: Forwarded to ``umap.UMAP``.
        df_type: When true, ``df`` already holds exactly the feature data.
    """
    # A None sentinel avoids sharing one mutable list between calls.
    if features_list is None:
        features_list = ['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']

    _fig, ax1 = plt.subplots(figsize=(8, 8), dpi=150)

    stats_data = df if df_type else df[features_list].values

    pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
    X = pipe.fit_transform(stats_data.copy())

    manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)

    umap.plot.points(manifold, labels=labels_name, ax=ax1, color_key=np.array(
        [(0, 0.35, 0.73), (1, 0.83, 0)]))
    ax1.set_title(title)
|
|
|
|
def _umap_task_panel(df, features_list, ax, xlabel, random_state, n_neighbors, min_dist, metric):
    """Embed ``df[features_list]`` with UMAP supervised by ``df.Task`` and draw it on ``ax``."""
    stats_data = df[features_list].values
    y = df.Task.values

    pipe = make_pipeline(SimpleImputer(strategy="mean"), QuantileTransformer())
    X = pipe.fit_transform(stats_data.copy())

    manifold = umap.UMAP(random_state=random_state, n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit(X, y)

    umap.plot.points(manifold, labels=y, ax=ax, color_key=np.array(
        [(0, 0.35, 0.73), (1, 0.83, 0)]))
    ax.set_xlabel(xlabel)


def umap_side_by_side_plot(df1, df2, features_list=None, random_state=70, n_neighbors=15, min_dist=0.25, metric="hamming"):
    """Plot the UMAP embeddings of two data frames next to each other.

    Each frame is processed independently (mean imputation, quantile
    transform, UMAP fitted with ``Task`` as target) and drawn on its own axes
    of a new 1x2 figure.

    Args:
        df1: Data for the left panel; needs the ``features_list`` columns and
            a ``Task`` column.
        df2: Data for the right panel; same requirements.
        features_list: Feature columns.  ``None`` means
            ``['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']``.
        random_state: Forwarded to ``umap.UMAP``.
        n_neighbors: Forwarded to ``umap.UMAP``.
        min_dist: Forwarded to ``umap.UMAP``.
        metric: Forwarded to ``umap.UMAP``.

    Returns:
        None.
    """
    # A None sentinel avoids sharing one mutable list between calls.
    if features_list is None:
        features_list = ['TT', 'RR', 'DET', 'LAM', 'L', 'Lentr']

    _fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 8), dpi=150)

    panels = (
        (df1, ax1, '(a) STFT Condition 0 - open eyes, 1 - closed eyes'),
        (df2, ax2, '(b) TDEMB Condition 0 - open eyes, 1 - closed eyes'),
    )
    for frame, ax, xlabel in panels:
        _umap_task_panel(frame, features_list, ax, xlabel, random_state, n_neighbors, min_dist, metric)

    return
|
|
|
|
def SVM_histogram(df, lin, lin_pred, title):
    """Plot KDE curves of the samples projected onto ``lin.coef_``, split by ``lin_pred``.

    The projection is ``np.dot(df, lin.coef_.T)``; the histogram bars drawn by
    seaborn are removed so only the two density curves stay visible.  The
    figure is displayed with ``plt.show()``.

    Args:
        df: Feature matrix (rows are samples) compatible with ``lin.coef_``.
        lin: Object exposing ``coef_``; presumably a fitted linear SVM --
            confirm against the caller.
        lin_pred: One label per sample, used as the ``Task`` hue.  The
            palette and bar removal assume exactly two distinct labels.
        title: Figure title.
    """
    plt.figure(dpi=150)
    projections = np.dot(df, lin.coef_.T)
    df_all = pd.DataFrame({'vectors': projections.ravel(), 'Task': lin_pred})

    a = sns.histplot(data=df_all, x='vectors', hue='Task', alpha=0.8, element="bars", fill=False, kde=True,
                     kde_kws={'bw_adjust': 0.4}, palette=np.array([(0.3, 0.85, 0), (0.8, 0.0, 0.44)]))
    # Remove the bar outlines (last container first) so only the KDE lines remain.
    a.containers[1].remove()
    a.containers[0].remove()

    plt.title(title)
    plt.xlabel('All')
    # Toggle the grid.  The old ``b=None`` keyword was renamed to ``visible``
    # and later removed; calling grid() with no arguments is the equivalent.
    plt.grid()
    plt.show()
|
|
|
|
def f_importances(coef, names):
    """Show a horizontal bar chart of ``coef`` sorted in ascending order.

    Args:
        coef: Iterable of coefficient values.
        names: Iterable of labels, paired element-wise with ``coef``; ties in
            ``coef`` are ordered by label.
    """
    ranked = sorted(zip(coef, names))
    sorted_coef, sorted_names = zip(*ranked)
    positions = range(len(sorted_names))

    plt.figure()
    plt.barh(positions, sorted_coef, align='center')
    plt.yticks(positions, sorted_names)
    plt.show()
|
|
|
|
def SVM_features_importance(lin):
    """Visualise the weights in ``lin.coef_[0]`` by feature and by electrode.

    The weights are assumed to be laid out feature-major: 64 electrode
    weights for ``TT``, then 64 for ``RR``, ``DET``, ``LAM``, ``L`` and
    ``L_entr`` (384 values in total).  Three figures are produced: two sorted
    bar charts via ``f_importances`` and one grouped seaborn bar plot.

    Note:
        Calls ``sns.set_theme`` and therefore changes the global seaborn
        style for plots created afterwards.

    Args:
        lin: Object exposing ``coef_`` whose first row holds 384 weights;
            presumably a fitted linear SVM -- confirm against the caller.
    """
    feature_names = ['TT', 'RR', 'DET', 'LAM', 'L', 'L_entr']

    # One 64-channel montage, split on any whitespace so that both tab- and
    # space-separated text work (the original only split tab-separated text).
    electrodes = (
        "Af3 Af4 Af7 Af8 Afz C1 C2 C3 C4 C5 C6 CZ Cp1 Cp2 Cp3 Cp4 Cp5 Cp6 Cpz "
        "F1 F2 F3 F4 F5 F6 F7 F8 Fc1 Fc2 Fc3 Fc4 Fc5 Fc6 Fcz Fp1 Fp2 Fpz Ft7 Ft8 "
        "Fz Iz O1 O2 OZ P1 P2 P3 P4 P5 P6 P7 P8 Po3 Po4 Po7 Po8 Poz Pz "
        "T10 T7 T8 T9 Tp7 Tp8"
    ).split()

    # Each feature name repeated once per electrode; the montage repeated once per feature.
    feature_column = [name for name in feature_names for _ in electrodes]
    electrode_column = electrodes * len(feature_names)

    df = pd.DataFrame({'feature': feature_column,
                       'electrode': electrode_column,
                       'coef': lin.coef_[0],
                       })

    f_importances(df.coef, df.feature)
    f_importances(df.coef, df.electrode)

    sns.set_theme(style='darkgrid', rc={'figure.dpi': 120},
                  font_scale=1.7)
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.set_title('Weight of features by electrodes')
    sns.barplot(x='feature', y='coef', data=df, ax=ax,
                ci=None,
                hue='electrode')
    ax.legend(bbox_to_anchor=(1, 1), title='electrode', prop={'size': 7})
|
|
|
|
|
|
|
|
def soft_bounds(T, seg):
    """Build one normalised soft-boundary matrix per entry of ``seg[1:]``.

    For every ``seg[it]`` (``it >= 1``; the first entry only supplies the
    number of states) and every pair of adjacent states ``k``/``k + 1``, the
    outer product of the per-step change in ``seg[it][:, k+1:].sum(1)`` with
    ``seg[it][1:, k:k+2].sum(1)`` is accumulated into rows/columns ``1:``.
    The sum is symmetrised with an element-wise maximum and scaled so its
    largest entry is 1.

    Args:
        T: Side length of each output matrix; must equal the number of rows
            of every ``seg[it]``.
        seg: Sequence of 2-D arrays indexed ``[time, state]``.

    Returns:
        list of ``len(seg) - 1`` arrays of shape ``(T, T)``.  A matrix whose
        maximum is 0 (e.g. only one state) is returned as all zeros instead
        of being divided by zero.
    """
    bounds_anim = []
    K = seg[0].shape[1]
    for it in range(1, len(seg)):
        states = seg[it]
        sb = np.zeros((T, T))
        for k in range(K - 1):
            p_change = np.diff(states[:, (k + 1):].sum(1))
            sb[1:, 1:] += np.outer(p_change, states[1:, k:(k + 2)].sum(1))
        sb = np.maximum(sb, sb.T)
        peak = np.max(sb)
        # Guard against 0/0, which would fill the matrix with NaN.
        if peak != 0:
            sb = sb / peak
        bounds_anim.append(sb)
    return bounds_anim
|
|
|
|
|
|
def _metastate_time_axes(ax, n_samples, s_rate, title, meta_tick):
    """Apply second-labelled ticks, axis labels, a title and diagonal tick markers to ``ax``."""
    ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, n_samples, 5)))
    ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.linspace(0, n_samples, 5)))
    ax.set_xticklabels(
        [str(np.round(x, 1)) for x in np.linspace(0, n_samples / s_rate, ax.get_xticks().shape[0])])
    ax.set_yticklabels(
        [str(np.round(x, 1)) for x in np.linspace(0, n_samples / s_rate, ax.get_yticks().shape[0])])
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Time (s)')
    ax.set_title(title, fontsize=10)
    ax.scatter(meta_tick, meta_tick, s=2)


def _annotate_metastate_labels(ax, meta_tick, metastate_id, state_width):
    """Label every metastate on ``ax`` with its id and duration in milliseconds."""
    for i, mstate in enumerate(metastate_id):
        anchor_y = meta_tick[i] + (state_width[i] / 2)
        # NOTE(review): the duration uses a hard-coded 160 Hz rate rather than
        # the caller's s_rate -- confirm this is intended.
        duration_ms = int(((1 / 160) * state_width[i]) * 1000)
        ax.annotate('s ' + str(mstate) + '| ' + str(duration_ms) + 'ms',
                    xy=(meta_tick[i], anchor_y),
                    xytext=(meta_tick[i], anchor_y + 70),
                    xycoords='data',
                    textcoords='data',
                    arrowprops=dict(arrowstyle="->", facecolor='blue'),
                    horizontalalignment='right', verticalalignment='top', fontsize=5
                    )


def fitting_animation(seg, matrix, s_rate, meta_tick, metastate_id, state_width, color_states_matrix):
    """Draw the three-panel metastate figure, save it as ``Metastate.png`` and return it.

    Panels, left to right: the viridis-mapped ``matrix`` multiplied by the
    averaged soft-boundary image; the averaged soft-boundary image alone with
    metastate labels; the same image (RGB channels) multiplied by
    ``color_states_matrix`` with metastate labels.

    Args:
        seg: Sequence of segmentation arrays passed to ``soft_bounds``; needs
            at least two entries so that there is something to average.
        matrix: Square 2-D array; ``matrix.shape[0]`` sets the tick range.
        s_rate: Divisor converting ``matrix.shape[0]`` to time tick labels.
        meta_tick: Per-metastate positions, marked on the diagonal.
        metastate_id: Per-metastate identifiers used in the labels.
        state_width: Per-metastate widths, in the same units as ``meta_tick``.
        color_states_matrix: Array multiplied with the RGB soft-boundary
            image; presumably shape ``(T, T, 3)`` -- confirm against caller.

    Returns:
        The created matplotlib figure.
    """
    n_samples = matrix.shape[0]
    bounds_anim = soft_bounds(n_samples, seg)

    fig = plt.figure(figsize=(18, 12), dpi=300)
    grid = plt.GridSpec(4, 12, hspace=0.0, wspace=3.5)
    ax1 = fig.add_subplot(grid[:, 0:4])
    ax2 = fig.add_subplot(grid[:, 4:8])
    ax3 = fig.add_subplot(grid[:, 8:])

    # Background: min-max normalised matrix mapped through viridis.
    bk = cm.viridis((matrix - np.min(matrix)) / (np.max(matrix) - np.min(matrix)))
    # Foreground: mean soft-boundary matrix, inverted and mapped to grey.
    fg = cm.gray(1 - (sum(bounds_anim) / len(bounds_anim)))

    # --- panel 1: boundaries over the recurrence matrix ----------------------
    im = ax1.imshow(bk, interpolation='none', origin='lower')
    im.set_array(bk * fg)
    _metastate_time_axes(ax1, n_samples, s_rate, 'Metastates plot over recurrence plot', meta_tick)

    # --- panel 2: boundaries only --------------------------------------------
    ax2.imshow(fg, interpolation='none', origin='lower')
    _metastate_time_axes(ax2, n_samples, s_rate, 'Metastates plot', meta_tick)
    _annotate_metastate_labels(ax2, meta_tick, metastate_id, state_width)

    # --- panel 3: boundaries tinted by state colour --------------------------
    ax3.imshow(fg[:, :, :3] * color_states_matrix, interpolation='none', origin='lower')
    _metastate_time_axes(ax3, n_samples, s_rate, 'Metastates plot', meta_tick)
    _annotate_metastate_labels(ax3, meta_tick, metastate_id, state_width)

    plt.savefig('Metastate.png')

    return fig
|
|
|
|
|
|
|
|
def fit_HMM(matrix, n_events):
    """Create an ``EventSegment`` with ``n_events`` events, fit it to ``matrix`` and return what ``fit`` returns."""
    segmenter = EventSegment(n_events=n_events)
    return segmenter.fit(matrix)
|
|
|
|
def metastates(seg, matrix, s_rate, meta_tick, metastate_id, state_width, color_states_matrix):
    """Draw the metastate figure by delegating to ``fitting_animation``.

    All arguments are forwarded unchanged and in order; the figure returned
    by ``fitting_animation`` is discarded, so this function returns ``None``.
    """
    forwarded = (seg, matrix, s_rate, meta_tick, metastate_id, state_width, color_states_matrix)
    fitting_animation(*forwarded)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|