Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from src.audio2pose_models.cvae import CVAE | |
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator | |
from src.audio2pose_models.audio_encoder import AudioEncoder | |
class Audio2Pose(nn.Module): | |
def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): | |
super().__init__() | |
self.cfg = cfg | |
self.seq_len = cfg.MODEL.CVAE.SEQ_LEN | |
self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE | |
self.device = device | |
self.audio_encoder = AudioEncoder(wav2lip_checkpoint) | |
self.audio_encoder.eval() | |
for param in self.audio_encoder.parameters(): | |
param.requires_grad = False | |
self.netG = CVAE(cfg) | |
self.netD_motion = PoseSequenceDiscriminator(cfg) | |
self.gan_criterion = nn.MSELoss() | |
self.reg_criterion = nn.L1Loss(reduction='none') | |
self.pair_criterion = nn.PairwiseDistance() | |
self.cosine_loss = nn.CosineSimilarity(dim=1) | |
def forward(self, x): | |
batch = {} | |
coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 | |
batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6 | |
batch['ref'] = coeff_gt[:, 0, -9:-3] #bs 6 | |
batch['class'] = x['class'].squeeze(0).cuda() # bs | |
indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 | |
# forward | |
audio_emb_list = [] | |
audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 | |
batch['audio_emb'] = audio_emb | |
batch = self.netG(batch) | |
pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 | |
pose_gt = coeff_gt[:, 1:, -9:-3].clone() # bs frame_len 6 | |
pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred # bs frame_len 6 | |
batch['pose_pred'] = pose_pred | |
batch['pose_gt'] = pose_gt | |
return batch | |
def test(self, x): | |
batch = {} | |
ref = x['ref'] #bs 1 70 | |
batch['ref'] = x['ref'][:,0,-6:] | |
batch['class'] = x['class'] | |
bs = ref.shape[0] | |
indiv_mels= x['indiv_mels'] # bs T 1 80 16 | |
indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame | |
num_frames = x['num_frames'] | |
num_frames = int(num_frames) - 1 | |
# | |
div = num_frames//self.seq_len | |
re = num_frames%self.seq_len | |
audio_emb_list = [] | |
pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, | |
device=batch['ref'].device)] | |
for i in range(div): | |
z = torch.randn(bs, self.latent_dim).to(ref.device) | |
batch['z'] = z | |
audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 | |
batch['audio_emb'] = audio_emb | |
batch = self.netG.test(batch) | |
pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 | |
if re != 0: | |
z = torch.randn(bs, self.latent_dim).to(ref.device) | |
batch['z'] = z | |
audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 | |
batch['audio_emb'] = audio_emb | |
batch = self.netG.test(batch) | |
pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) | |
pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) | |
batch['pose_motion_pred'] = pose_motion_pred | |
pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 | |
batch['pose_pred'] = pose_pred | |
return batch | |