|
from typing import Optional, List, Tuple

import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm

from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
from .utils import call_weight_data_normal_if_Conv
|
|
class Generator(torch.nn.Module):
    """Waveform generator: upsamples latent features to audio with a stack of
    transposed convolutions followed by multi-kernel residual blocks."""

    def __init__(
        self,
        initial_channel: int,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        gin_channels: int = 0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )

        # Transposed-convolution upsampling stack; channels halve at every stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # One residual block per (upsampling stage, kernel size) pair.
        self.resblocks = nn.ModuleList()
        resblock_module = ResBlock1 if resblock == "1" else ResBlock2
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock_module(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(call_weight_data_normal_if_Conv)

        # Optional projection for global conditioning (e.g. a speaker embedding).
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
    def __call__(
        self,
        x: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        n_res: Optional[int] = None,
    ) -> torch.Tensor:
        return super().__call__(x, g=g, n_res=n_res)
|
    def forward(
        self,
        x: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        n_res: Optional[int] = None,
    ):
        # Optionally resample the input features to n_res frames before decoding.
        if n_res is not None:
            n = int(n_res)
            if n != x.shape[-1]:
                x = F.interpolate(x, size=n, mode="linear")

        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of the residual blocks attached to this stage.
            n = i * self.num_kernels
            xs = self.resblocks[n](x)
            for j in range(1, self.num_kernels):
                xs += self.resblocks[n + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x
|
    def __prepare_scriptable__(self):
        # Strip the weight_norm forward pre-hooks so the module can be scripted.
        for l in self.ups:
            for hook in l._forward_pre_hooks.values():
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    torch.nn.utils.remove_weight_norm(l)

        for l in self.resblocks:
            for hook in l._forward_pre_hooks.values():
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    torch.nn.utils.remove_weight_norm(l)
        return self
|
    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
|
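
# Illustrative usage sketch for Generator (not part of the original module and
# not called anywhere). The hyperparameters below are assumed HiFi-GAN-style
# values chosen for the example, not settings mandated by this file.
def _example_generator_usage() -> torch.Tensor:
    gen = Generator(
        initial_channel=192,
        resblock="1",
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 6, 2, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[16, 16, 4, 4, 4],
        gin_channels=256,
    )
    x = torch.randn(1, 192, 50)  # (batch, initial_channel, frames)
    g = torch.randn(1, 256, 1)   # optional global conditioning, e.g. a speaker embedding
    # Output shape: (batch, 1, frames * prod(upsample_rates)) = (1, 1, 24000).
    return gen(x, g=g)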
|
class SineGenerator(torch.nn.Module):
    """Definition of the sine generator

    SineGenerator(samp_rate, harmonic_num=0,
                  sine_amp=0.1, noise_std=0.003,
                  voiced_threshold=0)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of the sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(
        self,
        samp_rate: int,
        harmonic_num: int = 0,
        sine_amp: float = 0.1,
        noise_std: float = 0.003,
        voiced_threshold: int = 0,
    ):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
|
    def __call__(
        self, f0: torch.Tensor, upp: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return super().__call__(f0, upp)
|
    def forward(
        self, f0: torch.Tensor, upp: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """sine_waves, uv, noise = forward(f0, upp)
        input f0: tensor(batchsize, length); f0 for unvoiced frames should be 0
        upp: upsampling factor from frame rate to sample rate
        output sine_waves: tensor(batchsize, length * upp, dim)
        output uv: tensor(batchsize, length * upp, 1)
        output noise: tensor(batchsize, length * upp, dim)
        """
        with torch.no_grad():
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)

            # Fundamental plus harmonic overtones: channel idx + 1 carries (idx + 2) * F0.
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in range(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            # Normalized per-frame phase increments with a random initial phase per harmonic.
            rad_values = (f0_buf / self.sampling_rate) % 1
            rand_ini = torch.rand(
                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
            )
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

            # Accumulate phase at frame rate, upsample to sample rate, and correct
            # the wrap-around points introduced by the modulo-1 phase.
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one: torch.Tensor = F.interpolate(
                tmp_over_one.transpose(2, 1),
                scale_factor=float(upp),
                mode="linear",
                align_corners=True,
            ).transpose(2, 1)
            rad_values: torch.Tensor = F.interpolate(
                rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
            ).transpose(2, 1)
            tmp_over_one %= 1
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
            sine_waves = torch.sin(
                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
            )
            sine_waves = sine_waves * self.sine_amp

            # Voiced frames keep the sines plus low-level noise; unvoiced frames are noise only.
            uv = self._f02uv(f0)
            uv: torch.Tensor = F.interpolate(
                uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
            ).transpose(2, 1)
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
    def _f02uv(self, f0):
        # Voiced/unvoiced mask: 1 where F0 exceeds the threshold, 0 elsewhere.
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        if uv.device.type == "privateuseone":  # "privateuseone" backends (e.g. torch-directml) need an explicit float cast
            uv = uv.float()
        return uv
|
|
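# Illustrative usage sketch for SineGenerator (not part of the original module
# and not called anywhere). It assumes 16 kHz audio and a 160-sample hop, so
# upp=160 upsamples frame-rate F0 to sample rate; the values are only examples.
def _example_sine_generator_usage() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    sine_gen = SineGenerator(samp_rate=16000, harmonic_num=8)
    f0 = torch.full((1, 100), 220.0)  # (batch, frames), 220 Hz fundamental
    f0[:, 50:] = 0.0                  # zero F0 marks the second half as unvoiced
    sine_waves, uv, noise = sine_gen(f0, upp=160)
    # sine_waves, noise: (1, 100 * 160, harmonic_num + 1); uv: (1, 100 * 160, 1)
    return sine_waves, uv, noise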