import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat

from models.codec.amphion_codec.quantize import ResidualVQ
from models.codec.kmeans.vocos import VocosBackbone

def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

def compute_codebook_perplexity(indices, codebook_size):
    # Perplexity of the empirical code-usage distribution, exp(entropy).
    # It approaches codebook_size when all entries are used uniformly.
    indices = indices.flatten()
    prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
    perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
    return perp

class RepCodec(nn.Module):
    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        downsample_scale=1,
        cfg=None,
    ):
        super().__init__()
        # cfg entries, when present, override the constructor defaults
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        downsample_scale = (
            cfg.downsample_scale
            if cfg is not None and hasattr(cfg, "downsample_scale")
            else downsample_scale
        )

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.num_quantizers = num_quantizers
        self.downsample_scale = downsample_scale

        # Optional 2x temporal downsampling before the encoder, plus a conv
        # applied after nearest-neighbor upsampling on the way back out.
        if self.downsample_scale is not None and self.downsample_scale > 1:
            self.down = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            )
            self.up = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
            )

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )
        self.decoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        self.reset_parameters()

    def forward(self, x):
        # x: (B, T, hidden_size)

        # downsample
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        # encode
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        # quantize
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        # decode
        x = self.decoder(quantized_out)

        # upsample back to the input frame rate if we downsampled
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x = self.up(x).transpose(1, 2)
        x_rec = x

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        # Same downsample/encode/quantize path as forward(), returning the
        # codebook indices and the quantized embeddings (B, T, hidden_size).
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        if all_indices.shape[0] == 1:
            return all_indices.squeeze(0), quantized_out.transpose(1, 2)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        self.apply(init_weights)

if __name__ == "__main__":
    repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
    print(repcodec)
    print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
    x = torch.randn(5, 10, 1024)
    x_rec, codebook_loss, all_indices = repcodec(x)
    print(x_rec.shape, codebook_loss, all_indices.shape)
    vq_id, emb = repcodec.quantize(x)
    print(vq_id.shape, emb.shape)
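    # Sketch: estimate codebook usage with the compute_codebook_perplexity
    # helper defined above (assumes vq_id holds integer codebook indices,
    # which is what quantize() returns).
    perplexity = compute_codebook_perplexity(vq_id.long(), repcodec.codebook_size)
    print(perplexity)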