# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class LookupFreeQuantize(nn.Module):
    """Lookup-free quantizer: each of the `codebook_dim` latent channels is
    squashed with a sigmoid and rounded to {0, 1}, so the effective codebook
    is the set of 2**codebook_dim binary codes and no embedding table is
    learned. Gradients flow through the rounding via a straight-through
    estimator.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # The implicit codebook is every binary string of length codebook_dim.
        assert 2**codebook_dim == codebook_size

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(
                self.input_dim, self.codebook_dim, kernel_size=1
            )
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

    def forward(self, z):
        # z: (B, input_dim, T) -> z_e: (B, codebook_dim, T), values in (0, 1)
        z_e = self.in_project(z)
        z_e = F.sigmoid(z_e)

        # Round each channel to {0, 1}; the straight-through trick keeps the
        # identity gradient on the backward pass.
        z_q = z_e + (torch.round(z_e) - z_e).detach()

        z_q = self.out_project(z_q)

        # There is no learned codebook, so commitment/codebook losses are zero;
        # they are returned only to keep the usual VQ interface.
        commit_loss = torch.zeros(z.shape[0], device=z.device)
        codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Interpret each frame's binary code as an integer index: channel i
        # contributes 2**i when its rounded value is 1.
        bits = (
            2
            ** torch.arange(self.codebook_dim, device=z.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .long()
        )  # (1, codebook_dim, 1)
        indices = (
            (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
        )  # (B, T)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def vq2emb(self, vq, out_proj=True):
        # Decode integer indices (B, T) back into binary embeddings by peeling
        # off one bit per channel, least significant bit first.
        emb = torch.zeros(
            vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
        )  # (B, codebook_dim, T)
        for i in range(self.codebook_dim):
            emb[:, i, :] = (vq % 2).float()
            vq = vq // 2
        if out_proj:
            emb = self.out_project(emb)
        return emb
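

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it round-trips a
# random latent through the quantizer and rebuilds the quantized embedding
# from the returned indices with `vq2emb`. The shapes and hyperparameters
# below (batch=2, input_dim=8, codebook_dim=6, T=50) are illustrative
# assumptions, not values taken from the rest of the codebase.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    quantizer = LookupFreeQuantize(input_dim=8, codebook_size=2**6, codebook_dim=6)

    z = torch.randn(2, 8, 50)  # (B, input_dim, T) latent, e.g. from an encoder
    z_q, commit_loss, codebook_loss, indices, z_e = quantizer(z)

    print(z_q.shape)      # torch.Size([2, 8, 50])  quantized latent, back in input_dim
    print(indices.shape)  # torch.Size([2, 50])     one integer code per frame, in [0, 64)

    # Rebuilding the embedding from the indices should reproduce z_q up to
    # floating-point rounding, since both paths apply the same out_project.
    emb = quantizer.vq2emb(indices)
    print(torch.allclose(emb, z_q, atol=1e-5))  # expected: True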