import torch
import torch.nn as nn

from sync_models.modules import *


class Transformer_RGB(nn.Module):
    """Two-stream model: a VGG-style 3D-conv video branch followed by a
    transformer encoder, and a VGG-style 2D-conv audio branch, each
    projected to 1024-d embeddings."""

    def __init__(self):
        super().__init__()

        # Video branch: 3D conv backbone -> positional encoding ->
        # transformer encoder -> MLP projection (512 -> 1024).
        self.net_vid = self.build_net_vid()
        self.ff_vid = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
        )
        self.pos_encoder = PositionalEncoding_RGB(d_model=512)
        encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

        # Audio branch: 2D conv backbone -> FC head (512 -> 1024).
        self.net_aud = self.build_net_aud()
        self.lstm = nn.LSTM(512, 256, num_layers=1, bidirectional=True, batch_first=True)
        self.ff_aud = NetFC_2D(input_dim=512, hidden_dim=512, embed_dim=1024)

        # Learnable scalar for scaling similarity logits, initialised to 1.
        self.logits_scale = nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.logits_scale.weight)
        self.fc = nn.Linear(1, 1)

    def build_net_vid(self):
        """VGG-style 3D conv backbone for RGB clips (3 input channels)."""
        layers = [
            {'type': 'conv3d', 'n_channels': 64, 'kernel_size': (5, 7, 7),
             'stride': (1, 3, 3), 'padding': (0),
             'maxpool': {'kernel_size': (1, 3, 3), 'stride': (1, 2, 2)}},
            {'type': 'conv3d', 'n_channels': 128, 'kernel_size': (1, 5, 5),
             'stride': (1, 2, 2), 'padding': (0, 0, 0)},
            {'type': 'conv3d', 'n_channels': 256, 'kernel_size': (1, 3, 3),
             'stride': (1, 2, 2), 'padding': (0, 1, 1)},
            {'type': 'conv3d', 'n_channels': 256, 'kernel_size': (1, 3, 3),
             'stride': (1, 1, 2), 'padding': (0, 1, 1)},
            {'type': 'conv3d', 'n_channels': 256, 'kernel_size': (1, 3, 3),
             'stride': (1, 1, 1), 'padding': (0, 1, 1),
             'maxpool': {'kernel_size': (1, 3, 3), 'stride': (1, 2, 2)}},
            {'type': 'fc3d', 'n_channels': 512, 'kernel_size': (1, 4, 4),
             'stride': (1, 1, 1), 'padding': (0)},
        ]
        return VGGNet(n_channels_in=3, layers=layers)

    def build_net_aud(self):
        """VGG-style 2D conv backbone for audio spectrograms (1 input channel)."""
        layers = [
            {'type': 'conv2d', 'n_channels': 64, 'kernel_size': (3, 3),
             'stride': (2, 2), 'padding': (1, 1),
             'maxpool': {'kernel_size': (3, 3), 'stride': (2, 2)}},
            {'type': 'conv2d', 'n_channels': 192, 'kernel_size': (3, 3),
             'stride': (1, 2), 'padding': (1, 1),
             'maxpool': {'kernel_size': (3, 3), 'stride': (2, 2)}},
            {'type': 'conv2d', 'n_channels': 384, 'kernel_size': (3, 3),
             'stride': (1, 1), 'padding': (1, 1)},
            {'type': 'conv2d', 'n_channels': 256, 'kernel_size': (3, 3),
             'stride': (1, 1), 'padding': (1, 1)},
            {'type': 'conv2d', 'n_channels': 256, 'kernel_size': (3, 3),
             'stride': (1, 1), 'padding': (1, 1),
             'maxpool': {'kernel_size': (2, 3), 'stride': (2, 2)}},
            {'type': 'fc2d', 'n_channels': 512, 'kernel_size': (4, 2),
             'stride': (1, 1), 'padding': (0, 0)},
        ]
        return VGGNet(n_channels_in=1, layers=layers)

    def forward_vid(self, x, return_feats=False):
        # Conv features over time, with the trailing spatial dims squeezed away.
        out_conv = self.net_vid(x).squeeze(-1).squeeze(-1)
        # print("Conv: ", out_conv.shape)  # Bx1024x21x1x1
        out = self.pos_encoder(out_conv.transpose(1, 2))
        out_trans = self.transformer_encoder(out)
        # print("Transformer: ", out_trans.shape)  # Bx21x1024
        out = self.ff_vid(out_trans).transpose(1, 2)
        # print("MLP output: ", out.shape)  # Bx1024
        if return_feats:
            return out, out_conv
        else:
            return out

    def forward_aud(self, x):
        out = self.net_aud(x)
        out = self.ff_aud(out)
        out = out.squeeze(-1)
        return out
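

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): how the video features
# flow through the positional encoding, transformer encoder, and MLP head in
# forward_vid. The sinusoidal encoding below is an assumed stand-in for
# PositionalEncoding_RGB (defined in sync_models.modules), and the feature
# shape (B, 512, T) with T = 21 is taken from the debug comments above; both
# are assumptions, not guaranteed to match the real modules.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import math

    def sinusoidal_pe(seq_len, d_model):
        # Standard sine/cosine positional encoding, shape (1, seq_len, d_model).
        pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                        * (-math.log(10000.0) / d_model))
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe.unsqueeze(0)

    B, T, D = 2, 21, 512
    feats = torch.randn(B, D, T)                     # stand-in conv features (B, 512, T)

    encoder_layer = nn.TransformerEncoderLayer(d_model=D, nhead=8, batch_first=True)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    ff_vid = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1024))

    x = feats.transpose(1, 2) + sinusoidal_pe(T, D)  # (B, T, 512), position-aware
    x = encoder(x)                                   # (B, T, 512) contextual features
    out = ff_vid(x).transpose(1, 2)                  # (B, 1024, T) video embeddings
    print(out.shape)                                 # torch.Size([2, 1024, 21])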