# Provenance (converted from hosting-page residue so the module parses):
# uploaded by Hilley in commit 509357a ("Upload 9 files"); file size 4.96 kB.
import math
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt block for channel-first sequences (adapted from Vocos).

    The input goes through these steps:
    - a depthwise (optionally dilated) 1-D convolution;
    - LayerNorm over channels;
    - a pointwise expand -> GELU -> project MLP;
    - an optional learnable per-channel scale ``gamma``.
    The result is then added back onto the input.

    Args:
        dim: Number of input/output channels.
        intermediate_dim: Hidden width of the pointwise MLP.
        kernel: Depthwise kernel size. Padding is ``dilation * (kernel // 2)``.
            That keeps the time length unchanged only for odd kernels; with an
            even kernel the residual addition fails on a length mismatch.
        dilation: Dilation of the depthwise convolution.
        layer_scale_init_value: Initial value of every entry of ``gamma``.
            When it is ``<= 0`` no layer scale is created (``gamma`` is None).
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel: int,
        dilation: int,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()
        # groups=dim -> one filter per channel, i.e. a depthwise convolution.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel,
            padding=dilation * (kernel // 2),
            dilation=dilation,
            groups=dim,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions written as Linear layers over (B, T, C).
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        """Apply the block to ``x`` of shape (B, C, T).

        ``cond`` is accepted for interface compatibility and is not used.
        """
        h = self.dwconv(x).transpose(1, 2)  # (B, C, T) -> (B, T, C)
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.gamma is not None:
            h = self.gamma * h
        # Back to (B, C, T) before the residual connection.
        return x + h.transpose(1, 2)
class GFSQ(nn.Module):
    """Wrapper around ``vector_quantize_pytorch.GroupedResidualFSQ``.

    The quantizer has ``G`` groups of ``R`` residual quantizers each.
    Per-frame indices are exposed flattened to ``G * R`` values.

    Layout: when ``transpose`` is True, inputs and outputs put the time axis
    last (e.g. indices are (b, G*R, t)). They are transposed to (b, t, ...)
    before reaching the quantizer. When ``transpose`` is False, they are used
    as (b, t, ...) directly.

    Args:
        dim: Feature dimension given to the quantizer.
        levels: FSQ levels per quantizer. Their product is the number of
            distinct index values per quantizer.
        G: Number of groups.
        R: Number of residual quantizers per group.
        eps: Stabiliser used in the perplexity computation.
        transpose: Whether inputs/outputs are time-last (see above).
    """

    def __init__(self, dim, levels, G, R, eps=1e-5, transpose=True):
        super().__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=levels,
            num_quantizers=R,
            groups=G,
        )
        # Number of distinct index values per quantizer (the one-hot width in forward).
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G
        self.R = R

    def _embed(self, x):
        """Decode flattened indices back into quantized features.

        ``x`` holds ``G * R`` indices per frame: (b, G*R, t) when ``transpose``
        is True, otherwise (b, t, G*R). Features come back in the matching
        layout (time axis last when ``transpose`` is True).
        """
        idx = x.transpose(1, 2) if self.transpose else x
        # Unflatten to the "g b t r" layout, the same layout forward() gets
        # back from the quantizer.
        idx = rearrange(idx, "b t (g r) -> g b t r", g=self.G, r=self.R)
        feat = self.quantizer.get_output_from_indices(idx)
        if self.transpose:
            return feat.transpose(1, 2)
        return feat

    def forward(self, x):
        """Quantize ``x`` and measure codebook usage.

        Returns a 5-tuple ``(zeros, feat, perplexity, None, ind)``:

        - ``zeros``: zero tensor shaped like ``perplexity``, with ``x``'s
          dtype and device. Presumably a placeholder loss slot; confirm with
          callers.
        - ``feat``: quantized features, transposed back when ``transpose``.
        - ``perplexity``: one value per flattened quantizer (shape ``(G*R,)``).
          Computed as exp of the entropy of index usage over batch and time.
        - ``None``: unused slot.
        - ``ind``: flattened indices, (b, G*R, t) when ``transpose`` is True,
          otherwise (b, t, G*R).
        """
        if self.transpose:
            x = x.transpose(1, 2)
        feat, ind = self.quantizer(x)
        ind = rearrange(ind, "g b t r ->b t (g r)")

        # Empirical usage distribution of each quantizer's indices.
        onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
        usage = onehot.mean(dim=[0, 1])
        usage = usage / (usage.sum(dim=1) + self.eps).unsqueeze(1)
        entropy = -torch.sum(usage * torch.log(usage + self.eps), dim=1)
        perplexity = torch.exp(entropy)

        placeholder = torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device)
        if self.transpose:
            feat = feat.transpose(1, 2)
            ind = ind.transpose(1, 2)
        return placeholder, feat, perplexity, None, ind
class DVAEDecoder(nn.Module):
    """ConvNeXt stack mapping (B, T, idim) features to (B, T, odim).

    Args:
        idim: Input feature size.
        odim: Output feature size.
        n_layer: Number of ConvNeXt blocks.
        bn_dim: Bottleneck width of the input projection.
        hidden: Channel width inside the ConvNeXt stack. Each block's MLP
            expands it to ``hidden * 4``.
        kernel: Depthwise kernel size of every block.
        dilation: Depthwise dilation of every block.
        up: Stored as ``self.up``; not read anywhere in this class.
    """

    def __init__(self, idim, odim,
                 n_layer=12, bn_dim=64, hidden=256,
                 kernel=7, dilation=2, up=False):
        super().__init__()
        self.up = up
        # idim -> bn_dim -> hidden, length-preserving (kernel 3, stride 1, pad 1).
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, hidden, 3, 1, 1),
        )
        blocks = [
            ConvNeXtBlock(hidden, hidden * 4, kernel, dilation)
            for _ in range(n_layer)
        ]
        self.decoder_block = nn.ModuleList(blocks)
        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, input, conditioning=None):
        """Decode ``input`` of shape (B, T, idim) into (B, T, odim).

        ``conditioning`` is passed unchanged to every block. The parameter
        name ``input`` is kept because callers pass it by keyword.
        """
        h = self.conv_in(input.transpose(1, 2))  # Conv1d wants (B, C, T)
        for block in self.decoder_block:
            h = block(h, conditioning)
        return self.conv_out(h).transpose(1, 2)
class DVAE(nn.Module):
    """Decode codebook indices (or continuous features) into a ``mel`` output.

    Args:
        decoder_config: Keyword arguments for ``DVAEDecoder``. Its ``odim``
            must equal ``dim`` for ``out_conv`` to accept its output.
        vq_config: Keyword arguments for ``GFSQ``, or ``None`` to skip
            quantization. Without it, ``forward`` expects continuous features.
        dim: Channel count going into the final ``out_conv``.
        num_mels: Number of output bins, presumably mel bands. Defaults to 100,
            the previously hard-coded value. It sizes both the ``coef`` buffer
            and ``out_conv``, which must agree.
    """

    def __init__(
        self, decoder_config, vq_config, dim=512, num_mels=100
    ):
        super().__init__()
        # Per-bin output scale, shaped (1, num_mels, 1) to broadcast over
        # (B, num_mels, T). Initialised randomly here; presumably it is
        # overwritten when a checkpoint is loaded (TODO confirm).
        self.register_buffer('coef', torch.randn(1, num_mels, 1))
        self.decoder = DVAEDecoder(**decoder_config)
        self.out_conv = nn.Conv1d(dim, num_mels, 3, 1, 1, bias=False)
        self.vq_layer = GFSQ(**vq_config) if vq_config is not None else None

    def forward(self, inp):
        """Decode ``inp`` into a (B, num_mels, T') output.

        With a ``vq_layer``, ``inp`` is flattened indices, embedded via
        ``vq_layer._embed``. Otherwise ``inp`` is continuous (B, C, T)
        features, and a detached copy is used.

        The channel axis is then split in half and the halves interleaved
        along time ("flatten trick"): (B, C, T) -> (B, C/2, 2T). So
        T' == 2T whenever the decoder preserves sequence length.
        """
        if self.vq_layer is not None:
            vq_feats = self.vq_layer._embed(inp)
        else:
            vq_feats = inp.detach().clone()

        # Output frame 2t comes from the first channel half, frame 2t+1 from
        # the second half.
        halves = torch.chunk(vq_feats, 2, dim=1)
        stacked = torch.stack(halves, -1)  # (B, C/2, T, 2)
        vq_feats = stacked.reshape(*stacked.shape[:2], -1)  # (B, C/2, 2T)

        # The decoder works on (B, T, C); out_conv works on (B, C, T).
        dec_out = self.decoder(input=vq_feats.transpose(1, 2))
        dec_out = self.out_conv(dec_out.transpose(1, 2))
        return dec_out * self.coef