import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from models.codec.amphion_codec.quantize import (
    ResidualVQ,
    VectorQuantize,
    FactorizedVectorQuantize,
    LookupFreeQuantize,
)
from models.codec.amphion_codec.vocos import Vocos


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + (1 / alpha) * sin^2(alpha * x), applied element-wise
    # with a per-channel alpha; the 1e-9 guards against division by zero.
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


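# Illustrative sketch (not part of the model): Snake1d preserves the input shape
# and learns one alpha per channel, e.g.
#
#     act = Snake1d(channels=16)
#     act(torch.randn(2, 16, 100)).shape   # -> torch.Size([2, 16, 100])

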
def init_weights(m):
    # Truncated-normal initialization for conv/linear weights, zero bias.
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-crop the residual if the block shortened the sequence
        # (a no-op with the padding chosen above).
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class CodecEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        up_ratios: list = [4, 5, 5, 6],
        out_channels: int = 256,
        use_tanh: bool = False,
        cfg=None,
    ):
        super().__init__()

        # When a cfg is given, it is expected to define all four fields below.
        d_model = cfg.d_model if cfg is not None else d_model
        up_ratios = cfg.up_ratios if cfg is not None else up_ratios
        out_channels = cfg.out_channels if cfg is not None else out_channels
        use_tanh = cfg.use_tanh if cfg is not None else use_tanh

        # First convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # EncoderBlocks that double the channels while downsampling by `stride`
        for stride in up_ratios:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Final projection to the latent dimension
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        if use_tanh:
            self.block += [nn.Tanh()]

        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        return self.block(x)

    def reset_parameters(self):
        self.apply(init_weights)


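# Usage sketch (illustrative only; assumes the defaults above): the encoder maps
# mono audio of shape (B, 1, T) to latents of shape (B, out_channels, T / 600),
# since the default up_ratios [4, 5, 5, 6] downsample by 4 * 5 * 5 * 6 = 600.
#
#     encoder = CodecEncoder()
#     z = encoder(torch.randn(1, 1, 24000))   # -> (1, 256, 40)

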
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class CodecDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        upsample_initial_channel: int = 1536,
        up_ratios: list = [5, 5, 4, 2],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        commitment: float = 0.25,
        codebook_loss_weight: float = 1.0,
        use_l2_normlize: bool = False,
        codebook_type: str = "euclidean",
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        weight_init: bool = False,
        use_vocos: bool = False,
        vocos_dim: int = 384,
        vocos_intermediate_dim: int = 1152,
        vocos_num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        cfg=None,
    ):
        super().__init__()

        # Override defaults from cfg when the corresponding attribute is present.
        in_channels = (
            cfg.in_channels
            if cfg is not None and hasattr(cfg, "in_channels")
            else in_channels
        )
        upsample_initial_channel = (
            cfg.upsample_initial_channel
            if cfg is not None and hasattr(cfg, "upsample_initial_channel")
            else upsample_initial_channel
        )
        up_ratios = (
            cfg.up_ratios
            if cfg is not None and hasattr(cfg, "up_ratios")
            else up_ratios
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        quantizer_type = (
            cfg.quantizer_type
            if cfg is not None and hasattr(cfg, "quantizer_type")
            else quantizer_type
        )
        quantizer_dropout = (
            cfg.quantizer_dropout
            if cfg is not None and hasattr(cfg, "quantizer_dropout")
            else quantizer_dropout
        )
        commitment = (
            cfg.commitment
            if cfg is not None and hasattr(cfg, "commitment")
            else commitment
        )
        codebook_loss_weight = (
            cfg.codebook_loss_weight
            if cfg is not None and hasattr(cfg, "codebook_loss_weight")
            else codebook_loss_weight
        )
        use_l2_normlize = (
            cfg.use_l2_normlize
            if cfg is not None and hasattr(cfg, "use_l2_normlize")
            else use_l2_normlize
        )
        codebook_type = (
            cfg.codebook_type
            if cfg is not None and hasattr(cfg, "codebook_type")
            else codebook_type
        )
        kmeans_init = (
            cfg.kmeans_init
            if cfg is not None and hasattr(cfg, "kmeans_init")
            else kmeans_init
        )
        kmeans_iters = (
            cfg.kmeans_iters
            if cfg is not None and hasattr(cfg, "kmeans_iters")
            else kmeans_iters
        )
        decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
        eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
        threshold_ema_dead_code = (
            cfg.threshold_ema_dead_code
            if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
            else threshold_ema_dead_code
        )
        weight_init = (
            cfg.weight_init
            if cfg is not None and hasattr(cfg, "weight_init")
            else weight_init
        )
        use_vocos = (
            cfg.use_vocos
            if cfg is not None and hasattr(cfg, "use_vocos")
            else use_vocos
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
        n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
        hop_size = (
            cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
        )
        padding = (
            cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
        )

        if quantizer_type == "vq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
                codebook_type=codebook_type,
                kmeans_init=kmeans_init,
                kmeans_iters=kmeans_iters,
                decay=decay,
                eps=eps,
                threshold_ema_dead_code=threshold_ema_dead_code,
                weight_init=weight_init,
            )
        elif quantizer_type == "fvq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
            )
        elif quantizer_type == "lfq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
            )
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        if not use_vocos:
            # Convolutional decoder: mirror of the encoder, upsampling by up_ratios.
            channels = upsample_initial_channel
            layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

            # DecoderBlocks that halve the channels while upsampling by `stride`
            for i, stride in enumerate(up_ratios):
                input_dim = channels // 2**i
                output_dim = channels // 2 ** (i + 1)
                layers += [DecoderBlock(input_dim, output_dim, stride)]

            # Final projection back to a mono waveform
            layers += [
                Snake1d(output_dim),
                WNConv1d(output_dim, 1, kernel_size=7, padding=3),
                nn.Tanh(),
            ]

            self.model = nn.Sequential(*layers)

        if use_vocos:
            # Vocos-based decoder that predicts STFT coefficients instead of
            # upsampling with transposed convolutions.
            self.model = Vocos(
                input_channels=in_channels,
                dim=vocos_dim,
                intermediate_dim=vocos_intermediate_dim,
                num_layers=vocos_num_layers,
                adanorm_num_embeddings=None,
                n_fft=n_fft,
                hop_size=hop_size,
                padding=padding,
            )

        self.reset_parameters()

    def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
        """
        If `vq` is True, `x` is the encoder output and the quantization results
        are returned; otherwise `x` is a quantized latent and the decoded
        waveform is returned.
        """
        if vq is True:
            if eval_vq:
                self.quantizer.eval()
            (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            ) = self.quantizer(x, n_quantizers=n_quantizers)
            return (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            )

        return self.model(x)

    def quantize(self, x, n_quantizers=None):
        self.quantizer.eval()
        quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
        return quantized_out, vq

    def vq2emb(self, vq, n_quantizers=None):
        return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)

    def decode(self, x):
        return self.model(x)

    def latent2dist(self, x, n_quantizers=None):
        return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)

    def reset_parameters(self):
        self.apply(init_weights)
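

# Minimal end-to-end sketch (illustrative only; assumes the default
# hyper-parameters above). Note that the default encoder strides give a
# 4 * 5 * 5 * 6 = 600x downsampling while the default decoder strides give a
# 5 * 5 * 4 * 2 = 200x upsampling, so the reconstruction length matches the
# input only when matching up_ratios are configured.
if __name__ == "__main__":
    encoder = CodecEncoder()
    decoder = CodecDecoder()

    wav = torch.randn(1, 1, 24000)   # (B, 1, T) mono audio
    z = encoder(wav)                 # (1, 256, 40) latents with the default strides
    z_q, codes, commit_losses, codebook_losses, _ = decoder(z, vq=True, eval_vq=True)
    recon = decoder(z_q)             # (1, 1, 8000) waveform with the default strides
    print(wav.shape, z.shape, recon.shape)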