# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/lifeiteng/vall-e/blob/main/valle/models/valle.py

import random
from typing import Dict, Iterator, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MulticlassAccuracy

from utils.util import make_pad_mask
from utils.topk_sampling import topk_sampling
from modules.general import Transpose
from modules.encoder import TokenEmbedding
from modules.general import PromptedFeatures
from modules.transformer import SinePositionalEmbedding
from modules.norms import AdaptiveLayerNorm, LayerNorm
from modules.transformer.transformer import TransformerEncoder, TransformerEncoderLayer


class VALLE(nn.Module):
    """VALL-E style model.

    The model has an autoregressive (AR) decoder that predicts the first
    quantizer's audio tokens, and a non-autoregressive (NAR) decoder that
    predicts quantizers ``1 .. num_quantizers - 1``. Module attributes prefixed
    with ``ar_`` belong to the AR stage and those prefixed with ``nar_`` to the
    NAR stage (see :meth:`stage_parameters`).

    The token id ``audio_token_num`` is used as EOS for the AR stage and as the
    ignore index for losses/metrics. When ``prepend_bos`` is set,
    ``audio_token_num + 1`` is used as BOS for the AR input.
    """

    def __init__(
        self,
        cfg,
        decoder_cls=TransformerEncoder,
        decoder_layer_cls=TransformerEncoderLayer,
    ):
        """Build the AR and (if ``cfg.num_quantizers > 1``) NAR stacks.

        Args:
            cfg: model config. Fields read here: ``decoder_dim``, ``nhead``,
                ``nar_scale_factor``, ``num_quantizers``, ``num_decoder_layers``,
                ``text_token_num``, ``audio_token_num``, ``prepend_bos``,
                ``add_prenet``, ``norm_first``, ``prefix_mode``,
                ``share_embedding``.
            decoder_cls: decoder stack class used for both AR and NAR.
            decoder_layer_cls: layer class instantiated and passed to
                ``decoder_cls``.
        """
        super().__init__()
        decoder_dim = cfg.decoder_dim
        nhead = cfg.nhead
        nar_scale_factor = cfg.nar_scale_factor
        num_quantizers = cfg.num_quantizers
        num_decoder_layers = cfg.num_decoder_layers
        nar_decoder_dim = int(decoder_dim * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(decoder_dim, cfg.text_token_num)
        self.nar_text_embedding = TokenEmbedding(nar_decoder_dim, cfg.text_token_num)

        self.ar_audio_prepend_bos = cfg.prepend_bos
        # Audio tokens + EOS (audio_token_num) [+ BOS (audio_token_num + 1)].
        self.ar_audio_embedding = TokenEmbedding(
            decoder_dim, cfg.audio_token_num + 1 + int(cfg.prepend_bos)
        )
        self.audio_token_num = cfg.audio_token_num

        # PreNet of AR
        if cfg.add_prenet:
            self.ar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
                nn.BatchNorm1d(decoder_dim),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(decoder_dim, decoder_dim),
            )

            self.ar_audio_prenet = nn.Sequential(
                nn.Linear(decoder_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, decoder_dim),
            )
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            decoder_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            decoder_dim,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                decoder_dim,
                nhead,
                dim_feedforward=decoder_dim * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=cfg.norm_first,
            ),
            num_layers=num_decoder_layers,
            norm=LayerNorm(decoder_dim) if cfg.norm_first else None,
        )
        # Predicts the next first-quantizer token, or EOS (= audio_token_num).
        self.ar_predict_layer = nn.Linear(
            decoder_dim, cfg.audio_token_num + 1, bias=False
        )

        self.ar_accuracy_metric = MulticlassAccuracy(
            cfg.audio_token_num + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=cfg.audio_token_num,
        )

        # Seeded RNG used to pick the NAR stage and random prompt segments.
        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = cfg.prefix_mode
        self.num_quantizers = num_quantizers

        assert num_quantizers >= 1
        if num_quantizers > 1:
            # The first-quantizer embedding has one extra row because the NAR
            # reuses the AR input sequence, whose padded frames hold the EOS
            # id ``audio_token_num`` (see ``pad_y_eos``).
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_decoder_dim, cfg.audio_token_num + 1)]
                + [
                    TokenEmbedding(nar_decoder_dim, cfg.audio_token_num)
                    for i in range(num_quantizers - 1)
                ]
            )

            if cfg.add_prenet:
                self.nar_text_prenet = nn.Sequential(
                    Transpose(),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(
                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
                    ),
                    nn.BatchNorm1d(nar_decoder_dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    Transpose(),
                    nn.Linear(nar_decoder_dim, nar_decoder_dim),
                )
                self.nar_audio_prenet = nn.Sequential(
                    nn.Linear(nar_decoder_dim, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, nar_decoder_dim),
                )
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_decoder_dim,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_decoder_dim,
                dropout=0.1,
                scale=False,
                alpha=False,
            )

            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_decoder_dim,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_decoder_dim * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=cfg.norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_decoder_layers * nar_scale_factor),
                norm=(
                    AdaptiveLayerNorm(
                        nar_decoder_dim, norm=nn.LayerNorm(nar_decoder_dim)
                    )
                    if cfg.norm_first
                    else None
                ),
            )
            # nar_predict_layers[k] predicts quantizer k + 1.
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_decoder_dim, cfg.audio_token_num, bias=False)
                    for i in range(num_quantizers - 1)
                ]
            )
            # One learned stage embedding per NAR stage, fed to the decoder's
            # adaptive layer norm.
            self.nar_stage_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_decoder_dim, 1) for i in range(num_quantizers - 1)]
            )

            if cfg.share_embedding:
                # NOTE(review): predict layer j (quantizer j + 1) is tied to
                # audio embedding j + 2 (quantizer j + 2); the last predict
                # layer stays untied. Confirm this offset is intended.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
                        j + 2
                    ].weight

            self.nar_accuracy_metric = MulticlassAccuracy(
                cfg.audio_token_num + 1,
                top_k=10,
                average="micro",
                multidim_average="global",
                ignore_index=cfg.audio_token_num,
            )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor, PromptedFeatures],
        y_lens: Union[torch.Tensor, PromptedFeatures],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the training loss for the AR and/or NAR stage.

        Args:
          x:
            A 2-D tensor of shape (N, S) of text tokens.
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, T, Q) of audio codes; codebook ``j`` is
            read for ``j < num_quantizers``. May instead be a
            ``PromptedFeatures`` whose ``data`` is ``(prompt_codes, codes)``
            (only with ``prefix_mode == 4``).
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `y`
            before padding. Must be a ``PromptedFeatures`` with ``data`` equal
            to ``(prompt_lens, y_lens)`` when `y` is one.
          reduction:
            Passed to ``F.cross_entropy``.
          train_stage:
            0: AR & NAR modules, 1: AR modules, 2: NAR modules
        Returns:
          ``(total_loss, metrics)``. ``total_loss`` is the sum of the active
          stage losses (halved when ``train_stage == 0``). ``metrics`` has
          ``"AR_Top100Acc"`` and/or ``"NAR_Top100Acc"``; despite the key names,
          each is the top-10 accuracy multiplied by ``y_lens.sum()``.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        y_prompts_codes = None
        if isinstance(y, PromptedFeatures):
            y_prompts_codes, y = y.data
            prompts_len, y_lens = y_lens.data
            assert prompts_len.min() == prompts_len.max()
            assert self.prefix_mode == 4
            y_prompts_codes = y_prompts_codes.type(torch.int64)

        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        x_mask = make_pad_mask(x_lens).to(x.device)
        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_int = y_mask.type(torch.int64)

        text = x
        # Zero the padded frames so they hold a valid token id.
        codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))

        y, targets = self.pad_y_eos(
            codes[..., 0], y_mask_int, eos_id=self.audio_token_num
        )
        self.y_mask_int = y_mask_int

        metrics = {}
        total_loss = 0.0

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
        if self.ar_audio_prepend_bos:
            # The BOS frame is never padding.
            ar_xy_padding_mask = torch.concat(
                [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
            )
        else:
            ar_xy_padding_mask = xy_padding_mask
        # Stored for the stage helpers below.
        self.xy_padding_mask = xy_padding_mask
        self.ar_xy_padding_mask = ar_xy_padding_mask

        # AR Decoder
        if train_stage in [0, 1]:
            # Pass the per-utterance lengths (not their max) so the metric is
            # weighted by the total number of frames, as in the NAR stage.
            ar_loss, ar_metrics = self._forward_ar_decoder(
                text, x_lens.max(), y, y_lens, targets, x_mask, y_mask, reduction
            )
            total_loss += ar_loss
            metrics["AR_Top100Acc"] = ar_metrics

        # NAR Decoder: drop the BOS frame so y aligns with codes.
        if self.ar_audio_prepend_bos:
            y = y[:, 1:]

        if self.num_quantizers > 1 and train_stage in [0, 2]:
            nar_loss, nar_metrics = self._forward_nar_decoder(
                text,
                x_lens,
                y,
                y_lens,
                codes,
                y_prompts_codes,
                x_mask,
                y_mask,
                reduction,
            )
            total_loss += nar_loss
            metrics["NAR_Top100Acc"] = nar_metrics

        if train_stage == 0:
            total_loss = total_loss / 2.0

        return total_loss, metrics

    def _forward_ar_decoder(
        self, x, x_len, y, y_lens, targets, x_mask, y_mask, reduction
    ):
        """Teacher-forced AR pass over the first quantizer.

        Args:
            x: (N, S) text tokens.
            x_len: maximum text length.
            y: AR input ids produced by ``pad_y_eos``.
            y_lens: (N,) number of valid frames per utterance.
            targets: AR targets produced by ``pad_y_eos``.
            x_mask, y_mask: unused; the key padding mask is read from
                ``self.ar_xy_padding_mask`` (set by ``forward``).
            reduction: passed to ``F.cross_entropy``.
        Returns:
            ``(ar_loss, ar_metrics)``, where ``ar_metrics`` is the top-10
            accuracy multiplied by ``y_lens.sum()``.
        """
        x = self.ar_text_embedding(x)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        y_len = y_lens.max() + int(self.ar_audio_prepend_bos)

        # True = blocked. Text positions see all text and no audio.
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        # Audio positions see all text and only past/current audio.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # Fold the key padding mask into a per-head attention mask.
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            self.ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_heads, -1, -1)
            .reshape(bsz * self.num_heads, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

        # Boolean -> additive float mask (-inf where blocked).
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)

        xy_pos = torch.concat([x, y_pos], dim=1)

        xy_dec, _ = self.ar_decoder(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        ar_loss = F.cross_entropy(logits, targets, reduction=reduction)

        ar_metrics = self.ar_accuracy_metric(
            logits.detach(), targets
        ).item() * y_lens.sum().type(torch.float32)

        return ar_loss, ar_metrics

    def _forward_nar_decoder(
        self, x, x_lens, y, y_lens, codes, y_prompts_codes, x_mask, y_mask, reduction
    ):
        """NAR pass for one randomly chosen quantizer stage.

        Args:
            x: (N, S) text tokens.
            x_lens: (N,) text lengths.
            y: (N, T) first-quantizer input ids (EOS id at padded frames).
            y_lens: (N,) number of valid frames per utterance.
            codes: (N, T, Q) int64 codes with padded frames zeroed. May be
                modified in place by ``_prepare_prompts`` (prefix_mode 2).
            y_prompts_codes: prompt codes (prefix_mode 4) or ``None``.
            x_mask, y_mask: (N, S) / (N, T) padding masks (True = padding).
            reduction: passed to ``F.cross_entropy``.
        Returns:
            ``(nar_loss, nar_metrics)``; ``nar_metrics`` is the top-10 accuracy
            multiplied by ``y_lens.sum()``.
        """
        num_nar_layers = self.num_quantizers - 1
        # Uniformly pick the quantizer index (1 .. num_quantizers - 1) to train.
        nar_stage = self.rng.choices(
            [_k for _k in range(1, self.num_quantizers)],
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        x = self.nar_text_embedding(x)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb, prefix_len = self._prepare_prompts(
            y, y_lens, codes, nar_stage, y_prompts_codes
        )

        y_len = y_lens.max()
        # Padded frames become audio_token_num, which loss and metric ignore.
        targets = codes[..., nar_stage] + self.audio_token_num * self.y_mask_int

        xy_padding_mask = self.xy_padding_mask
        if self.prefix_mode in [2, 4]:
            # y_emb has prompt frames prepended; extend the padding mask so
            # its length matches the decoder input (S + prompt + T).
            xy_padding_mask = torch.concat(
                [
                    x_mask,
                    F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
                ],
                dim=1,
            )
        elif self.prefix_mode == 1:
            # The prefix frames are prompts, not prediction targets.
            targets = targets[:, prefix_len:]

        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=xy_padding_mask,
        )
        xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
        if self.prefix_mode == 4:
            # Prompts come from separate codes, so no target frames were
            # excluded; disable the loss rescaling below.
            prefix_len = 0
        logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)

        total_length = (y_lens).sum().type(torch.float32)
        # Rescale so the loss is comparable to one over all frames, since the
        # prompt frames were excluded from the targets.
        nar_loss = F.cross_entropy(
            logits,
            targets,
            ignore_index=self.audio_token_num,
            reduction=reduction,
        ) * (total_length / (total_length - prefix_len * x.shape[0]))
        # Pad the class dim by one (the ignore id) to match the metric's
        # audio_token_num + 1 classes.
        nar_metrics = (
            self.nar_accuracy_metric(
                F.pad(
                    logits.detach(),
                    (0, 0, 0, 1, 0, 0),
                    value=logits.min().cpu().item(),
                ),
                targets,
            ).item()
            * total_length
        )
        return nar_loss, nar_metrics

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Generate audio codes for `x`, continuing from the prompt codes `y`.

        Args:
          x:
            A 2-D tensor of shape (1, S) of text tokens.
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, Q) of prompt audio codes.
          enroll_x_lens:
            Length of the enrolled (prompt) text inside `x`. Used only when
            ``prefix_mode`` is 2 or 4 to drop that text from the NAR input.
          top_k: (`optional`) int
            The number of highest probability tokens to keep for top-k-filtering. Default to -100.
          temperature: (`optional`) float
            The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
        Returns:
          Return the predicted audio code matrix of shape
          (1, T_new, num_quantizers), excluding the prompt frames.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)

        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=self.audio_token_num + 1)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        while True:
            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = torch.concat([x, y_pos], dim=1)

            y_len = y.shape[1]
            # True = blocked: text sees only text, audio is causal.
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                y.device
            )

            xy_dec, _ = self.ar_decoder(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits, top_k=top_k, top_p=1.0, temperature=temperature
            )

            # Stop on EOS (greedy or sampled) or after 16 frames per text token.
            if (
                torch.argmax(logits, dim=-1)[0] == self.audio_token_num
                or samples[0, 0] == self.audio_token_num
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                # NOTE(review): with ar_audio_prepend_bos, y starts one frame
                # longer than prompts, so this check never fires; confirm.
                if prompts.shape[1] == y.shape[1]:
                    raise SyntaxError("well trained model shouldn't reach here.")

                break

            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        # Non-AR Decoders
        y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])

        if self.prefix_mode in [2, 4]:
            # Drop the enrolled text, keeping the first token.
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1 :],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            # Prompt frames get every quantizer embedded up front.
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < self.num_quantizers - 2:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)

    def continual(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """Re-predict quantizers 1.. of `y` with the NAR decoder.

        The first ``min(T // 2, 225)`` frames of `y` are used as the prompt.
        For the remaining frames, quantizer 0 is taken from `y` (no AR
        decoding) and the other quantizers are predicted.

        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, Q) with ``Q >= num_quantizers``.
        Returns:
          Return the predicted audio code matrix of shape
          (1, T - prefix_len, num_quantizers).
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)
        # Generalized from a hard-coded 8; the NAR modules need >= 2 quantizers.
        assert self.num_quantizers > 1

        num_quantizers = self.num_quantizers

        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()

        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

        # AR Decoder
        prompts = y[:, :prefix_len]

        codes = [y[:, prefix_len:, 0]]
        # Non-AR Decoders
        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb = self.nar_audio_embeddings[0](y[..., 0])

        if self.prefix_mode == 0:
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                # Prenet then positional encoding, matching training.
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < num_quantizers - 2:
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            for j in range(1, num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < num_quantizers - 2:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == num_quantizers
        return torch.stack(codes, dim=-1)

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        """Yield the parameters of one training stage (1: ``ar_*``, 2: ``nar_*``)."""
        for _, param in self.stage_named_parameters(stage):
            yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        """Yield ``(name, param)`` pairs of one stage (1: ``ar_*``, 2: ``nar_*``).

        Any other positive ``stage`` yields nothing.
        """
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        """Build AR inputs and targets from first-quantizer codes.

        ``targets`` is `y` with one extra frame appended, where padded frames
        and the appended frame are ``eos_id`` (assumes padded entries of `y`
        are 0, as ``forward`` ensures).

        Returns:
            ``(inputs, targets)``. Without BOS: ``targets[:, :-1]`` and
            ``targets[:, 1:]``. With BOS: ``targets[:, :-1]`` with the BOS id
            ``audio_token_num + 1`` prepended, and the full ``targets``.
        """
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=self.audio_token_num + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]

    def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
        """Build the NAR audio input embedding for stage ``nar_stage``.

        Depending on ``prefix_mode``:
          0: no prompt; sums embeddings of `y` and quantizers ``1..nar_stage-1``.
          1: the first ``prefix_len`` frames of each utterance are a prompt
             (all quantizers embedded); the rest gets quantizers < nar_stage.
          2: a random ``prefix_len``-frame segment of each utterance is copied
             as the prompt and its stage-``nar_stage`` codes are overwritten
             in place with ``audio_token_num`` (ignored by the loss).
          4: prompts are the given ``y_prompts_codes``.

        Returns:
            ``(y_emb, prefix_len)``. For modes 2 and 4 the prompt embedding is
            prepended to the embedding of `y`.
        """
        # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
        # from the same utterance.
        # We implement this differently.
        if self.prefix_mode == 0:
            # no prefix
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                # Formula (4) (5)
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif self.prefix_mode == 1:
            # prefix at beginning
            # NOTE(review): torch.randint raises if int_low is 0
            # (y_lens.min() < 4).
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        elif self.prefix_mode in [2, 4]:
            if self.prefix_mode == 2:
                # random prefix
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    # Mark the prompt segment as ignored for this stage's
                    # targets (previously the undefined NUM_AUDIO_TOKENS).
                    codes[b, start : start + prefix_len, nar_stage] = (
                        self.audio_token_num
                    )
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        else:
            raise ValueError

        return y_emb, prefix_len