import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import random

from transformers import PreTrainedModel

from .configuration_fsae import FSAEConfig

# Global SNN hyperparameters: Vth is the firing threshold, tau the membrane decay,
# and aa the half-width of the rectangular surrogate-gradient window.
dt = 5
a = 0.25
aa = 0.5
Vth = 0.2
tau = 0.25


class SpikeAct(torch.autograd.Function):
    """
    Implementation of the spiking activation function with an approximation of gradient.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # if input = u > Vth then output = 1
        output = torch.gt(input, Vth)
        return output.float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # hu is an approximate func of df/du
        hu = abs(input) < aa
        hu = hu.float() / (2 * aa)
        return grad_input * hu


class LIFSpike(nn.Module):
    """
    Generates spikes based on LIF module. It can be considered as an activation function
    and is used similarly to ReLU. The input tensor needs to have an additional time
    dimension, which in this case is the last dimension of the data.
    """

    def __init__(self):
        super(LIFSpike, self).__init__()

    def forward(self, x):
        nsteps = x.shape[-1]
        u = torch.zeros(x.shape[:-1], device=x.device)
        out = torch.zeros(x.shape, device=x.device)
        for step in range(nsteps):
            u, out[..., step] = self.state_update(u, out[..., max(step - 1, 0)], x[..., step])
        return out

    def state_update(self, u_t_n1, o_t_n1, W_mul_o_t1_n, tau=tau):
        # leaky integration with reset-by-multiplication, then threshold via SpikeAct
        u_t1_n1 = tau * u_t_n1 * (1 - o_t_n1) + W_mul_o_t1_n
        o_t1_n1 = SpikeAct.apply(u_t1_n1)
        return u_t1_n1, o_t1_n1

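# Illustrative sketch, not part of the original module: how SpikeAct and LIFSpike are
# meant to be used. The input carries a trailing time dimension T, the forward pass
# emits binary spikes, and gradients reach the input through the rectangular surrogate
# defined in SpikeAct.backward. The helper name and the (N, C, T) sizes are assumptions
# made only for this demo.
def _demo_lif_spike():
    torch.manual_seed(0)
    x = torch.randn(2, 4, 8, requires_grad=True)  # (N, C, T) input currents
    spikes = LIFSpike()(x)                        # {0., 1.} tensor of the same shape
    # a loss on the spike train back-propagates through the surrogate gradient
    spikes.sum().backward()
    return spikes.detach(), x.grad
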
class tdLinear(nn.Linear):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 bn=None,
                 spike=None):
        assert type(in_features) == int, 'in_features should not be more than 1 dimension. It was: {}'.format(in_features)
        assert type(out_features) == int, 'out_features should not be more than 1 dimension. It was: {}'.format(out_features)
        super(tdLinear, self).__init__(in_features, out_features, bias=bias)

        self.bn = bn
        self.spike = spike

    def forward(self, x):
        """
        x : (N, C, T)
        """
        x = x.transpose(1, 2)  # (N, T, C)
        y = F.linear(x, self.weight, self.bias)
        y = y.transpose(1, 2)  # (N, C, T)

        if self.bn is not None:
            # tdBatchNorm expects a (N, C, H, W, T) tensor, so add dummy spatial dims
            y = y[:, :, None, None, :]
            y = self.bn(y)
            y = y[:, :, 0, 0, :]
        if self.spike is not None:
            y = self.spike(y)
        return y


class tdConv(nn.Conv3d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None,
                 is_first_conv=False):

        # kernel: the spatial kernel is extended with a time dimension of size 1
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise Exception('kernel_size can only be of 1 or 2 dimension. It was: {}'.format(kernel_size))

        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        super(tdConv, self).__init__(in_channels, out_channels, kernel, stride, padding,
                                     dilation, groups, bias=bias)
        self.bn = bn
        self.spike = spike
        self.is_first_conv = is_first_conv

    def forward(self, x):
        x = F.conv3d(x, self.weight, self.bias,
                     self.stride, self.padding, self.dilation, self.groups)
        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x


class tdConvTranspose(nn.ConvTranspose3d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bn=None,
                 spike=None):

        # kernel
        if type(kernel_size) == int:
            kernel = (kernel_size, kernel_size, 1)
        elif len(kernel_size) == 2:
            kernel = (kernel_size[0], kernel_size[1], 1)
        else:
            raise Exception('kernel_size can only be of 1 or 2 dimension. It was: {}'.format(kernel_size))

        # stride
        if type(stride) == int:
            stride = (stride, stride, 1)
        elif len(stride) == 2:
            stride = (stride[0], stride[1], 1)
        else:
            raise Exception('stride can be either int or tuple of size 2. It was: {}'.format(stride))

        # padding
        if type(padding) == int:
            padding = (padding, padding, 0)
        elif len(padding) == 2:
            padding = (padding[0], padding[1], 0)
        else:
            raise Exception('padding can be either int or tuple of size 2. It was: {}'.format(padding))

        # dilation
        if type(dilation) == int:
            dilation = (dilation, dilation, 1)
        elif len(dilation) == 2:
            dilation = (dilation[0], dilation[1], 1)
        else:
            raise Exception('dilation can be either int or tuple of size 2. It was: {}'.format(dilation))

        # output padding
        if type(output_padding) == int:
            output_padding = (output_padding, output_padding, 0)
        elif len(output_padding) == 2:
            output_padding = (output_padding[0], output_padding[1], 0)
        else:
            raise Exception('output_padding can be either int or tuple of size 2. It was: {}'.format(output_padding))

        super().__init__(in_channels, out_channels, kernel, stride, padding, output_padding,
                         groups, bias=bias, dilation=dilation)

        self.bn = bn
        self.spike = spike

    def forward(self, x):
        x = F.conv_transpose3d(x, self.weight, self.bias,
                               self.stride, self.padding, self.output_padding,
                               self.groups, self.dilation)
        if self.bn is not None:
            x = self.bn(x)
        if self.spike is not None:
            x = self.spike(x)
        return x

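# Illustrative sketch, not part of the original module: the td* layers keep the time
# axis as the last dimension and never convolve across it, because the 2D kernel is
# extended to (kH, kW, 1) for the underlying Conv3d. The tensor sizes, layer widths,
# and the helper name below are assumptions chosen only for this shape demo.
def _demo_td_layers():
    x = (torch.rand(2, 3, 32, 32, 16) > 0.5).float()  # (N, C, H, W, T) Bernoulli spikes
    conv = tdConv(3, 8, kernel_size=3, stride=2, padding=1,
                  spike=LIFSpike(), is_first_conv=True)
    y = conv(x)                 # (2, 8, 16, 16, 16): H and W halve, T is untouched
    fc = tdLinear(8, 4, spike=LIFSpike())
    z = fc(y.mean(dim=(2, 3)))  # (N, C, T) input -> (2, 4, 16); mean pooling is demo-only
    return y.shape, z.shape
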
class tdBatchNorm(nn.BatchNorm2d):
    """
    Implementation of tdBN. Link to related paper: https://arxiv.org/pdf/2011.05280.
    In short, the statistics are averaged over the time domain as well when doing BN.

    Args:
        num_features (int): same as nn.BatchNorm2d
        eps (float): same as nn.BatchNorm2d
        momentum (float): same as nn.BatchNorm2d
        alpha (float): an additional parameter which may change in resblock.
        affine (bool): same as nn.BatchNorm2d
        track_running_stats (bool): same as nn.BatchNorm2d
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, alpha=1, affine=True, track_running_stats=True):
        super(tdBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.alpha = alpha

    def forward(self, input):
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3, 4])
            # use biased var in train
            var = input.var([0, 2, 3, 4], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = self.alpha * Vth * (input - mean[None, :, None, None, None]) \
            / (torch.sqrt(var[None, :, None, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None]

        return input


class PSP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tau_s = 2

    def forward(self, inputs):
        """
        inputs: (N, C, T)
        """
        syns = None
        syn = 0
        n_steps = inputs.shape[-1]
        for t in range(n_steps):
            syn = syn + (inputs[..., t] - syn) / self.tau_s
            if syns is None:
                syns = syn.unsqueeze(-1)
            else:
                syns = torch.cat([syns, syn.unsqueeze(-1)], dim=-1)

        return syns


class MembraneOutputLayer(nn.Module):
    """
    Outputs the last-time membrane potential of the LIF neuron with V_th = infty.
    """

    def __init__(self) -> None:
        super().__init__()
        # n_steps = glv.n_steps
        n_steps = 16

        arr = torch.arange(n_steps - 1, -1, -1)
        self.register_buffer("coef", torch.pow(0.8, arr)[None, None, None, None, :])  # (1,1,1,1,T)

    def forward(self, x):
        """
        x : (N,C,H,W,T)
        """
        out = torch.sum(x * self.coef, dim=-1)
        return out


class PriorBernoulliSTBP(nn.Module):
    def __init__(self, k=20) -> None:
        """
        modeling of p(z_t | z_<t)
        """
        super().__init__()
        # ASSUMPTION: the latent width, number of time steps, and hidden layer sizes
        # below are placeholders; only their roles are recoverable from the usage in
        # forward() and sample().
        self.channels = 128  # assumed latent channel count C
        self.k = k
        self.n_steps = 16    # matches the hard-coded n_steps in MembraneOutputLayer

        # autoregressive network mapping z_<=t (B,C,t) to k Bernoulli samples per channel
        self.layers = nn.Sequential(
            tdLinear(self.channels, self.channels * 2, bias=True,
                     bn=tdBatchNorm(self.channels * 2, alpha=2), spike=LIFSpike()),
            tdLinear(self.channels * 2, self.channels * 4, bias=True,
                     bn=tdBatchNorm(self.channels * 4, alpha=2), spike=LIFSpike()),
            tdLinear(self.channels * 4, self.channels * self.k, bias=True,
                     bn=tdBatchNorm(self.channels * self.k, alpha=2), spike=LIFSpike())
        )

        self.register_buffer('initial_input', torch.zeros(1, self.channels, 1))  # (1,C,1)

    def forward(self, z, scheduled=False, p=None):
        """
        input z: (B,C,T)
        output p_z: (B,C,k,T), k samples of p(z_t | z_<t) per channel and time step
        """
        batch_size = z.shape[0]
        z_t_minus = self.initial_input.repeat(batch_size, 1, 1)  # (B,C,1) z_<=0

        if scheduled:
            for t in range(self.n_steps - 1):
                if t >= 5 and random.random() < p:  # scheduled sampling
                    outputs = self.layers(z_t_minus.detach())  # binary (B, C*k, t+1) z_<=t
                    p_z_t = outputs[..., -1]  # (B, C*k, 1)

                    # sampling from p(z_t | z_<t): average the k Bernoulli samples per
                    # channel and threshold at 0.5 (reconstructed from the visible shapes)
                    prob1 = p_z_t.view(batch_size, self.channels, self.k).mean(-1)  # (B,C)
                    z_t = (prob1 > 0.5).float()  # (B,C)
                    z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
                    z_t_minus = torch.cat([z_t_minus, z_t], dim=-1)  # (B,C,t+2)
                else:
                    z_t_minus = torch.cat([z_t_minus, z[..., t].unsqueeze(-1)], dim=-1)  # (B,C,t+2)
        else:  # for test time
            z_t_minus = torch.cat([z_t_minus, z[:, :, :-1]], dim=-1)  # (B,C,T)

        z_t_minus = z_t_minus.detach()  # (B,C,T) z_{<=T-1}
        p_z = self.layers(z_t_minus)  # (B,C*k,T)
        p_z = p_z.view(batch_size, self.channels, self.k, self.n_steps)  # (B,C,k,T)
        return p_z

    def sample(self, batch_size=64):
        z_minus_t = self.initial_input.repeat(batch_size, 1, 1)  # (B, C, 1)
        for t in range(self.n_steps):
            outputs = self.layers(z_minus_t)  # (B, C*k, t+1)
            p_z_t = outputs[..., -1]  # (B, C*k, 1)

            random_index = torch.randint(0, self.k, (batch_size * self.channels,)) \
                + torch.arange(start=0, end=batch_size * self.channels * self.k, step=self.k)  # (B*C,) pick one from k
            random_index = random_index.to(z_minus_t.device)

            z_t = p_z_t.view(batch_size * self.channels * self.k)[random_index]  # (B*C,)
            z_t = z_t.view(batch_size, self.channels, 1)  # (B,C,1)
            z_minus_t = torch.cat([z_minus_t, z_t], dim=-1)  # (B,C,t+2)

        sampled_z = z_minus_t[..., 1:]  # (B,C,T)
        return sampled_z

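# Illustrative sketch, not part of the original module: drawing binary latent spike
# trains from the autoregressive prior and collapsing a spiking feature map into a
# single membrane-potential readout. The batch size, spatial sizes, and the latent
# width implied by the assumed values in PriorBernoulliSTBP are demo-only choices.
def _demo_prior_and_readout():
    prior = PriorBernoulliSTBP(k=20)
    with torch.no_grad():
        z = prior.sample(batch_size=4)      # (B, C, T) binary latent spike train
    readout = MembraneOutputLayer()
    feats = torch.rand(4, 8, 4, 4, 16)      # (N, C, H, W, T) decoder output spikes
    img = readout(feats)                    # (N, C, H, W): weighted sum over T, late steps weighted most
    return z.shape, img.shape
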
class PosteriorBernoulliSTBP(nn.Module):
    def __init__(self, k=20) -> None:
        """
        modeling of q(z_t | x_<=t, z_<t)