lengyue233's picture
Init hf space integration
0a3525d verified
raw
history blame
No virus
6.47 kB
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
class Mish(nn.Module):
    """Mish activation, computed element-wise as ``x * tanh(softplus(x))``."""

    def forward(self, x):
        """Return the Mish activation of ``x`` (same shape as the input)."""
        soft_gate = torch.tanh(F.softplus(x))
        return x * soft_gate
class DiffusionEmbedding(nn.Module):
    """Diffusion Step Embedding.

    Maps each step value to a sinusoidal feature vector: the first half of
    the features are sines and the second half cosines of the step scaled by
    geometrically spaced frequencies (from 1 down to 1/10000).

    Args:
        d_denoiser: Requested embedding width. The output has
            ``2 * (d_denoiser // 2)`` features, so an odd width yields one
            feature fewer than requested.
    """

    def __init__(self, d_denoiser):
        super().__init__()
        self.dim = d_denoiser

    def forward(self, x):
        """Embed the steps in ``x``.

        Args:
            x: Tensor of step values, indexed as ``x[:, None]``
                (presumably 1-D with one step per batch item; confirm
                against callers).

        Returns:
            Tensor whose last dimension holds ``2 * (self.dim // 2)``
            features, created on the same device as ``x``.
        """
        half_dim = self.dim // 2
        # With a single frequency (half_dim == 1) the unguarded denominator
        # ``half_dim - 1`` is zero and raised ZeroDivisionError. The only
        # frequency index is then 0, so the scale has no effect on the result
        # and clamping the denominator to 1 is safe.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * frequencies[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class LinearNorm(nn.Module):
    """LinearNorm Projection.

    A thin wrapper around ``nn.Linear`` whose weight is initialised with
    Xavier-uniform and whose bias (if any) starts at zero. No normalisation
    is applied in the forward pass.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        bias: Whether the linear layer carries a bias term.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()

        projection = nn.Linear(in_features, out_features, bias)
        nn.init.xavier_uniform_(projection.weight)
        if bias:
            nn.init.constant_(projection.bias, 0.0)

        self.linear = projection

    def forward(self, x):
        """Apply the linear projection to ``x``."""
        return self.linear(x)
class ConvNorm(nn.Module):
    """1D Convolution.

    Wraps ``nn.Conv1d`` and initialises its weight with Kaiming-normal.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        padding: Explicit padding. When ``None`` it is derived as
            ``dilation * (kernel_size - 1) / 2``, which requires an odd
            ``kernel_size`` (checked with an ``assert``).
        dilation: Convolution dilation.
        bias: Whether the convolution carries a bias term.
        w_init_gain: Accepted for interface compatibility; this
            implementation does not read it.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super().__init__()

        if padding is None:
            # Symmetric padding only works out for odd kernels.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        """Run the convolution over ``signal`` and return the result."""
        return self.conv(signal)
class ResidualBlock(nn.Module):
    """Residual Block.

    A gated dilated-convolution block: the input (optionally shifted by a
    projected diffusion step) goes through a kernel-3 convolution, is
    optionally summed with a projected condition, passed through a
    sigmoid/tanh gate and finally split into a residual and a skip output.

    Args:
        residual_channels: Channel count of the block input and outputs.
        use_linear_bias: Whether the diffusion-step projection has a bias.
        dilation: Dilation (and padding) of the kernel-3 convolution.
        condition_channels: Channel count of the conditioning input. The
            diffusion-step and condition projections are only created when
            this is not ``None``; ``forward`` needs them whenever
            ``diffusion_step`` or ``condition`` is passed.
    """

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super().__init__()

        gated_channels = 2 * residual_channels

        self.conv_layer = ConvNorm(
            residual_channels,
            gated_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, gated_channels, kernel_size=1
            )

        self.output_projection = ConvNorm(
            residual_channels, gated_channels, kernel_size=1
        )

    def forward(self, x, condition=None, diffusion_step=None):
        """Run the block.

        Args:
            x: Input tensor; channels are on dimension 1.
            condition: Optional conditioning tensor, projected and added
                after the dilated convolution.
            diffusion_step: Optional step embedding, projected, given a
                trailing singleton dimension and added to the input.

        Returns:
            Tuple ``(residual_output, skip)`` where ``residual_output`` is
            ``(x + residual) / sqrt(2)``.
        """
        hidden = x

        if diffusion_step is not None:
            step_shift = self.diffusion_projection(diffusion_step)
            hidden = hidden + step_shift.unsqueeze(-1)

        hidden = self.conv_layer(hidden)

        if condition is not None:
            hidden = hidden + self.condition_projection(condition)

        # Split channels in two halves: one drives the sigmoid gate, the
        # other the tanh filter.
        gate_half, filter_half = hidden.chunk(2, dim=1)
        gated = torch.sigmoid(gate_half) * torch.tanh(filter_half)

        residual, skip = self.output_projection(gated).chunk(2, dim=1)
        return (x + residual) / math.sqrt(2.0), skip
class WaveNet(nn.Module):
    """Stack of gated residual blocks with summed skip connections.

    Args:
        input_channels: Channel count of the input. When given and different
            from ``residual_channels``, a kernel-1 convolution maps the input
            to ``residual_channels``. ``None`` means the input already has
            ``residual_channels`` channels.
        output_channels: Channel count of the output. When given and
            different from ``residual_channels``, a kernel-1 convolution maps
            the result to this width.
        residual_channels: Width of the residual stack.
        residual_layers: Number of residual blocks.
        dilation_cycle: Block ``i`` uses dilation ``2 ** (i % dilation_cycle)``;
            a falsy value makes every block use dilation 1.
        is_diffusion: When true, a step embedding and a small MLP are created
            for the ``t`` argument of ``forward``.
        condition_channels: Forwarded to every residual block.
    """

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection (only when the input width differs).
        if input_channels is None:
            input_channels = residual_channels
        needs_input_projection = input_channels != residual_channels
        self.input_projection = (
            ConvNorm(input_channels, residual_channels, kernel_size=1)
            if needs_input_projection
            else None
        )
        self.input_channels = input_channels

        # Residual layers
        dilations = [
            2 ** (index % dilation_cycle) if dilation_cycle else 1
            for index in range(residual_layers)
        ]
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=block_dilation,
                    condition_channels=condition_channels,
                )
                for block_dilation in dilations
            ]
        )

        # Skip projection
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection (only when the output width differs).
        needs_output_projection = (
            output_channels is not None and output_channels != residual_channels
        )
        self.output_projection = (
            ConvNorm(residual_channels, output_channels, kernel_size=1)
            if needs_output_projection
            else None
        )

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        # NOTE(review): reconstructed as unconditional (not nested under
        # ``is_diffusion``); the source's indentation was lost - confirm.
        # This re-initialises every Conv1d/Linear created above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights (std 0.02) and zero bias for Conv1d/Linear."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return

        nn.init.trunc_normal_(m.weight, std=0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        """Run the network.

        Args:
            x: Input tensor with channels on dimension 1.
            t: Optional diffusion steps; embedded and passed to every block.
                Requires the model to have been built with ``is_diffusion``.
            condition: Optional conditioning tensor passed to every block.

        Returns:
            The projected sum of all skip connections, scaled by
            ``1 / sqrt(number of blocks)``.
        """
        if self.input_projection is not None:
            # NOTE(review): SiLU is assumed to belong to this branch (mirrors
            # the output side); the source's indentation was lost - confirm.
            x = F.silu(self.input_projection(x))

        if t is not None:
            t = self.mlp(self.diffusion_embedding(t))

        skips = []
        for block in self.residual_layers:
            x, block_skip = block(x, condition, t)
            skips.append(block_skip)

        scale = math.sqrt(len(self.residual_layers))
        x = torch.stack(skips).sum(dim=0) / scale
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = self.output_projection(F.silu(x))

        return x