# nets/spg/s2glayers.py
'''
Speech-to-gesture layers for TalkSHOW: TF-style padded convolutions, a 1D UNet,
audio/speech encoders, the pose generator and the discriminator.
Not exactly the same as the official repo, but the results are good.
'''
import sys
import os
sys.path.append(os.getcwd())
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from nets.layers import SeqEncoder1D, SeqTranslator1D
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
class Conv2d_tf(nn.Conv2d):
"""
Conv2d with the padding behavior from TF
from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
"""
def __init__(self, *args, **kwargs):
super(Conv2d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding", "SAME")  # keep the TF-style padding string; actual padding is computed in forward()
def _compute_padding(self, input, dim):
input_size = input.size(dim + 2)
filter_size = self.weight.size(dim + 2)
effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
total_padding = max(
0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
)
additional_padding = int(total_padding % 2 != 0)
return additional_padding, total_padding
def forward(self, input):
if self.padding == "VALID":
return F.conv2d(
input,
self.weight,
self.bias,
self.stride,
padding=0,
dilation=self.dilation,
groups=self.groups,
)
rows_odd, padding_rows = self._compute_padding(input, dim=0)
cols_odd, padding_cols = self._compute_padding(input, dim=1)
if rows_odd or cols_odd:
input = F.pad(input, [0, cols_odd, 0, rows_odd])
return F.conv2d(
input,
self.weight,
self.bias,
self.stride,
padding=(padding_rows // 2, padding_cols // 2),
dilation=self.dilation,
groups=self.groups,
)
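# Worked example of the SAME-padding arithmetic above (hypothetical numbers):
#   input_size = 15, kernel = 4, stride = 2, dilation = 1
#   out_size      = (15 + 2 - 1) // 2            = 8
#   total_padding = max(0, (8 - 1) * 2 + 4 - 15) = 3  (odd -> one extra pad pixel on the
#                                                       right/bottom, as in TensorFlow)
# i.e. the output always has ceil(input_size / stride) steps per spatial dimension.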
class Conv1d_tf(nn.Conv1d):
"""
Conv1d with the padding behavior from TF
modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
"""
def __init__(self, *args, **kwargs):
super(Conv1d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding")  # TF-style padding mode string (may be None)
def _compute_padding(self, input, dim):
input_size = input.size(dim + 2)
filter_size = self.weight.size(dim + 2)
effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
total_padding = max(
0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
)
additional_padding = int(total_padding % 2 != 0)
return additional_padding, total_padding
def forward(self, input):
        # NOTE: the "valid" branch below is disabled, so SAME-style padding is always
        # applied regardless of the padding argument passed at construction.
        # if self.padding == "valid":
        #     return F.conv1d(input, self.weight, self.bias, self.stride,
        #                     padding=0, dilation=self.dilation, groups=self.groups)
rows_odd, padding_rows = self._compute_padding(input, dim=0)
if rows_odd:
input = F.pad(input, [0, rows_odd])
return F.conv1d(
input,
self.weight,
self.bias,
self.stride,
padding=(padding_rows // 2),
dilation=self.dilation,
groups=self.groups,
)
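# Example (hypothetical shapes): Conv1d_tf(256, 256, kernel_size=4, stride=2) maps a
# [B, 256, 60] sequence to [B, 256, 30]; with stride=1 the length is unchanged, which
# is what the ConvNormRelu blocks below rely on to keep the time axis aligned.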
def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, padding='valid', groups=1,
nonlinear='lrelu', bn='bn'):
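    """
    Builds a conv -> normalization -> activation block.
    When k and s are not given, downsample=False uses kernel 3 / stride 1 (length preserved)
    and downsample=True uses kernel 4 / stride 2 (halves the time axis).
    bn selects BatchNorm ('bn'), GroupNorm ('gn'), LayerNorm ('ln') or no normalization.
    """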
if k is None and s is None:
if not downsample:
k = 3
s = 1
padding = 'same'
else:
k = 4
s = 2
padding = 'valid'
if type == '1d':
conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
norm_block = nn.BatchNorm1d(out_channels)
elif type == '2d':
conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
norm_block = nn.BatchNorm2d(out_channels)
    else:
        raise ValueError(f"unsupported conv type: {type}")
if bn != 'bn':
if bn == 'gn':
norm_block = nn.GroupNorm(1, out_channels)
elif bn == 'ln':
norm_block = nn.LayerNorm(out_channels)
else:
norm_block = nn.Identity()
    if nonlinear == 'lrelu':
        nlinear = nn.LeakyReLU(0.2, True)
    elif nonlinear == 'tanh':
        nlinear = nn.Tanh()
    else:
        # 'none' (or any unrecognized value) falls back to identity instead of
        # leaving nlinear undefined
        nlinear = nn.Identity()
return nn.Sequential(
conv_block,
norm_block,
nlinear
)
class UnetUp(nn.Module):
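    """Upsampling block of the UNet: linearly interpolates x1 to the length of the
    skip connection x2, adds them, and refines the sum with a ConvNormRelu."""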
def __init__(self, in_ch, out_ch):
super(UnetUp, self).__init__()
self.conv = ConvNormRelu(in_ch, out_ch)
def forward(self, x1, x2):
# x1 = torch.repeat_interleave(x1, 2, dim=2)
# x1 = x1[:, :, :x2.shape[2]]
x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear')
x = x1 + x2
x = self.conv(x)
return x
class UNet(nn.Module):
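    """
    1D UNet over the time axis: three stride-1 conv blocks, an optional GRU warm start
    from previous-pose features, five stride-2 downsampling blocks, and five UnetUp
    blocks with additive skip connections. Returns the decoded features and the
    pre-GRU encoding x2_0.
    """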
def __init__(self, input_dim, dim):
super(UNet, self).__init__()
# dim = 512
self.down1 = nn.Sequential(
ConvNormRelu(input_dim, input_dim, '1d', False),
ConvNormRelu(input_dim, dim, '1d', False),
ConvNormRelu(dim, dim, '1d', False)
)
self.gru = nn.GRU(dim, dim, 1, batch_first=True)
self.down2 = ConvNormRelu(dim, dim, '1d', True)
self.down3 = ConvNormRelu(dim, dim, '1d', True)
self.down4 = ConvNormRelu(dim, dim, '1d', True)
self.down5 = ConvNormRelu(dim, dim, '1d', True)
self.down6 = ConvNormRelu(dim, dim, '1d', True)
self.up1 = UnetUp(dim, dim)
self.up2 = UnetUp(dim, dim)
self.up3 = UnetUp(dim, dim)
self.up4 = UnetUp(dim, dim)
self.up5 = UnetUp(dim, dim)
def forward(self, x1, pre_pose=None, w_pre=False):
x2_0 = self.down1(x1)
        if w_pre:
            # warm start: run the first encoded frame through the GRU with the last
            # pre-pose feature as the initial hidden state, then splice it back in
            # front of the remaining frames
            i = 1
            x2_pre = self.gru(x2_0[:, :, 0:i].permute(0, 2, 1),
                              pre_pose[:, :, -1:].permute(2, 0, 1).contiguous())[0].permute(0, 2, 1)
            x2 = torch.cat([x2_pre, x2_0[:, :, i:]], dim=-1)
# x2 = torch.cat([pre_pose, x2_0], dim=2) # [B, 512, 15]
else:
# x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1,2)
x2 = x2_0
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x6 = self.down5(x5)
x7 = self.down6(x6)
x = self.up1(x7, x6)
x = self.up2(x, x5)
x = self.up3(x, x4)
x = self.up4(x, x3)
x = self.up5(x, x2) # [B, 512, 15]
return x, x2_0
class AudioEncoder(nn.Module):
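    """
    Encodes audio features with a SeqTranslator1D stack followed by the UNet.
    With pose=True the encoder is variational (mu/var heads + reparameterization)
    and conditions the UNet on a concatenated template; with pose=False it is a
    plain deterministic encoder (mu and var are returned as None).
    """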
def __init__(self, n_frames, template_length, pose=False, common_dim=512):
super().__init__()
self.n_frames = n_frames
self.pose = pose
self.step = 0
self.weight = 0
if self.pose:
# self.first_net = nn.Sequential(
# ConvNormRelu(1, 64, '2d', False),
# ConvNormRelu(64, 64, '2d', True),
# ConvNormRelu(64, 128, '2d', False),
# ConvNormRelu(128, 128, '2d', True),
# ConvNormRelu(128, 256, '2d', False),
# ConvNormRelu(256, 256, '2d', True),
# ConvNormRelu(256, 256, '2d', False),
# ConvNormRelu(256, 256, '2d', False, padding='VALID')
# )
# decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4,
# dim_feedforward=2 * args.feature_dim, batch_first=True)
# a = nn.TransformerDecoder
self.first_net = SeqTranslator1D(256, 256,
min_layers_num=4,
residual=True
)
self.dropout_0 = nn.Dropout(0.1)
self.mu_fc = nn.Conv1d(256, 128, 1, 1)
self.var_fc = nn.Conv1d(256, 128, 1, 1)
self.trans_motion = SeqTranslator1D(common_dim, common_dim,
kernel_size=1,
stride=1,
min_layers_num=3,
residual=True
)
# self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1)
self.unet = UNet(128 + template_length, common_dim)
else:
self.first_net = SeqTranslator1D(256, 256,
min_layers_num=4,
residual=True
)
self.dropout_0 = nn.Dropout(0.1)
# self.att = nn.MultiheadAttention(256, 4, dropout=0.1)
self.unet = UNet(256, 256)
self.dropout_1 = nn.Dropout(0.0)
def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False):
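        # spectrogram: [B, T, C] audio features; transposed to [B, C, T] for the 1D convs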
self.step = self.step + 1
if self.pose:
spect = spectrogram.transpose(1, 2)
if w_pre:
spect = spect[:, :, :]
out = self.first_net(spect)
out = self.dropout_0(out)
mu = self.mu_fc(out)
var = self.var_fc(out)
audio = self.__reparam(mu, var)
# audio = out
# template = self.trans_motion(template)
x1 = torch.cat([audio, template], dim=1)#.permute(2,0,1)
# x1 = out
#x1, _ = self.att(x1, x1, x1)
#x1 = x1.permute(1,2,0)
x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre)
else:
spectrogram = spectrogram.transpose(1, 2)
x1 = self.first_net(spectrogram)#.permute(2,0,1)
#out, _ = self.att(out, out, out)
#out = out.permute(1, 2, 0)
x1 = self.dropout_0(x1)
x1, x2_0 = self.unet(x1)
x1 = self.dropout_1(x1)
mu = None
var = None
return x1, (mu, var), x2_0
    def __reparam(self, mu, log_var):
        # VAE reparameterization trick: z = mu + sigma * eps with sigma = exp(0.5 * log_var)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # sampled on the same device/dtype as std
        z = eps * std + mu
        return z
class Generator(nn.Module):
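    """
    Audio-to-pose generator. A template-conditioned AudioEncoder and a plain speech
    encoder feed the (optionally separate) decoder branches; when use_template is set,
    the template is sampled from the pose-conditioned mu/var during training and from
    a standard normal at inference. Two of the separate branches (indices 0 and 3)
    consume the plain speech features instead of the template-conditioned ones.
    """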
def __init__(self,
n_poses,
pose_dim,
pose,
n_pre_poses,
each_dim: list,
dim_list: list,
use_template=False,
template_length=0,
training=False,
device=None,
separate=False,
expression=False
):
super().__init__()
self.use_template = use_template
self.template_length = template_length
        self.training = training  # note: shadows nn.Module's training flag; a later .train()/.eval() call overwrites it
self.device = device
self.separate = separate
self.pose = pose
self.decoderf = True
self.expression = expression
common_dim = 256
if self.use_template:
assert template_length > 0
# self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device)
# self.pose_encoder = SeqEncoder1D(
# C_in=pose_dim,
# C_out=512,
# T_in=n_poses,
# min_layer_nums=6
#
# )
self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
# kernel_size=1,
# stride=1,
min_layers_num=3,
residual=True
)
self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
else:
self.template_length = 0
self.gen_length = n_poses
self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim)
self.speech_encoder = AudioEncoder(n_poses, template_length, False)
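        # audio_encoder (pose=True) yields template-conditioned features; speech_encoder
        # (pose=False) yields plain speech features used by decoder branches 0 and 3 in forward()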
# self.pre_pose_encoder = SeqEncoder1D(
# C_in=pose_dim,
# C_out=128,
# T_in=15,
# min_layer_nums=3
#
# )
# self.pmu_fc = nn.Linear(128, 64)
# self.pvar_fc = nn.Linear(128, 64)
self.pre_pose_encoder = SeqTranslator1D(pose_dim-50, common_dim,
min_layers_num=5,
residual=True
)
self.decoder_in = 256 + 64
self.dim_list = dim_list
if self.separate:
self.decoder = nn.ModuleList()
self.final_out = nn.ModuleList()
self.decoder.append(nn.Sequential(
ConvNormRelu(256, 64),
ConvNormRelu(64, 64),
ConvNormRelu(64, 64),
))
self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))
self.decoder.append(nn.Sequential(
ConvNormRelu(common_dim, common_dim),
ConvNormRelu(common_dim, common_dim),
ConvNormRelu(common_dim, common_dim),
))
self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1))
self.decoder.append(nn.Sequential(
ConvNormRelu(common_dim, common_dim),
ConvNormRelu(common_dim, common_dim),
ConvNormRelu(common_dim, common_dim),
))
self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1))
if self.expression:
self.decoder.append(nn.Sequential(
ConvNormRelu(256, 256),
ConvNormRelu(256, 256),
ConvNormRelu(256, 256),
))
self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1))
else:
self.decoder = nn.Sequential(
ConvNormRelu(self.decoder_in, 512),
ConvNormRelu(512, 512),
ConvNormRelu(512, 512),
ConvNormRelu(512, 512),
ConvNormRelu(512, 512),
ConvNormRelu(512, 512),
)
self.final_out = nn.Conv1d(512, pose_dim, 1, 1)
def __reparam(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std, device=self.device)
z = eps * std + mu
return z
def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True):
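        # Template selection: during training the template is sampled from the
        # pose-conditioned mu/var; at inference it is either drawn from N(0, I) or,
        # when gt_poses / an explicit template are provided, derived from them.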
if time_steps is not None:
self.gen_length = time_steps
if self.use_template:
if self.training:
if w_pre:
in_spec = in_spec[:, 15:, :]
pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1))
pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1))
mu = self.mu_fc(pose_enc)
var = self.var_fc(pose_enc)
template = self.__reparam(mu, var)
else:
pre_pose = None
pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
mu = self.mu_fc(pose_enc)
var = self.var_fc(pose_enc)
template = self.__reparam(mu, var)
elif pre_poses is not None:
if w_pre:
pre_pose = pre_poses[:, -1:, :-50]
if norm:
pre_pose = pre_pose.reshape(1, -1, 55, 5)
pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1),
F.normalize(pre_pose[..., 3:5], dim=-1)],
dim=-1).reshape(1, -1, 275)
pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1))
template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length ]).to(
in_spec.device)
else:
pre_pose = None
template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
            elif gt_poses is not None:
                pre_pose = None
                template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
elif template is None:
pre_pose = None
template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
        else:
            pre_pose = None
            template = None
            mu = None
            var = None
a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template, pre_pose=pre_pose, w_pre=w_pre)
s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps)
out = []
if self.separate:
            for i in range(len(self.decoder)):
                # branches 0 and 3 are driven by the plain speech features; the other
                # branches use the template-conditioned audio features
                if i == 0 or i == 3:
                    mid = self.decoder[i](s_f)
                else:
                    mid = self.decoder[i](a_t_f)
mid = self.final_out[i](mid)
out.append(mid)
out = torch.cat(out, dim=1)
else:
out = self.decoder(a_t_f)
out = self.final_out(out)
out = out.transpose(1, 2)
if self.training:
if w_pre:
return out, template, mu, var, (mu2, var2, x2_0, pre_pose)
else:
return out, template, mu, var, (mu2, var2, None, None)
else:
return out
class Discriminator(nn.Module):
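    """1D convolutional discriminator: maps a [B, T, pose_dim] pose sequence to a
    [B, 1, T'] map of per-step realism scores."""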
def __init__(self, pose_dim, pose):
super().__init__()
self.net = nn.Sequential(
Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'),
nn.LeakyReLU(0.2, True),
ConvNormRelu(64, 128, '1d', True),
ConvNormRelu(128, 256, '1d', k=4, s=1),
Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'),
)
def forward(self, x):
x = x.transpose(1, 2)
out = self.net(x)
return out
def main():
d = Discriminator(275, 55)
x = torch.randn([8, 60, 275])
result = d(x)
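    # With the SAME-style padding applied by Conv1d_tf, the two stride-2 stages halve
    # the 60-step sequence twice, so this should print torch.Size([8, 1, 15]).
    print(result.shape)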
if __name__ == "__main__":
main()