Commit ff36620
Duplicate from ntt123/Vietnam-female-voice-TTS
Co-authored-by: Thông Nguyễn <[email protected]>
- .gitattributes +35 -0
- README.md +14 -0
- app.py +230 -0
- attentions.py +329 -0
- commons.py +162 -0
- config.json +72 -0
- duration_model.pth +3 -0
- flow.py +120 -0
- gen_210k.pth +3 -0
- gen_543k.pth +3 -0
- gen_630k.pth +3 -0
- models.py +489 -0
- modules.py +356 -0
- packages.txt +1 -0
- phone_set.json +1 -0
- requirements.txt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: Vietnam Female Voice TTS
emoji: 👁
colorFrom: pink
colorTo: blue
sdk: gradio
sdk_version: 3.40.1
app_file: app.py
pinned: false
license: cc-by-sa-4.0
duplicated_from: ntt123/Vietnam-female-voice-TTS
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,230 @@
import torch  # isort:skip

torch.manual_seed(42)
import json
import re
import unicodedata
from types import SimpleNamespace

import gradio as gr
import numpy as np
import regex

from models import DurationNet, SynthesizerTrn

title = "LightSpeed: Vietnamese Female Voice TTS"
description = "Vietnam Female Voice TTS."
config_file = "config.json"
duration_model_path = "duration_model.pth"
lightspeed_model_path = "gen_630k.pth"
phone_set_file = "phone_set.json"
device = "cuda" if torch.cuda.is_available() else "cpu"
with open(config_file, "rb") as f:
    hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))

# load phone set json file
with open(phone_set_file, "r") as f:
    phone_set = json.load(f)

assert phone_set[0][1:-1] == "SEP"
assert "sil" in phone_set
sil_idx = phone_set.index("sil")

space_re = regex.compile(r"\s+")
number_re = regex.compile("([0-9]+)")
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
num_re = regex.compile(r"([0-9.,]*[0-9])")
alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
keep_text_re = regex.compile(rf"[^\s{alphabet}]")


def read_number(num: str) -> str:
    if len(num) == 1:
        return digits[int(num)]
    elif len(num) == 2 and num.isdigit():
        n = int(num)
        end = digits[n % 10]
        if n == 10:
            return "mười"
        if n % 10 == 5:
            end = "lăm"
        if n % 10 == 0:
            return digits[n // 10] + " mươi"
        elif n < 20:
            return "mười " + end
        else:
            if n % 10 == 1:
                end = "mốt"
            return digits[n // 10] + " mươi " + end
    elif len(num) == 3 and num.isdigit():
        n = int(num)
        if n % 100 == 0:
            return digits[n // 100] + " trăm"
        elif num[1] == "0":
            return digits[n // 100] + " trăm lẻ " + digits[n % 100]
        else:
            return digits[n // 100] + " trăm " + read_number(num[1:])
    elif len(num) >= 4 and len(num) <= 6 and num.isdigit():
        n = int(num)
        n1 = n // 1000
        return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
    elif "," in num:
        n1, n2 = num.split(",")
        return read_number(n1) + " phẩy " + read_number(n2)
    elif "." in num:
        parts = num.split(".")
        if len(parts) == 2:
            if parts[1] == "000":
                return read_number(parts[0]) + " ngàn"
            elif parts[1].startswith("00"):
                end = digits[int(parts[1][2:])]
                return read_number(parts[0]) + " ngàn lẻ " + end
            else:
                return read_number(parts[0]) + " ngàn " + read_number(parts[1])
        elif len(parts) == 3:
            return (
                read_number(parts[0])
                + " triệu "
                + read_number(parts[1])
                + " ngàn "
                + read_number(parts[2])
            )
    return num


def text_to_phone_idx(text):
    # lowercase
    text = text.lower()
    # unicode normalize
    text = unicodedata.normalize("NFKC", text)
    text = text.replace(".", " . ")
    text = text.replace(",", " , ")
    text = text.replace(";", " ; ")
    text = text.replace(":", " : ")
    text = text.replace("!", " ! ")
    text = text.replace("?", " ? ")
    text = text.replace("(", " ( ")

    text = num_re.sub(r" \1 ", text)
    words = text.split()
    words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
    text = " ".join(words)

    # remove redundant spaces
    text = re.sub(r"\s+", " ", text)
    # remove leading and trailing spaces
    text = text.strip()
    # convert words to phone indices
    tokens = []
    for c in text:
        # if c is "," or ".", add <sil> phone
        if c in ":,.!?;(":
            tokens.append(sil_idx)
        elif c in phone_set:
            tokens.append(phone_set.index(c))
        elif c == " ":
            # add <sep> phone
            tokens.append(0)
    if tokens[0] != sil_idx:
        # insert <sil> phone at the beginning
        tokens = [sil_idx, 0] + tokens
    if tokens[-1] != sil_idx:
        tokens = tokens + [0, sil_idx]
    return tokens


def text_to_speech(duration_net, generator, text):
    # prevent too long text
    if len(text) > 500:
        text = text[:500]

    phone_idx = text_to_phone_idx(text)
    batch = {
        "phone_idx": np.array([phone_idx]),
        "phone_length": np.array([len(phone_idx)]),
    }

    # predict phoneme duration
    phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
    phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
    with torch.inference_mode():
        phone_duration = duration_net(phone_idx, phone_length)[:, :, 0] * 1000
        phone_duration = torch.where(
            phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
        )
        phone_duration = torch.where(phone_idx == 0, 0, phone_duration)

    # generate waveform
    end_time = torch.cumsum(phone_duration, dim=-1)
    start_time = end_time - phone_duration
    start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
    end_frame = end_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
    spec_length = end_frame.max(dim=-1).values
    pos = torch.arange(0, spec_length.item(), device=device)
    attn = torch.logical_and(
        pos[None, :, None] >= start_frame[:, None, :],
        pos[None, :, None] < end_frame[:, None, :],
    ).float()
    with torch.inference_mode():
        y_hat = generator.infer(
            phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.0
        )[0]
    wave = y_hat[0, 0].data.cpu().numpy()
    return (wave * (2**15)).astype(np.int16)


def load_models():
    duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
    duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
    duration_net = duration_net.eval()
    generator = SynthesizerTrn(
        hps.data.vocab_size,
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **vars(hps.model),
    ).to(device)
    del generator.enc_q
    ckpt = torch.load(lightspeed_model_path, map_location=device)
    params = {}
    for k, v in ckpt["net_g"].items():
        k = k[7:] if k.startswith("module.") else k
        params[k] = v
    generator.load_state_dict(params, strict=False)
    del ckpt, params
    generator = generator.eval()
    return duration_net, generator


def speak(text):
    duration_net, generator = load_models()
    paragraphs = text.split("\n")
    clips = []  # list of audio clips
    # silence = np.zeros(hps.data.sampling_rate // 4)
    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if paragraph == "":
            continue
        clips.append(text_to_speech(duration_net, generator, paragraph))
        # clips.append(silence)
    y = np.concatenate(clips)
    return hps.data.sampling_rate, y


gr.Interface(
    fn=speak,
    inputs="text",
    outputs="audio",
    title=title,
    examples=[
        "Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
        "Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du",
        "Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
        "Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học lớn của Việt Nam trong thời phong kiến",
        "Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được; trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
    ],
    description=description,
    theme="default",
    allow_screenshot=False,
    allow_flagging="never",
).launch(debug=False)
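
For reference, a minimal usage sketch of the functions above outside the Gradio UI. It assumes `load_models` and `text_to_speech` are imported from a module that does not call `gr.Interface(...).launch()` at import time (app.py launches at module scope), and that `scipy` is available; the output filename is illustrative, not part of the commit:

from scipy.io import wavfile

# hypothetical: app.py's functions with the launch() call factored out
duration_net, generator = load_models()
wave = text_to_speech(duration_net, generator, "Xin chào các bạn.")  # int16 samples
wavfile.write("hello.wav", 16000, wave)  # config.json sets data.sampling_rate = 16000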
attentions.py
ADDED
@@ -0,0 +1,329 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

import commons
from modules import LayerNorm


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, _ = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.shape
        # padd along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length * length + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
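
A shape-check sketch for the Encoder above (sizes arbitrary, chosen to match config.json's model section): it takes activations as [batch, channels, time] plus a [batch, 1, time] mask and returns the same shape, zeroing padded frames.

import torch
import commons
from attentions import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1)
x = torch.randn(2, 192, 50)  # [b, h, t]
x_mask = commons.sequence_mask(torch.tensor([50, 30]), 50).unsqueeze(1).float()  # [b, 1, t]
out = enc(x, x_mask)  # -> [2, 192, 50]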
commons.py
ADDED
@@ -0,0 +1,162 @@
import math

import torch
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
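
Two of the helpers above do most of the shape bookkeeping in attentions.py, so a small sketch of their semantics may help: convert_pad_shape reverses a per-dimension [[left, right], ...] pad list into the flat, last-dimension-first form that F.pad expects, and sequence_mask turns a length vector into a boolean padding mask.

import torch
from commons import convert_pad_shape, sequence_mask

# pad one frame on the left of the last dimension of a [b, c, t] tensor
assert convert_pad_shape([[0, 0], [0, 0], [1, 0]]) == [1, 0, 0, 0, 0, 0]

print(sequence_mask(torch.tensor([2, 4]), max_length=5))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])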
config.json
ADDED
@@ -0,0 +1,72 @@
{
  "train": {
    "learning_rate": 2e-4,
    "betas": [
      0.8,
      0.99
    ],
    "eps": 1e-9,
    "lr_decay": 0.999875,
    "segment_size": 8192,
    "c_mel": 45,
    "c_kl": 1.0
  },
  "data": {
    "vocab_size": 256,
    "max_wav_value": 32768.0,
    "sampling_rate": 16000,
    "filter_length": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [
      3,
      7,
      11
    ],
    "resblock_dilation_sizes": [
      [
        1,
        3,
        5
      ],
      [
        1,
        3,
        5
      ],
      [
        1,
        3,
        5
      ]
    ],
    "upsample_rates": [
      8,
      8,
      2,
      2
    ],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [
      16,
      16,
      4,
      4
    ],
    "n_layers_q": 3,
    "use_spectral_norm": false
  }
}
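
One relationship in this config is worth noting (an observation, not part of the commit): the Generator's upsample_rates must multiply out to data.hop_length so that each spectrogram frame expands to exactly one hop of waveform; here 8 * 8 * 2 * 2 = 256 = hop_length. A quick check:

import json
import math

cfg = json.load(open("config.json"))
assert math.prod(cfg["model"]["upsample_rates"]) == cfg["data"]["hop_length"]  # 256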
duration_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e86ab30448a328933b112e5ed6c4c22d7f05f1673528e61d340c98a9cc899eb
size 1164051
flow.py
ADDED
@@ -0,0 +1,120 @@
import torch
from torch import nn

from modules import WN

DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


class ResidualCouplingLayer(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x


class Flip(nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
            return x, logdet
        else:
            return x


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x
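
A round-trip sketch for the coupling block above (sizes arbitrary; it assumes the full modules.py, truncated further below, provides the WN module that flow.py imports): applying the flow forward and then in reverse should reconstruct the input up to float error, which is the invertibility SynthesizerTrn.infer relies on.

import torch
from flow import ResidualCouplingBlock

flow = ResidualCouplingBlock(192, 192, 5, 1, 4).eval()
x = torch.randn(1, 192, 40)
x_mask = torch.ones(1, 1, 40)
with torch.no_grad():
    z = flow(x, x_mask)                    # forward direction
    x_rec = flow(z, x_mask, reverse=True)  # inverse direction
print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True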
gen_210k.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6f8485f44f492262492231e90633fead68b5db3f65bd1a73621d618a6f3a173
size 111280752
gen_543k.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e11e4b12e3e9a67ad23ec0cfafbcbd8e810f83d55fb6c38f61a6e9886c94cc5
size 111280752
gen_630k.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:93fd8e41a8138978e387db561e32ad6ca0798cb09d44aefa239add5ac47e13a6
size 111280317
models.py
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
8 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
9 |
+
|
10 |
+
import attentions
|
11 |
+
import commons
|
12 |
+
import modules
|
13 |
+
from commons import get_padding, init_weights
|
14 |
+
from flow import ResidualCouplingBlock
|
15 |
+
|
16 |
+
|
17 |
+
class PriorEncoder(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
n_vocab,
|
21 |
+
out_channels,
|
22 |
+
hidden_channels,
|
23 |
+
filter_channels,
|
24 |
+
n_heads,
|
25 |
+
n_layers,
|
26 |
+
kernel_size,
|
27 |
+
p_dropout,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.n_vocab = n_vocab
|
31 |
+
self.out_channels = out_channels
|
32 |
+
self.hidden_channels = hidden_channels
|
33 |
+
self.filter_channels = filter_channels
|
34 |
+
self.n_heads = n_heads
|
35 |
+
self.n_layers = n_layers
|
36 |
+
self.kernel_size = kernel_size
|
37 |
+
self.p_dropout = p_dropout
|
38 |
+
|
39 |
+
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
40 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
41 |
+
self.pre_attn_encoder = attentions.Encoder(
|
42 |
+
hidden_channels,
|
43 |
+
filter_channels,
|
44 |
+
n_heads,
|
45 |
+
n_layers // 2,
|
46 |
+
kernel_size,
|
47 |
+
p_dropout,
|
48 |
+
)
|
49 |
+
self.post_attn_encoder = attentions.Encoder(
|
50 |
+
hidden_channels,
|
51 |
+
filter_channels,
|
52 |
+
n_heads,
|
53 |
+
n_layers - n_layers // 2,
|
54 |
+
kernel_size,
|
55 |
+
p_dropout,
|
56 |
+
)
|
57 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
58 |
+
|
59 |
+
def forward(self, x, x_lengths, y_lengths, attn):
|
60 |
+
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
61 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
62 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
63 |
+
x.dtype
|
64 |
+
)
|
65 |
+
x = self.pre_attn_encoder(x * x_mask, x_mask)
|
66 |
+
y = torch.einsum("bht,blt->bhl", x, attn)
|
67 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
68 |
+
y.dtype
|
69 |
+
)
|
70 |
+
y = self.post_attn_encoder(y * y_mask, y_mask)
|
71 |
+
stats = self.proj(y) * y_mask
|
72 |
+
|
73 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
74 |
+
return y, m, logs, y_mask
|
75 |
+
|
76 |
+
|
77 |
+
class PosteriorEncoder(nn.Module):
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
in_channels,
|
81 |
+
out_channels,
|
82 |
+
hidden_channels,
|
83 |
+
kernel_size,
|
84 |
+
dilation_rate,
|
85 |
+
n_layers,
|
86 |
+
gin_channels=0,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
self.in_channels = in_channels
|
90 |
+
self.out_channels = out_channels
|
91 |
+
self.hidden_channels = hidden_channels
|
92 |
+
self.kernel_size = kernel_size
|
93 |
+
self.dilation_rate = dilation_rate
|
94 |
+
self.n_layers = n_layers
|
95 |
+
self.gin_channels = gin_channels
|
96 |
+
|
97 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
98 |
+
self.enc = modules.WN(
|
99 |
+
hidden_channels,
|
100 |
+
kernel_size,
|
101 |
+
dilation_rate,
|
102 |
+
n_layers,
|
103 |
+
gin_channels=gin_channels,
|
104 |
+
)
|
105 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
106 |
+
|
107 |
+
def forward(self, x, x_lengths, g=None):
|
108 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
109 |
+
x.dtype
|
110 |
+
)
|
111 |
+
x = self.pre(x) * x_mask
|
112 |
+
x = self.enc(x, x_mask, g=g)
|
113 |
+
stats = self.proj(x) * x_mask
|
114 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
115 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
116 |
+
return z, m, logs, x_mask
|
117 |
+
|
118 |
+
|
119 |
+
class Generator(torch.nn.Module):
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
initial_channel,
|
123 |
+
resblock,
|
124 |
+
resblock_kernel_sizes,
|
125 |
+
resblock_dilation_sizes,
|
126 |
+
upsample_rates,
|
127 |
+
upsample_initial_channel,
|
128 |
+
upsample_kernel_sizes,
|
129 |
+
gin_channels=0,
|
130 |
+
):
|
131 |
+
super(Generator, self).__init__()
|
132 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
133 |
+
self.num_upsamples = len(upsample_rates)
|
134 |
+
self.conv_pre = Conv1d(
|
135 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
136 |
+
)
|
137 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
138 |
+
|
139 |
+
self.ups = nn.ModuleList()
|
140 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
141 |
+
self.ups.append(
|
142 |
+
weight_norm(
|
143 |
+
ConvTranspose1d(
|
144 |
+
upsample_initial_channel // (2**i),
|
145 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
146 |
+
k,
|
147 |
+
u,
|
148 |
+
padding=(k - u) // 2,
|
149 |
+
)
|
150 |
+
)
|
151 |
+
)
|
152 |
+
|
153 |
+
self.resblocks = nn.ModuleList()
|
154 |
+
for i in range(len(self.ups)):
|
155 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
156 |
+
for j, (k, d) in enumerate(
|
157 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
158 |
+
):
|
159 |
+
self.resblocks.append(resblock(ch, k, d))
|
160 |
+
|
161 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
162 |
+
self.ups.apply(init_weights)
|
163 |
+
|
164 |
+
if gin_channels != 0:
|
165 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
166 |
+
|
167 |
+
def forward(self, x, g=None):
|
168 |
+
x = self.conv_pre(x)
|
169 |
+
if g is not None:
|
170 |
+
x = x + self.cond(g)
|
171 |
+
|
172 |
+
for i in range(self.num_upsamples):
|
173 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
174 |
+
x = self.ups[i](x)
|
175 |
+
xs = None
|
176 |
+
for j in range(self.num_kernels):
|
177 |
+
if xs is None:
|
178 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
179 |
+
else:
|
180 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
181 |
+
x = xs / self.num_kernels
|
182 |
+
x = F.leaky_relu(x)
|
183 |
+
x = self.conv_post(x)
|
184 |
+
x = torch.tanh(x)
|
185 |
+
|
186 |
+
return x
|
187 |
+
|
188 |
+
def remove_weight_norm(self):
|
189 |
+
print("Removing weight norm...")
|
190 |
+
for l in self.ups:
|
191 |
+
remove_weight_norm(l)
|
192 |
+
for l in self.resblocks:
|
193 |
+
l.remove_weight_norm()
|
194 |
+
|
195 |
+
|
196 |
+
class DiscriminatorP(torch.nn.Module):
|
197 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
198 |
+
super(DiscriminatorP, self).__init__()
|
199 |
+
self.period = period
|
200 |
+
self.use_spectral_norm = use_spectral_norm
|
201 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
202 |
+
self.convs = nn.ModuleList(
|
203 |
+
[
|
204 |
+
norm_f(
|
205 |
+
Conv2d(
|
206 |
+
1,
|
207 |
+
32,
|
208 |
+
(kernel_size, 1),
|
209 |
+
(stride, 1),
|
210 |
+
padding=(get_padding(kernel_size, 1), 0),
|
211 |
+
)
|
212 |
+
),
|
213 |
+
norm_f(
|
214 |
+
Conv2d(
|
215 |
+
32,
|
216 |
+
128,
|
217 |
+
(kernel_size, 1),
|
218 |
+
(stride, 1),
|
219 |
+
padding=(get_padding(kernel_size, 1), 0),
|
220 |
+
)
|
221 |
+
),
|
222 |
+
norm_f(
|
223 |
+
Conv2d(
|
224 |
+
128,
|
225 |
+
512,
|
226 |
+
(kernel_size, 1),
|
227 |
+
(stride, 1),
|
228 |
+
padding=(get_padding(kernel_size, 1), 0),
|
229 |
+
)
|
230 |
+
),
|
231 |
+
norm_f(
|
232 |
+
Conv2d(
|
233 |
+
512,
|
234 |
+
1024,
|
235 |
+
(kernel_size, 1),
|
236 |
+
(stride, 1),
|
237 |
+
padding=(get_padding(kernel_size, 1), 0),
|
238 |
+
)
|
239 |
+
),
|
240 |
+
norm_f(
|
241 |
+
Conv2d(
|
242 |
+
1024,
|
243 |
+
1024,
|
244 |
+
(kernel_size, 1),
|
245 |
+
1,
|
246 |
+
padding=(get_padding(kernel_size, 1), 0),
|
247 |
+
)
|
248 |
+
),
|
249 |
+
]
|
250 |
+
)
|
251 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
252 |
+
|
253 |
+
def forward(self, x):
|
254 |
+
fmap = []
|
255 |
+
|
256 |
+
# 1d to 2d
|
257 |
+
b, c, t = x.shape
|
258 |
+
if t % self.period != 0: # pad first
|
259 |
+
n_pad = self.period - (t % self.period)
|
260 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
261 |
+
t = t + n_pad
|
262 |
+
x = x.view(b, c, t // self.period, self.period)
|
263 |
+
|
264 |
+
for l in self.convs:
|
265 |
+
x = l(x)
|
266 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
267 |
+
fmap.append(x)
|
268 |
+
x = self.conv_post(x)
|
269 |
+
fmap.append(x)
|
270 |
+
x = torch.flatten(x, 1, -1)
|
271 |
+
|
272 |
+
return x, fmap
|
273 |
+
|
274 |
+
|
275 |
+
class DiscriminatorS(torch.nn.Module):
|
276 |
+
def __init__(self, use_spectral_norm=False):
|
277 |
+
super(DiscriminatorS, self).__init__()
|
278 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
279 |
+
self.convs = nn.ModuleList(
|
280 |
+
[
|
281 |
+
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
282 |
+
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
283 |
+
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
284 |
+
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
285 |
+
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
286 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
287 |
+
]
|
288 |
+
)
|
289 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
290 |
+
|
291 |
+
def forward(self, x):
|
292 |
+
fmap = []
|
293 |
+
|
294 |
+
for l in self.convs:
|
295 |
+
x = l(x)
|
296 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
297 |
+
fmap.append(x)
|
298 |
+
x = self.conv_post(x)
|
299 |
+
fmap.append(x)
|
300 |
+
x = torch.flatten(x, 1, -1)
|
301 |
+
|
302 |
+
return x, fmap
|
303 |
+
|
304 |
+
|
305 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
306 |
+
def __init__(self, use_spectral_norm=False):
|
307 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
308 |
+
periods = [2, 3, 5, 7, 11]
|
309 |
+
|
310 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
311 |
+
discs = discs + [
|
312 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
313 |
+
]
|
314 |
+
self.discriminators = nn.ModuleList(discs)
|
315 |
+
|
316 |
+
def forward(self, y, y_hat):
|
317 |
+
y_d_rs = []
|
318 |
+
y_d_gs = []
|
319 |
+
fmap_rs = []
|
320 |
+
fmap_gs = []
|
321 |
+
for i, d in enumerate(self.discriminators):
|
322 |
+
y_d_r, fmap_r = d(y)
|
323 |
+
y_d_g, fmap_g = d(y_hat)
|
324 |
+
y_d_rs.append(y_d_r)
|
325 |
+
y_d_gs.append(y_d_g)
|
326 |
+
fmap_rs.append(fmap_r)
|
327 |
+
fmap_gs.append(fmap_g)
|
328 |
+
|
329 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
330 |
+
|
331 |
+
|
332 |
+
class SynthesizerTrn(nn.Module):
|
333 |
+
"""
|
334 |
+
Synthesizer for Training
|
335 |
+
"""
|
336 |
+
|
337 |
+
def __init__(
|
338 |
+
self,
|
339 |
+
n_vocab,
|
340 |
+
spec_channels,
|
341 |
+
segment_size,
|
342 |
+
inter_channels,
|
343 |
+
hidden_channels,
|
344 |
+
filter_channels,
|
345 |
+
n_heads,
|
346 |
+
n_layers,
|
347 |
+
kernel_size,
|
348 |
+
p_dropout,
|
349 |
+
resblock,
|
350 |
+
resblock_kernel_sizes,
|
351 |
+
resblock_dilation_sizes,
|
352 |
+
upsample_rates,
|
353 |
+
upsample_initial_channel,
|
354 |
+
upsample_kernel_sizes,
|
355 |
+
n_speakers=0,
|
356 |
+
gin_channels=0,
|
357 |
+
**kwargs
|
358 |
+
):
|
359 |
+
super().__init__()
|
360 |
+
self.n_vocab = n_vocab
|
361 |
+
self.spec_channels = spec_channels
|
362 |
+
self.inter_channels = inter_channels
|
363 |
+
self.hidden_channels = hidden_channels
|
364 |
+
self.filter_channels = filter_channels
|
365 |
+
self.n_heads = n_heads
|
366 |
+
self.n_layers = n_layers
|
367 |
+
self.kernel_size = kernel_size
|
368 |
+
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.enc_p = PriorEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 2, 4, gin_channels=gin_channels
        )

        # The speaker embedding only exists for multi-speaker models.
        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, x, x_lengths, attn, y, y_lengths, sid=None):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, y_lengths, attn=attn)
        if self.n_speakers > 1:  # matches the guard in __init__; emb_g is absent otherwise
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        # Decode only random latent slices during training to bound memory use.
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        l_length = None
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )

    def infer(
        self,
        x,
        x_lengths,
        y_lengths,
        attn,
        sid=None,
        noise_scale=1,
        max_len=None,
    ):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, y_lengths, attn=attn)
        if self.n_speakers > 1:  # matches the guard in __init__; emb_g is absent otherwise
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, attn.shape[1]), 1).to(
            x_mask.dtype
        )
        # Sample latent frames from the prior, then invert the flow and vocode.
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)


class DurationNet(torch.nn.Module):
    def __init__(self, vocab_size: int, dim: int, num_layers=2):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embedding_dim=dim)
        self.rnn = torch.nn.GRU(
            dim,
            dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )
        self.proj = torch.nn.Linear(2 * dim, 1)

    def forward(self, token, lengths):
        x = self.embed(token)
        lengths = lengths.long().cpu()  # pack_padded_sequence requires CPU lengths
        x = pack_padded_sequence(
            x, lengths=lengths, batch_first=True, enforce_sorted=False
        )
        x, _ = self.rnn(x)
        x, _ = pad_packed_sequence(x, batch_first=True, total_length=token.shape[1])
        x = self.proj(x)
        x = torch.nn.functional.softplus(x)  # durations must be non-negative
        return x
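For orientation, here is a minimal sketch of running DurationNet on dummy inputs. The vocabulary size and hidden dimension are illustrative assumptions, not values read from this Space's config.json; only the class itself comes from models.py above.

import torch

VOCAB_SIZE, DIM = 90, 64  # hypothetical sizes for illustration

net = DurationNet(vocab_size=VOCAB_SIZE, dim=DIM, num_layers=2).eval()

tokens = torch.randint(0, VOCAB_SIZE, (2, 12))  # two padded token sequences
lengths = torch.tensor([12, 8])                 # their true (unpadded) lengths

with torch.no_grad():
    durations = net(tokens, lengths)  # [batch, seq_len, 1], non-negative via softplus
print(durations.shape)  # torch.Size([2, 12, 1])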
modules.py
ADDED
@@ -0,0 +1,356 @@
import torch
from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm

import commons
from commons import get_padding, init_weights

LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Layer norm over the channel axis of a [batch, channels, time] tensor.
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


class WN(torch.nn.Module):
    # WaveNet-style stack of gated dilated convolutions with skip connections.
    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)
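As a sanity check, WN can be exercised standalone with dummy tensors; it relies on fused_add_tanh_sigmoid_multiply from this repo's commons.py (imported above). The sizes below are made up for illustration, not taken from config.json.

import torch

B, C, T = 1, 192, 50  # hypothetical batch size, channels, frames

wn = WN(hidden_channels=C, kernel_size=5, dilation_rate=1, n_layers=4).eval()

x = torch.randn(B, C, T)      # hidden features, [batch, channels, time]
x_mask = torch.ones(B, 1, T)  # all frames valid

with torch.no_grad():
    y = wn(x, x_mask)  # sum of skip connections, same shape as x
print(y.shape)  # torch.Size([1, 192, 50])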
packages.txt
ADDED
@@ -0,0 +1 @@
libsndfile1-dev
phone_set.json
ADDED
@@ -0,0 +1 @@
["[SEP]", "a", "b", "c", "d", "e", "g", "h", "i", "k", "l", "m", "n", "o", "p", "q", "r", "s", "sil", "spn", "t", "u", "v", "x", "y", "\u00e0", "\u00e1", "\u00e2", "\u00e3", "\u00e8", "\u00e9", "\u00ea", "\u00ec", "\u00ed", "\u00f2", "\u00f3", "\u00f4", "\u00f5", "\u00f9", "\u00fa", "\u00fd", "\u0103", "\u0111", "\u0129", "\u0169", "\u01a1", "\u01b0", "\u1ea1", "\u1ea3", "\u1ea5", "\u1ea7", "\u1ea9", "\u1eab", "\u1ead", "\u1eaf", "\u1eb1", "\u1eb3", "\u1eb5", "\u1eb7", "\u1eb9", "\u1ebb", "\u1ebd", "\u1ebf", "\u1ec1", "\u1ec3", "\u1ec5", "\u1ec7", "\u1ec9", "\u1ecb", "\u1ecd", "\u1ecf", "\u1ed1", "\u1ed3", "\u1ed5", "\u1ed7", "\u1ed9", "\u1edb", "\u1edd", "\u1edf", "\u1ee1", "\u1ee3", "\u1ee5", "\u1ee7", "\u1ee9", "\u1eeb", "\u1eed", "\u1eef", "\u1ef1", "\u1ef3", "\u1ef5", "\u1ef7", "\u1ef9"]
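A short sketch of how this phone inventory might be turned into a token map, assuming (as is typical for an nn.Embedding lookup, though the commit itself does not document it) that each phone's id is its index in the list:

import json

with open("phone_set.json", "r", encoding="utf-8") as f:
    phone_set = json.load(f)

phone_to_id = {p: i for i, p in enumerate(phone_set)}
print(len(phone_set))        # inventory size
print(phone_to_id["[SEP]"])  # 0, the separator token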
requirements.txt
ADDED
@@ -0,0 +1,3 @@
gradio
regex
torch