# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from models.codec.amphion_codec.quantize import (
    ResidualVQ,
    VectorQuantize,
    FactorizedVectorQuantize,
    LookupFreeQuantize,
)
from models.codec.amphion_codec.vocos import Vocos


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings a ~1.4x model speed-up
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class CodecEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        up_ratios: list = [4, 5, 5, 6],
        out_channels: int = 256,
        use_tanh: bool = False,
        cfg=None,
    ):
        super().__init__()

        d_model = cfg.d_model if cfg is not None else d_model
        up_ratios = cfg.up_ratios if cfg is not None else up_ratios
        out_channels = cfg.out_channels if cfg is not None else out_channels
        use_tanh = cfg.use_tanh if cfg is not None else use_tanh

        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in up_ratios:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        if use_tanh:
            self.block += [nn.Tanh()]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        return self.block(x)

    def reset_parameters(self):
        self.apply(init_weights)
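# Shape note (added, assuming the default arguments above): each EncoderBlock
# divides the time axis by its stride and doubles the channel count, so with
# up_ratios=[4, 5, 5, 6] the encoder downsamples by 4 * 5 * 5 * 6 = 600 and
# reaches d_model * 2**4 = 1024 channels before the final projection to
# out_channels; e.g. a (1, 1, 24000) waveform maps to (1, 256, 40) latents.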
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class CodecDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        upsample_initial_channel: int = 1536,
        up_ratios: list = [5, 5, 4, 2],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        commitment: float = 0.25,
        codebook_loss_weight: float = 1.0,
        use_l2_normlize: bool = False,
        codebook_type: str = "euclidean",
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        weight_init: bool = False,
        use_vocos: bool = False,
        vocos_dim: int = 384,
        vocos_intermediate_dim: int = 1152,
        vocos_num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        cfg=None,
    ):
        super().__init__()

        in_channels = cfg.in_channels if cfg is not None and hasattr(cfg, "in_channels") else in_channels
        upsample_initial_channel = (
            cfg.upsample_initial_channel
            if cfg is not None and hasattr(cfg, "upsample_initial_channel")
            else upsample_initial_channel
        )
        up_ratios = cfg.up_ratios if cfg is not None and hasattr(cfg, "up_ratios") else up_ratios
        num_quantizers = cfg.num_quantizers if cfg is not None and hasattr(cfg, "num_quantizers") else num_quantizers
        codebook_size = cfg.codebook_size if cfg is not None and hasattr(cfg, "codebook_size") else codebook_size
        codebook_dim = cfg.codebook_dim if cfg is not None and hasattr(cfg, "codebook_dim") else codebook_dim
        quantizer_type = cfg.quantizer_type if cfg is not None and hasattr(cfg, "quantizer_type") else quantizer_type
        quantizer_dropout = (
            cfg.quantizer_dropout if cfg is not None and hasattr(cfg, "quantizer_dropout") else quantizer_dropout
        )
        commitment = cfg.commitment if cfg is not None and hasattr(cfg, "commitment") else commitment
        codebook_loss_weight = (
            cfg.codebook_loss_weight
            if cfg is not None and hasattr(cfg, "codebook_loss_weight")
            else codebook_loss_weight
        )
        use_l2_normlize = cfg.use_l2_normlize if cfg is not None and hasattr(cfg, "use_l2_normlize") else use_l2_normlize
        codebook_type = cfg.codebook_type if cfg is not None and hasattr(cfg, "codebook_type") else codebook_type
        kmeans_init = cfg.kmeans_init if cfg is not None and hasattr(cfg, "kmeans_init") else kmeans_init
        kmeans_iters = cfg.kmeans_iters if cfg is not None and hasattr(cfg, "kmeans_iters") else kmeans_iters
        decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
        eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
        threshold_ema_dead_code = (
            cfg.threshold_ema_dead_code
            if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
            else threshold_ema_dead_code
        )
        weight_init = cfg.weight_init if cfg is not None and hasattr(cfg, "weight_init") else weight_init
        use_vocos = cfg.use_vocos if cfg is not None and hasattr(cfg, "use_vocos") else use_vocos
        vocos_dim = cfg.vocos_dim if cfg is not None and hasattr(cfg, "vocos_dim") else vocos_dim
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = cfg.vocos_num_layers if cfg is not None and hasattr(cfg, "vocos_num_layers") else vocos_num_layers
        n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
        hop_size = cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
        padding = cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
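
        # Added note: all three quantizer choices below are wrapped in a
        # ResidualVQ stack; "vq" forwards the full EMA/k-means codebook options,
        # while "fvq" and "lfq" forward reduced argument sets (presumably the
        # FactorizedVectorQuantize and LookupFreeQuantize variants imported above).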
"padding") else padding ) if quantizer_type == "vq": self.quantizer = ResidualVQ( input_dim=in_channels, num_quantizers=num_quantizers, codebook_size=codebook_size, codebook_dim=codebook_dim, quantizer_type=quantizer_type, quantizer_dropout=quantizer_dropout, commitment=commitment, codebook_loss_weight=codebook_loss_weight, use_l2_normlize=use_l2_normlize, codebook_type=codebook_type, kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, decay=decay, eps=eps, threshold_ema_dead_code=threshold_ema_dead_code, weight_init=weight_init, ) elif quantizer_type == "fvq": self.quantizer = ResidualVQ( input_dim=in_channels, num_quantizers=num_quantizers, codebook_size=codebook_size, codebook_dim=codebook_dim, quantizer_type=quantizer_type, quantizer_dropout=quantizer_dropout, commitment=commitment, codebook_loss_weight=codebook_loss_weight, use_l2_normlize=use_l2_normlize, ) elif quantizer_type == "lfq": self.quantizer = ResidualVQ( input_dim=in_channels, num_quantizers=num_quantizers, codebook_size=codebook_size, codebook_dim=codebook_dim, quantizer_type=quantizer_type, ) else: raise ValueError(f"Unknown quantizer type {quantizer_type}") if not use_vocos: # Add first conv layer channels = upsample_initial_channel layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)] # Add upsampling + MRF blocks for i, stride in enumerate(up_ratios): input_dim = channels // 2**i output_dim = channels // 2 ** (i + 1) layers += [DecoderBlock(input_dim, output_dim, stride)] # Add final conv layer layers += [ Snake1d(output_dim), WNConv1d(output_dim, 1, kernel_size=7, padding=3), nn.Tanh(), ] self.model = nn.Sequential(*layers) if use_vocos: self.model = Vocos( input_channels=in_channels, dim=vocos_dim, intermediate_dim=vocos_intermediate_dim, num_layers=vocos_num_layers, adanorm_num_embeddings=None, n_fft=n_fft, hop_size=hop_size, padding=padding, ) self.reset_parameters() def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None): """ if vq is True, x = encoder output, then return quantized output; else, x = quantized output, then return decoder output """ if vq is True: if eval_vq: self.quantizer.eval() ( quantized_out, all_indices, all_commit_losses, all_codebook_losses, all_quantized, ) = self.quantizer(x, n_quantizers=n_quantizers) return ( quantized_out, all_indices, all_commit_losses, all_codebook_losses, all_quantized, ) return self.model(x) def quantize(self, x, n_quantizers=None): self.quantizer.eval() quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers) return quantized_out, vq # TODO: check consistency of vq2emb and quantize def vq2emb(self, vq, n_quantizers=None): return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers) def decode(self, x): return self.model(x) def latent2dist(self, x, n_quantizers=None): return self.quantizer.latent2dist(x, n_quantizers=n_quantizers) def reset_parameters(self): self.apply(init_weights)