congnguyen3695 ntt123 committed
Commit
ff36620
0 Parent(s):

Duplicate from ntt123/Vietnam-female-voice-TTS


Co-authored-by: Thông Nguyễn <[email protected]>

Files changed (16)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +230 -0
  4. attentions.py +329 -0
  5. commons.py +162 -0
  6. config.json +72 -0
  7. duration_model.pth +3 -0
  8. flow.py +120 -0
  9. gen_210k.pth +3 -0
  10. gen_543k.pth +3 -0
  11. gen_630k.pth +3 -0
  12. models.py +489 -0
  13. modules.py +356 -0
  14. packages.txt +1 -0
  15. phone_set.json +1 -0
  16. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Vietnam Female Voice TTS
+ emoji: 👁
+ colorFrom: pink
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 3.40.1
+ app_file: app.py
+ pinned: false
+ license: cc-by-sa-4.0
+ duplicated_from: ntt123/Vietnam-female-voice-TTS
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,230 @@
+ import torch # isort:skip
+
+ torch.manual_seed(42)
+ import json
+ import re
+ import unicodedata
+ from types import SimpleNamespace
+
+ import gradio as gr
+ import numpy as np
+ import regex
+
+ from models import DurationNet, SynthesizerTrn
+
+ title = "LightSpeed: Vietnamese Female Voice TTS"
+ description = "Vietnam Female Voice TTS."
+ config_file = "config.json"
+ duration_model_path = "duration_model.pth"
+ lightspeed_model_path = "gen_630k.pth"
+ phone_set_file = "phone_set.json"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ with open(config_file, "rb") as f:
+     hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))
+
+ # load phone set json file
+ with open(phone_set_file, "r") as f:
+     phone_set = json.load(f)
+
+ assert phone_set[0][1:-1] == "SEP"
+ assert "sil" in phone_set
+ sil_idx = phone_set.index("sil")
+
+ space_re = regex.compile(r"\s+")
+ number_re = regex.compile("([0-9]+)")
+ digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
+ num_re = regex.compile(r"([0-9.,]*[0-9])")
+ alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
+ keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
+ keep_text_re = regex.compile(rf"[^\s{alphabet}]")
+
+
+ def read_number(num: str) -> str:
+     if len(num) == 1:
+         return digits[int(num)]
+     elif len(num) == 2 and num.isdigit():
+         n = int(num)
+         end = digits[n % 10]
+         if n == 10:
+             return "mười"
+         if n % 10 == 5:
+             end = "lăm"
+         if n % 10 == 0:
+             return digits[n // 10] + " mươi"
+         elif n < 20:
+             return "mười " + end
+         else:
+             if n % 10 == 1:
+                 end = "mốt"
+             return digits[n // 10] + " mươi " + end
+     elif len(num) == 3 and num.isdigit():
+         n = int(num)
+         if n % 100 == 0:
+             return digits[n // 100] + " trăm"
+         elif num[1] == "0":
+             return digits[n // 100] + " trăm lẻ " + digits[n % 100]
+         else:
+             return digits[n // 100] + " trăm " + read_number(num[1:])
+     elif len(num) >= 4 and len(num) <= 6 and num.isdigit():
+         n = int(num)
+         n1 = n // 1000
+         return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
+     elif "," in num:
+         n1, n2 = num.split(",")
+         return read_number(n1) + " phẩy " + read_number(n2)
+     elif "." in num:
+         parts = num.split(".")
+         if len(parts) == 2:
+             if parts[1] == "000":
+                 return read_number(parts[0]) + " ngàn"
+             elif parts[1].startswith("00"):
+                 end = digits[int(parts[1][2:])]
+                 return read_number(parts[0]) + " ngàn lẻ " + end
+             else:
+                 return read_number(parts[0]) + " ngàn " + read_number(parts[1])
+         elif len(parts) == 3:
+             return (
+                 read_number(parts[0])
+                 + " triệu "
+                 + read_number(parts[1])
+                 + " ngàn "
+                 + read_number(parts[2])
+             )
+     return num
+
+
+ def text_to_phone_idx(text):
+     # lowercase
+     text = text.lower()
+     # unicode normalize
+     text = unicodedata.normalize("NFKC", text)
+     text = text.replace(".", " . ")
+     text = text.replace(",", " , ")
+     text = text.replace(";", " ; ")
+     text = text.replace(":", " : ")
+     text = text.replace("!", " ! ")
+     text = text.replace("?", " ? ")
+     text = text.replace("(", " ( ")
+
+     text = num_re.sub(r" \1 ", text)
+     words = text.split()
+     words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
+     text = " ".join(words)
+
+     # remove redundant spaces
+     text = re.sub(r"\s+", " ", text)
+     # remove leading and trailing spaces
+     text = text.strip()
+     # convert words to phone indices
+     tokens = []
+     for c in text:
+         # if c is "," or ".", add <sil> phone
+         if c in ":,.!?;(":
+             tokens.append(sil_idx)
+         elif c in phone_set:
+             tokens.append(phone_set.index(c))
+         elif c == " ":
+             # add <sep> phone
+             tokens.append(0)
+     if tokens[0] != sil_idx:
+         # insert <sil> phone at the beginning
+         tokens = [sil_idx, 0] + tokens
+     if tokens[-1] != sil_idx:
+         tokens = tokens + [0, sil_idx]
+     return tokens
+
+
+ def text_to_speech(duration_net, generator, text):
+     # prevent too long text
+     if len(text) > 500:
+         text = text[:500]
+
+     phone_idx = text_to_phone_idx(text)
+     batch = {
+         "phone_idx": np.array([phone_idx]),
+         "phone_length": np.array([len(phone_idx)]),
+     }
+
+     # predict phoneme duration
+     phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
+     phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
+     with torch.inference_mode():
+         phone_duration = duration_net(phone_idx, phone_length)[:, :, 0] * 1000
+     phone_duration = torch.where(
+         phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
+     )
+     phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
+
+     # generate waveform
+     end_time = torch.cumsum(phone_duration, dim=-1)
+     start_time = end_time - phone_duration
+     start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
+     end_frame = end_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
+     spec_length = end_frame.max(dim=-1).values
+     pos = torch.arange(0, spec_length.item(), device=device)
+     attn = torch.logical_and(
+         pos[None, :, None] >= start_frame[:, None, :],
+         pos[None, :, None] < end_frame[:, None, :],
+     ).float()
+     with torch.inference_mode():
+         y_hat = generator.infer(
+             phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.0
+         )[0]
+     wave = y_hat[0, 0].data.cpu().numpy()
+     return (wave * (2**15)).astype(np.int16)
+
+
+ def load_models():
+     duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
+     duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
+     duration_net = duration_net.eval()
+     generator = SynthesizerTrn(
+         hps.data.vocab_size,
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         **vars(hps.model),
+     ).to(device)
+     del generator.enc_q
+     ckpt = torch.load(lightspeed_model_path, map_location=device)
+     params = {}
+     for k, v in ckpt["net_g"].items():
+         k = k[7:] if k.startswith("module.") else k
+         params[k] = v
+     generator.load_state_dict(params, strict=False)
+     del ckpt, params
+     generator = generator.eval()
+     return duration_net, generator
+
+
+ def speak(text):
+     duration_net, generator = load_models()
+     paragraphs = text.split("\n")
+     clips = [] # list of audio clips
+     # silence = np.zeros(hps.data.sampling_rate // 4)
+     for paragraph in paragraphs:
+         paragraph = paragraph.strip()
+         if paragraph == "":
+             continue
+         clips.append(text_to_speech(duration_net, generator, paragraph))
+         # clips.append(silence)
+     y = np.concatenate(clips)
+     return hps.data.sampling_rate, y
+
+
+ gr.Interface(
+     fn=speak,
+     inputs="text",
+     outputs="audio",
+     title=title,
+     examples=[
+         "Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
+         "Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du",
+         "Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
+         "Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học lớn của Việt Nam trong thời phong kiến",
+         "Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được; trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
+     ],
+     description=description,
+     theme="default",
+     allow_screenshot=False,
+     allow_flagging="never",
+ ).launch(debug=False)
attentions.py ADDED
@@ -0,0 +1,329 @@
+ import math
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ import commons
+ from modules import LayerNorm
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         window_size=4,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+
+         self.drop = nn.Dropout(p_dropout)
+         self.attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     window_size=window_size,
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask):
+         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             y = self.attn_layers[i](x, x, attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(
+         self,
+         channels,
+         out_channels,
+         n_heads,
+         p_dropout=0.0,
+         window_size=None,
+         heads_share=True,
+         block_length=None,
+         proximal_bias=False,
+         proximal_init=False,
+     ):
+         super().__init__()
+         assert channels % n_heads == 0
+
+         self.channels = channels
+         self.out_channels = out_channels
+         self.n_heads = n_heads
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+         self.heads_share = heads_share
+         self.block_length = block_length
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+         # self.attn = None
+
+         self.k_channels = channels // n_heads
+         self.conv_q = nn.Conv1d(channels, channels, 1)
+         self.conv_k = nn.Conv1d(channels, channels, 1)
+         self.conv_v = nn.Conv1d(channels, channels, 1)
+         self.conv_o = nn.Conv1d(channels, out_channels, 1)
+         self.drop = nn.Dropout(p_dropout)
+
+         if window_size is not None:
+             n_heads_rel = 1 if heads_share else n_heads
+             rel_stddev = self.k_channels**-0.5
+             self.emb_rel_k = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+             self.emb_rel_v = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+
+         nn.init.xavier_uniform_(self.conv_q.weight)
+         nn.init.xavier_uniform_(self.conv_k.weight)
+         nn.init.xavier_uniform_(self.conv_v.weight)
+         if proximal_init:
+             with torch.no_grad():
+                 self.conv_k.weight.copy_(self.conv_q.weight)
+                 self.conv_k.bias.copy_(self.conv_q.bias)
+
+     def forward(self, x, c, attn_mask=None):
+         q = self.conv_q(x)
+         k = self.conv_k(c)
+         v = self.conv_v(c)
+
+         x, _ = self.attention(q, k, v, mask=attn_mask)
+
+         x = self.conv_o(x)
+         return x
+
+     def attention(self, query, key, value, mask=None):
+         # reshape [b, d, t] -> [b, n_h, t, d_k]
+         b, d, t_s, t_t = (*key.size(), query.size(2))
+         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+         if self.window_size is not None:
+             assert (
+                 t_s == t_t
+             ), "Relative attention is only available for self-attention."
+             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+             rel_logits = self._matmul_with_relative_keys(
+                 query / math.sqrt(self.k_channels), key_relative_embeddings
+             )
+             scores_local = self._relative_position_to_absolute_position(rel_logits)
+             scores = scores + scores_local
+         if self.proximal_bias:
+             assert t_s == t_t, "Proximal bias is only available for self-attention."
+             scores = scores + self._attention_bias_proximal(t_s).to(
+                 device=scores.device, dtype=scores.dtype
+             )
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, -1e4)
+             if self.block_length is not None:
+                 assert (
+                     t_s == t_t
+                 ), "Local attention is only available for self-attention."
+                 block_mask = (
+                     torch.ones_like(scores)
+                     .triu(-self.block_length)
+                     .tril(self.block_length)
+                 )
+                 scores = scores.masked_fill(block_mask == 0, -1e4)
+         p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+         p_attn = self.drop(p_attn)
+         output = torch.matmul(p_attn, value)
+         if self.window_size is not None:
+             relative_weights = self._absolute_position_to_relative_position(p_attn)
+             value_relative_embeddings = self._get_relative_embeddings(
+                 self.emb_rel_v, t_s
+             )
+             output = output + self._matmul_with_relative_values(
+                 relative_weights, value_relative_embeddings
+             )
+         output = (
+             output.transpose(2, 3).contiguous().view(b, d, t_t)
+         ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+         return output, p_attn
+
+     def _matmul_with_relative_values(self, x, y):
+         """
+         x: [b, h, l, m]
+         y: [h or 1, m, d]
+         ret: [b, h, l, d]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0))
+         return ret
+
+     def _matmul_with_relative_keys(self, x, y):
+         """
+         x: [b, h, l, d]
+         y: [h or 1, m, d]
+         ret: [b, h, l, m]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+         return ret
+
+     def _get_relative_embeddings(self, relative_embeddings, length):
+         max_relative_position = 2 * self.window_size + 1
+         # Pad first before slice to avoid using cond ops.
+         pad_length = max(length - (self.window_size + 1), 0)
+         slice_start_position = max((self.window_size + 1) - length, 0)
+         slice_end_position = slice_start_position + 2 * length - 1
+         if pad_length > 0:
+             padded_relative_embeddings = F.pad(
+                 relative_embeddings,
+                 commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+             )
+         else:
+             padded_relative_embeddings = relative_embeddings
+         used_relative_embeddings = padded_relative_embeddings[
+             :, slice_start_position:slice_end_position
+         ]
+         return used_relative_embeddings
+
+     def _relative_position_to_absolute_position(self, x):
+         """
+         x: [b, h, l, 2*l-1]
+         ret: [b, h, l, l]
+         """
+         batch, heads, length, _ = x.size()
+         # Concat columns of pad to shift from relative to absolute indexing.
+         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+         # Concat extra elements so to add up to shape (len+1, 2*len-1).
+         x_flat = x.view([batch, heads, length * 2 * length])
+         x_flat = F.pad(
+             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+         )
+
+         # Reshape and slice out the padded elements.
+         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+             :, :, :length, length - 1 :
+         ]
+         return x_final
+
+     def _absolute_position_to_relative_position(self, x):
+         """
+         x: [b, h, l, l]
+         ret: [b, h, l, 2*l-1]
+         """
+         batch, heads, length, _ = x.shape
+         # padd along column
+         x = F.pad(
+             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+         )
+         x_flat = x.view([batch, heads, length * length + length * (length - 1)])
+         # add 0's in the beginning that will skew the elements after reshape
+         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+         return x_final
+
+     def _attention_bias_proximal(self, length):
+         """Bias for self-attention to encourage attention to close positions.
+         Args:
+             length: an integer scalar.
+         Returns:
+             a Tensor with shape [1, 1, length, length]
+         """
+         r = torch.arange(length, dtype=torch.float32)
+         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+ class FFN(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         filter_channels,
+         kernel_size,
+         p_dropout=0.0,
+         activation=None,
+         causal=False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.activation = activation
+         self.causal = causal
+
+         if causal:
+             self.padding = self._causal_padding
+         else:
+             self.padding = self._same_padding
+
+         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+         self.drop = nn.Dropout(p_dropout)
+
+     def forward(self, x, x_mask):
+         x = self.conv_1(self.padding(x * x_mask))
+         if self.activation == "gelu":
+             x = x * torch.sigmoid(1.702 * x)
+         else:
+             x = torch.relu(x)
+         x = self.drop(x)
+         x = self.conv_2(self.padding(x * x_mask))
+         return x * x_mask
+
+     def _causal_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = self.kernel_size - 1
+         pad_r = 0
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
+
+     def _same_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = (self.kernel_size - 1) // 2
+         pad_r = self.kernel_size // 2
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
commons.py ADDED
@@ -0,0 +1,162 @@
+ import math
+
+ import torch
+ from torch.nn import functional as F
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]
+     pad_shape = [item for sublist in l for item in sublist]
+     return pad_shape
+
+
+ def intersperse(lst, item):
+     result = [item] * (len(lst) * 2 + 1)
+     result[1::2] = lst
+     return result
+
+
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
+     """KL(P||Q)"""
+     kl = (logs_q - logs_p) - 0.5
+     kl += (
+         0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+     )
+     return kl
+
+
+ def rand_gumbel(shape):
+     """Sample from the Gumbel distribution, protect from overflows."""
+     uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+     return -torch.log(-torch.log(uniform_samples))
+
+
+ def rand_gumbel_like(x):
+     g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+     return g
+
+
+ def slice_segments(x, ids_str, segment_size=4):
+     ret = torch.zeros_like(x[:, :, :segment_size])
+     for i in range(x.size(0)):
+         idx_str = ids_str[i]
+         idx_end = idx_str + segment_size
+         ret[i] = x[i, :, idx_str:idx_end]
+     return ret
+
+
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
+     b, d, t = x.size()
+     if x_lengths is None:
+         x_lengths = t
+     ids_str_max = x_lengths - segment_size + 1
+     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+     ret = slice_segments(x, ids_str, segment_size)
+     return ret, ids_str
+
+
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+     position = torch.arange(length, dtype=torch.float)
+     num_timescales = channels // 2
+     log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+         num_timescales - 1
+     )
+     inv_timescales = min_timescale * torch.exp(
+         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+     )
+     scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+     signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+     signal = F.pad(signal, [0, 0, 0, channels % 2])
+     signal = signal.view(1, channels, length)
+     return signal
+
+
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+ def subsequent_mask(length):
+     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+     return mask
+
+
+ @torch.jit.script
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+     n_channels_int = n_channels[0]
+     in_act = input_a + input_b
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+     return acts
+
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]
+     pad_shape = [item for sublist in l for item in sublist]
+     return pad_shape
+
+
+ def shift_1d(x):
+     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+     return x
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def generate_path(duration, mask):
+     """
+     duration: [b, 1, t_x]
+     mask: [b, 1, t_y, t_x]
+     """
+     device = duration.device
+
+     b, _, t_y, t_x = mask.shape
+     cum_duration = torch.cumsum(duration, -1)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+     path = path.unsqueeze(1).transpose(2, 3) * mask
+     return path
+
+
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = list(filter(lambda p: p.grad is not None, parameters))
+     norm_type = float(norm_type)
+     if clip_value is not None:
+         clip_value = float(clip_value)
+
+     total_norm = 0
+     for p in parameters:
+         param_norm = p.grad.data.norm(norm_type)
+         total_norm += param_norm.item() ** norm_type
+         if clip_value is not None:
+             p.grad.data.clamp_(min=-clip_value, max=clip_value)
+     total_norm = total_norm ** (1.0 / norm_type)
+     return total_norm
config.json ADDED
@@ -0,0 +1,72 @@
+ {
+   "train": {
+     "learning_rate": 2e-4,
+     "betas": [
+       0.8,
+       0.99
+     ],
+     "eps": 1e-9,
+     "lr_decay": 0.999875,
+     "segment_size": 8192,
+     "c_mel": 45,
+     "c_kl": 1.0
+   },
+   "data": {
+     "vocab_size": 256,
+     "max_wav_value": 32768.0,
+     "sampling_rate": 16000,
+     "filter_length": 1024,
+     "hop_length": 256,
+     "win_length": 1024,
+     "n_mel_channels": 80,
+     "mel_fmin": 0.0,
+     "mel_fmax": null
+   },
+   "model": {
+     "inter_channels": 192,
+     "hidden_channels": 192,
+     "filter_channels": 768,
+     "n_heads": 2,
+     "n_layers": 6,
+     "kernel_size": 3,
+     "p_dropout": 0.1,
+     "resblock": "1",
+     "resblock_kernel_sizes": [
+       3,
+       7,
+       11
+     ],
+     "resblock_dilation_sizes": [
+       [
+         1,
+         3,
+         5
+       ],
+       [
+         1,
+         3,
+         5
+       ],
+       [
+         1,
+         3,
+         5
+       ]
+     ],
+     "upsample_rates": [
+       8,
+       8,
+       2,
+       2
+     ],
+     "upsample_initial_channel": 512,
+     "upsample_kernel_sizes": [
+       16,
+       16,
+       4,
+       4
+     ],
+     "n_layers_q": 3,
+     "use_spectral_norm": false
+   }
+ }
duration_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e86ab30448a328933b112e5ed6c4c22d7f05f1673528e61d340c98a9cc899eb
+ size 1164051
flow.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ from torch import nn
+
+ from modules import WN
+
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
+ DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+ class ResidualCouplingLayer(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         p_dropout=0,
+         gin_channels=0,
+         mean_only=False,
+     ):
+         assert channels % 2 == 0, "channels should be divisible by 2"
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.half_channels = channels // 2
+         self.mean_only = mean_only
+
+         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+         self.enc = WN(
+             hidden_channels,
+             kernel_size,
+             dilation_rate,
+             n_layers,
+             p_dropout=p_dropout,
+             gin_channels=gin_channels,
+         )
+         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+         self.post.weight.data.zero_()
+         self.post.bias.data.zero_()
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+         h = self.pre(x0) * x_mask
+         h = self.enc(h, x_mask, g=g)
+         stats = self.post(h) * x_mask
+         if not self.mean_only:
+             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+         else:
+             m = stats
+             logs = torch.zeros_like(m)
+
+         if not reverse:
+             x1 = m + x1 * torch.exp(logs) * x_mask
+             x = torch.cat([x0, x1], 1)
+             logdet = torch.sum(logs, [1, 2])
+             return x, logdet
+         else:
+             x1 = (x1 - m) * torch.exp(-logs) * x_mask
+             x = torch.cat([x0, x1], 1)
+             return x
+
+
+ class Flip(nn.Module):
+     def forward(self, x, *args, reverse=False, **kwargs):
+         x = torch.flip(x, [1])
+         if not reverse:
+             logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+             return x, logdet
+         else:
+             return x
+
+
+ class ResidualCouplingBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         n_flows=4,
+         gin_channels=0,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.n_flows = n_flows
+         self.gin_channels = gin_channels
+
+         self.flows = nn.ModuleList()
+         for i in range(n_flows):
+             self.flows.append(
+                 ResidualCouplingLayer(
+                     channels,
+                     hidden_channels,
+                     kernel_size,
+                     dilation_rate,
+                     n_layers,
+                     gin_channels=gin_channels,
+                     mean_only=True,
+                 )
+             )
+             self.flows.append(Flip())
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         if not reverse:
+             for flow in self.flows:
+                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
+         else:
+             for flow in reversed(self.flows):
+                 x = flow(x, x_mask, g=g, reverse=reverse)
+         return x
gen_210k.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c6f8485f44f492262492231e90633fead68b5db3f65bd1a73621d618a6f3a173
+ size 111280752
gen_543k.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e11e4b12e3e9a67ad23ec0cfafbcbd8e810f83d55fb6c38f61a6e9886c94cc5
+ size 111280752
gen_630k.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93fd8e41a8138978e387db561e32ad6ca0798cb09d44aefa239add5ac47e13a6
+ size 111280317
models.py ADDED
@@ -0,0 +1,489 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
6
+ from torch.nn import functional as F
7
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
8
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
9
+
10
+ import attentions
11
+ import commons
12
+ import modules
13
+ from commons import get_padding, init_weights
14
+ from flow import ResidualCouplingBlock
15
+
16
+
17
+ class PriorEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ n_vocab,
21
+ out_channels,
22
+ hidden_channels,
23
+ filter_channels,
24
+ n_heads,
25
+ n_layers,
26
+ kernel_size,
27
+ p_dropout,
28
+ ):
29
+ super().__init__()
30
+ self.n_vocab = n_vocab
31
+ self.out_channels = out_channels
32
+ self.hidden_channels = hidden_channels
33
+ self.filter_channels = filter_channels
34
+ self.n_heads = n_heads
35
+ self.n_layers = n_layers
36
+ self.kernel_size = kernel_size
37
+ self.p_dropout = p_dropout
38
+
39
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
40
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
41
+ self.pre_attn_encoder = attentions.Encoder(
42
+ hidden_channels,
43
+ filter_channels,
44
+ n_heads,
45
+ n_layers // 2,
46
+ kernel_size,
47
+ p_dropout,
48
+ )
49
+ self.post_attn_encoder = attentions.Encoder(
50
+ hidden_channels,
51
+ filter_channels,
52
+ n_heads,
53
+ n_layers - n_layers // 2,
54
+ kernel_size,
55
+ p_dropout,
56
+ )
57
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
58
+
59
+ def forward(self, x, x_lengths, y_lengths, attn):
60
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
61
+ x = torch.transpose(x, 1, -1) # [b, h, t]
62
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
63
+ x.dtype
64
+ )
65
+ x = self.pre_attn_encoder(x * x_mask, x_mask)
66
+ y = torch.einsum("bht,blt->bhl", x, attn)
67
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
68
+ y.dtype
69
+ )
70
+ y = self.post_attn_encoder(y * y_mask, y_mask)
71
+ stats = self.proj(y) * y_mask
72
+
73
+ m, logs = torch.split(stats, self.out_channels, dim=1)
74
+ return y, m, logs, y_mask
75
+
76
+
77
+ class PosteriorEncoder(nn.Module):
78
+ def __init__(
79
+ self,
80
+ in_channels,
81
+ out_channels,
82
+ hidden_channels,
83
+ kernel_size,
84
+ dilation_rate,
85
+ n_layers,
86
+ gin_channels=0,
87
+ ):
88
+ super().__init__()
89
+ self.in_channels = in_channels
90
+ self.out_channels = out_channels
91
+ self.hidden_channels = hidden_channels
92
+ self.kernel_size = kernel_size
93
+ self.dilation_rate = dilation_rate
94
+ self.n_layers = n_layers
95
+ self.gin_channels = gin_channels
96
+
97
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
98
+ self.enc = modules.WN(
99
+ hidden_channels,
100
+ kernel_size,
101
+ dilation_rate,
102
+ n_layers,
103
+ gin_channels=gin_channels,
104
+ )
105
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
106
+
107
+ def forward(self, x, x_lengths, g=None):
108
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
109
+ x.dtype
110
+ )
111
+ x = self.pre(x) * x_mask
112
+ x = self.enc(x, x_mask, g=g)
113
+ stats = self.proj(x) * x_mask
114
+ m, logs = torch.split(stats, self.out_channels, dim=1)
115
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
116
+ return z, m, logs, x_mask
117
+
118
+
119
+ class Generator(torch.nn.Module):
120
+ def __init__(
121
+ self,
122
+ initial_channel,
123
+ resblock,
124
+ resblock_kernel_sizes,
125
+ resblock_dilation_sizes,
126
+ upsample_rates,
127
+ upsample_initial_channel,
128
+ upsample_kernel_sizes,
129
+ gin_channels=0,
130
+ ):
131
+ super(Generator, self).__init__()
132
+ self.num_kernels = len(resblock_kernel_sizes)
133
+ self.num_upsamples = len(upsample_rates)
134
+ self.conv_pre = Conv1d(
135
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
136
+ )
137
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
138
+
139
+ self.ups = nn.ModuleList()
140
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
141
+ self.ups.append(
142
+ weight_norm(
143
+ ConvTranspose1d(
144
+ upsample_initial_channel // (2**i),
145
+ upsample_initial_channel // (2 ** (i + 1)),
146
+ k,
147
+ u,
148
+ padding=(k - u) // 2,
149
+ )
150
+ )
151
+ )
152
+
153
+ self.resblocks = nn.ModuleList()
154
+ for i in range(len(self.ups)):
155
+ ch = upsample_initial_channel // (2 ** (i + 1))
156
+ for j, (k, d) in enumerate(
157
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
158
+ ):
159
+ self.resblocks.append(resblock(ch, k, d))
160
+
161
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
162
+ self.ups.apply(init_weights)
163
+
164
+ if gin_channels != 0:
165
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
166
+
167
+ def forward(self, x, g=None):
168
+ x = self.conv_pre(x)
169
+ if g is not None:
170
+ x = x + self.cond(g)
171
+
172
+ for i in range(self.num_upsamples):
173
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
174
+ x = self.ups[i](x)
175
+ xs = None
176
+ for j in range(self.num_kernels):
177
+ if xs is None:
178
+ xs = self.resblocks[i * self.num_kernels + j](x)
179
+ else:
180
+ xs += self.resblocks[i * self.num_kernels + j](x)
181
+ x = xs / self.num_kernels
182
+ x = F.leaky_relu(x)
183
+ x = self.conv_post(x)
184
+ x = torch.tanh(x)
185
+
186
+ return x
187
+
188
+ def remove_weight_norm(self):
189
+ print("Removing weight norm...")
190
+ for l in self.ups:
191
+ remove_weight_norm(l)
192
+ for l in self.resblocks:
193
+ l.remove_weight_norm()
194
+
195
+
196
+ class DiscriminatorP(torch.nn.Module):
197
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
198
+ super(DiscriminatorP, self).__init__()
199
+ self.period = period
200
+ self.use_spectral_norm = use_spectral_norm
201
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
202
+ self.convs = nn.ModuleList(
203
+ [
204
+ norm_f(
205
+ Conv2d(
206
+ 1,
207
+ 32,
208
+ (kernel_size, 1),
209
+ (stride, 1),
210
+ padding=(get_padding(kernel_size, 1), 0),
211
+ )
212
+ ),
213
+ norm_f(
214
+ Conv2d(
215
+ 32,
216
+ 128,
217
+ (kernel_size, 1),
218
+ (stride, 1),
219
+ padding=(get_padding(kernel_size, 1), 0),
220
+ )
221
+ ),
222
+ norm_f(
223
+ Conv2d(
224
+ 128,
225
+ 512,
226
+ (kernel_size, 1),
227
+ (stride, 1),
228
+ padding=(get_padding(kernel_size, 1), 0),
229
+ )
230
+ ),
231
+ norm_f(
232
+ Conv2d(
233
+ 512,
234
+ 1024,
235
+ (kernel_size, 1),
236
+ (stride, 1),
237
+ padding=(get_padding(kernel_size, 1), 0),
238
+ )
239
+ ),
240
+ norm_f(
241
+ Conv2d(
242
+ 1024,
243
+ 1024,
244
+ (kernel_size, 1),
245
+ 1,
246
+ padding=(get_padding(kernel_size, 1), 0),
247
+ )
248
+ ),
249
+ ]
250
+ )
251
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
252
+
253
+ def forward(self, x):
254
+ fmap = []
255
+
256
+ # 1d to 2d
257
+ b, c, t = x.shape
258
+ if t % self.period != 0: # pad first
259
+ n_pad = self.period - (t % self.period)
260
+ x = F.pad(x, (0, n_pad), "reflect")
261
+ t = t + n_pad
262
+ x = x.view(b, c, t // self.period, self.period)
263
+
264
+ for l in self.convs:
265
+ x = l(x)
266
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
267
+ fmap.append(x)
268
+ x = self.conv_post(x)
269
+ fmap.append(x)
270
+ x = torch.flatten(x, 1, -1)
271
+
272
+ return x, fmap
273
+
274
+
275
+ class DiscriminatorS(torch.nn.Module):
276
+ def __init__(self, use_spectral_norm=False):
277
+ super(DiscriminatorS, self).__init__()
278
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
279
+ self.convs = nn.ModuleList(
280
+ [
281
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
282
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
283
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
284
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
285
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
286
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
287
+ ]
288
+ )
289
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
290
+
291
+ def forward(self, x):
292
+ fmap = []
293
+
294
+ for l in self.convs:
295
+ x = l(x)
296
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
297
+ fmap.append(x)
298
+ x = self.conv_post(x)
299
+ fmap.append(x)
300
+ x = torch.flatten(x, 1, -1)
301
+
302
+ return x, fmap
303
+
304
+
305
+ class MultiPeriodDiscriminator(torch.nn.Module):
306
+ def __init__(self, use_spectral_norm=False):
307
+ super(MultiPeriodDiscriminator, self).__init__()
308
+ periods = [2, 3, 5, 7, 11]
309
+
310
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
311
+ discs = discs + [
312
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
313
+ ]
314
+ self.discriminators = nn.ModuleList(discs)
315
+
316
+ def forward(self, y, y_hat):
317
+ y_d_rs = []
318
+ y_d_gs = []
319
+ fmap_rs = []
320
+ fmap_gs = []
321
+ for i, d in enumerate(self.discriminators):
322
+ y_d_r, fmap_r = d(y)
323
+ y_d_g, fmap_g = d(y_hat)
324
+ y_d_rs.append(y_d_r)
325
+ y_d_gs.append(y_d_g)
326
+ fmap_rs.append(fmap_r)
327
+ fmap_gs.append(fmap_g)
328
+
329
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
330
+
331
+
332
+ class SynthesizerTrn(nn.Module):
333
+ """
334
+ Synthesizer for Training
335
+ """
336
+
337
+ def __init__(
338
+ self,
339
+ n_vocab,
340
+ spec_channels,
341
+ segment_size,
342
+ inter_channels,
343
+ hidden_channels,
344
+ filter_channels,
345
+ n_heads,
346
+ n_layers,
347
+ kernel_size,
348
+ p_dropout,
349
+ resblock,
350
+ resblock_kernel_sizes,
351
+ resblock_dilation_sizes,
352
+ upsample_rates,
353
+ upsample_initial_channel,
354
+ upsample_kernel_sizes,
355
+ n_speakers=0,
356
+ gin_channels=0,
357
+ **kwargs
358
+ ):
359
+ super().__init__()
360
+ self.n_vocab = n_vocab
361
+ self.spec_channels = spec_channels
362
+ self.inter_channels = inter_channels
363
+ self.hidden_channels = hidden_channels
364
+ self.filter_channels = filter_channels
365
+ self.n_heads = n_heads
366
+ self.n_layers = n_layers
367
+ self.kernel_size = kernel_size
368
+ self.p_dropout = p_dropout
369
+ self.resblock = resblock
370
+ self.resblock_kernel_sizes = resblock_kernel_sizes
371
+ self.resblock_dilation_sizes = resblock_dilation_sizes
372
+ self.upsample_rates = upsample_rates
373
+ self.upsample_initial_channel = upsample_initial_channel
374
+ self.upsample_kernel_sizes = upsample_kernel_sizes
375
+ self.segment_size = segment_size
376
+ self.n_speakers = n_speakers
377
+ self.gin_channels = gin_channels
378
+
379
+ self.enc_p = PriorEncoder(
380
+ n_vocab,
381
+ inter_channels,
382
+ hidden_channels,
383
+ filter_channels,
384
+ n_heads,
385
+ n_layers,
386
+ kernel_size,
387
+ p_dropout,
388
+ )
389
+ self.dec = Generator(
390
+ inter_channels,
391
+ resblock,
392
+ resblock_kernel_sizes,
393
+ resblock_dilation_sizes,
394
+ upsample_rates,
395
+ upsample_initial_channel,
396
+ upsample_kernel_sizes,
397
+ gin_channels=gin_channels,
398
+ )
399
+ self.enc_q = PosteriorEncoder(
400
+ spec_channels,
401
+ inter_channels,
402
+ hidden_channels,
403
+ 5,
404
+ 1,
405
+ 16,
406
+ gin_channels=gin_channels,
407
+ )
408
+ self.flow = ResidualCouplingBlock(
409
+ inter_channels, hidden_channels, 5, 2, 4, gin_channels=gin_channels
410
+ )
411
+
412
+ if n_speakers > 1:
413
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
414
+
415
+ def forward(self, x, x_lengths, attn, y, y_lengths, sid=None):
416
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, y_lengths, attn=attn)
417
+ if self.n_speakers > 0:
418
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
419
+ else:
420
+ g = None
421
+
422
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
423
+ z_p = self.flow(z, y_mask, g=g)
424
+
425
+ z_slice, ids_slice = commons.rand_slice_segments(
426
+ z, y_lengths, self.segment_size
427
+ )
428
+ o = self.dec(z_slice, g=g)
429
+ l_length = None
430
+ return (
431
+ o,
432
+ l_length,
433
+ attn,
434
+ ids_slice,
435
+ x_mask,
436
+ y_mask,
437
+ (z, z_p, m_p, logs_p, m_q, logs_q),
438
+ )
439
+
440
+ def infer(
441
+ self,
442
+ x,
443
+ x_lengths,
444
+ y_lengths,
445
+ attn,
446
+ sid=None,
447
+ noise_scale=1,
448
+ max_len=None,
449
+ ):
450
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, y_lengths, attn=attn)
451
+ if self.n_speakers > 0:
452
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
453
+ else:
454
+ g = None
455
+
456
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, attn.shape[1]), 1).to(
457
+ x_mask.dtype
458
+ )
459
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
460
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
461
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
462
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
463
+
464
+
465
+ class DurationNet(torch.nn.Module):
466
+ def __init__(self, vocab_size: int, dim: int, num_layers=2):
467
+ super().__init__()
468
+ self.embed = torch.nn.Embedding(vocab_size, embedding_dim=dim)
469
+ self.rnn = torch.nn.GRU(
470
+ dim,
471
+ dim,
472
+ num_layers=num_layers,
473
+ batch_first=True,
474
+ bidirectional=True,
475
+ dropout=0.2,
476
+ )
477
+ self.proj = torch.nn.Linear(2 * dim, 1)
478
+
479
+ def forward(self, token, lengths):
480
+ x = self.embed(token)
481
+ lengths = lengths.long().cpu()
482
+ x = pack_padded_sequence(
483
+ x, lengths=lengths, batch_first=True, enforce_sorted=False
484
+ )
485
+ x, _ = self.rnn(x)
486
+ x, _ = pad_packed_sequence(x, batch_first=True, total_length=token.shape[1])
487
+ x = self.proj(x)
488
+ x = torch.nn.functional.softplus(x)
489
+ return x
modules.py ADDED
@@ -0,0 +1,356 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Conv1d
4
+ from torch.nn import functional as F
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+
7
+ import commons
8
+ from commons import get_padding, init_weights
9
+
10
+ LRELU_SLOPE = 0.1
11
+
12
+
13
+ class LayerNorm(nn.Module):
14
+ def __init__(self, channels, eps=1e-5):
15
+ super().__init__()
16
+ self.channels = channels
17
+ self.eps = eps
18
+
19
+ self.gamma = nn.Parameter(torch.ones(channels))
20
+ self.beta = nn.Parameter(torch.zeros(channels))
21
+
22
+ def forward(self, x):
23
+ x = x.transpose(1, -1)
24
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
25
+ return x.transpose(1, -1)
26
+
27
+
28
+ class ConvReluNorm(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_channels,
32
+ hidden_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ n_layers,
36
+ p_dropout,
37
+ ):
38
+ super().__init__()
39
+ self.in_channels = in_channels
40
+ self.hidden_channels = hidden_channels
41
+ self.out_channels = out_channels
42
+ self.kernel_size = kernel_size
43
+ self.n_layers = n_layers
44
+ self.p_dropout = p_dropout
45
+ assert n_layers > 1, "Number of layers should be larger than 0."
46
+
47
+ self.conv_layers = nn.ModuleList()
48
+ self.norm_layers = nn.ModuleList()
49
+ self.conv_layers.append(
50
+ nn.Conv1d(
51
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
52
+ )
53
+ )
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
56
+ for _ in range(n_layers - 1):
57
+ self.conv_layers.append(
58
+ nn.Conv1d(
59
+ hidden_channels,
60
+ hidden_channels,
61
+ kernel_size,
62
+ padding=kernel_size // 2,
63
+ )
64
+ )
65
+ self.norm_layers.append(LayerNorm(hidden_channels))
66
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
67
+ self.proj.weight.data.zero_()
68
+ self.proj.bias.data.zero_()
69
+
70
+ def forward(self, x, x_mask):
71
+ x_org = x
72
+ for i in range(self.n_layers):
73
+ x = self.conv_layers[i](x * x_mask)
74
+ x = self.norm_layers[i](x)
75
+ x = self.relu_drop(x)
76
+ x = x_org + self.proj(x)
77
+ return x * x_mask
78
+
79
+
80
+ class DDSConv(nn.Module):
81
+ """
82
+ Dialted and Depth-Separable Convolution
83
+ """
84
+
85
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
86
+ super().__init__()
87
+ self.channels = channels
88
+ self.kernel_size = kernel_size
89
+ self.n_layers = n_layers
90
+ self.p_dropout = p_dropout
91
+
92
+ self.drop = nn.Dropout(p_dropout)
93
+ self.convs_sep = nn.ModuleList()
94
+ self.convs_1x1 = nn.ModuleList()
95
+ self.norms_1 = nn.ModuleList()
96
+ self.norms_2 = nn.ModuleList()
97
+ for i in range(n_layers):
98
+ dilation = kernel_size**i
99
+ padding = (kernel_size * dilation - dilation) // 2
100
+ self.convs_sep.append(
101
+ nn.Conv1d(
102
+ channels,
103
+ channels,
104
+ kernel_size,
105
+ groups=channels,
106
+ dilation=dilation,
107
+ padding=padding,
108
+ )
109
+ )
110
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
111
+ self.norms_1.append(LayerNorm(channels))
112
+ self.norms_2.append(LayerNorm(channels))
113
+
114
+ def forward(self, x, x_mask, g=None):
115
+ if g is not None:
116
+ x = x + g
117
+ for i in range(self.n_layers):
118
+ y = self.convs_sep[i](x * x_mask)
119
+ y = self.norms_1[i](y)
120
+ y = F.gelu(y)
121
+ y = self.convs_1x1[i](y)
122
+ y = self.norms_2[i](y)
123
+ y = F.gelu(y)
124
+ y = self.drop(y)
125
+ x = x + y
126
+ return x * x_mask
127
+
128
+
129
+ class WN(torch.nn.Module):
130
+ def __init__(
131
+ self,
132
+ hidden_channels,
133
+ kernel_size,
134
+ dilation_rate,
135
+ n_layers,
136
+ gin_channels=0,
137
+ p_dropout=0,
138
+ ):
139
+ super(WN, self).__init__()
140
+ assert kernel_size % 2 == 1
141
+ self.hidden_channels = hidden_channels
142
+ self.kernel_size = (kernel_size,)
143
+ self.dilation_rate = dilation_rate
144
+ self.n_layers = n_layers
145
+ self.gin_channels = gin_channels
146
+ self.p_dropout = p_dropout
147
+
148
+ self.in_layers = torch.nn.ModuleList()
149
+ self.res_skip_layers = torch.nn.ModuleList()
150
+ self.drop = nn.Dropout(p_dropout)
151
+
152
+ if gin_channels != 0:
153
+ cond_layer = torch.nn.Conv1d(
154
+ gin_channels, 2 * hidden_channels * n_layers, 1
155
+ )
156
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
157
+
158
+ for i in range(n_layers):
159
+ dilation = dilation_rate**i
160
+ padding = int((kernel_size * dilation - dilation) / 2)
161
+ in_layer = torch.nn.Conv1d(
162
+ hidden_channels,
163
+ 2 * hidden_channels,
164
+ kernel_size,
165
+ dilation=dilation,
166
+ padding=padding,
167
+ )
168
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
169
+ self.in_layers.append(in_layer)
170
+
171
+ # last one is not necessary
172
+ if i < n_layers - 1:
173
+ res_skip_channels = 2 * hidden_channels
174
+ else:
175
+ res_skip_channels = hidden_channels
176
+
177
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
178
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
179
+ self.res_skip_layers.append(res_skip_layer)
180
+
181
+ def forward(self, x, x_mask, g=None, **kwargs):
182
+ output = torch.zeros_like(x)
183
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
184
+
185
+ if g is not None:
186
+ g = self.cond_layer(g)
187
+
188
+ for i in range(self.n_layers):
189
+ x_in = self.in_layers[i](x)
190
+ if g is not None:
191
+ cond_offset = i * 2 * self.hidden_channels
192
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
193
+ else:
194
+ g_l = torch.zeros_like(x_in)
195
+
196
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
197
+ acts = self.drop(acts)
198
+
199
+ res_skip_acts = self.res_skip_layers[i](acts)
200
+ if i < self.n_layers - 1:
201
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
202
+ x = (x + res_acts) * x_mask
203
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
204
+ else:
205
+ output = output + res_skip_acts
206
+ return output * x_mask
207
+
208
+ def remove_weight_norm(self):
209
+ if self.gin_channels != 0:
210
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
211
+ for l in self.in_layers:
212
+ torch.nn.utils.remove_weight_norm(l)
213
+ for l in self.res_skip_layers:
214
+ torch.nn.utils.remove_weight_norm(l)
215
+
216
+
217
+ class ResBlock1(torch.nn.Module):
218
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
219
+ super(ResBlock1, self).__init__()
220
+ self.convs1 = nn.ModuleList(
221
+ [
222
+ weight_norm(
223
+ Conv1d(
224
+ channels,
225
+ channels,
226
+ kernel_size,
227
+ 1,
228
+ dilation=dilation[0],
229
+ padding=get_padding(kernel_size, dilation[0]),
230
+ )
231
+ ),
232
+ weight_norm(
233
+ Conv1d(
234
+ channels,
235
+ channels,
236
+ kernel_size,
237
+ 1,
238
+ dilation=dilation[1],
239
+ padding=get_padding(kernel_size, dilation[1]),
240
+ )
241
+ ),
242
+ weight_norm(
243
+ Conv1d(
244
+ channels,
245
+ channels,
246
+ kernel_size,
247
+ 1,
248
+ dilation=dilation[2],
249
+ padding=get_padding(kernel_size, dilation[2]),
250
+ )
251
+ ),
252
+ ]
253
+ )
254
+ self.convs1.apply(init_weights)
255
+
256
+ self.convs2 = nn.ModuleList(
257
+ [
258
+ weight_norm(
259
+ Conv1d(
260
+ channels,
261
+ channels,
262
+ kernel_size,
263
+ 1,
264
+ dilation=1,
265
+ padding=get_padding(kernel_size, 1),
266
+ )
267
+ ),
268
+ weight_norm(
269
+ Conv1d(
270
+ channels,
271
+ channels,
272
+ kernel_size,
273
+ 1,
274
+ dilation=1,
275
+ padding=get_padding(kernel_size, 1),
276
+ )
277
+ ),
278
+ weight_norm(
279
+ Conv1d(
280
+ channels,
281
+ channels,
282
+ kernel_size,
283
+ 1,
284
+ dilation=1,
285
+ padding=get_padding(kernel_size, 1),
286
+ )
287
+ ),
288
+ ]
289
+ )
290
+ self.convs2.apply(init_weights)
291
+
292
+ def forward(self, x, x_mask=None):
293
+ for c1, c2 in zip(self.convs1, self.convs2):
294
+ xt = F.leaky_relu(x, LRELU_SLOPE)
295
+ if x_mask is not None:
296
+ xt = xt * x_mask
297
+ xt = c1(xt)
298
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
299
+ if x_mask is not None:
300
+ xt = xt * x_mask
301
+ xt = c2(xt)
302
+ x = xt + x
303
+ if x_mask is not None:
304
+ x = x * x_mask
305
+ return x
306
+
307
+ def remove_weight_norm(self):
308
+ for l in self.convs1:
309
+ remove_weight_norm(l)
310
+ for l in self.convs2:
311
+ remove_weight_norm(l)
312
+
313
+
314
+ class ResBlock2(torch.nn.Module):
315
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
316
+ super(ResBlock2, self).__init__()
317
+ self.convs = nn.ModuleList(
318
+ [
319
+ weight_norm(
320
+ Conv1d(
321
+ channels,
322
+ channels,
323
+ kernel_size,
324
+ 1,
325
+ dilation=dilation[0],
326
+ padding=get_padding(kernel_size, dilation[0]),
327
+ )
328
+ ),
329
+ weight_norm(
330
+ Conv1d(
331
+ channels,
332
+ channels,
333
+ kernel_size,
334
+ 1,
335
+ dilation=dilation[1],
336
+ padding=get_padding(kernel_size, dilation[1]),
337
+ )
338
+ ),
339
+ ]
340
+ )
341
+ self.convs.apply(init_weights)
342
+
343
+ def forward(self, x, x_mask=None):
344
+ for c in self.convs:
345
+ xt = F.leaky_relu(x, LRELU_SLOPE)
346
+ if x_mask is not None:
347
+ xt = xt * x_mask
348
+ xt = c(xt)
349
+ x = xt + x
350
+ if x_mask is not None:
351
+ x = x * x_mask
352
+ return x
353
+
354
+ def remove_weight_norm(self):
355
+ for l in self.convs:
356
+ remove_weight_norm(l)
packages.txt ADDED
@@ -0,0 +1 @@
+ libsndfile1-dev
phone_set.json ADDED
@@ -0,0 +1 @@
+ ["[SEP]", "a", "b", "c", "d", "e", "g", "h", "i", "k", "l", "m", "n", "o", "p", "q", "r", "s", "sil", "spn", "t", "u", "v", "x", "y", "\u00e0", "\u00e1", "\u00e2", "\u00e3", "\u00e8", "\u00e9", "\u00ea", "\u00ec", "\u00ed", "\u00f2", "\u00f3", "\u00f4", "\u00f5", "\u00f9", "\u00fa", "\u00fd", "\u0103", "\u0111", "\u0129", "\u0169", "\u01a1", "\u01b0", "\u1ea1", "\u1ea3", "\u1ea5", "\u1ea7", "\u1ea9", "\u1eab", "\u1ead", "\u1eaf", "\u1eb1", "\u1eb3", "\u1eb5", "\u1eb7", "\u1eb9", "\u1ebb", "\u1ebd", "\u1ebf", "\u1ec1", "\u1ec3", "\u1ec5", "\u1ec7", "\u1ec9", "\u1ecb", "\u1ecd", "\u1ecf", "\u1ed1", "\u1ed3", "\u1ed5", "\u1ed7", "\u1ed9", "\u1edb", "\u1edd", "\u1edf", "\u1ee1", "\u1ee3", "\u1ee5", "\u1ee7", "\u1ee9", "\u1eeb", "\u1eed", "\u1eef", "\u1ef1", "\u1ef3", "\u1ef5", "\u1ef7", "\u1ef9"]
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ gradio
+ regex
+ torch