RayeRen's picture
init
d1b91e7
raw
history blame
776 Bytes
import torch
import torch.nn.functional as F
def build_word_mask(x2word, y2word):
    """Build a pairwise "same word" mask between two index sequences.

    Args:
        x2word: tensor of word indices, assumed shape [B, Tx] — TODO confirm
            against callers.
        y2word: tensor of word indices, assumed shape [B, Ty].

    Returns:
        Long tensor (for the assumed 2-D inputs, shape [B, Tx, Ty]) whose
        entry [b, i, j] is 1 when ``x2word[b, i] == y2word[b, j]`` and 0
        otherwise. Any equal values match, including a shared padding value.
    """
    rows = x2word.unsqueeze(2)  # insert an axis after the x time axis
    cols = y2word.unsqueeze(1)  # insert an axis before the y time axis
    same_word = torch.eq(rows, cols)  # broadcast comparison
    return same_word.long()
def mel2ph_to_mel2word(mel2ph, ph2word):
    """Map a frame-to-phoneme alignment onto a frame-to-word alignment.

    ``mel2ph`` is treated as 1-based with 0 meaning "no phoneme": value k
    selects ``ph2word[:, k - 1]``, and frames holding 0 come out as 0.

    Args:
        mel2ph: integer index tensor, gathered along dim 1 (presumably
            [B, T_frame] — TODO confirm).
        ph2word: tensor of word indices per phoneme, indexed along dim 1.

    Returns:
        Tensor shaped like ``mel2ph`` holding the looked-up ``ph2word``
        value for each frame, or 0 where ``mel2ph`` is 0.
    """
    frame_is_valid = (mel2ph > 0).long()
    # Shift to 0-based gather indices; padding frames (0 -> -1) are clamped
    # to index 0 and then zeroed out by the validity mask.
    ph_index = torch.clamp(mel2ph - 1, min=0)
    looked_up = torch.gather(ph2word - 1, 1, ph_index) + 1
    return looked_up * frame_is_valid
def clip_mel2token_to_multiple(mel2token, frames_multiple):
    """Trim dim 1 of ``mel2token`` down to a multiple of ``frames_multiple``.

    Trailing entries beyond the largest multiple are dropped. If the length
    along dim 1 already leaves no positive remainder, the input object itself
    is returned without slicing.
    """
    n_frames = mel2token.shape[1]
    remainder = n_frames % frames_multiple
    if remainder <= 0:
        return mel2token
    return mel2token[:, :n_frames - remainder]
def expand_states(h, mel2token):
    """Expand token-level states to frame level using a 1-based alignment.

    One all-zero state is prepended along the token axis of ``h``. An index of
    0 in ``mel2token`` therefore selects zeros (padding), and an index of k >= 1
    selects ``h[:, k - 1]``.

    Args:
        h: state tensor, assumed [B, T_token, H]. It must be 3-D so that the
            pad on dim -2 and the gather on dim 1 address the same axis.
        mel2token: int64 index tensor, assumed [B, T_frame], with values in
            [0, T_token].

    Returns:
        Tensor of shape [B, T_frame, H].
    """
    h = F.pad(h, [0, 0, 1, 0])  # prepend one zero state on the token axis
    # expand() yields a broadcast view rather than repeat()'s full copy.
    # gather only reads the index, so no [B, T_frame, H] int64 buffer is
    # allocated.
    index = mel2token[..., None].expand(-1, -1, h.shape[-1])
    return torch.gather(h, 1, index)  # [B, T_frame, H]