'''
not exactly the same as the official repo but the results are good
'''
import sys
import os
sys.path.append(os.getcwd())

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from nets.layers import SeqEncoder1D, SeqTranslator1D

""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """
class Conv2d_tf(nn.Conv2d):
    """
    Conv2d with the padding behavior from TF
    from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv2d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding", "SAME")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)
        return additional_padding, total_padding

    def forward(self, input):
        if self.padding == "VALID":
            return F.conv2d(
                input,
                self.weight,
                self.bias,
                self.stride,
                padding=0,
                dilation=self.dilation,
                groups=self.groups,
            )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        cols_odd, padding_cols = self._compute_padding(input, dim=1)
        if rows_odd or cols_odd:
            input = F.pad(input, [0, cols_odd, 0, rows_odd])
        return F.conv2d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2, padding_cols // 2),
            dilation=self.dilation,
            groups=self.groups,
        )
class Conv1d_tf(nn.Conv1d):
    """
    Conv1d with the padding behavior from TF
    modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv1d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)
        return additional_padding, total_padding

    def forward(self, input):
        # NOTE: the "valid" branch below is commented out, so this layer always
        # applies TF-style "SAME" padding regardless of the stored padding value.
        # if self.padding == "valid":
        #     return F.conv1d(
        #         input,
        #         self.weight,
        #         self.bias,
        #         self.stride,
        #         padding=0,
        #         dilation=self.dilation,
        #         groups=self.groups,
        #     )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        if rows_odd:
            input = F.pad(input, [0, rows_odd])
        return F.conv1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2),
            dilation=self.dilation,
            groups=self.groups,
        )
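
# A minimal usage sketch (not part of the original code; assumes a PyTorch build
# that passes string paddings through to the Conv1d constructor unchanged):
# with "SAME" behaviour the output length is ceil(input_length / stride), e.g.
#     conv = Conv1d_tf(64, 64, kernel_size=4, stride=2, padding='SAME')
#     y = conv(torch.randn(8, 64, 15))   # the 15-frame input is padded to 16,
#                                        # so y.shape == (8, 64, 8)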
def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, padding='valid', groups=1,
                 nonlinear='lrelu', bn='bn'):
    if k is None and s is None:
        if not downsample:
            k = 3
            s = 1
            padding = 'same'
        else:
            k = 4
            s = 2
            padding = 'valid'

    if type == '1d':
        conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
        norm_block = nn.BatchNorm1d(out_channels)
    elif type == '2d':
        conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding, groups=groups)
        norm_block = nn.BatchNorm2d(out_channels)
    else:
        raise ValueError("type must be '1d' or '2d', got %r" % type)

    if bn != 'bn':
        if bn == 'gn':
            norm_block = nn.GroupNorm(1, out_channels)
        elif bn == 'ln':
            norm_block = nn.LayerNorm(out_channels)
        else:
            norm_block = nn.Identity()

    if nonlinear == 'lrelu':
        nlinear = nn.LeakyReLU(0.2, True)
    elif nonlinear == 'tanh':
        nlinear = nn.Tanh()
    elif nonlinear == 'none':
        nlinear = nn.Identity()
    else:
        raise ValueError("unsupported nonlinear: %r" % nonlinear)

    return nn.Sequential(
        conv_block,
        norm_block,
        nlinear
    )
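
# For reference (illustrative, not from the original repo):
#   ConvNormRelu(256, 256, '1d')                  -> Conv1d_tf(k=3, s=1, 'same'), temporal length preserved
#   ConvNormRelu(256, 256, '1d', downsample=True) -> Conv1d_tf(k=4, s=2), temporal length roughly halved
# each followed by BatchNorm1d (or GroupNorm / LayerNorm / Identity) and LeakyReLU(0.2).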
class UnetUp(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UnetUp, self).__init__()
        self.conv = ConvNormRelu(in_ch, out_ch)

    def forward(self, x1, x2):
        # upsample x1 to the temporal length of the skip feature x2, then fuse by addition
        # x1 = torch.repeat_interleave(x1, 2, dim=2)
        # x1 = x1[:, :, :x2.shape[2]]
        x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear')
        x = x1 + x2
        x = self.conv(x)
        return x
class UNet(nn.Module):
    def __init__(self, input_dim, dim):
        super(UNet, self).__init__()
        # dim = 512
        self.down1 = nn.Sequential(
            ConvNormRelu(input_dim, input_dim, '1d', False),
            ConvNormRelu(input_dim, dim, '1d', False),
            ConvNormRelu(dim, dim, '1d', False)
        )
        self.gru = nn.GRU(dim, dim, 1, batch_first=True)
        self.down2 = ConvNormRelu(dim, dim, '1d', True)
        self.down3 = ConvNormRelu(dim, dim, '1d', True)
        self.down4 = ConvNormRelu(dim, dim, '1d', True)
        self.down5 = ConvNormRelu(dim, dim, '1d', True)
        self.down6 = ConvNormRelu(dim, dim, '1d', True)
        self.up1 = UnetUp(dim, dim)
        self.up2 = UnetUp(dim, dim)
        self.up3 = UnetUp(dim, dim)
        self.up4 = UnetUp(dim, dim)
        self.up5 = UnetUp(dim, dim)

    def forward(self, x1, pre_pose=None, w_pre=False):
        x2_0 = self.down1(x1)
        if w_pre:
            # warm-start the first frame with the GRU, using the previous pose feature as hidden state
            i = 1
            x2_pre = self.gru(x2_0[:, :, 0:i].permute(0, 2, 1),
                              pre_pose[:, :, -1:].permute(2, 0, 1).contiguous())[0].permute(0, 2, 1)
            x2 = torch.cat([x2_pre, x2_0[:, :, i:]], dim=-1)
            # x2 = torch.cat([pre_pose, x2_0], dim=2)  # [B, 512, 15]
        else:
            # x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1, 2)
            x2 = x2_0
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.down6(x6)
        x = self.up1(x7, x6)
        x = self.up2(x, x5)
        x = self.up3(x, x4)
        x = self.up4(x, x3)
        x = self.up5(x, x2)  # [B, 512, 15]
        return x, x2_0
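
# Shape sketch (following the inline [B, 512, 15] comments): down1 lifts the input
# to `dim` channels at full temporal length, down2-down6 halve the length five
# times, and up1-up5 interpolate back up, adding the matching encoder feature at
# each scale as an additive skip connection.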
class AudioEncoder(nn.Module):
    def __init__(self, n_frames, template_length, pose=False, common_dim=512):
        super().__init__()
        self.n_frames = n_frames
        self.pose = pose
        self.step = 0
        self.weight = 0
        if self.pose:
            # self.first_net = nn.Sequential(
            #     ConvNormRelu(1, 64, '2d', False),
            #     ConvNormRelu(64, 64, '2d', True),
            #     ConvNormRelu(64, 128, '2d', False),
            #     ConvNormRelu(128, 128, '2d', True),
            #     ConvNormRelu(128, 256, '2d', False),
            #     ConvNormRelu(256, 256, '2d', True),
            #     ConvNormRelu(256, 256, '2d', False),
            #     ConvNormRelu(256, 256, '2d', False, padding='VALID')
            # )
            # decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4,
            #                                            dim_feedforward=2 * args.feature_dim, batch_first=True)
            # a = nn.TransformerDecoder
            self.first_net = SeqTranslator1D(256, 256,
                                             min_layers_num=4,
                                             residual=True
                                             )
            self.dropout_0 = nn.Dropout(0.1)
            self.mu_fc = nn.Conv1d(256, 128, 1, 1)
            self.var_fc = nn.Conv1d(256, 128, 1, 1)
            self.trans_motion = SeqTranslator1D(common_dim, common_dim,
                                                kernel_size=1,
                                                stride=1,
                                                min_layers_num=3,
                                                residual=True
                                                )
            # self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1)
            self.unet = UNet(128 + template_length, common_dim)
        else:
            self.first_net = SeqTranslator1D(256, 256,
                                             min_layers_num=4,
                                             residual=True
                                             )
            self.dropout_0 = nn.Dropout(0.1)
            # self.att = nn.MultiheadAttention(256, 4, dropout=0.1)
            self.unet = UNet(256, 256)
            self.dropout_1 = nn.Dropout(0.0)

    def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False):
        self.step = self.step + 1
        if self.pose:
            spect = spectrogram.transpose(1, 2)
            if w_pre:
                spect = spect[:, :, :]
            out = self.first_net(spect)
            out = self.dropout_0(out)
            mu = self.mu_fc(out)
            var = self.var_fc(out)
            audio = self.__reparam(mu, var)
            # audio = out
            # template = self.trans_motion(template)
            x1 = torch.cat([audio, template], dim=1)  # .permute(2, 0, 1)
            # x1 = out
            # x1, _ = self.att(x1, x1, x1)
            # x1 = x1.permute(1, 2, 0)
            x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre)
        else:
            spectrogram = spectrogram.transpose(1, 2)
            x1 = self.first_net(spectrogram)  # .permute(2, 0, 1)
            # out, _ = self.att(out, out, out)
            # out = out.permute(1, 2, 0)
            x1 = self.dropout_0(x1)
            x1, x2_0 = self.unet(x1)
            x1 = self.dropout_1(x1)
            mu = None
            var = None
        return x1, (mu, var), x2_0

    def __reparam(self, mu, log_var):
        # reparameterisation trick; sample on the same device as the inputs
        # (the original hard-coded device='cuda', which breaks CPU inference)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = eps * std + mu
        return z
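
# Rough data flow when pose=True (as used by the Generator below): the input
# spectrogram features [B, T, 256] are transposed to [B, 256, T], translated by
# first_net, mapped to a 128-channel latent via the reparameterisation trick,
# concatenated with the `template_length`-channel template, and decoded by the
# UNet into [B, common_dim, T] features.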
class Generator(nn.Module):
    def __init__(self,
                 n_poses,
                 pose_dim,
                 pose,
                 n_pre_poses,
                 each_dim: list,
                 dim_list: list,
                 use_template=False,
                 template_length=0,
                 training=False,
                 device=None,
                 separate=False,
                 expression=False
                 ):
        super().__init__()

        self.use_template = use_template
        self.template_length = template_length
        # NOTE: this overrides nn.Module's built-in `training` flag with the constructor argument
        self.training = training
        self.device = device
        self.separate = separate
        self.pose = pose
        self.decoderf = True
        self.expression = expression
        common_dim = 256

        if self.use_template:
            assert template_length > 0
            # self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device)
            # self.pose_encoder = SeqEncoder1D(
            #     C_in=pose_dim,
            #     C_out=512,
            #     T_in=n_poses,
            #     min_layer_nums=6
            # )
            self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
                                                # kernel_size=1,
                                                # stride=1,
                                                min_layers_num=3,
                                                residual=True
                                                )
            self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
            self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
        else:
            self.template_length = 0

        self.gen_length = n_poses

        self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim)
        self.speech_encoder = AudioEncoder(n_poses, template_length, False)
        # self.pre_pose_encoder = SeqEncoder1D(
        #     C_in=pose_dim,
        #     C_out=128,
        #     T_in=15,
        #     min_layer_nums=3
        # )
        # self.pmu_fc = nn.Linear(128, 64)
        # self.pvar_fc = nn.Linear(128, 64)
        self.pre_pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
                                                min_layers_num=5,
                                                residual=True
                                                )

        self.decoder_in = 256 + 64
        self.dim_list = dim_list

        if self.separate:
            self.decoder = nn.ModuleList()
            self.final_out = nn.ModuleList()

            self.decoder.append(nn.Sequential(
                ConvNormRelu(256, 64),
                ConvNormRelu(64, 64),
                ConvNormRelu(64, 64),
            ))
            self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))

            self.decoder.append(nn.Sequential(
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
            ))
            self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1))

            self.decoder.append(nn.Sequential(
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
            ))
            self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1))

            if self.expression:
                self.decoder.append(nn.Sequential(
                    ConvNormRelu(256, 256),
                    ConvNormRelu(256, 256),
                    ConvNormRelu(256, 256),
                ))
                self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1))
        else:
            self.decoder = nn.Sequential(
                ConvNormRelu(self.decoder_in, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
            )
            self.final_out = nn.Conv1d(512, pose_dim, 1, 1)

    def __reparam(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std, device=self.device)
        z = eps * std + mu
        return z
    def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True):
        if time_steps is not None:
            self.gen_length = time_steps

        if self.use_template:
            if self.training:
                if w_pre:
                    in_spec = in_spec[:, 15:, :]
                    pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1))
                    pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1))
                    mu = self.mu_fc(pose_enc)
                    var = self.var_fc(pose_enc)
                    template = self.__reparam(mu, var)
                else:
                    pre_pose = None
                    pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
                    mu = self.mu_fc(pose_enc)
                    var = self.var_fc(pose_enc)
                    template = self.__reparam(mu, var)
            elif pre_poses is not None:
                if w_pre:
                    pre_pose = pre_poses[:, -1:, :-50]
                    if norm:
                        pre_pose = pre_pose.reshape(1, -1, 55, 5)
                        pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1),
                                              F.normalize(pre_pose[..., 3:5], dim=-1)],
                                             dim=-1).reshape(1, -1, 275)
                    pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1))
                    template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(
                        in_spec.device)
                else:
                    pre_pose = None
                    template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
            elif gt_poses is not None:
                pre_pose = None  # missing in the original, which raised a NameError on this path
                template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
            elif template is None:
                pre_pose = None
                template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(in_spec.device)
        else:
            pre_pose = None  # missing in the original, which raised a NameError on this path
            template = None
            mu = None
            var = None

        a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template,
                                                      pre_pose=pre_pose, w_pre=w_pre)
        s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps)

        out = []
        if self.separate:
            for i in range(len(self.decoder)):
                if i == 0 or i == 3:
                    # branches 0 and 3 are decoded from the speech_encoder feature,
                    # the others from the audio/template feature
                    mid = self.decoder[i](s_f)
                else:
                    mid = self.decoder[i](a_t_f)
                mid = self.final_out[i](mid)
                out.append(mid)
            out = torch.cat(out, dim=1)
        else:
            out = self.decoder(a_t_f)
            out = self.final_out(out)

        out = out.transpose(1, 2)

        if self.training:
            if w_pre:
                return out, template, mu, var, (mu2, var2, x2_0, pre_pose)
            else:
                return out, template, mu, var, (mu2, var2, None, None)
        else:
            return out
class Discriminator(nn.Module):
    def __init__(self, pose_dim, pose):
        super().__init__()
        self.net = nn.Sequential(
            Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'),
            nn.LeakyReLU(0.2, True),
            ConvNormRelu(64, 128, '1d', True),
            ConvNormRelu(128, 256, '1d', k=4, s=1),
            Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'),
        )

    def forward(self, x):
        x = x.transpose(1, 2)
        out = self.net(x)
        return out
def main():
    # smoke test: run a random 60-frame, 275-dim pose batch through the discriminator
    d = Discriminator(275, 55)
    x = torch.randn([8, 60, 275])
    result = d(x)
    print(result.shape)


if __name__ == "__main__":
    main()
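
# Illustrative note (not from the original repo): with the smoke test above the
# discriminator maps poses of shape [B, T, pose_dim] to patch logits of shape
# roughly [B, 1, T / 4] (60 frames -> 15 scores), since the first two conv
# stages use stride 2 with TF-style "SAME" padding.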