Superxixixi
committed on
Commit b98cec2 • 1 Parent(s): a25806a
Upload 5 files
- attentionLayer.py +39 -0
- audioEncoder.py +108 -0
- convLayer.py +42 -0
- loconet_encoder.py +90 -0
- visualEncoder.py +199 -0
attentionLayer.py
ADDED
@@ -0,0 +1,39 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import MultiheadAttention


class attentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.1):
        super(attentionLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, d_model * 4)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model * 4, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = F.relu

    def forward(self, src, tar, adjust=False, attn_mask=None):
        # Cross-attention between src and tar, followed by a feed-forward block,
        # both with residual connections and LayerNorm (transformer-style).
        src = src.transpose(0, 1)  # B, T, C -> T, B, C
        tar = tar.transpose(0, 1)  # B, T, C -> T, B, C
        if adjust:
            src2 = self.self_attn(src, tar, tar, attn_mask=None, key_padding_mask=None)[0]
        else:
            src2 = self.self_attn(tar, src, src, attn_mask=None, key_padding_mask=None)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        src = src.transpose(0, 1)  # T, B, C -> B, T, C
        return src
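A minimal shape check for the layer above. The batch/sequence sizes and d_model=128 are illustrative (they match the cross-attention blocks in loconet_encoder.py) and are not part of this commit.

import torch
from attentionLayer import attentionLayer

layer = attentionLayer(d_model=128, nhead=8)
audio = torch.randn(2, 25, 128)    # B, T, C
video = torch.randn(2, 25, 128)    # B, T, C
out = layer(src=audio, tar=video)  # cross-attention + feed-forward, residual kept on src
print(out.shape)                   # torch.Size([2, 25, 128])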
audioEncoder.py
ADDED
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out

class SELayer(nn.Module):
    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class audioEncoder(nn.Module):
    def __init__(self, layers, num_filters, **kwargs):
        super(audioEncoder, self).__init__()
        block = SEBasicBlock
        self.inplanes = num_filters[0]

        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=7, stride=(2, 1), padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, num_filters[0], layers[0])
        self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
        self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
        self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(1, 1))
        out_dim = num_filters[3] * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = torch.mean(x, dim=2, keepdim=True)
        x = x.view((x.size()[0], x.size()[1], -1))
        x = x.transpose(1, 2)

        return x
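A hypothetical forward pass through audioEncoder. The layer/filter counts and the (B, 1, n_mfcc, 4*T) MFCC layout are assumptions borrowed from TalkNet-style usage, not values defined in this commit.

import torch
from audioEncoder import audioEncoder

enc = audioEncoder(layers=[3, 4, 6, 3], num_filters=[16, 32, 64, 128])
mfcc = torch.randn(2, 1, 13, 100)  # B, 1, n_mfcc, 4*T (100 audio steps ~ 25 video frames)
feat = enc(mfcc)                   # frequency axis is averaged out, time axis is kept
print(feat.shape)                  # torch.Size([2, 25, 128]) -> B, T, C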
convLayer.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


class ConvLayer(nn.Module):

    def __init__(self, cfg):
        super(ConvLayer, self).__init__()
        self.cfg = cfg
        self.s = cfg.num_speakers
        self.conv2d = torch.nn.Conv2d(256, 256 * self.s, (self.s, 7), padding=(0, 3))
        # below line is speaker parallel 93.88 code
        # self.conv2d = torch.nn.Conv2d(256, 256 * self.s, (3, 7), padding=(0, 3))
        self.ln = torch.nn.LayerNorm(256)
        self.conv2d_1x1 = torch.nn.Conv2d(256, 512, (1, 1), padding=(0, 0))
        self.conv2d_1x1_2 = torch.nn.Conv2d(512, 256, (1, 1), padding=(0, 0))
        self.gelu = nn.GELU()

    def forward(self, x, b, s):

        identity = x  # b*s, t, c
        t = x.shape[1]
        c = x.shape[2]
        out = x.view(b, s, t, c)
        out = out.permute(0, 3, 1, 2)  # b, c, s, t

        out = self.conv2d(out)  # b, s*c, 1, t
        out = out.view(b, c, s, t)
        out = out.permute(0, 2, 3, 1)  # b, s, t, c
        out = self.ln(out)
        out = out.permute(0, 3, 1, 2)
        out = self.conv2d_1x1(out)
        out = self.gelu(out)
        out = self.conv2d_1x1_2(out)  # b, c, s, t

        out = out.permute(0, 2, 3, 1)  # b, s, t, c
        out = out.reshape(b * s, t, c)  # reshape, not view: the permuted tensor is non-contiguous

        out += identity

        return out, b, s
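A minimal usage sketch for ConvLayer. The SimpleNamespace config is a stand-in for the project's real config object (only num_speakers is read here), and the sizes are illustrative.

import torch
from types import SimpleNamespace
from convLayer import ConvLayer

cfg = SimpleNamespace(num_speakers=3)  # hypothetical stand-in config
layer = ConvLayer(cfg)
b, s, t, c = 2, 3, 25, 256             # c must be 256 to match the conv/LayerNorm widths
x = torch.randn(b * s, t, c)           # per-speaker features stacked along the batch axis
out, b, s = layer(x, b, s)             # mixes information across speakers, keeps b*s, t, c
print(out.shape)                       # torch.Size([6, 25, 256])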
loconet_encoder.py
ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn

from attentionLayer import attentionLayer
from convLayer import ConvLayer
from torchvggish import vggish
from visualEncoder import visualFrontend, visualConv1D, visualTCN


class locoencoder(nn.Module):

    def __init__(self, cfg):
        super(locoencoder, self).__init__()
        self.cfg = cfg
        # Visual Temporal Encoder
        self.visualFrontend = visualFrontend(cfg)  # Visual Frontend
        self.visualTCN = visualTCN()  # Visual Temporal Network TCN
        self.visualConv1D = visualConv1D()  # Visual Temporal Network Conv1d

        urls = {
            'vggish':
                "https://github.com/harritaylor/torchvggish/releases/download/v0.1/vggish-10086976.pth"
        }
        self.audioEncoder = vggish.VGGish(urls, preprocess=False, postprocess=False)
        self.audio_pool = nn.AdaptiveAvgPool1d(1)

        # Audio-visual Cross Attention
        self.crossA2V = attentionLayer(d_model=128, nhead=8)
        self.crossV2A = attentionLayer(d_model=128, nhead=8)

        # Audio-visual Self Attention
        num_layers = self.cfg.av_layers
        layers = nn.ModuleList()
        for i in range(num_layers):
            layers.append(ConvLayer(cfg))
            layers.append(attentionLayer(d_model=256, nhead=8))
        self.convAV = layers

    def forward_visual_frontend(self, x):
        B, T, W, H = x.shape
        x = x.view(B * T, 1, 1, W, H)
        x = (x / 255 - 0.4161) / 0.1688
        x = self.visualFrontend(x)
        x = x.view(B, T, 512)
        x = x.transpose(1, 2)
        x = self.visualTCN(x)
        x = self.visualConv1D(x)
        x = x.transpose(1, 2)
        return x

    def forward_audio_frontend(self, x):
        t = x.shape[-2]
        numFrames = t // 4
        pad = 8 - (t % 8)
        x = torch.nn.functional.pad(x, (0, 0, 0, pad), "constant")
        # x = x.unsqueeze(1).transpose(2, 3)
        x = self.audioEncoder(x)

        b, c, t2, freq = x.shape
        x = x.view(b * c, t2, freq)
        x = self.audio_pool(x)
        x = x.view(b, c, t2)[:, :, :numFrames]
        x = x.permute(0, 2, 1)
        return x

    def forward_cross_attention(self, x1, x2):
        x1_c = self.crossA2V(src=x1, tar=x2, adjust=self.cfg.adjust_attention)
        x2_c = self.crossV2A(src=x2, tar=x1, adjust=self.cfg.adjust_attention)
        return x1_c, x2_c

    def forward_audio_visual_backend(self, x1, x2, b=1, s=1):
        x = torch.cat((x1, x2), 2)  # B*S, T, 2C
        for i, layer in enumerate(self.convAV):
            if i % 2 == 0:
                x, b, s = layer(x, b, s)
            else:
                x = layer(src=x, tar=x)

        x = torch.reshape(x, (-1, 256))
        return x

    def forward_audio_backend(self, x):
        x = torch.reshape(x, (-1, 128))
        return x

    def forward_visual_backend(self, x):
        x = torch.reshape(x, (-1, 128))
        return x
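A hypothetical construction of locoencoder. The config fields are the ones this file and ConvLayer actually read (num_speakers, av_layers, adjust_attention); the values and the SimpleNamespace stand-in are illustrative, and torchvggish must be installed (constructing VGGish loads the pretrained weights from the URL in the code).

from types import SimpleNamespace
from loconet_encoder import locoencoder

cfg = SimpleNamespace(num_speakers=3, av_layers=3, adjust_attention=False)  # hypothetical config
model = locoencoder(cfg)  # builds visual frontend, VGGish audio encoder, and the AV attention stack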
visualEncoder.py
ADDED
@@ -0,0 +1,199 @@
##
# ResNet18 Pretrained network to extract lip embedding
# This code is modified based on https://github.com/lordmartian/deep_avsr
##

import torch
import torch.nn as nn
import torch.nn.functional as F
from attentionLayer import attentionLayer


class ResNetLayer(nn.Module):
    """
    A ResNet layer used to build the ResNet network.
    Architecture:
    --> conv-bn-relu -> conv -> + -> bn-relu -> conv-bn-relu -> conv -> + -> bn-relu -->
                    |                        |   |                                    |
                     -----> downsample ------>    ------------------------------------>
    """

    def __init__(self, inplanes, outplanes, stride):
        super(ResNetLayer, self).__init__()
        self.conv1a = nn.Conv2d(inplanes,
                                outplanes,
                                kernel_size=3,
                                stride=stride,
                                padding=1,
                                bias=False)
        self.bn1a = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
        self.conv2a = nn.Conv2d(outplanes,
                                outplanes,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.stride = stride
        if self.stride != 1:
            self.downsample = nn.Conv2d(inplanes,
                                        outplanes,
                                        kernel_size=(1, 1),
                                        stride=stride,
                                        bias=False)
        self.outbna = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)

        self.conv1b = nn.Conv2d(outplanes,
                                outplanes,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.bn1b = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
        self.conv2b = nn.Conv2d(outplanes,
                                outplanes,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.outbnb = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
        return

    def forward(self, inputBatch):
        batch = F.relu(self.bn1a(self.conv1a(inputBatch)))
        batch = self.conv2a(batch)
        if self.stride == 1:
            residualBatch = inputBatch
        else:
            residualBatch = self.downsample(inputBatch)
        batch = batch + residualBatch
        intermediateBatch = batch
        batch = F.relu(self.outbna(batch))

        batch = F.relu(self.bn1b(self.conv1b(batch)))
        batch = self.conv2b(batch)
        residualBatch = intermediateBatch
        batch = batch + residualBatch
        outputBatch = F.relu(self.outbnb(batch))
        return outputBatch


class ResNet(nn.Module):
    """
    An 18-layer ResNet architecture.
    """

    def __init__(self):
        super(ResNet, self).__init__()
        self.layer1 = ResNetLayer(64, 64, stride=1)
        self.layer2 = ResNetLayer(64, 128, stride=2)
        self.layer3 = ResNetLayer(128, 256, stride=2)
        self.layer4 = ResNetLayer(256, 512, stride=2)
        self.avgpool = nn.AvgPool2d(kernel_size=(4, 4), stride=(1, 1))
        return

    def forward(self, inputBatch):
        batch = self.layer1(inputBatch)
        batch = self.layer2(batch)
        batch = self.layer3(batch)
        batch = self.layer4(batch)
        outputBatch = self.avgpool(batch)
        return outputBatch


class GlobalLayerNorm(nn.Module):

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        gLN_y = self.gamma * (y - mean) / torch.pow(var + 1e-8, 0.5) + self.beta
        return gLN_y


class visualFrontend(nn.Module):
    """
    A visual feature extraction module. Generates a 512-dim feature vector per video frame.
    Architecture: A 3D convolution block followed by an 18-layer ResNet.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        super(visualFrontend, self).__init__()
        self.frontend3D = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3),
                      bias=False), nn.BatchNorm3d(64, momentum=0.01, eps=0.001), nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
        self.resnet = ResNet()
        return

    def forward(self, inputBatch):
        inputBatch = inputBatch.transpose(0, 1).transpose(1, 2)
        batchsize = inputBatch.shape[0]
        batch = self.frontend3D(inputBatch)

        batch = batch.transpose(1, 2)
        batch = batch.reshape(batch.shape[0] * batch.shape[1], batch.shape[2], batch.shape[3],
                              batch.shape[4])
        outputBatch = self.resnet(batch)
        outputBatch = outputBatch.reshape(batchsize, -1, 512)
        outputBatch = outputBatch.transpose(1, 2)
        outputBatch = outputBatch.transpose(1, 2).transpose(0, 1)
        return outputBatch


class DSConv1d(nn.Module):

    def __init__(self):
        super(DSConv1d, self).__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Conv1d(512, 512, 3, stride=1, padding=1, dilation=1, groups=512, bias=False),
            nn.PReLU(),
            GlobalLayerNorm(512),
            nn.Conv1d(512, 512, 1, bias=False),
        )

    def forward(self, x):
        out = self.net(x)
        return out + x


class visualTCN(nn.Module):

    def __init__(self):
        super(visualTCN, self).__init__()
        stacks = []
        for x in range(5):
            stacks += [DSConv1d()]
        self.net = nn.Sequential(*stacks)  # Visual Temporal Network V-TCN

    def forward(self, x):
        out = self.net(x)
        return out


class visualConv1D(nn.Module):

    def __init__(self):
        super(visualConv1D, self).__init__()
        self.net = nn.Sequential(
            nn.Conv1d(512, 256, 5, stride=1, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 128, 1),
        )

    def forward(self, x):
        out = self.net(x)
        return out
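A hypothetical end-to-end check of the visual branch on 112x112 grayscale face crops, mirroring the ordering in forward_visual_frontend of loconet_encoder.py (without its pixel normalization). The batch/frame counts are made up, and the empty SimpleNamespace works only because visualFrontend stores cfg without reading it.

import torch
from types import SimpleNamespace
from visualEncoder import visualFrontend, visualTCN, visualConv1D

B, T = 2, 25
frontend = visualFrontend(SimpleNamespace())     # cfg is stored but not used here
tcn, conv1d = visualTCN(), visualConv1D()

faces = torch.randn(B, T, 112, 112)
x = faces.view(B * T, 1, 1, 112, 112)
x = frontend(x).view(B, T, 512).transpose(1, 2)  # B, 512, T
x = conv1d(tcn(x)).transpose(1, 2)               # B, T, 128
print(x.shape)                                   # torch.Size([2, 25, 128])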