import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn.pytorch import GatedGraphConv


def sequence_mask(lengths, maxlen, dtype=torch.bool):
    """Boolean mask [batch, maxlen] that is True at positions before each sequence length."""
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
    return mask.type(dtype)
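
# Illustrative example for sequence_mask (comment only, not executed):
# for lengths = torch.tensor([2, 3]) and maxlen = 4, it returns
#   [[True, True, False, False],
#    [True, True, True,  False]]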


def group_hidden_by_segs(h, seg_ids, max_len):
    """
    :param h: [B, T, H]
    :param seg_ids: [B, T]
    :return: h_ph: [B, T_ph, H]
    """
    B, T, H = h.shape
    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
    all_ones = h.new_ones(h.shape[:2])
    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
    h_gby_segs = h_gby_segs[:, 1:]
    cnt_gby_segs = cnt_gby_segs[:, 1:]
    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
    # assert h_gby_segs.shape[-1] == 192
    return h_gby_segs
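
# Illustrative example for group_hidden_by_segs (comment only): with seg_ids = [[1, 1, 2, 0]],
# the two frames labelled 1 are mean-pooled into segment slot 1 and the frame labelled 2
# into slot 2; label 0 is treated as padding and dropped by the `[:, 1:]` slices above.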


class GraphAuxEnc(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, n_iterations=5, n_edge_types=6):
        super(GraphAuxEnc, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.skip_connect = True
        self.dropout_after_gae = False

        self.ggc_1 = GatedGraphConv(in_feats=in_dim, out_feats=hid_dim,
                                    n_steps=n_iterations, n_etypes=n_edge_types)
        self.ggc_2 = GatedGraphConv(in_feats=hid_dim, out_feats=out_dim,
                                    n_steps=n_iterations, n_etypes=n_edge_types)
        self.dropout = nn.Dropout(p=0.5)
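
        # Note: with skip_connect enabled, the forward passes add inp + gcc1_out + gcc2_out,
        # so the encoder is normally built with in_dim == hid_dim == out_dim
        # (see the assert in word_forward).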

    @staticmethod
    def ph_encoding_to_word_encoding(ph_encoding, ph2word, word_len):
        """
        ph_encoding: [batch, t_p, hid]
        ph2word: tensor [batch, t_p]
        word_len: tensor [batch]
        """
        word_encoding_for_graph, batch_word_encoding, has_word_row_idx = GraphAuxEnc._process_ph_to_word_encoding(
            ph_encoding,
            ph2word,
            word_len)
        # [batch, t_w, hid]
        return batch_word_encoding, word_encoding_for_graph

    def pad_word_encoding_to_phoneme(self, word_encoding, ph2word, t_p):
        return self._postprocess_word2ph(word_encoding, ph2word, t_p)

    @staticmethod
    def _process_ph_to_word_encoding(ph_encoding, ph2word, word_len=None):
        """
        ph_encoding: [batch, t_p, hid]
        ph2word: tensor [batch, t_p]
        word_len: tensor [batch]
        """
        word_len = word_len.reshape([-1])
        max_len = max(word_len)
        num_nodes = sum(word_len)

        batch_word_encoding = group_hidden_by_segs(ph_encoding, ph2word, max_len)
        bs, t_w, hid = batch_word_encoding.shape
        has_word_mask = sequence_mask(word_len, max_len)  # [batch, t_w]
        word_encoding = batch_word_encoding.reshape([bs * t_w, hid])
        has_word_row_idx = has_word_mask.reshape([-1])
        word_encoding = word_encoding[has_word_row_idx]
        assert word_encoding.shape[0] == num_nodes
        return word_encoding, batch_word_encoding, has_word_row_idx
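
    # Illustrative shape flow for _process_ph_to_word_encoding (comment only):
    # ph_encoding [B, t_p, H] is mean-pooled via ph2word into [B, t_w, H]; the rows that
    # correspond to real words are then flattened into word_encoding [sum(word_len), H],
    # one row per node of the batched graph.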

    @staticmethod
    def _postprocess_word2ph(word_encoding, ph2word, t_p):
        word_encoding = F.pad(word_encoding, [0, 0, 1, 0])
        ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
        out = torch.gather(word_encoding, 1, ph2word_)  # [B, T, H]
        return out
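
    # Illustrative example for _postprocess_word2ph (comment only): with ph2word = [[1, 1, 2, 0]],
    # rows 1 and 2 of the (zero-padded) word_encoding are copied to every phoneme that maps to
    # those words, while index 0 gathers the all-zero padding row.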

    @staticmethod
    def _repeat_one_sequence(x, d, T):
        """Repeat each frame according to its duration."""
        if d.sum() == 0:
            d = d.fill_(1)
        hid = x.shape[-1]
        expanded_lst = [x_.repeat(int(d_), 1) for x_, d_ in zip(x, d) if d_ != 0]
        expanded = torch.cat(expanded_lst, dim=0)
        if T > expanded.shape[0]:
            expanded = torch.cat([expanded, torch.zeros([T - expanded.shape[0], hid]).to(expanded.device)], dim=0)
        return expanded
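
    # Illustrative example for _repeat_one_sequence (comment only): with d = [2, 0, 1],
    # x[0] is repeated twice, x[1] is dropped, x[2] is kept once, and the result is
    # zero-padded up to T rows if needed.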

    def word_forward(self, graph_lst, word_encoding, etypes_lst):
        """
        word encoding in, word encoding out.
        """
        batched_graph = dgl.batch(graph_lst)
        inp = word_encoding
        batched_etypes = torch.cat(etypes_lst)  # [num_edges_in_batch, 1]
        assert batched_graph.num_nodes() == inp.shape[0]

        gcc1_out = self.ggc_1(batched_graph, inp, batched_etypes)
        if self.dropout_after_gae:
            gcc1_out = self.dropout(gcc1_out)
        gcc2_out = self.ggc_2(batched_graph, gcc1_out, batched_etypes)  # [num_nodes_in_batch, hin]
        if self.dropout_after_gae:
            gcc2_out = self.dropout(gcc2_out)  # apply dropout, mirroring the gcc1_out branch above
        if self.skip_connect:
            assert self.in_dim == self.hid_dim and self.hid_dim == self.out_dim
            gcc2_out = inp + gcc1_out + gcc2_out

        word_len = torch.tensor([g.num_nodes() for g in graph_lst]).reshape([-1])
        max_len = max(word_len)
        has_word_mask = sequence_mask(word_len, max_len)  # [batch, t_w]
        has_word_row_idx = has_word_mask.reshape([-1])
        bs = len(graph_lst)
        t_w = max([g.num_nodes() for g in graph_lst])
        hid = word_encoding.shape[-1]
        output = torch.zeros([bs * t_w, hid]).to(gcc2_out.device)
        output[has_word_row_idx] = gcc2_out
        output = output.reshape([bs, t_w, hid])
        word_level_output = output
        return torch.transpose(word_level_output, 1, 2)

    def forward(self, graph_lst, ph_encoding, ph2word, etypes_lst, return_word_encoding=False):
        """
        graph_lst: list of dgl graphs, one per sentence
        ph_encoding: [batch, hid, t_p]
        ph2word: LongTensor [batch, t_p], e.g. [[1, 2, 2, 2, 3, 3, 3], ...]
        etypes_lst: list of etypes; etypes: torch.LongTensor
        """
        t_p = ph_encoding.shape[-1]
        ph_encoding = ph_encoding.transpose(1, 2)  # [batch, t_p, hid]
        word_len = torch.tensor([g.num_nodes() for g in graph_lst]).reshape([-1])
        batched_graph = dgl.batch(graph_lst)
        inp, batched_word_encoding, has_word_row_idx = self._process_ph_to_word_encoding(
            ph_encoding, ph2word, word_len=word_len)  # [num_nodes_in_batch, in_dim]
        bs, t_w, hid = batched_word_encoding.shape
        batched_etypes = torch.cat(etypes_lst)  # [num_edges_in_batch, 1]
        gcc1_out = self.ggc_1(batched_graph, inp, batched_etypes)
        gcc2_out = self.ggc_2(batched_graph, gcc1_out, batched_etypes)  # [num_nodes_in_batch, hin]
        # skip connection
        gcc2_out = inp + gcc1_out + gcc2_out  # [n_nodes, hid]

        output = torch.zeros([bs * t_w, hid]).to(gcc2_out.device)
        output[has_word_row_idx] = gcc2_out
        output = output.reshape([bs, t_w, hid])
        word_level_output = output
        output = self._postprocess_word2ph(word_level_output, ph2word, t_p)  # [batch, t_p, hid]
        output = torch.transpose(output, 1, 2)

        if return_word_encoding:
            return output, torch.transpose(word_level_output, 1, 2)
        else:
            return output


if __name__ == '__main__':
    # Unit Test for batching graphs
    from modules.syntaspeech.syntactic_graph_buider import Sentence2GraphParser, plot_dgl_sentence_graph
    parser = Sentence2GraphParser("en")

    # Unit Test for English Graph Builder
    text1 = "To be or not to be , that 's a question ."
    text2 = "I love you . You love me . Mixue ice-scream and tea ."
    graph1, etypes1 = parser.parse(text1)
    graph2, etypes2 = parser.parse(text2)
    batched_text = "<BOS> " + text1 + " <EOS>" + " " + "<BOS> " + text2 + " <EOS>"
    batched_nodes = [graph1.num_nodes(), graph2.num_nodes()]
    plot_dgl_sentence_graph(dgl.batch([graph1, graph2]), {i: w for i, w in enumerate(batched_text.split(" "))})
    etypes_lst = [etypes1, etypes2]

    # Unit Test for Graph Encoder forward
    in_feats = 4
    out_feats = 4
    enc = GraphAuxEnc(in_dim=in_feats, hid_dim=in_feats, out_dim=out_feats)
    ph2word = torch.tensor([
        [1, 2, 3, 3, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
        [1, 2, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    ])
    inp = torch.randn([2, in_feats, 17])  # [N_sentence, feat, ph_length]
    graph_lst = [graph1, graph2]
    out = enc(graph_lst, inp, ph2word, etypes_lst)
    print(out.shape)  # [N_sentence, feat, ph_length]
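
    # Extra sketch: exercise word_forward directly on a hand-built toy graph.
    # (Assumption: a 3-node cycle with edge type 0 stands in for a parser-produced
    # dependency graph; real graphs and etypes come from Sentence2GraphParser.)
    toy_graph = dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3)
    toy_etypes = torch.zeros([3], dtype=torch.long)  # one edge type id per edge
    toy_word_enc = torch.randn([3, in_feats])        # [num_nodes, feat]
    toy_out = enc.word_forward([toy_graph], toy_word_enc, [toy_etypes])
    print(toy_out.shape)  # [N_sentence, feat, num_words] -> torch.Size([1, 4, 3])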