TheComputerMan commited on
Commit
3edffb9
1 Parent(s): d927b86

Upload LengthRegulator.py

Browse files
Files changed (1) hide show
  1. LengthRegulator.py +62 -0
LengthRegulator.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ from abc import ABC
6
+
7
+ import torch
8
+
9
+ from Utility.utils import pad_list
10
+
11
+
12
+ class LengthRegulator(torch.nn.Module, ABC):
13
+ """
14
+ Length regulator module for feed-forward Transformer.
15
+
16
+ This is a module of length regulator described in
17
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
18
+ The length regulator expands char or
19
+ phoneme-level embedding features to frame-level by repeating each
20
+ feature based on the corresponding predicted durations.
21
+
22
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
23
+ https://arxiv.org/pdf/1905.09263.pdf
24
+
25
+ """
26
+
27
+ def __init__(self, pad_value=0.0):
28
+ """
29
+ Initialize length regulator module.
30
+
31
+ Args:
32
+ pad_value (float, optional): Value used for padding.
33
+ """
34
+ super(LengthRegulator, self).__init__()
35
+ self.pad_value = pad_value
36
+
37
+ def forward(self, xs, ds, alpha=1.0):
38
+ """
39
+ Calculate forward propagation.
40
+
41
+ Args:
42
+ xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
43
+ ds (LongTensor): Batch of durations of each frame (B, T).
44
+ alpha (float, optional): Alpha value to control speed of speech.
45
+
46
+ Returns:
47
+ Tensor: replicated input tensor based on durations (B, T*, D).
48
+ """
49
+ if alpha != 1.0:
50
+ assert alpha > 0
51
+ ds = torch.round(ds.float() * alpha).long()
52
+
53
+ if ds.sum() == 0:
54
+ ds[ds.sum(dim=1).eq(0)] = 1
55
+
56
+ return pad_list([self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)], self.pad_value)
57
+
58
+ def _repeat_one_sequence(self, x, d):
59
+ """
60
+ Repeat each frame according to duration
61
+ """
62
+ return torch.repeat_interleave(x, d, dim=0)