# maskgct/models/codec/kmeans/repcodec_model.py
# Uploaded by Hecheng0625 ("Upload 409 files"), commit c968fc3 (verified).
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ALL_COMPLETED
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat
from models.codec.amphion_codec.quantize import ResidualVQ
from models.codec.kmeans.vocos import VocosBackbone
def init_weights(m):
    """Initialise ``nn.Conv1d`` / ``nn.Linear`` layers in place.

    Weights are drawn from a truncated normal with ``std=0.02`` and biases are
    set to zero. Any other module type is left untouched, so this is suitable
    for ``nn.Module.apply``.

    Args:
        m: a module visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers built with ``bias=False`` expose ``m.bias = None``; filling
        # it unconditionally would fail, so only zero an existing bias.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def compute_codebook_perplexity(indices, codebook_size):
    """Return the perplexity (``exp`` of entropy) of codebook usage.

    The empirical distribution over ``codebook_size`` codes is built from all
    entries of ``indices`` (any shape; it is flattened). The result lies
    between 1 (a single code used) and ``codebook_size`` (uniform usage).
    ``indices`` must hold non-negative integers, as required by
    ``torch.bincount``. An empty ``indices`` divides by zero and yields NaN.
    """
    flat_ids = indices.flatten()
    total = flat_ids.size(0)
    counts = torch.bincount(flat_ids, minlength=codebook_size).float()
    probs = counts / total
    # The 1e-10 keeps log() finite for codes that never occur (p == 0).
    entropy = -torch.sum(probs * torch.log(probs + 1e-10))
    return torch.exp(entropy)
class RepCodec(nn.Module):
    """Feature codec: optional 2x downsample -> encoder -> residual VQ -> decoder -> optional 2x upsample.

    Every constructor argument except ``cfg`` may be overridden by an
    attribute of the same name on ``cfg`` (when ``cfg`` is not None).
    """

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        downsample_scale=1,
        cfg=None,
    ):
        super().__init__()

        def _from_cfg(name, default):
            # An attribute on ``cfg`` wins over the keyword default. Each field
            # is looked up under its own name.
            if cfg is None:
                return default
            return getattr(cfg, name, default)

        codebook_size = _from_cfg("codebook_size", codebook_size)
        codebook_dim = _from_cfg("codebook_dim", codebook_dim)
        hidden_size = _from_cfg("hidden_size", hidden_size)
        vocos_dim = _from_cfg("vocos_dim", vocos_dim)
        vocos_intermediate_dim = _from_cfg(
            "vocos_intermediate_dim", vocos_intermediate_dim
        )
        vocos_num_layers = _from_cfg("vocos_num_layers", vocos_num_layers)
        num_quantizers = _from_cfg("num_quantizers", num_quantizers)
        downsample_scale = _from_cfg("downsample_scale", downsample_scale)

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.num_quantizers = num_quantizers
        self.downsample_scale = downsample_scale

        if self._uses_downsample():
            # NOTE: the resampling layers implement a fixed factor of 2
            # (stride=2 here, scale_factor=2 in ``_upsample``). Any
            # ``downsample_scale`` > 1 enables them; larger values are not
            # honoured beyond 2.
            self.down = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            )
            self.up = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
            )

        # Encoder and decoder share the same architecture (built in this order).
        self.encoder = self._build_backbone()
        self.decoder = self._build_backbone()

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )
        self.reset_parameters()

    def _build_backbone(self):
        """Return a VocosBackbone followed by a projection from ``vocos_dim`` back to ``hidden_size``."""
        return nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )

    def _uses_downsample(self):
        """True when the 2x down/up-sampling layers are in use."""
        return self.downsample_scale is not None and self.downsample_scale > 1

    def _encode(self, x):
        """Optionally downsample ``x``, run the encoder, and return it channels-first for the quantizer.

        ``x`` presumably is (batch, seq_len, hidden_size): the downsample path
        transposes dims 1/2 before a ``Conv1d(hidden_size, ...)``.
        """
        if self._uses_downsample():
            # Conv1d works on (batch, channels, time), hence the transposes.
            x = x.transpose(1, 2)
            x = F.gelu(self.down(x))
            x = x.transpose(1, 2)
        # The backbone is fed channels-first. Its Linear head leaves
        # hidden_size on the last axis, and the final transpose moves it back
        # to dim 1.
        return self.encoder(x.transpose(1, 2)).transpose(1, 2)

    def _upsample(self, x):
        """Reverse the 2x downsampling: nearest-neighbour repeat along time, then a kernel-3 Conv1d.

        NOTE(review): for an odd input length T the strided conv yields
        ceil(T/2) frames, so the output here has T + 1 frames.
        """
        x = F.interpolate(x.transpose(1, 2), scale_factor=2, mode="nearest")
        return self.up(x).transpose(1, 2)

    def forward(self, x):
        """Encode, quantise and reconstruct ``x``.

        Args:
            x: input features, presumably (batch, seq_len, hidden_size).

        Returns:
            Tuple ``(x_rec, codebook_loss, all_indices)``:
            - ``x_rec``: decoder output, upsampled back when downsampling is enabled.
            - ``codebook_loss``: mean of the quantizer's codebook + commitment losses.
            - ``all_indices``: code indices returned by the quantizer.
        """
        x = self._encode(x)
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)
        x_rec = self.decoder(quantized_out)
        # Without downsampling there is nothing to undo. The original code left
        # ``x_rec`` unassigned in that case and raised UnboundLocalError.
        if self._uses_downsample():
            x_rec = self._upsample(x_rec)
        codebook_loss = (all_codebook_losses + all_commit_losses).mean()
        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        """Return ``(indices, quantized)`` for ``x`` without decoding.

        The leading axis of the indices is squeezed away when it has size 1.
        Presumably that axis is the quantizer axis, so a single quantizer
        gives per-frame indices; verify against ResidualVQ. ``quantized`` is
        the quantizer output transposed back to (batch, time, channels)
        order.
        """
        x = self._encode(x)
        quantized_out, all_indices, _, _, _ = self.quantizer(x)
        if all_indices.shape[0] == 1:
            all_indices = all_indices.squeeze(0)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        """Re-initialise every Conv1d / Linear sub-layer via ``init_weights``."""
        self.apply(init_weights)
if __name__ == "__main__":
    # Smoke test: build a model with downsampling enabled, report its size in
    # millions of parameters, then run forward() and quantize() on random input.
    model = RepCodec(vocos_dim=1024, downsample_scale=2)
    print(model)
    n_params_millions = sum(param.numel() for param in model.parameters()) / 1e6
    print(n_params_millions)

    feats = torch.randn(5, 10, 1024)  # (batch, seq_len, hidden_size)
    recon, vq_loss, codes = model(feats)
    print(recon.shape, vq_loss, codes.shape)

    token_ids, token_emb = model.quantize(feats)
    print(token_ids.shape, token_emb.shape)