# Source: NCU/BrainPulse, event.py
# Author: Łukasz Furman (commit a59bdc5, "update app.py")
"""Event segmentation using a Hidden Markov Model
Adapted from the brainiak package for this workshop.
See https://brainiak.org/ for full documentation."""
import numpy as np
from scipy import stats
import logging
import copy
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted, check_array
from sklearn.exceptions import NotFittedError
import itertools
def masked_log(x):
    """Return the elementwise natural log of ``x``, with ``-inf`` where x <= 0.

    Unlike a plain ``np.log``, non-positive entries are mapped to ``-inf``
    without ever being passed to the logarithm, so no RuntimeWarning is
    emitted for zeros or negatives. NaN entries stay NaN.

    Parameters
    ----------
    x : array-like
        Input values, of any shape.

    Returns
    -------
    ndarray
        Same shape as ``x``. Floating-point/complex inputs keep their dtype;
        any other input (e.g. integer or boolean) produces ``float64``, since
        such dtypes cannot represent ``-inf`` or fractional logarithms.
    """
    x = np.asarray(x)
    dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    y = np.full(x.shape, -np.inf, dtype=dtype)
    # Written as ~(x <= 0) rather than x > 0 so that NaN entries are still
    # handed to np.log (yielding NaN) instead of being masked to -inf.
    np.log(x, out=y, where=~(x <= 0))
    return y
class EventSegment(BaseEstimator):
    """Hidden Markov Model (HMM) for segmenting time series into events.

    Each event k is represented by a spatial pattern (column k of
    ``event_pat_``). Timepoints are assigned to events by a left-to-right
    HMM: every event either stays active or advances to the next event of
    its chain, and the last event of each chain feeds an absorbing sink
    state.

    Parameters
    ----------
    n_events : int, default 2
        Number of events (HMM states, not counting the sink state).
    step_var : callable, default ``_default_var_schedule``
        Maps the 1-based fitting iteration number to the observation
        variance used during that iteration.
    n_iter : int, default 500
        Maximum number of fitting iterations.
    event_chains : array-like of length n_events, optional
        Chain label of each event; events sharing a label are linked in
        order. ``None`` puts all events in a single chain. ``fit`` only
        supports a single chain (use ``set_event_patterns`` otherwise).
    split_merge : bool, default False
        If True, try merge/split proposals on every iteration after the
        first.
    split_merge_proposals : int, default 1
        Number of best merge candidates and best split candidates whose
        combinations are evaluated.

    Attributes
    ----------
    classes_ : ndarray of shape (n_events,)
        Event indices ``0 .. n_events - 1``.
    ll_ : ndarray of shape (n_accepted_iterations, n_datasets)
        Training log-likelihood of each dataset at each accepted iteration.
    segments_ : list of ndarray of shape (time, n_events)
        Event probabilities of each training dataset.
    event_var_ : float or ndarray
        Observation variance of the last accepted iteration.
    event_pat_ : ndarray of shape (voxels, n_events)
        Pattern of each event.
    p_start, p_end, P : ndarray
        Start, end and transition probabilities (including the sink state),
        rebuilt by every call to ``_forward_backward``.
    """

    # Deliberately defined without ``self``: it is only used as the plain
    # function default for ``step_var`` (evaluated while the class body
    # runs), and ``fit`` calls it through an instance attribute, which does
    # not bind ``self``.
    def _default_var_schedule(step):
        """Variance for 1-based iteration ``step``: 4, decaying by 0.98/step."""
        return 4 * (0.98 ** (step - 1))

    def __init__(self, n_events=2,
                 step_var=_default_var_schedule,
                 n_iter=500, event_chains=None,
                 split_merge=False, split_merge_proposals=1):
        self.n_events = n_events
        self.step_var = step_var
        self.n_iter = n_iter
        self.split_merge = split_merge
        self.split_merge_proposals = split_merge_proposals
        # None means every event belongs to the same chain (label 0).
        if event_chains is None:
            self.event_chains = np.zeros(n_events)
        else:
            self.event_chains = event_chains

    def _fit_validate(self, X):
        """Validate and normalize training data for ``fit``.

        Parameters
        ----------
        X : array-like of shape (time, voxels), or a list of such arrays
            One or more training datasets with the same number of voxels.
            The caller's data is never modified.

        Returns
        -------
        list of ndarray of shape (voxels, time)
            Transposed copies of the datasets, z-scored along time.

        Raises
        ------
        RuntimeError
            If ``event_chains`` defines more than one chain.
        ValueError
            If the datasets do not all have the same number of voxels.
        """
        if len(np.unique(self.event_chains)) > 1:
            raise RuntimeError("Cannot fit chains, use set_event_patterns")

        # Work on a copy so the caller's list and arrays stay untouched.
        X = copy.deepcopy(X)
        if not isinstance(X, list):
            X = [X]
        # check_array validates 2-D numeric input; transpose to (voxels, time).
        X = [check_array(x).T for x in X]

        n_dim = X[0].shape[0]
        if any(x.shape[0] != n_dim for x in X):
            raise ValueError("All datasets must have the same number of "
                             "voxels")

        # Double-check that data is z-scored in time
        return [stats.zscore(x, axis=1, ddof=1) for x in X]

    def _mean_event_patterns(self, X, log_gamma):
        """Compute each event's pattern from the current soft segmentation.

        Parameters
        ----------
        X : list of ndarray of shape (voxels, time)
        log_gamma : list of ndarray of shape (time, n_events)
            Log event probabilities, one array per dataset.

        Returns
        -------
        seg_prob : list of ndarray of shape (time, n_events)
            Event probabilities normalized to sum to 1 over time, so each
            event pattern is a probability-weighted average of timepoints.
        mean_pat : ndarray of shape (voxels, n_events)
            Event patterns averaged across datasets.
        """
        seg_prob = []
        for lg in log_gamma:
            gamma = np.exp(lg)
            seg_prob.append(gamma / np.sum(gamma, axis=0))
        mean_pat = np.mean([x.dot(sp) for x, sp in zip(X, seg_prob)], axis=0)
        return seg_prob, mean_pat

    def fit(self, X, y=None):
        """Learn event patterns and the event segmentation of ``X``.

        Alternates between (1) estimating each event's pattern from the
        current soft segmentation and (2) re-estimating the segmentation with
        the forward-backward algorithm, using the variance given by
        ``step_var`` for the current iteration. Stops after ``n_iter``
        iterations, or as soon as the mean log-likelihood decreases (that
        last iteration is discarded).

        Parameters
        ----------
        X : array-like of shape (time, voxels), or a list of such arrays
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        list of ndarray of shape (time, n_events)
            Snapshot of the first dataset's event probabilities after every
            accepted iteration. Note: unlike the usual scikit-learn
            convention, ``self`` is not returned.
        """
        seg = []
        X = self._fit_validate(X)
        n_train = len(X)
        self.classes_ = np.arange(self.n_events)

        # Start from a uniform segmentation (all log-probabilities equal).
        log_gamma = [np.zeros((x.shape[1], self.n_events)) for x in X]

        step = 1
        best_ll = float("-inf")
        # One row per iteration, one column per training dataset.
        self.ll_ = np.empty((0, n_train))
        while step <= self.n_iter:
            iteration_var = self.step_var(step)

            # Based on the current segmentation, compute the mean pattern
            # for each event.
            _, mean_pat = self._mean_event_patterns(X, log_gamma)

            # Based on the current mean patterns, compute the event
            # segmentation.
            self.ll_ = np.append(self.ll_, np.empty((1, n_train)), axis=0)
            for i, x in enumerate(X):
                logprob = self._logprob_obs(x, mean_pat, iteration_var)
                log_gamma[i], self.ll_[-1, i] = self._forward_backward(logprob)

            if step > 1 and self.split_merge:
                curr_ll = np.mean(self.ll_[-1, :])
                self.ll_[-1, :], log_gamma, mean_pat = \
                    self._split_merge(X, log_gamma, iteration_var, curr_ll)

            # If log-likelihood has started decreasing, undo last step and stop
            if np.mean(self.ll_[-1, :]) < best_ll:
                self.ll_ = self.ll_[:-1, :]
                break

            self.segments_ = [np.exp(lg) for lg in log_gamma]
            self.event_var_ = iteration_var
            self.event_pat_ = mean_pat
            best_ll = np.mean(self.ll_[-1, :])
            seg.append(self.segments_[0].copy())
            step += 1

        return seg

    def _logprob_obs(self, data, mean_pat, var):
        """Gaussian log-likelihood of every event at every timepoint.

        Parameters
        ----------
        data : ndarray of shape (voxels, time)
        mean_pat : ndarray of shape (voxels, n_events)
        var : float or ndarray of shape (n_events,)
            Observation variance; a scalar is shared by all events.

        Returns
        -------
        ndarray of shape (time, n_events)
            Log-likelihoods, divided by the number of voxels.
        """
        n_vox = data.shape[0]
        t = data.shape[1]

        # z-score both data and mean patterns in space, so that Gaussians
        # are measuring Pearson correlations and are insensitive to overall
        # activity changes
        data_z = stats.zscore(data, axis=0, ddof=1)
        mean_pat_z = stats.zscore(mean_pat, axis=0, ddof=1)

        if not isinstance(var, np.ndarray):
            var = var * np.ones(self.n_events)

        logprob = np.empty((t, self.n_events))
        for k in range(self.n_events):
            sq_dist = np.sum((data_z - mean_pat_z[:, k, np.newaxis]) ** 2,
                             axis=0)
            logprob[:, k] = (-0.5 * n_vox * np.log(2 * np.pi * var[k])
                             - 0.5 * sq_dist / var[k])
        logprob /= n_vox
        return logprob

    def _forward_backward(self, logprob):
        """Run the scaled forward-backward algorithm over the event HMM.

        Also rebuilds the transition structure from ``event_chains`` and
        stores it on the instance as ``p_start``, ``p_end`` and ``P``.

        Parameters
        ----------
        logprob : ndarray of shape (time, n_events)
            Observation log-likelihood of each event at each timepoint.

        Returns
        -------
        log_gamma : ndarray of shape (time, n_events)
            Log posterior probability of each event at each timepoint.
        ll : float
            Log-likelihood of the observations under the model.

        Raises
        ------
        ValueError
            If a chain has too many events for the number of timepoints.
        """
        t = logprob.shape[0]
        # Append the sink state, which can never emit an observation.
        # (hstack builds a new array, so the caller's logprob is untouched.)
        logprob = np.hstack((logprob, float("-inf") * np.ones((t, 1))))

        log_scale = np.zeros(t)
        log_alpha = np.zeros((t, self.n_events + 1))
        log_beta = np.zeros((t, self.n_events + 1))

        # Set up transition matrix, with final sink state
        self.p_start = np.zeros(self.n_events + 1)
        self.p_end = np.zeros(self.n_events + 1)
        self.P = np.zeros((self.n_events + 1, self.n_events + 1))
        label_ind = np.unique(self.event_chains, return_inverse=True)[1]
        n_chains = np.max(label_ind) + 1

        # For each chain of events, link them together and then to sink state
        for c in range(n_chains):
            chain_ind = np.nonzero(label_ind == c)[0]
            self.p_start[chain_ind[0]] = 1 / n_chains
            self.p_end[chain_ind[-1]] = 1 / n_chains

            # Chosen so that about len(chain_ind) - 1 transitions are
            # expected over the t timepoints.
            p_trans = (len(chain_ind) - 1) / t
            if p_trans >= 1:
                raise ValueError('Too few timepoints')
            for i, ev in enumerate(chain_ind):
                self.P[ev, ev] = 1 - p_trans
                if i < len(chain_ind) - 1:
                    self.P[ev, chain_ind[i + 1]] = p_trans
                else:
                    self.P[ev, -1] = p_trans
        self.P[-1, -1] = 1

        # Forward pass, renormalized at each timepoint; the log normalizers
        # are kept in log_scale.
        for i in range(t):
            if i == 0:
                log_alpha[0, :] = self._log(self.p_start) + logprob[0, :]
            else:
                log_alpha[i, :] = self._log(
                    np.exp(log_alpha[i - 1, :]).dot(self.P)) + logprob[i, :]
            log_scale[i] = np.logaddexp.reduce(log_alpha[i, :])
            log_alpha[i] -= log_scale[i]

        # Backward pass, reusing the forward scaling factors
        log_beta[-1, :] = self._log(self.p_end) - log_scale[-1]
        for i in reversed(range(t - 1)):
            obs_weighted = log_beta[i + 1, :] + logprob[i + 1, :]
            # Subtract the max before exponentiating for numerical stability.
            offset = np.max(obs_weighted)
            log_beta[i, :] = offset + self._log(
                np.exp(obs_weighted - offset).dot(self.P.T)) - log_scale[i]

        # Combine and normalize over states at every timepoint
        log_gamma = log_alpha + log_beta
        log_gamma -= np.logaddexp.reduce(log_gamma, axis=1, keepdims=True)

        ll = np.sum(log_scale[:(t - 1)]) + np.logaddexp.reduce(
            log_alpha[-1, :] + log_scale[-1] + self._log(self.p_end))

        # Drop the sink-state column.
        log_gamma = log_gamma[:, :-1]
        return log_gamma, ll

    def _log(self, x):
        """Elementwise log of an array of any shape, ``-inf`` where x <= 0.

        Flattens ``x``, applies the module-level ``masked_log`` and restores
        the original shape.
        """
        return masked_log(x.flatten()).reshape(x.shape)

    def set_event_patterns(self, event_pat):
        """Set the event patterns directly instead of learning them by fit.

        Parameters
        ----------
        event_pat : ndarray of shape (voxels, n_events)
            A copy is stored as ``event_pat_``.

        Raises
        ------
        ValueError
            If the number of columns differs from ``n_events``.
        """
        if event_pat.shape[1] != self.n_events:
            raise ValueError(("Number of columns of event_pat must match "
                              "number of events"))
        self.event_pat_ = event_pat.copy()

    def find_events(self, testing_data, var=None, scramble=False):
        """Compute event probabilities for new data with the stored patterns.

        Parameters
        ----------
        testing_data : ndarray of shape (time, voxels)
        var : float or ndarray of shape (n_events,), optional
            Observation variance; defaults to ``event_var_`` set by ``fit``.
        scramble : bool, default False
            If True, randomly permute the order of the event patterns
            (using NumPy's global random state) before segmenting.

        Returns
        -------
        segments : ndarray of shape (time, n_events)
            Event probabilities at each timepoint.
        test_ll : float
            Log-likelihood of ``testing_data``.

        Raises
        ------
        NotFittedError
            If ``var`` is not given and was never set by ``fit``, or if no
            event patterns have been set.
        """
        if var is None:
            if not hasattr(self, 'event_var_'):
                raise NotFittedError(("Event variance must be provided, if "
                                      "not previously set by fit()"))
            var = self.event_var_

        if not hasattr(self, 'event_pat_'):
            raise NotFittedError(("The event patterns must first be set "
                                  "by fit() or set_event_patterns()"))

        if scramble:
            mean_pat = self.event_pat_[:, np.random.permutation(self.n_events)]
        else:
            mean_pat = self.event_pat_

        logprob = self._logprob_obs(testing_data.T, mean_pat, var)
        lg, test_ll = self._forward_backward(logprob)
        return np.exp(lg), test_ll

    def predict(self, X):
        """Return the most probable event index at each timepoint of ``X``.

        Parameters
        ----------
        X : array-like of shape (time, voxels)

        Returns
        -------
        ndarray of shape (time,)
        """
        check_is_fitted(self, ["event_pat_", "event_var_"])
        X = check_array(X)
        segments, _ = self.find_events(X)
        return np.argmax(segments, axis=1)

    def calc_weighted_event_var(self, D, weights, event_pat):
        """Estimate each event's variance from probability-weighted data.

        Parameters
        ----------
        D : ndarray of shape (time, voxels)
            Data; each row is z-scored across voxels before use.
        weights : ndarray of shape (time, n_events)
            Weight of every timepoint for every event.
        event_pat : ndarray of shape (voxels, n_events)

        Returns
        -------
        ndarray of shape (n_events,)
            Weighted squared distance to each event pattern, using the
            bias-corrected denominator ``sum(w) - sum(w**2) / sum(w)``,
            divided by the number of voxels.
        """
        Dz = stats.zscore(D, axis=1, ddof=1)
        ev_var = np.empty(event_pat.shape[1])
        for e in range(event_pat.shape[1]):
            # Only compute variances for weights > 0.1% of max weight
            nz = weights[:, e] > np.max(weights[:, e]) / 1000
            w = weights[nz, e]
            sumsq = np.dot(w, np.sum(np.square(Dz[nz, :] - event_pat[:, e]),
                                     axis=1))
            ev_var[e] = sumsq / (np.sum(w) - np.sum(np.square(w)) / np.sum(w))
        return ev_var / D.shape[1]

    def model_prior(self, t):
        """Return the event probabilities implied by the HMM structure alone.

        Runs forward-backward with identical (zero) observation
        log-likelihoods for every event.

        Parameters
        ----------
        t : int
            Number of timepoints.

        Returns
        -------
        segments : ndarray of shape (t, n_events)
        test_ll : float
        """
        lg, test_ll = self._forward_backward(np.zeros((t, self.n_events)))
        return np.exp(lg), test_ll

    def _split_merge(self, X, log_gamma, iteration_var, curr_ll):
        """Try simultaneously merging two events and splitting another.

        For each event, a "split" pattern pair is built by cutting its
        probability distribution at the median timepoint, and a "merged"
        pattern by averaging it with the next event's distribution. The
        most promising merge and split candidates are combined; each
        combination keeps ``n_events`` events and is scored by its mean
        log-likelihood over the datasets.

        Parameters
        ----------
        X : list of ndarray of shape (voxels, time)
        log_gamma : list of ndarray of shape (time, n_events)
        iteration_var : float or ndarray
            Variance for the current fitting iteration.
        curr_ll : float
            Mean log-likelihood of the current solution.

        Returns
        -------
        return_ll : float or ndarray of shape (n_datasets,)
            ``curr_ll`` if no proposal beat it, otherwise the per-dataset
            log-likelihoods of the best proposal.
        return_lg : list of ndarray of shape (time, n_events)
            Event log-probabilities matching ``return_ll``.
        return_mp : ndarray of shape (voxels, n_events)
            Event patterns matching ``return_ll``.
        """
        n_train = len(X)
        n_dim = X[0].shape[0]
        # Compute current probabilities and mean patterns
        seg_prob, mean_pat = self._mean_event_patterns(X, log_gamma)

        # For each event, merge its probability distribution with the next
        # event, and also split its probability distribution at its median
        # into two separate events. Use these distributions to compute
        # merged and split event patterns.
        merge_pat = np.empty((n_train, n_dim, self.n_events))
        split_pat = np.empty((n_train, n_dim, 2 * self.n_events))
        for i, sp in enumerate(seg_prob):  # Iterate over datasets
            m_evprob = np.zeros((sp.shape[0], sp.shape[1]))
            s_evprob = np.zeros((sp.shape[0], 2 * sp.shape[1]))
            cs = np.cumsum(sp, axis=0)
            for e in range(sp.shape[1]):
                # Split distribution at midpoint and normalize each half
                mid = np.where(cs[:, e] >= 0.5)[0][0]
                cs_first = cs[mid, e] - sp[mid, e]
                cs_second = 1 - cs_first
                s_evprob[:mid, 2 * e] = sp[:mid, e] / cs_first
                s_evprob[mid:, 2 * e + 1] = sp[mid:, e] / cs_second

                # Merge distribution with next event distribution (the last
                # event is "merged" with itself; it is discarded below).
                m_evprob[:, e] = sp[:, e:(e + 2)].mean(1)

            # Weight data by distribution to get event patterns
            merge_pat[i, :, :] = X[i].dot(m_evprob)
            split_pat[i, :, :] = X[i].dot(s_evprob)

        # Average across datasets
        merge_pat = np.mean(merge_pat, axis=0)
        split_pat = np.mean(split_pat, axis=0)

        # Correlate the current event patterns with the split and merged
        # patterns
        merge_corr = np.zeros(self.n_events)
        split_corr = np.zeros(self.n_events)
        for e in range(self.n_events):
            split_corr[e] = np.corrcoef(mean_pat[:, e],
                                        split_pat[:, (2 * e):(2 * e + 2)],
                                        rowvar=False)[0, 1:3].max()
            merge_corr[e] = np.corrcoef(merge_pat[:, e],
                                        mean_pat[:, e:(e + 2)],
                                        rowvar=False)[0, 1:3].min()
        # The last event has no successor to merge with.
        merge_corr = merge_corr[:-1]

        # A high merge_corr means both events of a pair resemble their merged
        # pattern (good merge candidate); a low split_corr means an event's
        # pattern differs from both of its halves (good split candidate).
        n_prop = self.split_merge_proposals
        best_merge = np.flipud(np.argsort(merge_corr))[:n_prop]
        best_split = np.argsort(split_corr)[:n_prop]

        # Evaluate every merge/split pair; keep the one with the highest mean
        # log-likelihood if it beats curr_ll. The running best is tracked as
        # a scalar mean: return_ll itself becomes a per-dataset array once a
        # proposal is accepted and must not be used for later comparisons.
        return_ll = curr_ll
        best_mean_ll = curr_ll
        return_lg = copy.deepcopy(log_gamma)
        return_mp = mean_pat.copy()
        for m_e, s_e in itertools.product(best_merge, best_split):
            if m_e == s_e or m_e + 1 == s_e:
                # Don't attempt to merge/split same event
                continue

            # Replace event s_e by its two split halves...
            mean_pat_ms = np.delete(mean_pat, s_e, axis=1)
            mean_pat_ms = np.insert(mean_pat_ms, [s_e, s_e],
                                    split_pat[:, (2 * s_e):(2 * s_e + 2)],
                                    axis=1)
            # ...then replace events m_e and m_e + 1 (shifted right by one
            # if the split happened before them) by their merged pattern.
            m_idx = m_e + (s_e < m_e)
            mean_pat_ms = np.delete(mean_pat_ms, [m_idx, m_idx + 1], axis=1)
            mean_pat_ms = np.insert(mean_pat_ms, m_idx, merge_pat[:, m_e],
                                    axis=1)

            # Measure log-likelihood with these new patterns
            ll_ms = np.zeros(n_train)
            log_gamma_ms = []
            for i, x in enumerate(X):
                logprob = self._logprob_obs(x, mean_pat_ms, iteration_var)
                lg, ll_ms[i] = self._forward_backward(logprob)
                log_gamma_ms.append(lg)

            # If better than best ll so far, save to return to fit()
            if ll_ms.mean() > best_mean_ll:
                best_mean_ll = ll_ms.mean()
                return_mp = mean_pat_ms.copy()
                return_ll = ll_ms
                return_lg = [lg.copy() for lg in log_gamma_ms]

        return return_ll, return_lg, return_mp