# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from concurrent.futures import ALL_COMPLETED
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat

from models.codec.amphion_codec.quantize import ResidualVQ
from models.codec.kmeans.vocos import VocosBackbone


def init_weights(m):
    """Initialise a ``Conv1d`` / ``Linear`` module in place.

    Weights are drawn from a truncated normal with ``std=0.02`` and the bias
    (when the layer has one) is set to zero. Any other module type is left
    untouched, which makes this function suitable for ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers created with ``bias=False`` expose ``bias`` as None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def compute_codebook_perplexity(indices, codebook_size):
    """Return the perplexity of the empirical code-usage distribution.

    Args:
        indices: integer tensor of codebook indices (any shape; it is
            flattened). ``torch.bincount`` requires non-negative integers.
        codebook_size: number of codebook entries; used as the minimum number
            of histogram bins so unused codes count as probability zero.

    Returns:
        A scalar tensor ``exp(-sum(p * log(p + 1e-10)))``.
    """
    indices = indices.flatten()
    prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
    # The epsilon keeps log() finite for codes that were never used.
    perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
    return perp


def _cfg_override(cfg, name, default):
    """Return ``cfg.<name>`` when ``cfg`` provides it, otherwise ``default``."""
    if cfg is None:
        return default
    return getattr(cfg, name, default)


class RepCodec(nn.Module):
    """Encoder -> residual vector quantizer -> decoder over feature sequences.

    The input is a ``(batch, time, hidden_size)`` tensor. It is optionally
    downsampled along time, encoded by a ``VocosBackbone`` + ``Linear`` stack,
    quantized with ``ResidualVQ``, decoded by a second ``VocosBackbone`` +
    ``Linear`` stack and, if it was downsampled, upsampled again.

    Args:
        codebook_size: number of entries per codebook.
        hidden_size: channel size of the input / reconstructed features.
        codebook_dim: dimensionality passed to ``ResidualVQ`` as
            ``codebook_dim``.
        vocos_dim: ``dim`` of both ``VocosBackbone`` instances.
        vocos_intermediate_dim: ``intermediate_dim`` of both backbones.
        vocos_num_layers: ``num_layers`` of both backbones.
        num_quantizers: number of residual quantizers.
        downsample_scale: temporal resampling is enabled when this is greater
            than 1 (``None`` or values <= 1 disable it).
        cfg: optional config object; any attribute on it named like one of the
            arguments above takes precedence over that argument.

    NOTE(review): the resampling layers use a fixed stride of 2 and a fixed
    interpolation factor of 2, whatever value ``downsample_scale`` has; it
    effectively acts as an on/off switch.
    """

    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        downsample_scale=1,
        cfg=None,
    ):
        super().__init__()

        # Each hyper-parameter is looked up under its *own* name on cfg.
        codebook_size = _cfg_override(cfg, "codebook_size", codebook_size)
        codebook_dim = _cfg_override(cfg, "codebook_dim", codebook_dim)
        hidden_size = _cfg_override(cfg, "hidden_size", hidden_size)
        vocos_dim = _cfg_override(cfg, "vocos_dim", vocos_dim)
        vocos_intermediate_dim = _cfg_override(
            cfg, "vocos_intermediate_dim", vocos_intermediate_dim
        )
        vocos_num_layers = _cfg_override(cfg, "vocos_num_layers", vocos_num_layers)
        num_quantizers = _cfg_override(cfg, "num_quantizers", num_quantizers)
        downsample_scale = _cfg_override(cfg, "downsample_scale", downsample_scale)

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.num_quantizers = num_quantizers
        self.downsample_scale = downsample_scale

        # The resampling convolutions only exist when resampling is enabled.
        if self._uses_resampling():
            self.down = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            )
            self.up = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
            )

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )
        self.decoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        self.reset_parameters()

    def _uses_resampling(self):
        """Return True when temporal down/up-sampling is enabled."""
        return self.downsample_scale is not None and self.downsample_scale > 1

    def _encode(self, x):
        """Downsample (if enabled) and encode ``x`` for the quantizer.

        Takes ``(batch, time, hidden_size)`` and returns a channel-first
        ``(batch, hidden_size, time')`` tensor.
        """
        if self._uses_resampling():
            # Conv1d convolves over the last axis, so go channel-first and back.
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        # The encoder is fed channel-first input; its final Linear maps the
        # last axis to hidden_size, so transpose back to channel-first.
        return self.encoder(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """Encode, quantize and reconstruct ``x``.

        Args:
            x: tensor of shape ``(batch, time, hidden_size)``.

        Returns:
            Tuple ``(x_rec, codebook_loss, all_indices)``: the reconstruction
            (channel-last, like the input), the mean of the summed codebook
            and commitment losses, and the indices returned by the quantizer.

        With resampling enabled, an odd ``time`` yields a reconstruction that
        is one frame longer than the input (stride-2 conv rounds the length
        up, nearest interpolation then doubles it).
        """
        x = self._encode(x)

        # The fifth value returned by the quantizer is not used here.
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        x = self.decoder(quantized_out)

        if self._uses_resampling():
            x = x.transpose(1, 2)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x_rec = self.up(x).transpose(1, 2)
        else:
            # Without resampling the decoder output already is the
            # reconstruction; previously x_rec was left unassigned here.
            x_rec = x

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        """Return quantizer indices and quantized embeddings for ``x``.

        Args:
            x: tensor of shape ``(batch, time, hidden_size)``.

        Returns:
            Tuple ``(indices, embeddings)``. ``embeddings`` is the quantizer
            output transposed to channel-last. ``indices`` has its leading
            axis squeezed away when that axis has size 1 (presumably the
            quantizer axis -- TODO confirm against ``ResidualVQ``).
        """
        x = self._encode(x)
        quantized_out, all_indices, _, _, _ = self.quantizer(x)

        if all_indices.shape[0] == 1:
            return all_indices.squeeze(0), quantized_out.transpose(1, 2)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        """Re-initialise every ``Conv1d`` / ``Linear`` submodule."""
        self.apply(init_weights)


if __name__ == "__main__":
    # Smoke test: build a model with resampling and run both entry points.
    repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
    print(repcodec)
    print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
    x = torch.randn(5, 10, 1024)
    x_rec, codebook_loss, all_indices = repcodec(x)
    print(x_rec.shape, codebook_loss, all_indices.shape)
    vq_id, emb = repcodec.quantize(x)
    print(vq_id.shape, emb.shape)