maskgct / models /tts /jets /length_regulator.py
Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
1.58 kB
# Copyright (c) 2024 Amphion.
#
# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/length_regulator.py
# Licensed under Apache License 2.0
import torch
class GaussianUpsampling(torch.nn.Module):
    """Gaussian upsampling with fixed temperature, as in:
    https://arxiv.org/abs/2010.04301

    Each output frame is a softmax-weighted average of the token hidden
    states. The logits are ``-delta * (t - c) ** 2``, where ``t`` is the frame
    index and ``c`` is the midpoint of each token's duration span.

    Args:
        delta (float): Fixed temperature. Larger values make the per-frame
            weights concentrate more sharply on the nearest token centre.
    """

    def __init__(self, delta=0.1):
        super().__init__()
        self.delta = delta

    def forward(self, hs, ds, h_masks=None, d_masks=None):
        """Expand token-level hidden states to frame level.

        Args:
            hs (Tensor): Batched hidden state to be expanded (B, T_text, adim).
            ds (Tensor): Batched token duration (B, T_text).
            h_masks (Tensor, optional): Frame mask (B, T_feats); nonzero marks
                valid frames. When given, its last dimension sets the output
                length and padded frames are mapped to position 0.
            d_masks (Tensor, optional): Token mask (B, T_text); nonzero/True
                marks valid tokens. Masked tokens receive zero weight.

        Returns:
            Tensor: Expanded hidden state (B, T_feats, adim). Without
            ``h_masks``, T_feats is the largest per-utterance total duration
            in the batch (truncated to an integer).
        """
        device = ds.device
        if h_masks is None:
            # Pad to the longest utterance. Summing over the whole batch
            # (ds.sum()) would allocate the batch total when B > 1.
            T_feats = int(ds.sum(dim=-1).max()) if ds.size(0) > 0 else 0
        else:
            T_feats = h_masks.size(-1)

        # Frame positions created directly on the target device, shape
        # (1, T_feats); broadcasting supplies the batch dimension.
        t = torch.arange(T_feats, device=device, dtype=torch.float).unsqueeze(0)
        if h_masks is not None:
            # Padded frames collapse onto position 0.
            t = t * h_masks.to(t.dtype)

        # Centre of each token's span on the frame axis.
        c = ds.cumsum(dim=-1) - ds / 2
        # (B or 1, T_feats, 1) - (B, 1, T_text) -> (B, T_feats, T_text)
        energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2
        if d_masks is not None:
            # masked_fill broadcasts the (B, 1, T_text) mask over frames;
            # .bool() lets integer/float masks be used with ``~``.
            energy = energy.masked_fill(~d_masks.bool().unsqueeze(1), -float("inf"))

        p_attn = torch.softmax(energy, dim=2)  # (B, T_feats, T_text)
        hs = torch.matmul(p_attn, hs)
        return hs