import torch
import math
from argparse import Namespace
from typing import Optional
from tqdm import tqdm

from .Layer import Conv1d, Lambda
class Diffusion(torch.nn.Module):
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            raise ValueError(f'Unsupported feature type: {self.hp.Feature_Type}')

        self.denoiser = Denoiser(
            hyper_parameters= self.hp
            )

        self.timesteps = self.hp.Diffusion.Max_Step
        betas = torch.linspace(1e-4, 0.06, self.timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim= 0)
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('alphas_cumprod', alphas_cumprod) # [Diffusion_t]
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)    # [Diffusion_t]
        self.register_buffer('sqrt_alphas_cumprod', alphas_cumprod.sqrt())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1.0 - alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recip_alphas_cumprod', (1.0 / alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', (1.0 / alphas_cumprod - 1.0).sqrt())

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance', torch.maximum(posterior_variance, torch.tensor([1e-20])).log())
        self.register_buffer('posterior_mean_coef1', betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod))
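        # Reference formulas behind the buffers above (DDPM, Ho et al. 2020), with
        # alpha_bar_t = prod_{s<=t} (1 - beta_s):
        #   q(x_t | x_0)          = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
        #   q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance * I)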

    def forward(
        self,
        encodings: torch.Tensor,
        features: Optional[torch.Tensor]= None
        ):
        '''
        encodings: [Batch, Enc_d, Enc_t]
        features: [Batch, Feature_d, Feature_t]
        '''
        if features is not None:    # train
            diffusion_steps = torch.randint(
                low= 0,
                high= self.timesteps,
                size= (encodings.size(0),),
                dtype= torch.long,
                device= encodings.device
                )   # a single random step per batch item
            noises, epsilons = self.Get_Noise_Epsilon_for_Train(
                features= features,
                encodings= encodings,
                diffusion_steps= diffusion_steps,
                )
            return None, noises, epsilons
        else:   # inference
            features = self.Sampling(
                encodings= encodings,
                )
            return features, None, None

    def Sampling(
        self,
        encodings: torch.Tensor,
        ):
        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )
        for diffusion_step in reversed(range(self.timesteps)):
            features = self.P_Sampling(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= diffusion_step,
                    dtype= torch.long,
                    device= encodings.device
                    ),
                )

        return features
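
    # P_Sampling performs one ancestral sampling step:
    #   x_{t-1} = posterior_mean + exp(0.5 * posterior_log_variance) * z,  z ~ N(0, I),
    # where the mask removes the noise term at t = 0 so the final step is deterministic.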
    def P_Sampling(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        posterior_means, posterior_log_variances = self.Get_Posterior(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps,
            )

        noises = torch.randn_like(features) # [Batch, Feature_d, Feature_t]
        masks = (diffusion_steps > 0).float().unsqueeze(1).unsqueeze(1) # [Batch, 1, 1]

        return posterior_means + masks * (0.5 * posterior_log_variances).exp() * noises
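
    # Get_Posterior first reconstructs the clean features from the denoiser's noise
    # estimate, x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps_hat,
    # then plugs that x_0 into the closed-form posterior q(x_{t-1} | x_t, x_0).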
    def Get_Posterior(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        noised_predictions = self.denoiser(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        feature_starts = \
            features * self.sqrt_recip_alphas_cumprod[diffusion_steps][:, None, None] - \
            noised_predictions * self.sqrt_recipm1_alphas_cumprod[diffusion_steps][:, None, None]   # predicted x_0
        feature_starts.clamp_(-1.0, 1.0)    # clipped

        posterior_means = \
            feature_starts * self.posterior_mean_coef1[diffusion_steps][:, None, None] + \
            features * self.posterior_mean_coef2[diffusion_steps][:, None, None]
        posterior_log_variances = \
            self.posterior_log_variance[diffusion_steps][:, None, None]

        return posterior_means, posterior_log_variances
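
    # Training draws a random step t and noises the target features with
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
    # then returns (eps, eps_hat) so the caller can score the denoiser's noise
    # prediction (e.g. an L1 or L2 loss between the two; the exact criterion is
    # chosen by the trainer, not by this module).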
    def Get_Noise_Epsilon_for_Train(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        noises = torch.randn_like(features)

        noised_features = \
            features * self.sqrt_alphas_cumprod[diffusion_steps][:, None, None] + \
            noises * self.sqrt_one_minus_alphas_cumprod[diffusion_steps][:, None, None]

        epsilons = self.denoiser(
            features= noised_features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        return noises, epsilons
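
    # DDIM sampling (Song et al. 2021) walks a short subsequence of timesteps with
    # the update
    #   x_prev = sqrt(alpha_bar_prev) * x_0_hat
    #          + sqrt(1 - alpha_bar_prev - sigma_t^2) * eps_hat
    #          + sigma_t * z,
    # which is fully deterministic at eta = 0 and approaches DDPM as eta -> 1.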
    def DDIM(
        self,
        encodings: torch.Tensor,
        ddim_steps: int,
        eta: float= 0.0,
        temperature: float= 1.0,
        use_tqdm: bool= False
        ):
        ddim_timesteps = self.Get_DDIM_Steps(
            ddim_steps= ddim_steps
            )
        sigmas, alphas, alphas_prev = self.Get_DDIM_Sampling_Parameters(
            ddim_timesteps= ddim_timesteps,
            eta= eta
            )
        sqrt_one_minus_alphas = (1.0 - alphas).sqrt()

        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )

        step_range = reversed(range(ddim_steps))
        if use_tqdm:
            step_range = tqdm(
                step_range,
                desc= '[Diffusion]',
                total= ddim_steps
                )

        for diffusion_steps in step_range:
            noised_predictions = self.denoiser(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= ddim_timesteps[diffusion_steps].item(), # condition on the actual timestep, not the subsequence index
                    dtype= torch.long,
                    device= encodings.device
                    )
                )
            feature_starts = (features - sqrt_one_minus_alphas[diffusion_steps] * noised_predictions) / alphas[diffusion_steps].sqrt()
            direction_pointings = (1.0 - alphas_prev[diffusion_steps] - sigmas[diffusion_steps].pow(2.0)).sqrt() * noised_predictions
            noises = sigmas[diffusion_steps] * torch.randn_like(features) * temperature

            features = alphas_prev[diffusion_steps].sqrt() * feature_starts + direction_pointings + noises

        return features

    # https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
    def Get_DDIM_Steps(
        self,
        ddim_steps: int,
        ddim_discr_method: str= 'uniform'
        ):
        if ddim_discr_method == 'uniform':
            ddim_timesteps = torch.arange(0, self.timesteps, self.timesteps // ddim_steps).long()
        elif ddim_discr_method == 'quad':
            ddim_timesteps = torch.linspace(0, math.sqrt(self.timesteps * 0.8), ddim_steps).pow(2.0).long()
        else:
            raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

        ddim_timesteps[-1] = self.timesteps - 1

        return ddim_timesteps
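
    # The stochasticity schedule follows the DDIM paper:
    #   sigma_t = eta * sqrt((1 - alpha_bar_prev) / (1 - alpha_bar) * (1 - alpha_bar / alpha_bar_prev)),
    # so eta = 0.0 gives deterministic sampling and eta = 1.0 recovers
    # DDPM-like posterior noise on the chosen subsequence.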
    def Get_DDIM_Sampling_Parameters(self, ddim_timesteps, eta):
        alphas = self.alphas_cumprod[ddim_timesteps]
        alphas_prev = torch.cat([self.alphas_cumprod[:1], self.alphas_cumprod[ddim_timesteps[:-1]]])    # alpha_bar at the previous step of the subsequence
        sigmas = eta * ((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)).sqrt()

        return sigmas, alphas, alphas_prev


class Denoiser(torch.nn.Module):
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            raise ValueError(f'Unsupported feature type: {self.hp.Feature_Type}')

        self.prenet = torch.nn.Sequential(
            Conv1d(
                in_channels= feature_size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish()
            )

        self.step_ffn = torch.nn.Sequential(
            Diffusion_Embedding(
                channels= self.hp.Diffusion.Size
                ),
            Lambda(lambda x: x.unsqueeze(2)),
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size * 4,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size * 4,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'linear'
                )
            )

        self.residual_blocks = torch.nn.ModuleList([
            Residual_Block(
                in_channels= self.hp.Diffusion.Size,
                kernel_size= self.hp.Diffusion.Kernel_Size,
                condition_channels= self.hp.Encoder.Size + feature_size
                )
            for _ in range(self.hp.Diffusion.Stack)
            ])

        self.projection = torch.nn.Sequential(
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.ReLU(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= feature_size,
                kernel_size= 1
                ),
            )

        torch.nn.init.zeros_(self.projection[-1].weight)    # zero-init: the denoiser starts from a zero noise prediction, which stabilizes early training

    def forward(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        features: [Batch, Feature_d, Feature_t]
        encodings: [Batch, Enc_d, Feature_t]
        diffusion_steps: [Batch]
        '''
        x = self.prenet(features)
        diffusion_steps = self.step_ffn(diffusion_steps)    # [Batch, Res_d, 1]

        skips_list = []
        for residual_block in self.residual_blocks:
            x, skips = residual_block(
                x= x,
                conditions= encodings,
                diffusion_steps= diffusion_steps
                )
            skips_list.append(skips)

        x = torch.stack(skips_list, dim= 0).sum(dim= 0) / math.sqrt(self.hp.Diffusion.Stack)
        x = self.projection(x)

        return x
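

# Sinusoidal step embedding, analogous to the Transformer positional encoding;
# the first half of the channels carries sin(t * w_i) and the second half
# cos(t * w_i), with frequencies w_i spaced geometrically from 1 down to 1 / 10000.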
class Diffusion_Embedding(torch.nn.Module):
    def __init__(
        self,
        channels: int
        ):
        super().__init__()
        self.channels = channels

    def forward(self, x: torch.Tensor):
        half_channels = self.channels // 2  # sine and cosine
        embeddings = math.log(10000.0) / (half_channels - 1)
        embeddings = torch.exp(torch.arange(half_channels, device= x.device) * -embeddings)
        embeddings = x.unsqueeze(1) * embeddings.unsqueeze(0)
        embeddings = torch.cat([embeddings.sin(), embeddings.cos()], dim= -1)

        return embeddings
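

# WaveNet-style residual block: the step embedding is added to the input, the
# conditions are added after the convolution, and a gated activation
# sigmoid(a) * tanh(b) feeds both the residual and the skip path.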
class Residual_Block(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        kernel_size: int,
        condition_channels: int
        ):
        super().__init__()
        self.in_channels = in_channels

        self.condition = Conv1d(
            in_channels= condition_channels,
            out_channels= in_channels * 2,
            kernel_size= 1
            )
        self.diffusion_step = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels,
            kernel_size= 1
            )
        self.conv = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels * 2,
            kernel_size= kernel_size,
            padding= kernel_size // 2
            )
        self.projection = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels * 2,
            kernel_size= 1
            )

    def forward(
        self,
        x: torch.Tensor,
        conditions: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        residuals = x

        conditions = self.condition(conditions)
        diffusion_steps = self.diffusion_step(diffusion_steps)

        x = self.conv(x + diffusion_steps) + conditions
        x_a, x_b = x.chunk(chunks= 2, dim= 1)
        x = x_a.sigmoid() * x_b.tanh()  # gated activation unit
        x = self.projection(x)
        x, skips = x.chunk(chunks= 2, dim= 1)

        return (x + residuals) / math.sqrt(2.0), skips
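

# --- Usage sketch (added for illustration; not part of the original module) ---
# The hyper-parameter values below are assumptions chosen only to exercise the
# shapes; the attribute names themselves are the ones read by the classes above.
# Note that `encodings` must carry Encoder.Size + feature_size channels, since
# Residual_Block conditions on that concatenation.
def _example_usage():
    hp = Namespace(
        Feature_Type= 'Mel',
        Sound= Namespace(Mel_Dim= 80, N_FFT= 1024),
        Diffusion= Namespace(Max_Step= 100, Size= 256, Kernel_Size= 5, Stack= 4),
        Encoder= Namespace(Size= 256)
        )
    model = Diffusion(hp)

    encodings = torch.randn(2, hp.Encoder.Size + hp.Sound.Mel_Dim, 50)  # [Batch, Enc_d + Feature_d, Feature_t]
    features = torch.randn(2, hp.Sound.Mel_Dim, 50) # [Batch, Feature_d, Feature_t]

    _, noises, epsilons = model(encodings= encodings, features= features)   # train path
    loss = torch.nn.functional.l1_loss(epsilons, noises)    # illustrative criterion; the real trainer may use another loss

    sampled, _, _ = model(encodings= encodings) # inference: full ancestral sampling
    fast = model.DDIM(encodings= encodings, ddim_steps= 10) # inference: accelerated DDIM sampling

    return loss, sampled, fast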