|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from attentionLayer import attentionLayer |
|
|
|
|
|
class ResNetLayer(nn.Module):
    """
    A ResNet layer used to build the ResNet network.

    It consists of two basic residual sub-blocks ("a" then "b"):

    --> conv-bn-relu -> conv -> + -> bn-relu -> conv-bn-relu -> conv -> + -> bn-relu -->
     |                          ^ |                                     ^
     |                          | |                                     |
     ------> downsample --------- ---------------------------------------

    The shortcut of sub-block "a" goes through a 1x1 ``downsample`` convolution
    only when the shapes would otherwise not match (stride != 1 or a change in
    channel count); otherwise it is the identity. The shortcut of sub-block "b"
    is the sum produced by sub-block "a" *before* its BatchNorm/ReLU.
    """

    def __init__(self, inplanes, outplanes, stride):
        """
        Args:
            inplanes: number of input channels.
            outplanes: number of output channels.
            stride: stride of the first 3x3 convolution (and of the shortcut
                projection), i.e. the spatial downsampling factor.
        """
        super(ResNetLayer, self).__init__()
        # Sub-block "a". Modules are created in the original order so that
        # seeded initialisation and state_dict keys are unchanged.
        self.conv1a = nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
        self.bn1a = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
        self.conv2a = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.stride = stride
        # The shortcut needs a 1x1 projection whenever the residual branch
        # changes the spatial size (stride != 1) OR the channel count;
        # projecting only on stride != 1 made e.g. ResNetLayer(64, 128, 1)
        # fail with a shape mismatch in forward().
        if stride != 1 or inplanes != outplanes:
            self.downsample = nn.Conv2d(inplanes, outplanes, kernel_size=(1, 1),
                                        stride=stride, bias=False)
        else:
            self.downsample = None
        self.outbna = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)

        # Sub-block "b": shapes are already (outplanes, stride-reduced size).
        self.conv1b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.bn1b = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
        self.conv2b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.outbnb = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)

    def forward(self, inputBatch):
        """
        Args:
            inputBatch: 4-D tensor (N, inplanes, H, W) as required by Conv2d.

        Returns:
            Tensor (N, outplanes, H', W') where H', W' are reduced by ``stride``.
        """
        # Sub-block "a".
        batch = F.relu(self.bn1a(self.conv1a(inputBatch)))
        batch = self.conv2a(batch)
        if self.downsample is None:
            residualBatch = inputBatch
        else:
            residualBatch = self.downsample(inputBatch)
        # Keep the pre-BN sum: it is the shortcut for sub-block "b".
        intermediateBatch = batch + residualBatch
        batch = F.relu(self.outbna(intermediateBatch))

        # Sub-block "b".
        batch = F.relu(self.bn1b(self.conv1b(batch)))
        batch = self.conv2b(batch)
        batch = batch + intermediateBatch
        outputBatch = F.relu(self.outbnb(batch))
        return outputBatch
|
|
|
|
|
class ResNet(nn.Module):
    """
    An 18-layer ResNet architecture.

    Four ResNetLayer stages (attributes ``layer1`` .. ``layer4``) take the
    channel count 64 -> 64 -> 128 -> 256 -> 512. Stages 2-4 use stride 2.
    A final 4x4 average pool with stride 1 follows the last stage.
    """

    def __init__(self):
        super(ResNet, self).__init__()
        # (in_channels, out_channels, stride) per stage, in registration order.
        stage_specs = ((64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2))
        for index, (in_ch, out_ch, stage_stride) in enumerate(stage_specs, start=1):
            setattr(self, "layer{}".format(index),
                    ResNetLayer(in_ch, out_ch, stride=stage_stride))
        self.avgpool = nn.AvgPool2d(kernel_size=(4, 4), stride=(1, 1))

    def forward(self, inputBatch):
        """Run the four stages in order, then average-pool the result."""
        features = inputBatch
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        return self.avgpool(features)
|
|
|
|
|
class GlobalLayerNorm(nn.Module):
    """
    Global layer normalisation (gLN).

    Each sample is normalised by the mean and (biased) variance taken jointly
    over dims 1 and 2. A per-channel affine transform then applies learnable
    ``gamma`` and ``beta``, both shaped (1, channel_size, 1). Inputs are
    therefore expected to be 3-D with dim 1 of size ``channel_size``.
    """

    def __init__(self, channel_size, eps=1e-8):
        """
        Args:
            channel_size: size of dim 1 of the input (number of channels).
            eps: value added to the variance before the square root, for
                numerical stability. Defaults to the previously hard-coded 1e-8.
        """
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps  # plain attribute: not part of the state_dict
        # torch.empty instead of the legacy torch.Tensor(...) constructor; the
        # values are set immediately by reset_parameters().
        self.gamma = nn.Parameter(torch.empty(1, channel_size, 1))
        self.beta = nn.Parameter(torch.empty(1, channel_size, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset to the identity transform: gamma = 1, beta = 0."""
        # nn.init.* run under no_grad, replacing direct `.data` mutation.
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, y):
        """Return gamma * (y - mean) / sqrt(var + eps) + beta, statistics over dims (1, 2)."""
        mean = y.mean(dim=(1, 2), keepdim=True)
        var = (y - mean).pow(2).mean(dim=(1, 2), keepdim=True)
        gLN_y = self.gamma * (y - mean) / torch.sqrt(var + self.eps) + self.beta
        return gLN_y
|
|
|
|
|
class visualFrontend(nn.Module):
    """
    A visual feature extraction module. Generates a 512-dim feature vector per video frame.

    Architecture: A 3D convolution block followed by an 18-layer ResNet.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: configuration object. It is stored as ``self.cfg`` but not
                read anywhere inside this class.
        """
        # nn.Module.__init__ must run before any attribute assignment.
        # Assigning first only happens to work for plain objects and raises
        # for Modules/Parameters.
        super(visualFrontend, self).__init__()
        self.cfg = cfg
        self.frontend3D = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3),
                      bias=False),
            nn.BatchNorm3d(64, momentum=0.01, eps=0.001),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
        self.resnet = ResNet()

    def forward(self, inputBatch):
        """
        Args:
            inputBatch: tensor indexed as (time, batch, 1, height, width).
                The first two transposes turn it into Conv3d's (N, C, D, H, W)
                layout, and Conv3d requires exactly one input channel. The
                spatial size must shrink to 1x1 after the ResNet's 4x4 average
                pool (e.g. 112x112 frames).

        Returns:
            Tensor of shape (time, batch, 512).

        Raises:
            ValueError: if the ResNet output is not spatially 1x1. Reshaping it
                to (batch, -1, 512) would then silently mix spatial positions
                into the time axis.
        """
        # (T, B, 1, H, W) -> (B, 1, T, H, W)
        inputBatch = inputBatch.transpose(0, 1).transpose(1, 2)
        batchsize = inputBatch.shape[0]
        batch = self.frontend3D(inputBatch)

        # (B, 64, T, H', W') -> (B*T, 64, H', W'): apply the 2D ResNet per frame.
        batch = batch.transpose(1, 2)
        batch = batch.reshape(batch.shape[0] * batch.shape[1], batch.shape[2], batch.shape[3],
                              batch.shape[4])
        outputBatch = self.resnet(batch)
        if outputBatch.shape[-2:] != (1, 1):
            raise ValueError(
                f"visualFrontend expects the ResNet output to be spatially 1x1, "
                f"got {tuple(outputBatch.shape[-2:])}; check the input frame size")

        # (B*T, 512, 1, 1) -> (B, T, 512) -> (T, B, 512)
        outputBatch = outputBatch.reshape(batchsize, -1, 512)
        outputBatch = outputBatch.transpose(0, 1)
        return outputBatch
|
|
|
|
|
class DSConv1d(nn.Module):
    """
    Residual depthwise-separable 1-D convolution block over 512 channels.

    The branch ``net`` applies, in order:
    ReLU -> BatchNorm1d -> depthwise 3-tap Conv1d -> PReLU -> GlobalLayerNorm
    -> pointwise (1x1) Conv1d.
    Its output is added back onto the input. Expects (batch, 512, length)
    tensors; the sequence length is preserved (kernel 3, padding 1, stride 1).
    """

    def __init__(self):
        super(DSConv1d, self).__init__()
        channels = 512
        layers = [
            nn.ReLU(),
            nn.BatchNorm1d(channels),
            # groups == channels: every channel gets its own kernel (depthwise).
            nn.Conv1d(channels, channels, 3, stride=1, padding=1, dilation=1,
                      groups=channels, bias=False),
            nn.PReLU(),
            GlobalLayerNorm(channels),
            # 1x1 convolution mixes information across channels (pointwise).
            nn.Conv1d(channels, channels, 1, bias=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + net(x)``."""
        shortcut = x
        return shortcut + self.net(x)
|
|
|
|
|
class visualTCN(nn.Module):
    """Temporal convolution stack: five DSConv1d blocks applied one after another."""

    def __init__(self):
        super(visualTCN, self).__init__()
        num_blocks = 5
        self.net = nn.Sequential(*[DSConv1d() for _ in range(num_blocks)])

    def forward(self, x):
        """Pass ``x`` through every DSConv1d block in order."""
        return self.net(x)
|
|
|
|
|
class visualConv1D(nn.Module):
    """
    Projects 512-channel 1-D features down to 128 channels.

    The pipeline is: Conv1d(512 -> 256, kernel 5, padding 2, so the length is
    preserved) -> BatchNorm1d -> ReLU -> Conv1d(256 -> 128, kernel 1).
    """

    def __init__(self):
        super(visualConv1D, self).__init__()
        in_channels, hidden_channels, out_channels = 512, 256, 128
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Apply the projection pipeline to ``x``."""
        return self.net(x)
|
|