''' not exactly the same as the official repo, but the results are good '''
import sys
import os

from transformers import Wav2Vec2Processor
from .wav2vec import Wav2Vec2Model
from torchaudio.sox_effects import apply_effects_tensor

sys.path.append(os.getcwd())

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
import math

from nets.layers import SeqEncoder1D, SeqTranslator1D, ConvNormRelu

""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """


def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
    """
    :param audio: 1 x T tensor containing a 16kHz audio signal
    :param frame_rate: frame rate for video (we need one audio chunk per video frame)
    :param chunk_size: number of audio samples per chunk
    :return: num_chunks x chunk_size tensor containing sliced audio
    """
    samples_per_frame = 16000 // frame_rate
    padding = (chunk_size - samples_per_frame) // 2
    audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
    anchor_points = list(range(chunk_size // 2, audio.shape[-1] - chunk_size // 2, samples_per_frame))
    audio = torch.cat([audio[:, i - chunk_size // 2:i + chunk_size // 2] for i in anchor_points], dim=0)
    return audio


class MeshtalkEncoder(nn.Module):
    def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'):
        """
        :param latent_dim: size of the latent audio embedding
        :param model_name: name of the model, used to load and save the model
        """
        super().__init__()

        self.melspec = ta.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80
        )

        conv_len = 5
        self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
        self.weights_init(self.convert_dimensions)
        self.receptive_field = conv_len

        convs = []
        for i in range(6):
            dilation = 2 * (i % 3 + 1)
            self.receptive_field += (conv_len - 1) * dilation
            convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
            self.weights_init(convs[-1])
        self.convs = torch.nn.ModuleList(convs)
        self.code = torch.nn.Linear(128, latent_dim)

        self.apply(lambda x: self.weights_init(x))

    def weights_init(self, m):
        if isinstance(m, torch.nn.Conv1d):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, .01)

    def forward(self, audio: torch.Tensor):
        """
        :param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame
        :return: code: B x T x latent_dim Tensor containing a latent audio code/embedding
        """
        B, T = audio.shape[0], audio.shape[1]
        x = self.melspec(audio).squeeze(1)
        x = torch.log(x.clamp(min=1e-10, max=None))
        if T == 1:
            x = x.unsqueeze(1)

        # Convert to the right dimensionality
        x = x.view(-1, x.shape[2], x.shape[3])
        x = F.leaky_relu(self.convert_dimensions(x), .2)

        # Process stacks
        for conv in self.convs:
            x_ = F.leaky_relu(conv(x), .2)
            if self.training:
                x_ = F.dropout(x_, .2)
            l = (x.shape[2] - x_.shape[2]) // 2
            x = (x[:, :, l:-l] + x_) / 2

        x = torch.mean(x, dim=-1)
        x = x.view(B, T, x.shape[-1])
        x = self.code(x)

        return {"code": x}


class AudioEncoder(nn.Module):
    def __init__(self, in_dim, out_dim, identity=False, num_classes=0):
        super().__init__()
        self.identity = identity
        if self.identity:
            # reserve 64 extra channels for the speaker-identity embedding
            in_dim = in_dim + 64
            self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1)
        self.first_net = SeqTranslator1D(in_dim, out_dim,
                                         min_layers_num=3,
                                         residual=True,
                                         norm='ln'
                                         )
        self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        # self.att = nn.MultiheadAttention(out_dim, 4, dropout=0.1,
        #                                  batch_first=True)

    def forward(self, spectrogram, pre_state=None, id=None, time_steps=None):
        spectrogram = self.dropout(spectrogram)
        if self.identity:
            # expand the one-hot speaker id along time and concatenate it as extra channels
            id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
            id = self.id_mlp(id)
            spectrogram = torch.cat([spectrogram, id], dim=1)
        x1 = self.first_net(spectrogram)  # .permute(0, 2, 1)
        if time_steps is not None:
            x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
        # x1, _ = self.att(x1, x1, x1)
        # x1, hidden_state = self.grus(x1)
        # x1 = x1.permute(0, 2, 1)
        hidden_state = None

        return x1, hidden_state


class Generator(nn.Module):
    def __init__(self,
                 n_poses,
                 each_dim: list,
                 dim_list: list,
                 training=False,
                 device=None,
                 identity=True,
                 num_classes=0,
                 ):
        super().__init__()

        self.training = training
        self.device = device
        self.gen_length = n_poses
        self.identity = identity

        norm = 'ln'
        in_dim = 256
        out_dim = 256

        self.encoder_choice = 'faceformer'

        if self.encoder_choice == 'meshtalk':
            self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim)
        elif self.encoder_choice == 'faceformer':
            # wav2vec 2.0 weights initialization
            # alternative checkpoint: "vitouphy/wav2vec2-xls-r-300m-phoneme"
            self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
            self.audio_encoder.feature_extractor._freeze_parameters()
            self.audio_feature_map = nn.Linear(768, in_dim)
        else:
            self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim)

        self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)

        self.dim_list = dim_list

        self.decoder = nn.ModuleList()
        self.final_out = nn.ModuleList()

        self.decoder.append(nn.Sequential(
            ConvNormRelu(out_dim, 64, norm=norm),
            ConvNormRelu(64, 64, norm=norm),
            ConvNormRelu(64, 64, norm=norm),
        ))
        self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))

        self.decoder.append(nn.Sequential(
            ConvNormRelu(out_dim, out_dim, norm=norm),
            ConvNormRelu(out_dim, out_dim, norm=norm),
            ConvNormRelu(out_dim, out_dim, norm=norm),
        ))
        self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1))

    def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
        if self.training:
            time_steps = gt_poses.shape[1]

        # vector, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
        if self.encoder_choice == 'meshtalk':
            in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
            feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
        elif self.encoder_choice == 'faceformer':
            hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1),
                                               frame_num=time_steps).last_hidden_state
            feature = self.audio_feature_map(hidden_states).transpose(1, 2)
        else:
            feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
        # hidden_states = in_spec

        feature, _ = self.audio_middle(feature, id=id)

        out = []
        for i in range(len(self.decoder)):
            mid = self.decoder[i](feature)
            mid = self.final_out[i](mid)
            out.append(mid)
        out = torch.cat(out, dim=1)
        out = out.transpose(1, 2)

        return out, None
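

# ---------------------------------------------------------------------------
# Minimal usage sketch for the default 'faceformer' encoder path. It assumes
# `nets.layers` and the local `.wav2vec` wrapper are importable, that the
# "facebook/wav2vec2-base-960h" checkpoint can be downloaded, and that the
# wrapper resamples its hidden states to `frame_num` frames. The values below
# (each_dim, dim_list, num_classes, batch_size) are hypothetical and chosen
# only for illustration. Because of the relative import, run this as a module
# (python -m <package>.<this_module>) rather than as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, num_frames, num_classes = 2, 30, 4
    each_dim = [64, 0, 0, 6]   # hypothetical; only indices 0 and 3 are used by Generator
    dim_list = [64, 0, 0, 6]   # stored on the module but not used in forward()

    generator = Generator(
        n_poses=num_frames,
        each_dim=each_dim,
        dim_list=dim_list,
        training=False,
        device=torch.device("cpu"),
        identity=True,
        num_classes=num_classes,
    )

    # One second of 16 kHz audio per sample plus a one-hot speaker identity.
    raw_audio = torch.randn(batch_size, 16000)
    speaker_id = F.one_hot(torch.tensor([0, 1]), num_classes=num_classes).float()

    with torch.no_grad():
        motion, _ = generator(raw_audio, id=speaker_id, time_steps=num_frames)
    # Expected shape: (batch_size, num_frames, each_dim[0] + each_dim[3])
    print(motion.shape)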