|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math

import torch
from torch import nn
from torch.nn import functional as F

from modules.flow.modules import *
|
|
|
|
|
class StochasticDurationPredictor(nn.Module):
    """Flow-based stochastic duration predictor (VITS-style).

    Models phoneme durations with a normalizing flow conditioned on the
    text encoding ``x`` and an optional speaker embedding ``g``:

    * ``reverse=False`` (training): returns the per-sequence negative
      log-likelihood of the observed durations ``w`` plus the posterior
      term ``log q`` (i.e. a negative ELBO with variational
      dequantization of the integer durations).
    * ``reverse=True`` (inference): samples log-durations by inverting
      the flow on Gaussian noise scaled by ``noise_scale``.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        """
        Args:
            in_channels: channels of the conditioning text encoding ``x``.
            filter_channels: requested hidden width — NOTE: overridden to
                ``in_channels`` below (kept for signature/checkpoint compat).
            kernel_size: conv kernel size for the DDSConv/ConvFlow blocks.
            p_dropout: dropout probability inside the DDSConv blocks.
            n_flows: number of ConvFlow steps in the main duration flow.
            gin_channels: speaker-embedding channels; 0 disables speaker
                conditioning.
        """
        super().__init__()
        # NOTE(review): the hidden width is deliberately forced to
        # in_channels, discarding the filter_channels argument — this
        # matches upstream VITS; kept as-is for checkpoint compatibility.
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow over 2-channel latents [log-duration, auxiliary].
        self.log_flow = Log()
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        for _ in range(n_flows):
            self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(Flip())

        # Posterior flow q(u, nu | w, x): training-only, dequantizes the
        # integer durations w into a continuous variable.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        # NOTE(review): the posterior flow depth is hard-coded to 4 and does
        # NOT follow n_flows — this mirrors upstream VITS; confirm before
        # unifying the two.
        for _ in range(4):
            self.post_flows.append(
                ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(Flip())

        # Shared preprocessing of the text conditioning.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Compute duration NLL (training) or sample log-durations (inference).

        Args:
            x: text encoding, shape ``(B, in_channels, T)`` (fed to a
                Conv1d with ``in_channels`` inputs).
            x_mask: binary mask over valid time steps, broadcastable to
                ``(B, 1, T)``.
            w: observed durations ``(B, 1, T)``; required when
                ``reverse=False``.
            g: optional speaker conditioning ``(B, gin_channels, *)``;
                only used when the model was built with gin_channels != 0.
            reverse: False -> return NLL + log q; True -> sample.
            noise_scale: std-dev multiplier on the sampling noise
                (reverse direction only).

        Returns:
            ``(B,)`` tensor of ``nll + logq`` when ``reverse=False``,
            otherwise the ``(B, 1, T)`` log-duration tensor.
        """
        # Stop gradients flowing back into the text encoder: the duration
        # predictor is trained with its own objective.
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            # --- Posterior flow: sample (u, nu) ~ q(.|w, x) for
            # variational dequantization of the integer durations w.
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            # u in (0, 1): the dequantization offset subtracted from w.
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            # log|det| of the sigmoid transform applied to z_u.
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            # log q(u | w, x) under the standard-normal base sample e_q.
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            # --- Main flow: map (w - u, nu) toward the standard-normal
            # base, accumulating log-determinants for the likelihood.
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            # Negative log-likelihood under the N(0, I) base distribution.
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq
        else:
            # Inference: run the flow list in reverse on Gaussian noise.
            # The slice drops one flow from the reversed list while keeping
            # the final ElementwiseAffine — NOTE(review): matches the
            # upstream VITS trick of skipping an unused flow kept only for
            # checkpoint compatibility; confirm before changing.
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            # First channel of the inverted latent is the log-duration.
            logw = z0
            return logw