# gestsync/sync_models/gestsync_models.py
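# Transformer_RGB: RGB-video and audio encoders used for sync-offset prediction (GestSync).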
import torch
import torch.nn as nn

# sync_models.modules provides VGGNet, PositionalEncoding_RGB and NetFC_2D used below.
from sync_models.modules import *


class Transformer_RGB(nn.Module):
    def __init__(self):
        super().__init__()

        # Video branch: 3D-conv front-end, positional encoding, transformer encoder,
        # and an MLP projecting the 512-d features to the 1024-d embedding space.
        self.net_vid = self.build_net_vid()
        self.ff_vid = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1024)
        )
        self.pos_encoder = PositionalEncoding_RGB(d_model=512)
        encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

        # Audio branch: 2D-conv front-end, a bidirectional LSTM, and a fully-connected
        # head projecting to the same 1024-d embedding space.
        self.net_aud = self.build_net_aud()
        self.lstm = nn.LSTM(512, 256, num_layers=1, bidirectional=True, batch_first=True)
        self.ff_aud = NetFC_2D(input_dim=512, hidden_dim=512, embed_dim=1024)

        # Learnable scale for the similarity logits, initialised to 1.
        self.logits_scale = nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.logits_scale.weight)

        self.fc = nn.Linear(1, 1)

def build_net_vid(self):
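        # VGG-style 3D-conv front-end over the RGB clip; each dict below configures one
        # block (output channels, kernel, stride, padding, and an optional max-pool).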
layers = [
{
'type': 'conv3d',
'n_channels': 64,
'kernel_size': (5, 7, 7),
'stride': (1, 3, 3),
                'padding': (0, 0, 0),
'maxpool': {
'kernel_size': (1, 3, 3),
'stride': (1, 2, 2)
}
},
{
'type': 'conv3d',
'n_channels': 128,
'kernel_size': (1, 5, 5),
'stride': (1, 2, 2),
'padding': (0, 0, 0),
},
{
'type': 'conv3d',
'n_channels': 256,
'kernel_size': (1, 3, 3),
'stride': (1, 2, 2),
'padding': (0, 1, 1),
},
{
'type': 'conv3d',
'n_channels': 256,
'kernel_size': (1, 3, 3),
'stride': (1, 1, 2),
'padding': (0, 1, 1),
},
{
'type': 'conv3d',
'n_channels': 256,
'kernel_size': (1, 3, 3),
'stride': (1, 1, 1),
'padding': (0, 1, 1),
'maxpool': {
'kernel_size': (1, 3, 3),
'stride': (1, 2, 2)
}
},
{
'type': 'fc3d',
'n_channels': 512,
'kernel_size': (1, 4, 4),
'stride': (1, 1, 1),
                'padding': (0, 0, 0),
},
]
return VGGNet(n_channels_in=3, layers=layers)
def build_net_aud(self):
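        # VGG-style 2D-conv front-end over the single-channel audio input, mirroring the
        # video branch; the final fc2d layer produces 512-channel features.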
layers = [
{
'type': 'conv2d',
'n_channels': 64,
'kernel_size': (3, 3),
'stride': (2, 2),
'padding': (1, 1),
'maxpool': {
'kernel_size': (3, 3),
'stride': (2, 2)
}
},
{
'type': 'conv2d',
'n_channels': 192,
'kernel_size': (3, 3),
'stride': (1, 2),
'padding': (1, 1),
'maxpool': {
'kernel_size': (3, 3),
'stride': (2, 2)
}
},
{
'type': 'conv2d',
'n_channels': 384,
'kernel_size': (3, 3),
'stride': (1, 1),
'padding': (1, 1),
},
{
'type': 'conv2d',
'n_channels': 256,
'kernel_size': (3, 3),
'stride': (1, 1),
'padding': (1, 1),
},
{
'type': 'conv2d',
'n_channels': 256,
'kernel_size': (3, 3),
'stride': (1, 1),
'padding': (1, 1),
'maxpool': {
'kernel_size': (2, 3),
'stride': (2, 2)
}
},
{
'type': 'fc2d',
'n_channels': 512,
'kernel_size': (4, 2),
'stride': (1, 1),
'padding': (0, 0),
},
]
return VGGNet(n_channels_in=1, layers=layers)
    def forward_vid(self, x, return_feats=False):
        # 3D-conv features with the spatial dims squeezed out: B x 512 x T (e.g. B x 512 x 21)
        out_conv = self.net_vid(x).squeeze(-1).squeeze(-1)

        # Add positional encodings and run the transformer encoder: B x T x 512
        out = self.pos_encoder(out_conv.transpose(1, 2))
        out_trans = self.transformer_encoder(out)

        # Project to the 1024-d embedding space and restore channel-first layout: B x 1024 x T
        out = self.ff_vid(out_trans).transpose(1, 2)

        if return_feats:
            return out, out_conv
        else:
            return out
    def forward_aud(self, x):
        # 2D-conv features from the audio front-end.
        out = self.net_aud(x)
        # Project to the 1024-d embedding space and drop the trailing singleton dimension.
        out = self.ff_aud(out)
        out = out.squeeze(-1)
        return out
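

# A minimal smoke-test sketch (not part of the original file). The video input size below is
# an assumption inferred from the layer configuration above: 25 RGB frames at 270x480 reduce
# to 21 temporal feature steps, matching the shape comments in forward_vid. Adjust as needed
# for the actual preprocessing used by the app. The audio branch (forward_aud) expects a
# single-channel 2-D input whose size depends on the audio preprocessing, so it is not
# exercised here.
if __name__ == "__main__":
    model = Transformer_RGB().eval()
    with torch.no_grad():
        frames = torch.randn(1, 3, 25, 270, 480)   # B x C x T x H x W (assumed input size)
        vid_emb = model.forward_vid(frames)        # expected shape: 1 x 1024 x 21
        print("video embedding:", vid_emb.shape)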