# LoCoNet_ASD / loconet_encoder.py
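"""LoCoNet audio-visual encoder.

Pairs a visual frontend (visualFrontend + visualTCN + visualConv1D) with a
pretrained VGGish audio encoder, fuses the two streams with cross-attention,
and refines them with interleaved conv and self-attention layers (convAV).
"""
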
import torch
import torch.nn as nn
from attentionLayer import attentionLayer
from convLayer import ConvLayer
from torchvggish import vggish
from visualEncoder import visualFrontend, visualConv1D, visualTCN


class locoencoder(nn.Module):
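    """Audio-visual encoder used by LoCoNet for active speaker detection."""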
def __init__(self, cfg):
super(locoencoder, self).__init__()
self.cfg = cfg
# Visual Temporal Encoder
self.visualFrontend = visualFrontend(cfg) # Visual Frontend
self.visualTCN = visualTCN() # Visual Temporal Network TCN
self.visualConv1D = visualConv1D() # Visual Temporal Network Conv1d
        # Audio Temporal Encoder: pretrained VGGish backbone used as a feature
        # extractor (no input preprocessing or PCA postprocessing)
        urls = {
            'vggish':
                "https://github.com/harritaylor/torchvggish/releases/download/v0.1/vggish-10086976.pth"
        }
        self.audioEncoder = vggish.VGGish(urls, preprocess=False, postprocess=False)
        self.audio_pool = nn.AdaptiveAvgPool1d(1)  # pools VGGish features over the frequency axis
# Audio-visual Cross Attention
self.crossA2V = attentionLayer(d_model=128, nhead=8)
self.crossV2A = attentionLayer(d_model=128, nhead=8)
        # Audio-visual backend: interleaved ConvLayer + self-attention blocks
        num_layers = self.cfg.av_layers
        layers = nn.ModuleList()
        for _ in range(num_layers):
            layers.append(ConvLayer(cfg))
            layers.append(attentionLayer(d_model=256, nhead=8))
        self.convAV = layers

    def forward_visual_frontend(self, x):
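        """Encode grayscale face crops (B, T, W, H) into per-frame features.

        Returns (B, T, 128); 128 matches the d_model of the cross-attention
        layers.
        """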
B, T, W, H = x.shape
        x = x.view(B * T, 1, 1, W, H)
        x = (x / 255 - 0.4161) / 0.1688  # scale to [0, 1], then normalize with fixed mean/std
        x = self.visualFrontend(x)
        x = x.view(B, T, 512)
        x = x.transpose(1, 2)  # (B, 512, T) for the 1-D temporal convolutions
        x = self.visualTCN(x)
        x = self.visualConv1D(x)
        x = x.transpose(1, 2)  # back to (B, T, C)
return x

    def forward_audio_frontend(self, x):
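        """Encode audio features (time on axis -2) into per-video-frame features.

        Assumes 4 audio feature steps per video frame: the VGGish output is
        pooled over frequency and trimmed to t // 4 frames, giving
        (B, numFrames, C).
        """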
        t = x.shape[-2]
        numFrames = t // 4  # 4 audio feature steps per video frame
        pad = (8 - t % 8) % 8  # pad time axis up to a multiple of 8 (no-op when already aligned)
        x = torch.nn.functional.pad(x, (0, 0, 0, pad), "constant")
# x = x.unsqueeze(1).transpose(2, 3)
x = self.audioEncoder(x)
        b, c, t2, freq = x.shape
        x = x.view(b * c, t2, freq)
        x = self.audio_pool(x)  # average over the frequency axis
        x = x.view(b, c, t2)[:, :, :numFrames]  # trim to one feature per video frame
        x = x.permute(0, 2, 1)  # (B, numFrames, C)
return x

    def forward_cross_attention(self, x1, x2):
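        """Cross-attend two feature streams (e.g. audio and visual).

        Each stream is re-encoded by attending to the other; both attended
        streams are returned.
        """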
x1_c = self.crossA2V(src=x1, tar=x2, adjust=self.cfg.adjust_attention)
x2_c = self.crossV2A(src=x2, tar=x1, adjust=self.cfg.adjust_attention)
return x1_c, x2_c

    def forward_audio_visual_backend(self, x1, x2, b=1, s=1):
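        """Fuse the two streams and run the interleaved conv/attention stack.

        Even-indexed layers are ConvLayer blocks, which also carry the batch
        size b and speaker count s; odd-indexed layers are self-attention.
        Returns the fused features flattened to (-1, 256).
        """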
x = torch.cat((x1, x2), 2) # B*S, T, 2C
        for i, layer in enumerate(self.convAV):
            if i % 2 == 0:
                x, b, s = layer(x, b, s)  # ConvLayer: updates batch/speaker factors
            else:
                x = layer(src=x, tar=x)  # self-attention over the fused features
x = torch.reshape(x, (-1, 256))
return x

    def forward_audio_backend(self, x):
        # Flatten to one 128-d audio feature per time step
        x = torch.reshape(x, (-1, 128))
        return x

    def forward_visual_backend(self, x):
        # Flatten to one 128-d visual feature per time step
        x = torch.reshape(x, (-1, 128))
        return x
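

# Minimal usage sketch (hypothetical, for illustration only). Assumes a cfg
# object exposing av_layers and adjust_attention, B videos with S candidate
# speakers each, T video frames, and audio features at 4x the video frame
# rate; the shapes below are assumptions, not guarantees.
#
#   model = locoencoder(cfg)
#   v = model.forward_visual_frontend(faces)                  # (B*S, T, 128)
#   a = model.forward_audio_frontend(mels)                    # (B*S, T, 128)
#   a, v = model.forward_cross_attention(a, v)
#   av = model.forward_audio_visual_backend(a, v, b=B, s=S)   # (B*S*T, 256)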