# Spaces: Sleeping  (page-status text captured together with the file; not part of the module)
# Standard library
import logging
import math

# Third-party
import torch
import torch.nn as nn

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _fused_tanh_sigmoid(h):
    """Gated activation unit.

    Splits ``h`` into two chunks along dim 1 and returns
    ``tanh(first_chunk) * sigmoid(second_chunk)``.
    """
    filt, gate = torch.chunk(h, 2, dim=1)
    return torch.tanh(filt) * torch.sigmoid(gate)
class WNLayer(nn.Module):
    """
    One gated, dilated residual layer of a DiffWave-like WN.

    The layer returns a residual output (same channel count as its input)
    and a skip output, both carved out of a single 1x1 output convolution.
    """

    def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
        super().__init__()
        # The gated activation consumes two halves (tanh half, sigmoid half).
        gate_channels = 2 * hidden_dim

        # Conditioning projections are only created when the matching
        # conditioning dimension is provided.
        if global_dim is not None:
            self.gconv = nn.Conv1d(global_dim, hidden_dim, kernel_size=1)
        if local_dim is not None:
            self.lconv = nn.Conv1d(local_dim, gate_channels, kernel_size=1)

        self.dconv = nn.Conv1d(
            hidden_dim,
            gate_channels,
            kernel_size,
            dilation=dilation,
            padding="same",
        )
        # Produces the residual half and the skip half in one projection.
        self.out = nn.Conv1d(hidden_dim, gate_channels, kernel_size=1)

    def forward(self, z, l, g):
        """
        Args:
            z: layer input
            l: local condition, or None to skip the local branch
            g: global condition, or None to skip the global branch
        Returns:
            (o, s): residual output ``(z_branch + input) / sqrt(2)`` and the
            skip output.
        """
        residual = z

        if g is not None:
            # A 2-D global condition gets a trailing axis so Conv1d accepts it.
            g_in = g.unsqueeze(-1) if g.dim() == 2 else g
            z = z + self.gconv(g_in)

        gates = self.dconv(z)
        if l is not None:
            gates = gates + self.lconv(l)

        hidden = self.out(_fused_tanh_sigmoid(gates))
        res_branch, skip = torch.chunk(hidden, 2, dim=1)
        return (res_branch + residual) / math.sqrt(2), skip
class WN(nn.Module):
    """
    WN network: an input 1x1 convolution, a stack of ``WNLayer`` blocks with
    cyclic power-of-two dilations, and an output 1x1 convolution applied to
    the normalised sum of all skip outputs.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        local_dim=None,
        global_dim=None,
        n_layers=30,
        kernel_size=3,
        dilation_cycle=5,
        hidden_dim=512,
    ):
        """
        Args:
            input_dim: channels of the input ``z``
            output_dim: channels of the returned tensor
            local_dim: channels of the local condition, or None to disable it
            global_dim: channels of the global condition, or None to disable it
            n_layers: number of ``WNLayer`` blocks (must be >= 1)
            kernel_size: kernel size of each dilated convolution (must be odd)
            dilation_cycle: layer ``i`` uses dilation ``2 ** (i % dilation_cycle)``
                (must be >= 1)
            hidden_dim: channels used inside the stack (must be even)
        Raises:
            AssertionError: if ``kernel_size`` is even or ``hidden_dim`` is odd.
            ValueError: if ``n_layers`` or ``dilation_cycle`` is smaller than 1.
        """
        super().__init__()

        # Raised explicitly instead of via `assert` so the checks survive
        # `python -O`; AssertionError is kept so callers see the same
        # exception type as before.
        if kernel_size % 2 != 1:
            raise AssertionError(f"kernel_size must be odd, got {kernel_size}")
        if hidden_dim % 2 != 0:
            raise AssertionError(f"hidden_dim must be even, got {hidden_dim}")
        # forward() stacks one skip tensor per layer and divides by
        # sqrt(n_layers), so an empty stack cannot work.
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        # dilation_cycle is used as a modulus when computing the dilations.
        if dilation_cycle < 1:
            raise ValueError(f"dilation_cycle must be >= 1, got {dilation_cycle}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.local_dim = local_dim
        self.global_dim = global_dim

        self.start = nn.Conv1d(input_dim, hidden_dim, 1)

        # The local condition is normalised once, before entering the layers.
        if local_dim is not None:
            self.local_norm = nn.InstanceNorm1d(local_dim)

        self.layers = nn.ModuleList(
            [
                WNLayer(
                    hidden_dim=hidden_dim,
                    local_dim=local_dim,
                    global_dim=global_dim,
                    kernel_size=kernel_size,
                    dilation=2 ** (i % dilation_cycle),
                )
                for i in range(n_layers)
            ]
        )

        self.end = nn.Conv1d(hidden_dim, output_dim, 1)

    def forward(self, z, l=None, g=None):
        """
        Args:
            z: input (b c t)
            l: local condition (b c t)
            g: global condition (b d)
        Returns:
            The output convolution applied to the sum of all layers' skip
            outputs divided by ``sqrt(number of layers)``.
        """
        z = self.start(z)

        if l is not None:
            l = self.local_norm(l)

        # Skips
        s_list = []

        for layer in self.layers:
            z, s = layer(z, l, g)
            s_list.append(s)

        # Sum the skips and rescale by sqrt(n_layers).
        s_list = torch.stack(s_list, dim=0).sum(dim=0)
        s_list = s_list / math.sqrt(len(self.layers))

        o = self.end(s_list)

        return o

    def summarize(self, length=100):
        """
        Print the ptflops complexity report for an input of shape
        ``(input_dim, length)``.

        ``ptflops`` is imported here, so it is only required when this method
        is called. Only the main input shape is handed to ptflops, so the
        conditioning inputs presumably stay at their ``None`` defaults.
        """
        from ptflops import get_model_complexity_info

        # Only the shape is reported, so no tensor needs to be allocated.
        input_shape = torch.Size((1, self.input_dim, length))

        macs, params = get_model_complexity_info(
            self,
            (self.input_dim, length),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        print(f"Input shape: {input_shape}")
        print(f"Computational complexity: {macs}")
        print(f"Number of parameters: {params}")
if __name__ == "__main__":
    # Smoke run: build a 64-in / 64-out network with the default settings
    # and print its complexity summary.
    model = WN(64, output_dim=64)
    model.summarize()