# TalkSHOW: nets/smplx_body_vq.py
# Commit 865fd8a ("first version") by feifeifeiliu.
import os
import sys
from torch.optim.lr_scheduler import StepLR
sys.path.append(os.getcwd())
from nets.layers import *
from nets.base import TrainWrapperBaseClass
from nets.spg.s2glayers import Generator as G_S2G, Discriminator as D_S2G
from nets.spg.vqvae_1d import VQVAE as s2g_body
from nets.utils import parse_audio, denormalize
from data_utils import get_mfcc, get_melspec, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from sklearn.preprocessing import normalize
from data_utils.lower_body import c_index, c_index_3d, c_index_6d
class TrainWrapper(TrainWrapperBaseClass):
    '''
    A wrapper that receives a batch from data_utils, computes the VQ-VAE
    reconstruction losses and steps the optimizer(s).

    With ``config.Model.composition`` enabled, body and hand poses are modelled
    by two separate VQ-VAEs (``g_body`` / ``g_hand``); otherwise a single
    VQ-VAE (``g``) models both parts together.
    '''

    # Frames per chunk when infer_on_audio(continuity=True) reconstructs a
    # sequence piece by piece.
    CONTINUITY_CHUNK_LEN = 60

    def __init__(self, args, config):
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.expression = self.config.Data.pose.expression
        self.epoch = 0
        # Fills self.each_dim & co.; must run after self.expression is set.
        self.init_params()
        self.num_classes = 4
        self.composition = self.config.Model.composition
        if self.composition:
            self.g_body = self._build_vqvae(self.each_dim[1])
            self.g_hand = self._build_vqvae(self.each_dim[2])
        else:
            self.g = self._build_vqvae(self.each_dim[1] + self.each_dim[2])
        self.discriminator = None
        self.c_index = c_index_6d if self.convert_to_6d else c_index_3d
        # Called last. NOTE(review): presumably TrainWrapperBaseClass.__init__
        # invokes init_optimizer(), which needs the models built above —
        # confirm in nets/base.py.
        super().__init__(args, config)

    def _build_vqvae(self, in_dim):
        """Build one VQ-VAE over `in_dim` pose channels on self.device."""
        return s2g_body(in_dim, embedding_dim=64, num_embeddings=self.config.Model.code_num,
                        num_hiddens=1024, num_residual_layers=2,
                        num_residual_hiddens=512).to(self.device)

    def init_optimizer(self):
        """Create one Adam optimizer per VQ-VAE, all using the generator learning rate."""
        print('using Adam')
        if self.composition:
            self.g_body_optimizer = self._make_adam(self.g_body)
            self.g_hand_optimizer = self._make_adam(self.g_hand)
        else:
            self.g_optimizer = self._make_adam(self.g)

    def _make_adam(self, model):
        """Return an Adam optimizer over `model`'s parameters (betas 0.9 / 0.999)."""
        return optim.Adam(
            model.parameters(),
            lr=self.config.Train.learning_rate.generator_learning_rate,
            betas=[0.9, 0.999],
        )

    def state_dict(self):
        """Return a checkpoint dict with model and optimizer states.

        Keys are 'g_body', 'g_body_optim', 'g_hand', 'g_hand_optim' in
        composition mode, or 'g', 'g_optim' otherwise. Both modes also carry
        'discriminator' and 'discriminator_optim', which are None when no
        discriminator is set (always the case in this class).
        """
        if self.composition:
            model_state = {
                'g_body': self.g_body.state_dict(),
                'g_body_optim': self.g_body_optimizer.state_dict(),
                'g_hand': self.g_hand.state_dict(),
                'g_hand_optim': self.g_hand_optimizer.state_dict(),
            }
        else:
            model_state = {
                'g': self.g.state_dict(),
                'g_optim': self.g_optimizer.state_dict(),
            }
        has_discriminator = self.discriminator is not None
        model_state['discriminator'] = (
            self.discriminator.state_dict() if has_discriminator else None)
        model_state['discriminator_optim'] = (
            self.discriminator_optimizer.state_dict() if has_discriminator else None)
        return model_state

    def init_params(self):
        """Compute the channel counts of each pose part.

        Sets:
            self.each_dim: [jaw, eye + body, hand, face] channel counts.
            self.dim_list: running offsets of jaw/eye/body/hand/face, which
                appear to be start indices in a concatenated pose vector.
            self.full_dim: jaw + eye + body + hand channels (face excluded).
            self.pose: full_dim divided by 3 * scale.

        Every size is multiplied by 2 when ``convert_to_6d`` is set. Jaw, eyes
        and global orientation contribute 0 channels here. The body uses
        63 - 24 (unscaled) channels, each hand 45, and the face 100 only when
        expressions are enabled.
        """
        scale = 2 if self.config.Data.pose.convert_to_6d else 1
        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(0 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)
        expression = 100 if self.expression else 0

        jaw_dim = jaw_pose
        eye_dim = leye_pose + reye_pose
        body_dim = global_orient + body_pose
        hand_dim = left_hand_pose + right_hand_pose
        face_dim = expression

        b_j = 0
        b_e = b_j + jaw_dim
        b_b = b_e + eye_dim
        b_h = b_b + body_dim
        b_f = b_h + hand_dim
        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]

    def __call__(self, bat):
        """Run one training step on `bat` and return ``(None, loss_dict)``.

        The optimizers are stepped inside vq_train, so no total loss is
        returned. `loss_dict` maps '<part><loss name>' (part 'b', 'h' or 'g')
        to Python floats.
        """
        self.global_step += 1
        total_loss = None
        loss_dict = {}
        poses = bat['poses'].to(self.device).to(torch.float32)
        # Keep the modelled channels (dim 1), then move time to dim 1: (B, T, C).
        poses = poses[:, self.c_index, :]
        gt_poses = poses.permute(0, 2, 1)
        loss = 0
        if self.composition:
            b_poses = gt_poses[..., :self.each_dim[1]]
            h_poses = gt_poses[..., self.each_dim[1]:]
            loss_dict, loss = self.vq_train(b_poses, 'b', self.g_body, loss_dict, loss)
            loss_dict, loss = self.vq_train(h_poses, 'h', self.g_hand, loss_dict, loss)
        else:
            loss_dict, loss = self.vq_train(gt_poses, 'g', self.g, loss_dict, loss)
        return total_loss, loss_dict

    def vq_train(self, gt, name, model, dict, total_loss, pre=None):
        """Run one forward / backward / optimizer step for a single VQ-VAE.

        Args:
            gt: target poses; get_loss treats dim 1 as time.
            name: 'b' (body), 'h' (hand) or 'g' (joint model). Selects the
                optimizer and prefixes the logged loss keys.
            model: the VQ-VAE to train.
            dict: output mapping that receives ``name + key -> float``. The
                parameter name shadows the builtin but is kept for API
                compatibility.
            total_loss: returned unchanged.
            pre: optional preceding frames. They are passed to the model as
                `pre_state` and used for the first-frame velocity term.

        Returns:
            (dict, total_loss)

        Raises:
            ValueError: if `name` is not 'b', 'h' or 'g'.
        """
        optimizer_names = {
            'b': 'g_body_optimizer',
            'h': 'g_hand_optimizer',
            'g': 'g_optimizer',
        }
        if name not in optimizer_names:
            raise ValueError("model's name must be b, h or g")
        optimizer = getattr(self, optimizer_names[name])

        e_q_loss, x_recon = model(gt_poses=gt, pre_state=pre)
        loss, loss_dict = self.get_loss(pred_poses=x_recon, gt_poses=gt, e_q_loss=e_q_loss, pre=pre)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for key, value in loss_dict.items():
            dict[name + key] = value.item()
        return dict, total_loss

    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 e_q_loss,
                 pre=None
                 ):
        """Return ``(gen_loss, loss_dict)`` for a reconstruction.

        gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel, where:
          * rec_loss is the mean absolute error between prediction and target.
          * velocity_loss is the L1 error between first differences along dim 1.
          * f0_vel is the L1 error between the first-frame velocities relative
            to ``pre[:, -1]``, or 0 when `pre` is None.

        loss_dict holds rec_loss and velocity_loss (plus f0_vel when `pre` is
        given). e_q_loss is deliberately not logged.
        """
        loss_dict = {}
        rec_loss = torch.mean(torch.abs(pred_poses - gt_poses))
        v_pr = pred_poses[:, 1:] - pred_poses[:, :-1]
        v_gt = gt_poses[:, 1:] - gt_poses[:, :-1]
        velocity_loss = torch.mean(torch.abs(v_pr - v_gt))
        if pre is None:
            f0_vel = 0
        else:
            v0_pr = pred_poses[:, 0] - pre[:, -1]
            v0_gt = gt_poses[:, 0] - pre[:, -1]
            f0_vel = torch.mean(torch.abs(v0_pr - v0_gt))
        gen_loss = rec_loss + e_q_loss + velocity_loss + f0_vel
        loss_dict['rec_loss'] = rec_loss
        loss_dict['velocity_loss'] = velocity_loss
        if pre is not None:
            loss_dict['f0_vel'] = f0_vel
        return gen_loss, loss_dict

    def infer_on_audio(self, aud_fn, initial_pose=None, norm_stats=None, exp=None, var=None, w_pre=False,
                       continuity=False, id=None, fps=15, sr=22000, smooth=False, **kwargs):
        '''
        Reconstruct `initial_pose` through the VQ-VAE(s).

        The output depends only on `initial_pose`. `aud_fn`, `exp`, `var`,
        `w_pre`, `fps`, `sr` and `**kwargs` are unused and are accepted only
        for interface compatibility.

        Args:
            initial_pose: (B, C, T) normalized poses. Required.
            norm_stats: (mean, std). Required when pose normalization is enabled.
            continuity: composition mode only. Reconstruct in
                CONTINUITY_CHUNK_LEN-frame chunks, feeding each chunk's last
                predicted frame to the next chunk as `pre_state`. Frames after
                the last full chunk are dropped.
            id: forwarded as ``id=`` to the models on the non-continuity path.
                Defaults to a one-hot of class 0.
            smooth: blend frames 149..158 of the first sample toward their
                predecessor. Requires at least 159 output frames.

        Returns:
            np.ndarray: the per-sample (T, C) motions concatenated along the
            channel axis, i.e. (T, C) when B == 1.

        Raises:
            ValueError: if `initial_pose` is None, or if `continuity` is set and
                the sequence is shorter than one chunk.
        '''
        assert self.args.infer, "train mode"
        if initial_pose is None:
            raise ValueError("initial_pose is required: this model reconstructs the given poses")

        if self.composition:
            self.g_body.eval()
            self.g_hand.eval()
        else:
            self.g.eval()

        if self.config.Data.pose.normalization:
            assert norm_stats is not None
            data_mean = norm_stats[0]
            data_std = norm_stats[1]

        gt = initial_pose.to(self.device).to(torch.float32)
        if id is None:
            id = F.one_hot(torch.tensor([[0]]), self.num_classes).to(self.device)

        with torch.no_grad():
            # Keep the modelled channels and move time to dim 1: (B, T, C).
            gt_poses = gt[:, self.c_index].permute(0, 2, 1)
            if self.composition:
                body_poses = gt_poses[..., :self.each_dim[1]]
                hand_poses = gt_poses[..., self.each_dim[1]:]
                if continuity:
                    pred_poses_body = self._reconstruct_in_chunks(self.g_body, body_poses)
                    pred_poses_hand = self._reconstruct_in_chunks(self.g_hand, hand_poses)
                else:
                    _, pred_poses_body = self.g_body(gt_poses=body_poses, id=id)
                    _, pred_poses_hand = self.g_hand(gt_poses=hand_poses, id=id)
                # Eval-mode predictions are treated as channel-first: join parts on dim 1.
                pred_poses = torch.cat([pred_poses_body, pred_poses_hand], dim=1)
            else:
                _, pred_poses = self.g(gt_poses=gt_poses, id=id)

        output = pred_poses.transpose(1, 2).cpu().numpy()
        if self.config.Data.pose.normalization:
            output = denormalize(output, data_mean, data_std)

        if smooth:
            lamda = 0.8
            smooth_f = 10
            frame = 149
            # Each blended frame uses the already-updated previous frame.
            for i in range(smooth_f):
                f = frame + i
                l = lamda * (i + 1) / smooth_f
                output[0, f] = (1 - l) * output[0, f - 1] + l * output[0, f]

        output = np.concatenate(output, axis=1)
        return output

    def _reconstruct_in_chunks(self, model, poses):
        """Reconstruct (B, T, C) `poses` chunk by chunk with `model`.

        The chunks are consecutive CONTINUITY_CHUNK_LEN-frame windows. Each
        chunk's last predicted frame is passed to the next chunk as
        `pre_state`. Frames after the last full chunk are dropped. Chunk
        predictions are concatenated along dim 2, which this code treats as
        time for the models' eval-mode output.

        Raises:
            ValueError: if `poses` has fewer frames than one chunk.
        """
        chunk_len = self.CONTINUITY_CHUNK_LEN
        num_frames = poses.shape[1]
        num_chunks = num_frames // chunk_len
        if num_chunks == 0:
            raise ValueError(
                f"continuity inference needs at least {chunk_len} frames, got {num_frames}")
        preds = []
        pre = None
        for start in range(0, num_chunks * chunk_len, chunk_len):
            _, pred = model(gt_poses=poses[:, start:start + chunk_len], pre_state=pre)
            # Last predicted frame, moved to (B, 1, C) assuming channel-first output.
            pre = pred[..., -1:].transpose(1, 2)
            preds.append(pred)
        return torch.cat(preds, dim=2)

    def load_state_dict(self, state_dict):
        """Load model weights from a dict produced by state_dict().

        Only the model weights are restored; optimizer states in the dict are
        ignored.
        """
        if self.composition:
            self.g_body.load_state_dict(state_dict['g_body'])
            self.g_hand.load_state_dict(state_dict['g_hand'])
        else:
            self.g.load_state_dict(state_dict['g'])