"""Event segmentation using a Hidden Markov Model

Adapted from the brainiak package for this workshop.
See https://brainiak.org/ for full documentation.
"""

import copy
import itertools

import numpy as np
from scipy import stats
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted, check_array


def masked_log(x):
    """Compute the natural log, mapping nonpositive entries to -inf.

    Avoids numpy warnings when taking the log of probabilities that may
    be exactly zero.
    """
    y = np.empty(x.shape, dtype=x.dtype)
    lim = x.shape[0]
    for i in range(lim):
        if x[i] <= 0:
            y[i] = float('-inf')
        else:
            y[i] = np.log(x[i])
    return y


class EventSegment(BaseEstimator):
    """Segment continuous data into a sequence of discrete events.

    Each event is modeled as a Gaussian around a mean voxel pattern, and a
    Hidden Markov Model constrains the events to occur in a fixed order.

    Parameters
    ----------
    n_events : int
        Number of events to learn.
    step_var : callable
        Maps the iteration number to the Gaussian observation variance
        used on that iteration (an annealing schedule).
    n_iter : int
        Maximum number of fitting iterations.
    event_chains : ndarray, optional
        Label for each event; events sharing a label form a chain that
        must occur in sequence. Patterns for multiple chains must be set
        with set_event_patterns() rather than learned by fit().
    split_merge : bool
        Whether to attempt split/merge proposals during fitting.
    split_merge_proposals : int
        Number of split/merge proposals to evaluate per iteration.
    """

    def _default_var_schedule(step):
        # Default annealing schedule: variance decays from 4 by 2% per
        # step. Bound as the default for step_var below at class-definition
        # time, so no self argument is needed.
        return 4 * (0.98 ** (step - 1))

    def __init__(self, n_events=2,
                 step_var=_default_var_schedule,
                 n_iter=500, event_chains=None,
                 split_merge=False, split_merge_proposals=1):
        self.n_events = n_events
        self.step_var = step_var
        self.n_iter = n_iter
        self.split_merge = split_merge
        self.split_merge_proposals = split_merge_proposals
        if event_chains is None:
            self.event_chains = np.zeros(n_events)
        else:
            self.event_chains = event_chains

    def _fit_validate(self, X):
        """Validate input to fit() and reshape to voxels x timepoints."""
        if len(np.unique(self.event_chains)) > 1:
            raise RuntimeError("Cannot fit chains, use set_event_patterns")

        # Copy X into a list (one entry per dataset) and transpose each
        # dataset to voxels x timepoints
        X = copy.deepcopy(X)
        if type(X) is not list:
            X = [X]
        for i in range(len(X)):
            X[i] = check_array(X[i])
            X[i] = X[i].T

        # Check that the number of voxels is consistent across datasets
        n_dim = X[0].shape[0]
        for i in range(len(X)):
            assert (X[i].shape[0] == n_dim)

        # z-score each voxel's timecourse
        for i in range(len(X)):
            X[i] = stats.zscore(X[i], axis=1, ddof=1)

        return X

    def fit(self, X, y=None):
        """Learn a segmentation of the data into events.

        X is a timepoints x voxels ndarray, or a list of such ndarrays
        (one per dataset). In this workshop adaptation, fit() returns the
        first dataset's segmentation after every step so that the course
        of fitting can be visualized; the final model is stored in
        segments_, event_pat_, and event_var_.
        """
        seg = []
        X = self._fit_validate(X)
        n_train = len(X)
        n_dim = X[0].shape[0]
        self.classes_ = np.arange(self.n_events)

        # Initialize the log posterior event probabilities and the
        # log-likelihood history
        log_gamma = []
        for i in range(n_train):
            log_gamma.append(np.zeros((X[i].shape[1], self.n_events)))
        step = 1
        best_ll = float("-inf")
        self.ll_ = np.empty((0, n_train))
        while step <= self.n_iter:
            iteration_var = self.step_var(step)

            # Based on the current segmentation, compute the mean voxel
            # pattern for each event, averaged across datasets
            seg_prob = [np.exp(lg) / np.sum(np.exp(lg), axis=0)
                        for lg in log_gamma]
            mean_pat = np.empty((n_train, n_dim, self.n_events))
            for i in range(n_train):
                mean_pat[i, :, :] = X[i].dot(seg_prob[i])
            mean_pat = np.mean(mean_pat, axis=0)

            # Based on the current mean patterns, recompute the event
            # segmentation and log-likelihood for each dataset
            self.ll_ = np.append(self.ll_, np.empty((1, n_train)), axis=0)
            for i in range(n_train):
                logprob = self._logprob_obs(X[i], mean_pat, iteration_var)
                log_gamma[i], self.ll_[-1, i] = self._forward_backward(logprob)

            # Optionally propose splitting one event and merging two others
            if step > 1 and self.split_merge:
                curr_ll = np.mean(self.ll_[-1, :])
                self.ll_[-1, :], log_gamma, mean_pat = \
                    self._split_merge(X, log_gamma, iteration_var, curr_ll)

            # If the log-likelihood has started decreasing, undo the last
            # step and stop
            if np.mean(self.ll_[-1, :]) < best_ll:
                self.ll_ = self.ll_[:-1, :]
                break

            self.segments_ = [np.exp(lg) for lg in log_gamma]
            self.event_var_ = iteration_var
            self.event_pat_ = mean_pat
            best_ll = np.mean(self.ll_[-1, :])

            # Record the first dataset's segmentation at this step
            seg.append(self.segments_[0].copy())

            step += 1

        return seg

    def _logprob_obs(self, data, mean_pat, var):
        """Log probability of each timepoint under each event's Gaussian.

        data is voxels x timepoints, mean_pat is voxels x events, and var
        is a scalar or per-event array of observation variances. Returns a
        timepoints x events array of log-probabilities.
        """
        n_vox = data.shape[0]
        t = data.shape[1]

        # z-score in space, so the likelihood is sensitive to each
        # timepoint's spatial pattern rather than its overall activity level
        data_z = stats.zscore(data, axis=0, ddof=1)
        mean_pat_z = stats.zscore(mean_pat, axis=0, ddof=1)

        logprob = np.empty((t, self.n_events))

        # Broadcast a scalar variance to all events
        if type(var) is not np.ndarray:
            var = var * np.ones(self.n_events)

        # Log density of a spherical Gaussian around each event pattern
        for k in range(self.n_events):
            logprob[:, k] = -0.5 * n_vox * np.log(
                2 * np.pi * var[k]) - 0.5 * np.sum(
                (data_z.T - mean_pat_z[:, k]).T ** 2, axis=0) / var[k]

        # Scale by the number of voxels, keeping magnitudes comparable
        # across datasets with different voxel counts
        logprob /= n_vox
        return logprob

    def _forward_backward(self, logprob):
        """Run the forward-backward algorithm on observation log-probs.

        logprob is timepoints x events. Returns the log posterior
        probability of each event at each timepoint (log_gamma) and the
        total log-likelihood (ll). A dummy absorbing end state is appended
        so that the final event must extend to the last timepoint.
        """
        logprob = copy.copy(logprob)
        t = logprob.shape[0]
        logprob = np.hstack((logprob, float("-inf") * np.ones((t, 1))))

        # Initialize the scaled forward and backward variables
        log_scale = np.zeros(t)
        log_alpha = np.zeros((t, self.n_events + 1))
        log_beta = np.zeros((t, self.n_events + 1))

        # Set up the start/end probabilities and the transition matrix
        self.p_start = np.zeros(self.n_events + 1)
        self.p_end = np.zeros(self.n_events + 1)
        self.P = np.zeros((self.n_events + 1, self.n_events + 1))
        label_ind = np.unique(self.event_chains, return_inverse=True)[1]
        n_chains = np.max(label_ind) + 1

        # Each chain starts at its first event and ends at its last. Within
        # a chain, events advance one at a time, with the transition
        # probability set by the expected number of transitions per timepoint
        for c in range(n_chains):
            chain_ind = np.nonzero(label_ind == c)[0]
            self.p_start[chain_ind[0]] = 1 / n_chains
            self.p_end[chain_ind[-1]] = 1 / n_chains

            p_trans = (len(chain_ind) - 1) / t
            if p_trans >= 1:
                raise ValueError('Too few timepoints')
            for i in range(len(chain_ind)):
                self.P[chain_ind[i], chain_ind[i]] = 1 - p_trans
                if i < len(chain_ind) - 1:
                    self.P[chain_ind[i], chain_ind[i+1]] = p_trans
                else:
                    self.P[chain_ind[i], -1] = p_trans
        self.P[-1, -1] = 1

        # Forward pass, rescaling at each timepoint to avoid underflow
        for i in range(t):
            if i == 0:
                log_alpha[0, :] = self._log(self.p_start) + logprob[0, :]
            else:
                log_alpha[i, :] = self._log(np.exp(log_alpha[i - 1, :])
                                            .dot(self.P)) + logprob[i, :]

            log_scale[i] = np.logaddexp.reduce(log_alpha[i, :])
            log_alpha[i] -= log_scale[i]

        # Backward pass, reusing the forward scaling factors
        log_beta[-1, :] = self._log(self.p_end) - log_scale[-1]
        for i in reversed(range(t - 1)):
            obs_weighted = log_beta[i + 1, :] + logprob[i + 1, :]
            offset = np.max(obs_weighted)
            log_beta[i, :] = offset + self._log(
                np.exp(obs_weighted - offset).dot(self.P.T)) - log_scale[i]

        # Combine the forward and backward passes and renormalize
        log_gamma = log_alpha + log_beta
        log_gamma -= np.logaddexp.reduce(log_gamma, axis=1, keepdims=True)

        # Total log-likelihood of the data under the model
        ll = np.sum(log_scale[:(t - 1)]) + np.logaddexp.reduce(
            log_alpha[-1, :] + log_scale[-1] + self._log(self.p_end))

        # Drop the dummy end state before returning
        log_gamma = log_gamma[:, :-1]

        return log_gamma, ll

    def _log(self, x):
        """Apply masked_log to an array of any shape."""
        xshape = x.shape
        _x = x.flatten()
        y = masked_log(_x)
        return y.reshape(xshape)

    def set_event_patterns(self, event_pat):
        """Set the mean event patterns directly, instead of learning them.

        event_pat is a voxels x n_events ndarray.
        """
        if event_pat.shape[1] != self.n_events:
            raise ValueError(("Number of columns of event_pat must match "
                              "number of events"))
        self.event_pat_ = event_pat.copy()

    def find_events(self, testing_data, var=None, scramble=False):
        """Apply the learned (or set) event patterns to a new dataset.

        testing_data is timepoints x voxels; var defaults to the variance
        from fit(). If scramble is True, the event order is randomly
        permuted (useful as a null model). Returns the timepoints x events
        segmentation matrix and its log-likelihood.
        """
        if var is None:
            if not hasattr(self, 'event_var_'):
                raise NotFittedError(("Event variance must be provided, if "
                                      "not previously set by fit()"))
            else:
                var = self.event_var_

        if not hasattr(self, 'event_pat_'):
            raise NotFittedError(("The event patterns must first be set "
                                  "by fit() or set_event_patterns()"))
        if scramble:
            mean_pat = self.event_pat_[:, np.random.permutation(self.n_events)]
        else:
            mean_pat = self.event_pat_

        logprob = self._logprob_obs(testing_data.T, mean_pat, var)
        lg, test_ll = self._forward_backward(logprob)
        segments = np.exp(lg)

        return segments, test_ll

    def predict(self, X):
        """Return the most probable event for each timepoint in X."""
        check_is_fitted(self, ["event_pat_", "event_var_"])
        X = check_array(X)
        segments, test_ll = self.find_events(X)
        return np.argmax(segments, axis=1)

    def calc_weighted_event_var(self, D, weights, event_pat):
        """Compute the normalized weighted variance around each event pattern.

        D is timepoints x voxels, weights is timepoints x events (e.g. the
        segments_ from fit()), and event_pat is voxels x events.
        """
        Dz = stats.zscore(D, axis=1, ddof=1)
        ev_var = np.empty(event_pat.shape[1])
        for e in range(event_pat.shape[1]):
            # Only use timepoints with non-negligible weight
            # (> 0.1% of the maximum weight for this event)
            nz = weights[:, e] > np.max(weights[:, e])/1000

            # Weighted sum of squared deviations from the event pattern,
            # normalized by the effective sample size of the weights
            sumsq = np.dot(weights[nz, e],
                           np.sum(np.square(Dz[nz, :] -
                                  event_pat[:, e]), axis=1))
            ev_var[e] = sumsq/(np.sum(weights[nz, e]) -
                               np.sum(np.square(weights[nz, e])) /
                               np.sum(weights[nz, e]))
        ev_var = ev_var / D.shape[1]
        return ev_var

    def model_prior(self, t):
        """Return the prior segmentation implied by the HMM alone.

        Runs forward-backward over t timepoints with uniform observation
        probabilities, so the result reflects only the transition structure.
        """
        lg, test_ll = self._forward_backward(np.zeros((t, self.n_events)))
        segments = np.exp(lg)

        return segments, test_ll

    def _split_merge(self, X, log_gamma, iteration_var, curr_ll):
        """Attempt to improve the log-likelihood by splitting and merging.

        The annealed fitting in fit() can get stuck in local optima with
        misplaced event boundaries. This proposes splitting one event in
        two while merging two adjacent events (keeping n_events constant),
        and accepts a proposal only if it increases the mean log-likelihood
        across datasets.
        """

        # Compute the current mean pattern for each event, as in fit()
        n_train = len(X)
        n_dim = X[0].shape[0]

        seg_prob = [np.exp(lg) / np.sum(np.exp(lg), axis=0)
                    for lg in log_gamma]
        mean_pat = np.empty((n_train, n_dim, self.n_events))
        for i in range(n_train):
            mean_pat[i, :, :] = X[i].dot(seg_prob[i])
        mean_pat = np.mean(mean_pat, axis=0)

        # Compute the mean patterns that would result from merging each
        # pair of adjacent events and from splitting each event in two at
        # its probability midpoint
        merge_pat = np.empty((n_train, n_dim, self.n_events))
        split_pat = np.empty((n_train, n_dim, 2 * self.n_events))
        for i, sp in enumerate(seg_prob):
            m_evprob = np.zeros((sp.shape[0], sp.shape[1]))
            s_evprob = np.zeros((sp.shape[0], 2 * sp.shape[1]))
            cs = np.cumsum(sp, axis=0)
            for e in range(sp.shape[1]):
                # Split event e at the timepoint where half of its
                # probability mass has accumulated
                mid = np.where(cs[:, e] >= 0.5)[0][0]
                cs_first = cs[mid, e] - sp[mid, e]
                cs_second = 1 - cs_first
                s_evprob[:mid, 2 * e] = sp[:mid, e] / cs_first
                s_evprob[mid:, 2 * e + 1] = sp[mid:, e] / cs_second

                # Merge events e and e+1 by averaging their probabilities
                m_evprob[:, e] = sp[:, e:(e + 2)].mean(1)

            merge_pat[i, :, :] = X[i].dot(m_evprob)
            split_pat[i, :, :] = X[i].dot(s_evprob)

        # Average the proposal patterns across datasets
        merge_pat = np.mean(merge_pat, axis=0)
        split_pat = np.mean(split_pat, axis=0)

        # Score candidates: the best events to split are those whose two
        # halves correlate least with the original pattern, and the best
        # pairs to merge are those whose original patterns both correlate
        # highly with the merged pattern
        merge_corr = np.zeros(self.n_events)
        split_corr = np.zeros(self.n_events)
        for e in range(self.n_events):
            split_corr[e] = np.corrcoef(mean_pat[:, e],
                                        split_pat[:, (2 * e):(2 * e + 2)],
                                        rowvar=False)[0, 1:3].max()
            merge_corr[e] = np.corrcoef(merge_pat[:, e],
                                        mean_pat[:, e:(e + 2)],
                                        rowvar=False)[0, 1:3].min()
        merge_corr = merge_corr[:-1]

        # Take the top split_merge_proposals candidates of each type
        best_merge = np.flipud(np.argsort(merge_corr))
        best_merge = best_merge[:self.split_merge_proposals]
        best_split = np.argsort(split_corr)
        best_split = best_split[:self.split_merge_proposals]

        # Evaluate each split/merge combination and keep the best
        mean_pat_last = mean_pat.copy()
        best_mean_ll = curr_ll
        return_ll = curr_ll
        return_lg = copy.deepcopy(log_gamma)
        return_mp = mean_pat.copy()
        for m_e, s_e in itertools.product(best_merge, best_split):
            if m_e == s_e or m_e + 1 == s_e:
                # Can't merge away the same event we are splitting
                continue

            # Build the proposed pattern matrix: replace event s_e with its
            # two halves, then replace events m_e and m_e+1 (whose indices
            # shift by one if the split occurred earlier in the sequence)
            # with their merged pattern
            mean_pat_ms = np.delete(mean_pat_last, s_e, axis=1)
            mean_pat_ms = np.insert(mean_pat_ms, [s_e, s_e],
                                    split_pat[:, (2 * s_e):(2 * s_e + 2)],
                                    axis=1)
            mean_pat_ms = np.delete(mean_pat_ms,
                                    [m_e + (s_e < m_e), m_e + (s_e < m_e) + 1],
                                    axis=1)
            mean_pat_ms = np.insert(mean_pat_ms, m_e + (s_e < m_e),
                                    merge_pat[:, m_e], axis=1)

            # Score the proposal on each dataset
            ll_ms = np.zeros(n_train)
            log_gamma_ms = list()
            for i in range(n_train):
                logprob = self._logprob_obs(X[i],
                                            mean_pat_ms, iteration_var)
                lg, ll_ms[i] = self._forward_backward(logprob)
                log_gamma_ms.append(lg)

            # Accept if it improves the mean log-likelihood. The best mean
            # is tracked in the scalar best_mean_ll so this comparison stays
            # well-defined after a proposal has already been accepted
            # (return_ll becomes a per-dataset array at that point).
            if ll_ms.mean() > best_mean_ll:
                best_mean_ll = ll_ms.mean()
                return_mp = mean_pat_ms.copy()
                return_ll = ll_ms
                for i in range(n_train):
                    return_lg[i] = log_gamma_ms[i].copy()

        return return_ll, return_lg, return_mp
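
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the brainiak API).
# Fits the model to synthetic timepoints x voxels data with a known event
# structure. The event count, event length, and noise level below are
# assumptions chosen for this demo, not workshop parameters.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    n_events, n_vox, trs_per_event = 5, 30, 20
    n_trs = n_events * trs_per_event

    # One distinct spatial pattern per event, held for trs_per_event TRs,
    # plus Gaussian noise
    patterns = rng.standard_normal((n_events, n_vox))
    data = (np.repeat(patterns, trs_per_event, axis=0)
            + 0.5 * rng.standard_normal((n_trs, n_vox)))

    ev = EventSegment(n_events=n_events)
    seg_history = ev.fit(data)  # this adaptation returns per-step segmentations
    print("Fitting steps until convergence:", len(seg_history))
    print("Most likely event per timepoint:", ev.predict(data))

    # Apply the learned patterns to an independent noisy realization
    test_data = (np.repeat(patterns, trs_per_event, axis=0)
                 + 0.5 * rng.standard_normal((n_trs, n_vox)))
    segments, test_ll = ev.find_events(test_data)
    print("Held-out log-likelihood:", test_ll)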