# code adapted from: https://github.com/Stability-AI/stable-audio-tools

import math
from typing import Literal, Dict, Any

import torch
from torch import nn

import comfy.ops

ops = comfy.ops.disable_weight_init


def vae_sample(mean, scale):
    """Draw a reparameterized Gaussian sample and compute a KL-style penalty.

    Args:
        mean: Mean of the diagonal Gaussian.
        scale: Unconstrained scale; mapped through ``softplus`` (+1e-4) to a
            strictly positive standard deviation.

    Returns:
        ``(latents, kl)`` where ``latents = mean + stdev * eps`` with
        ``eps ~ N(0, 1)``, and ``kl`` is ``mean**2 + var - log(var) - 1``
        summed over dim 1 and averaged over all remaining elements (note: no
        0.5 factor is applied).
    """
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean
    kl = (mean * mean + var - logvar - 1).sum(1).mean()
    return latents, kl


class VAEBottleneck(nn.Module):
    """Continuous VAE bottleneck: split the input into mean/scale halves and sample."""

    def __init__(self):
        super().__init__()
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        """Sample latents from ``x``.

        Args:
            x: Tensor whose dim 1 is split into two equal chunks: mean, then scale.
            return_info: If True, also return a dict holding the ``"kl"`` term.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The sampled latents, or ``(latents, {"kl": kl})`` when ``return_info``.
        """
        info = {}
        mean, scale = x.chunk(2, dim=1)
        x, kl = vae_sample(mean, scale)
        info["kl"] = kl
        if return_info:
            return x, info
        return x

    def decode(self, x):
        """Return ``x`` unchanged; this bottleneck applies no decode transform."""
        return x


def snake_beta(x, alpha, beta):
    """SnakeBeta activation: ``x + 1 / (beta + 1e-9) * sin(alpha * x) ** 2``."""
    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)


# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):
    """Per-channel SnakeBeta activation with learnable ``alpha`` and ``beta``.

    Args:
        in_features: Number of channels; one ``alpha``/``beta`` value per channel.
        alpha: Multiplier applied to the parameter initialisation. In log-scale
            mode the parameters start at zeros, so this multiplier has no effect.
        alpha_trainable: Whether ``alpha`` and ``beta`` receive gradients.
        alpha_logscale: If True, parameters are stored in log space and
            exponentiated in ``forward``.
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        # Log-scale parameters start at zeros (exp(0) == 1); linear-scale ones at ones.
        init = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = nn.Parameter(init(in_features) * alpha)
        self.beta = nn.Parameter(init(in_features) * alpha)

        # Honour alpha_trainable (it used to be accepted but silently ignored).
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation; ``x`` is assumed to be [B, C, T] — TODO confirm."""
        # Reshape (C,) -> (1, C, 1) so the per-channel params broadcast over x, and
        # match x's device and dtype so the output dtype is not promoted.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(device=x.device, dtype=x.dtype)
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(device=x.device, dtype=x.dtype)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return snake_beta(x, alpha, beta)


def _weight_norm(module):
    """Apply weight normalisation to ``module``, preferring the parametrization API.

    ``torch.nn.utils.parametrizations.weight_norm`` is missing on older PyTorch
    releases; fall back to the legacy hook-based ``torch.nn.utils.weight_norm``.
    Only the missing-attribute case triggers the fallback, so genuine errors
    raised while applying weight norm are not swallowed.
    """
    try:
        apply_weight_norm = torch.nn.utils.parametrizations.weight_norm
    except AttributeError:  # support older pytorch versions
        apply_weight_norm = torch.nn.utils.weight_norm
    return apply_weight_norm(module)


def WNConv1d(*args, **kwargs):
    """Weight-normalised ``ops.Conv1d``; all arguments go to the conv constructor."""
    return _weight_norm(ops.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    """Weight-normalised ``ops.ConvTranspose1d``; all arguments go to the constructor."""
    return _weight_norm(ops.ConvTranspose1d(*args, **kwargs))


def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build the activation module named by ``activation``.

    Args:
        activation: ``"elu"``, ``"snake"`` (SnakeBeta) or ``"none"`` (identity).
        antialias: If True, wrap the activation in ``Activation1d``.
        channels: Channel count; required for ``"snake"``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    if activation == "elu":
        act = torch.nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        # NOTE(review): Activation1d is not defined or imported in the visible
        # source, so antialias=True likely raises NameError here — confirm.
        act = Activation1d(act)  # noqa: F821

    return act


class ResidualUnit(nn.Module):
    """Residual block: act -> dilated k=7 conv -> act -> 1x1 conv, plus a skip.

    The skip adds the input unchanged, so ``in_channels`` must equal
    ``out_channels`` for the addition to broadcast.
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()
        self.dilation = dilation
        kernel_size = 7
        # Keeps the time length unchanged for this stride-1, odd-kernel conv.
        padding = (dilation * (kernel_size - 1)) // 2
        act_name = "snake" if use_snake else "elu"
        self.layers = nn.Sequential(
            # The first activation sees the block input, so size it by in_channels.
            get_activation(act_name, antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=kernel_size, dilation=dilation, padding=padding),
            get_activation(act_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1),
        )

    def forward(self, x):
        return self.layers(x) + x


class EncoderBlock(nn.Module):
    """Residual units (dilations 1, 3, 9) followed by a strided downsampling conv."""

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()
        act_name = "snake" if use_snake else "elu"
        # As in the upstream code, antialias_activation only affects the final
        # activation, not the residual units.
        residual_units = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                         dilation=dilation, use_snake=use_snake)
            for dilation in (1, 3, 9)
        ]
        self.layers = nn.Sequential(
            *residual_units,
            get_activation(act_name, antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)


class DecoderBlock(nn.Module):
    """Activation, upsampling by ``stride``, then residual units (dilations 1, 3, 9).

    Upsampling uses a transposed conv by default, or nearest-neighbour
    interpolation followed by a ``'same'``-padded conv when
    ``use_nearest_upsample`` is True.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()
        act_name = "snake" if use_snake else "elu"

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=2 * stride, stride=1, bias=False, padding='same'),
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                               kernel_size=2 * stride, stride=stride,
                                               padding=math.ceil(stride / 2))

        self.layers = nn.Sequential(
            get_activation(act_name, antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            *[
                ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                             dilation=dilation, use_snake=use_snake)
                for dilation in (1, 3, 9)
            ],
        )

    def forward(self, x):
        return self.layers(x)


class OobleckEncoder(nn.Module):
    """Oobleck encoder: input conv, one EncoderBlock per stride, output projection.

    Args:
        in_channels: Channels of the input signal.
        channels: Base channel width; multiplied by each entry of ``c_mults``.
        latent_dim: Channels of the output projection.
        c_mults: Per-stage channel multipliers (a leading 1 is prepended).
        strides: Downsampling stride of each EncoderBlock; must have at least
            ``len(c_mults)`` entries.
        use_snake: Use SnakeBeta instead of ELU activations.
        antialias_activation: Wrap the final activation in ``Activation1d``.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        # list() lets callers pass either a list or a tuple.
        c_mults = [1] + list(c_mults)
        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth - 1):
            layers.append(EncoderBlock(in_channels=c_mults[i] * channels,
                                       out_channels=c_mults[i + 1] * channels,
                                       stride=strides[i], use_snake=use_snake))

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation,
                           channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    """Oobleck decoder: mirror of ``OobleckEncoder`` with upsampling DecoderBlocks.

    Args:
        out_channels: Channels of the reconstructed signal.
        channels: Base channel width; multiplied by each entry of ``c_mults``.
        latent_dim: Channels of the latent input.
        c_mults: Per-stage channel multipliers (a leading 1 is prepended).
        strides: Upsampling stride of each DecoderBlock (applied in reverse order).
        use_snake: Use SnakeBeta instead of ELU activations.
        antialias_activation: Wrap activations in ``Activation1d``.
        use_nearest_upsample: Use nearest-neighbour upsampling + conv instead of
            transposed convs.
        final_tanh: Apply ``tanh`` to the output.
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # list() lets callers pass either a list or a tuple.
        c_mults = [1] + list(c_mults)
        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]

        # Walk the stages from widest to narrowest.
        for i in range(self.depth - 1, 0, -1):
            layers.append(DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample,
            ))

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation,
                           channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels,
                     kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity(),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class AudioOobleckVAE(nn.Module):
    """Oobleck audio VAE: encoder, Gaussian bottleneck, and decoder.

    The encoder outputs ``2 * latent_dim`` channels, which the bottleneck splits
    into mean and scale halves before sampling ``latent_dim`` latent channels.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults=(1, 2, 4, 8, 16),
                 strides=(2, 4, 4, 8, 8),
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides,
                                      use_snake, antialias_activation)
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides,
                                      use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        """Encode ``x`` and return sampled latents (the KL term is discarded)."""
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        """Decode latents ``x`` back to the signal domain."""
        return self.decoder(self.bottleneck.decode(x))