gestsync / sync_models /modules.py
sindhuhegde's picture
Add sync-offset-prediction app
aa5ee46
raw
history blame
No virus
6.73 kB
import torch
import torch.nn as nn
from torch.autograd import Variable
import math
class PositionalEncoding_RGB(nn.Module):
"Implement the PE function."
def __init__(self, d_model, dropout=0.1, max_len=50):
super(PositionalEncoding_RGB, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + Variable(self.pe[:, :x.size(1)],
requires_grad=False)
return self.dropout(x)
def calc_receptive_field(layers, imsize, layer_names=None):
if layer_names is not None:
print("-------Net summary------")
currentLayer = [imsize, 1, 1, 0.5]
for l_id, layer in enumerate(layers):
conv = [
layer[key][-1] if type(layer[key]) in [list, tuple] else layer[key]
for key in ['kernel_size', 'stride', 'padding']
]
currentLayer = outFromIn(conv, currentLayer)
if 'maxpool' in layer:
conv = [
(layer['maxpool'][key][-1] if type(layer['maxpool'][key])
in [list, tuple] else layer['maxpool'][key]) if
(not key == 'padding' or 'padding' in layer['maxpool']) else 0
for key in ['kernel_size', 'stride', 'padding']
]
currentLayer = outFromIn(conv, currentLayer, ceil_mode=False)
return currentLayer
def outFromIn(conv, layerIn, ceil_mode=True):
n_in = layerIn[0]
j_in = layerIn[1]
r_in = layerIn[2]
start_in = layerIn[3]
k = conv[0]
s = conv[1]
p = conv[2]
n_out = math.floor((n_in - k + 2 * p) / s) + 1
actualP = (n_out - 1) * s - n_in + k
pR = math.ceil(actualP / 2)
pL = math.floor(actualP / 2)
j_out = j_in * s
r_out = r_in + (k - 1) * j_in
start_out = start_in + ((k - 1) / 2 - pL) * j_in
return n_out, j_out, r_out, start_out
class DebugModule(nn.Module):
"""
Wrapper class for printing the activation dimensions
"""
def __init__(self, name=None):
super().__init__()
self.name = name
self.debug_log = True
def debug_line(self, layer_str, output, memuse=1, final_call=False):
if self.debug_log:
namestr = '{}: '.format(self.name) if self.name is not None else ''
# print('{}{:80s}: dims {}'.format(namestr, repr(layer_str),
# output.shape))
if final_call:
self.debug_log = False
# print()
class VGGNet(DebugModule):
conv_dict = {
'conv1d': nn.Conv1d,
'conv2d': nn.Conv2d,
'conv3d': nn.Conv3d,
'fc1d': nn.Conv1d,
'fc2d': nn.Conv2d,
'fc3d': nn.Conv3d,
}
pool_dict = {
'conv1d': nn.MaxPool1d,
'conv2d': nn.MaxPool2d,
'conv3d': nn.MaxPool3d,
}
norm_dict = {
'conv1d': nn.BatchNorm1d,
'conv2d': nn.BatchNorm2d,
'conv3d': nn.BatchNorm3d,
'fc1d': nn.BatchNorm1d,
'fc2d': nn.BatchNorm2d,
'fc3d': nn.BatchNorm3d,
}
def __init__(self, n_channels_in, layers):
super(VGGNet, self).__init__()
self.layers = layers
n_channels_prev = n_channels_in
for l_id, lr in enumerate(self.layers):
l_id += 1
name = 'fc' if 'fc' in lr['type'] else 'conv'
conv_type = self.conv_dict[lr['type']]
norm_type = self.norm_dict[lr['type']]
self.__setattr__(
'{:s}{:d}'.format(name, l_id),
conv_type(n_channels_prev,
lr['n_channels'],
kernel_size=lr['kernel_size'],
stride=lr['stride'],
padding=lr['padding']))
n_channels_prev = lr['n_channels']
self.__setattr__('bn{:d}'.format(l_id), norm_type(lr['n_channels']))
if 'maxpool' in lr:
pool_type = self.pool_dict[lr['type']]
padding = lr['maxpool']['padding'] if 'padding' in lr[
'maxpool'] else 0
self.__setattr__(
'mp{:d}'.format(l_id),
pool_type(kernel_size=lr['maxpool']['kernel_size'],
stride=lr['maxpool']['stride'],
padding=padding),
)
def forward(self, inp):
self.debug_line('Input', inp)
out = inp
for l_id, lr in enumerate(self.layers):
l_id += 1
name = 'fc' if 'fc' in lr['type'] else 'conv'
out = self.__getattr__('{:s}{:d}'.format(name, l_id))(out)
out = self.__getattr__('bn{:d}'.format(l_id))(out)
out = nn.ReLU(inplace=True)(out)
self.debug_line(self.__getattr__('{:s}{:d}'.format(name, l_id)),
out)
if 'maxpool' in lr:
out = self.__getattr__('mp{:d}'.format(l_id))(out)
self.debug_line(self.__getattr__('mp{:d}'.format(l_id)), out)
self.debug_line('Output', out, final_call=True)
return out
class NetFC(DebugModule):
def __init__(self, input_dim, hidden_dim, embed_dim):
super(NetFC, self).__init__()
self.fc7 = nn.Conv3d(input_dim, hidden_dim, kernel_size=(1, 1, 1))
self.bn7 = nn.BatchNorm3d(hidden_dim)
self.fc8 = nn.Conv3d(hidden_dim, embed_dim, kernel_size=(1, 1, 1))
def forward(self, inp):
out = self.fc7(inp)
self.debug_line(self.fc7, out)
out = self.bn7(out)
out = nn.ReLU(inplace=True)(out)
out = self.fc8(out)
self.debug_line(self.fc8, out, final_call=True)
return out
class NetFC_2D(DebugModule):
def __init__(self, input_dim, hidden_dim, embed_dim):
super(NetFC_2D, self).__init__()
self.fc7 = nn.Conv2d(input_dim, hidden_dim, kernel_size=(1, 1))
self.bn7 = nn.BatchNorm2d(hidden_dim)
self.fc8 = nn.Conv2d(hidden_dim, embed_dim, kernel_size=(1, 1))
def forward(self, inp):
out = self.fc7(inp)
self.debug_line(self.fc7, out)
out = self.bn7(out)
out = nn.ReLU(inplace=True)(out)
out = self.fc8(out)
self.debug_line(self.fc8, out, final_call=True)
return out