ATForest commited on
Commit
3d67931
0 Parent(s):

Duplicate from ATForest/test

Browse files
.flake8 ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ ignore =
3
+ # E203 whitespace before ':'
4
+ E203
5
+ D203,
6
+ # line too long
7
+ E501
8
+ per-file-ignores =
9
+ # imported but unused
10
+ # __init__.py: F401
11
+ test_*.py: F401
12
+ exclude =
13
+ .git,
14
+ __pycache__,
15
+ docs/source/conf.py,
16
+ old,
17
+ build,
18
+ dist,
19
+ .venv
20
+ pad*.py
21
+ max-complexity = 25
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitattributes copy ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ flagged
3
+ call-activate.bat
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Test
3
+ emoji: 😻
4
+ colorFrom: purple
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: ATForest/test
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from textwrap import dedent
4
+
5
+ import edge_tts
6
+ import tempfile
7
+ from tts_voice import tts_order_voice
8
+
9
+ from english.translate import Translate
10
+ from english.split_text import sentence_split
11
+ from english.generator import generatorArticle
12
+
13
+ import random
14
+ import codecs
15
+ import torch
16
+ import librosa
17
+ from models import SynthesizerTrn
18
+
19
+ from scipy.io.wavfile import write
20
+ import utils
21
+ from mel_processing import mel_spectrogram_torch
22
+ from speaker_encoder.voice_encoder import SpeakerEncoder
23
+ from transformers import WavLMModel
24
+
25
+ language_dict = tts_order_voice
26
+
27
+ def parse_text(input):
28
+ text = generatorArticle(input).strip()
29
+
30
+ lines = text.split("\n")
31
+ lines = [line for line in lines if line != ""]
32
+ count = 0
33
+ for i, line in enumerate(lines):
34
+ if "```" in line:
35
+ count += 1
36
+ items = line.split("`")
37
+ if count % 2 == 1:
38
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
39
+ else:
40
+ lines[i] = "<br></code></pre>"
41
+ else:
42
+ if i > 0:
43
+ if count % 2 == 1:
44
+ line = line.replace("`", r"\`")
45
+ line = line.replace("<", "&lt;")
46
+ line = line.replace(">", "&gt;")
47
+ line = line.replace(" ", "&nbsp;")
48
+ line = line.replace("*", "&ast;")
49
+ line = line.replace("_", "&lowbar;")
50
+ line = line.replace("-", "&#45;")
51
+ line = line.replace(".", "&#46;")
52
+ line = line.replace("!", "&#33;")
53
+ line = line.replace("(", "&#40;")
54
+ line = line.replace(")", "&#41;")
55
+ line = line.replace("$", "&#36;")
56
+ lines[i] = "<br>" + line
57
+ return text
58
+
59
+ def predict(input):
60
+ article = parse_text(input)
61
+ yield article,article
62
+
63
+ async def text_to_speech_edge(text, language_code):
64
+ voice = language_dict[language_code]
65
+ communicate = edge_tts.Communicate(text, voice)
66
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
67
+ tmp_path = tmp_file.name
68
+ await communicate.save(tmp_path)
69
+
70
+ return tmp_path
71
+
72
+ def tran_2_chianese(text):
73
+ translate = Translate()
74
+ sentence_str = sentence_split(text)
75
+ i = 0
76
+ result=''
77
+ length = len(sentence_str)
78
+ while(i < length):
79
+ tmp = sentence_str[i]
80
+ print('\n'+tmp)
81
+ tran = translate.translateToZh(tmp)
82
+ result = result+tmp+'\n'+tran+'\n'
83
+ i+=1
84
+ return result
85
+
86
+ def readWorldsFile(file_path):
87
+ fp = codecs.open(file_path, 'r', encoding='gb2312')
88
+ lines = fp.readlines()
89
+ worlds ,paraphrase = [],[]
90
+ for line in lines:
91
+ tmp = line.split('|')
92
+ worlds.append(tmp[0].strip())
93
+ paraphrase.append(tmp[1].strip())
94
+ fp.close()
95
+ return worlds, paraphrase
96
+
97
+ def generatorWorlds(file_path):
98
+ worlds,paraphrase = readWorldsFile(file_path)
99
+ length = len(worlds)
100
+
101
+ index = 0
102
+ worlds_text = ''
103
+
104
+ while index < 15:
105
+ num = random.randint(0,length)
106
+ worlds_text += f'{worlds[num]},【{paraphrase[num]}】\n'
107
+ index += 1
108
+
109
+ print('\n' + worlds_text)
110
+ return worlds_text
111
+
112
+ def choose_word_from_file(input):
113
+ result = generatorWorlds(input.orig_name)
114
+ return result
115
+
116
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
117
+
118
+ smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
119
+
120
+ print("Loading FreeVC(24k)...")
121
+ hps = utils.get_hparams_from_file("configs/freevc-24.json")
122
+ freevc_24 = SynthesizerTrn(
123
+ hps.data.filter_length // 2 + 1,
124
+ hps.train.segment_size // hps.data.hop_length,
125
+ **hps.model).to(device)
126
+ _ = freevc_24.eval()
127
+ _ = utils.load_checkpoint("checkpoints/freevc-24.pth", freevc_24, None)
128
+
129
+ print("Loading WavLM for content...")
130
+ cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
131
+
132
+
133
+ def convert(model, src, tgt):
134
+ with torch.no_grad():
135
+ # tgt
136
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
137
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
138
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
139
+ g_tgt = smodel.embed_utterance(wav_tgt)
140
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
141
+ else:
142
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
143
+ mel_tgt = mel_spectrogram_torch(
144
+ wav_tgt,
145
+ hps.data.filter_length,
146
+ hps.data.n_mel_channels,
147
+ hps.data.sampling_rate,
148
+ hps.data.hop_length,
149
+ hps.data.win_length,
150
+ hps.data.mel_fmin,
151
+ hps.data.mel_fmax
152
+ )
153
+ # src
154
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
155
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
156
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
157
+ # infer
158
+ if model == "FreeVC":
159
+ audio = freevc.infer(c, g=g_tgt)
160
+ elif model == "FreeVC-s":
161
+ audio = freevc_s.infer(c, mel=mel_tgt)
162
+ else:
163
+ audio = freevc_24.infer(c, g=g_tgt)
164
+ audio = audio[0][0].data.cpu().float().numpy()
165
+ if model == "FreeVC" or model == "FreeVC-s":
166
+ write("out.wav", hps.data.sampling_rate, audio)
167
+ else:
168
+ write("out.wav", 24000, audio)
169
+ out = "out.wav"
170
+ return out
171
+
172
+ with gr.Blocks(title="Learn English By AI", theme=gr.themes.Soft(text_size="sm")) as demo:
173
+ gr.HTML("<center>"
174
+ "<h1>OpenAI + 声音克隆:根据单词生成短文,帮助理解单词使用的语境!!</h1>"
175
+ "</center>")
176
+
177
+ with gr.Accordion("📒 相关信息", open=True):
178
+ _ = f"""OpenAI Prompt 的可选参数信息:
179
+ * 输入 10-15 个单词为宜
180
+ * prompt = '你是一个非常厉害的英语助手,请将'{'words'}'组成一篇英语文章,字数限制在100 字以内'
181
+ * Open AI 用的是限制账号,每分钟请求 3 次
182
+ * 单词文件:每个单词及解释单独成行,单词与注释同行,用 “|” 分割
183
+ """
184
+ gr.Markdown(dedent(_))
185
+
186
+ with gr.Row():
187
+
188
+ file = gr.File()
189
+ chooseBtn = gr.Button("从文件提取或输入 -》", variant="secondary")
190
+ user_input = gr.Textbox(
191
+ max_lines=5,
192
+ lines=3,
193
+ label="单词用逗号分割:",
194
+ placeholder="10-15 words will be better",
195
+ )
196
+
197
+ with gr.Column(scale=1):
198
+ submitBtn = gr.Button("开始生成英语短文", variant="primary")
199
+ chatbot = gr.Textbox(label="英语短文:", lines = 5, max_lines=8)
200
+
201
+ chooseBtn.click(
202
+ choose_word_from_file,
203
+ inputs=[file],
204
+ outputs=[user_input],
205
+ show_progress="full",
206
+ api_name="choose_word_from_file"
207
+ )
208
+
209
+ with gr.Column(scale=3):
210
+ with gr.Row():
211
+ tran_result = gr.Textbox(label="翻译结果", lines = 5,max_lines=8,scale=2)
212
+ tran_btn = gr.Button("翻译", variant="primary")
213
+
214
+ tran_btn.click(
215
+ tran_2_chianese,
216
+ inputs=[chatbot],
217
+ outputs=[tran_result],
218
+ show_progress="full",
219
+ api_name="tran_2_chianese"
220
+ )
221
+
222
+ with gr.Column(min_width=32, scale=2):
223
+ with gr.Row():
224
+ with gr.Column():
225
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
226
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
227
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
228
+
229
+ tts_btn.click(text_to_speech_edge, inputs=[chatbot, language], outputs=[output_audio])
230
+
231
+ with gr.Row():
232
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
233
+ audio1 = output_audio
234
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
235
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
236
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
237
+
238
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
239
+
240
+ user_input.submit(
241
+ predict,
242
+ [user_input],
243
+ [chatbot,tran_result],
244
+ show_progress="full",
245
+ )
246
+
247
+ submitBtn.click(
248
+ predict,
249
+ [user_input],
250
+ [chatbot,tran_result],
251
+ show_progress="full",
252
+ api_name="predict",
253
+ )
254
+ # submitBtn.click(reset_user_input, [], [user_input])
255
+
256
+ demo.queue().launch(show_error=True, debug=True)
app_backup.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import gradio as gr
5
+ from scipy.io.wavfile import write
6
+ from transformers import WavLMModel
7
+
8
+ import utils
9
+ from models import SynthesizerTrn
10
+ from mel_processing import mel_spectrogram_torch
11
+ from speaker_encoder.voice_encoder import SpeakerEncoder
12
+
13
+ import time
14
+ from textwrap import dedent
15
+
16
+ import mdtex2html
17
+ from loguru import logger
18
+ from transformers import AutoModel, AutoTokenizer
19
+
20
+ from tts_voice import tts_order_voice
21
+ import edge_tts
22
+ import tempfile
23
+ import anyio
24
+
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
28
+
29
+ print("Loading FreeVC(24k)...")
30
+ hps = utils.get_hparams_from_file("configs/freevc-24.json")
31
+ freevc_24 = SynthesizerTrn(
32
+ hps.data.filter_length // 2 + 1,
33
+ hps.train.segment_size // hps.data.hop_length,
34
+ **hps.model).to(device)
35
+ _ = freevc_24.eval()
36
+ _ = utils.load_checkpoint("checkpoints/freevc-24.pth", freevc_24, None)
37
+
38
+ print("Loading WavLM for content...")
39
+ cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
40
+
41
+ def convert(model, src, tgt):
42
+ with torch.no_grad():
43
+ # tgt
44
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
45
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
46
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
47
+ g_tgt = smodel.embed_utterance(wav_tgt)
48
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
49
+ else:
50
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
51
+ mel_tgt = mel_spectrogram_torch(
52
+ wav_tgt,
53
+ hps.data.filter_length,
54
+ hps.data.n_mel_channels,
55
+ hps.data.sampling_rate,
56
+ hps.data.hop_length,
57
+ hps.data.win_length,
58
+ hps.data.mel_fmin,
59
+ hps.data.mel_fmax
60
+ )
61
+ # src
62
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
63
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
64
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
65
+ # infer
66
+ if model == "FreeVC":
67
+ audio = freevc.infer(c, g=g_tgt)
68
+ elif model == "FreeVC-s":
69
+ audio = freevc_s.infer(c, mel=mel_tgt)
70
+ else:
71
+ audio = freevc_24.infer(c, g=g_tgt)
72
+ audio = audio[0][0].data.cpu().float().numpy()
73
+ if model == "FreeVC" or model == "FreeVC-s":
74
+ write("out.wav", hps.data.sampling_rate, audio)
75
+ else:
76
+ write("out.wav", 24000, audio)
77
+ out = "out.wav"
78
+ return out
79
+
80
+ # GLM2
81
+
82
+ language_dict = tts_order_voice
83
+
84
+ # fix timezone in Linux
85
+ os.environ["TZ"] = "Asia/Shanghai"
86
+ try:
87
+ time.tzset() # type: ignore # pylint: disable=no-member
88
+ except Exception:
89
+ # Windows
90
+ logger.warning("Windows, cant run time.tzset()")
91
+
92
+ # model_name = "THUDM/chatglm2-6b"
93
+ model_name = "THUDM/chatglm2-6b-int4"
94
+
95
+ RETRY_FLAG = False
96
+
97
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
98
+
99
+ # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
100
+
101
+ # 4/8 bit
102
+ # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
103
+
104
+ has_cuda = torch.cuda.is_available()
105
+
106
+ # has_cuda = False # force cpu
107
+
108
+ if has_cuda:
109
+ model_glm = (
110
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
111
+ ) # 3.92G
112
+ else:
113
+ model_glm = AutoModel.from_pretrained(
114
+ model_name, trust_remote_code=True
115
+ ).float() # .float() .half().float()
116
+
117
+ model_glm = model_glm.eval()
118
+
119
+ _ = """Override Chatbot.postprocess"""
120
+
121
+
122
+ def postprocess(self, y):
123
+ if y is None:
124
+ return []
125
+ for i, (message, response) in enumerate(y):
126
+ y[i] = (
127
+ None if message is None else mdtex2html.convert((message)),
128
+ None if response is None else mdtex2html.convert(response),
129
+ )
130
+ return y
131
+
132
+
133
+ gr.Chatbot.postprocess = postprocess
134
+
135
+
136
+ def parse_text(text):
137
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
138
+ lines = text.split("\n")
139
+ lines = [line for line in lines if line != ""]
140
+ count = 0
141
+ for i, line in enumerate(lines):
142
+ if "```" in line:
143
+ count += 1
144
+ items = line.split("`")
145
+ if count % 2 == 1:
146
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
147
+ else:
148
+ lines[i] = "<br></code></pre>"
149
+ else:
150
+ if i > 0:
151
+ if count % 2 == 1:
152
+ line = line.replace("`", r"\`")
153
+ line = line.replace("<", "&lt;")
154
+ line = line.replace(">", "&gt;")
155
+ line = line.replace(" ", "&nbsp;")
156
+ line = line.replace("*", "&ast;")
157
+ line = line.replace("_", "&lowbar;")
158
+ line = line.replace("-", "&#45;")
159
+ line = line.replace(".", "&#46;")
160
+ line = line.replace("!", "&#33;")
161
+ line = line.replace("(", "&#40;")
162
+ line = line.replace(")", "&#41;")
163
+ line = line.replace("$", "&#36;")
164
+ lines[i] = "<br>" + line
165
+ text = "".join(lines)
166
+ return text
167
+
168
+ def predict(
169
+ RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
170
+ ):
171
+ try:
172
+ chatbot.append((parse_text(input), ""))
173
+ except Exception as exc:
174
+ logger.error(exc)
175
+ logger.debug(f"{chatbot=}")
176
+ _ = """
177
+ if chatbot:
178
+ chatbot[-1] = (parse_text(input), str(exc))
179
+ yield chatbot, history, past_key_values
180
+ # """
181
+ yield chatbot, history, past_key_values
182
+
183
+ for response, history, past_key_values in model_glm.stream_chat(
184
+ tokenizer,
185
+ input,
186
+ history,
187
+ past_key_values=past_key_values,
188
+ return_past_key_values=True,
189
+ max_length=max_length,
190
+ top_p=top_p,
191
+ temperature=temperature,
192
+ ):
193
+ chatbot[-1] = (parse_text(input), parse_text(response))
194
+ # chatbot[-1][-1] = parse_text(response)
195
+
196
+ yield chatbot, history, past_key_values, parse_text(response)
197
+
198
+ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
199
+ if max_length < 10:
200
+ max_length = 4096
201
+ if top_p < 0.1 or top_p > 1:
202
+ top_p = 0.85
203
+ if temperature <= 0 or temperature > 1:
204
+ temperature = 0.01
205
+ try:
206
+ res, _ = model_glm.chat(
207
+ tokenizer,
208
+ input,
209
+ history=[],
210
+ past_key_values=None,
211
+ max_length=max_length,
212
+ top_p=top_p,
213
+ temperature=temperature,
214
+ )
215
+ # logger.debug(f"{res=} \n{_=}")
216
+ except Exception as exc:
217
+ logger.error(f"{exc=}")
218
+ res = str(exc)
219
+
220
+ return res
221
+
222
+ def reset_user_input():
223
+ return gr.update(value="")
224
+
225
+
226
+ def reset_state():
227
+ return [], [], None, ""
228
+
229
+
230
+ # Delete last turn
231
+ def delete_last_turn(chat, history):
232
+ if chat and history:
233
+ chat.pop(-1)
234
+ history.pop(-1)
235
+ return chat, history
236
+
237
+ # Regenerate response
238
+ def retry_last_answer(
239
+ user_input, chatbot, max_length, top_p, temperature, history, past_key_values
240
+ ):
241
+ if chatbot and history:
242
+ # Removing the previous conversation from chat
243
+ chatbot.pop(-1)
244
+ # Setting up a flag to capture a retry
245
+ RETRY_FLAG = True
246
+ # Getting last message from user
247
+ user_input = history[-1][0]
248
+ # Removing bot response from the history
249
+ history.pop(-1)
250
+
251
+ yield from predict(
252
+ RETRY_FLAG, # type: ignore
253
+ user_input,
254
+ chatbot,
255
+ max_length,
256
+ top_p,
257
+ temperature,
258
+ history,
259
+ past_key_values,
260
+ )
261
+
262
+ # print
263
+
264
+ def print(text):
265
+ return text
266
+
267
+ # TTS
268
+
269
+ async def text_to_speech_edge(text, language_code):
270
+ voice = language_dict[language_code]
271
+ communicate = edge_tts.Communicate(text, voice)
272
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
273
+ tmp_path = tmp_file.name
274
+
275
+ await communicate.save(tmp_path)
276
+
277
+ return tmp_path
278
+
279
+ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm")) as demo:
280
+ gr.HTML("<center>"
281
+ "<h1>🥳💕🎶 - ChatGLM2 + 声音克隆:和你喜欢的角色畅所欲言吧!</h1>"
282
+ "</center>")
283
+
284
+ with gr.Accordion("📒 相关信息", open=False):
285
+ _ = f""" ChatGLM2的可选参数信息:
286
+ * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
287
+ * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
288
+ * Top P controls dynamic vocabulary selection based on context.\n
289
+ 如果您想让ChatGLM2进行角色扮演并与之对话,请先输入恰当的提示词,如“请你扮演成动漫角色蜡笔小新并和我进行对话”;您也可以为ChatGLM2提供自定义的角色设定\n
290
+ 当您使用声音克隆功能时,请先在此程序的对应位置上传一段您喜欢的音频
291
+ """
292
+ gr.Markdown(dedent(_))
293
+
294
+ chatbot = gr.Chatbot(height=300)
295
+ with gr.Row():
296
+ with gr.Column(scale=4):
297
+ with gr.Column(scale=12):
298
+ user_input = gr.Textbox(
299
+ label="请在此处和GLM2聊天 (按回车键即可发送)",
300
+ placeholder="聊点什么吧",
301
+ )
302
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
303
+ with gr.Column(min_width=32, scale=1):
304
+ with gr.Row():
305
+ submitBtn = gr.Button("开始和GLM2交流吧", variant="primary")
306
+ deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
307
+ retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
308
+
309
+ with gr.Accordion("🔧 更多设置", open=False):
310
+ with gr.Row():
311
+ emptyBtn = gr.Button("清空所有聊天记录")
312
+ max_length = gr.Slider(
313
+ 0,
314
+ 32768,
315
+ value=8192,
316
+ step=1.0,
317
+ label="Maximum length",
318
+ interactive=True,
319
+ )
320
+ top_p = gr.Slider(
321
+ 0, 1, value=0.85, step=0.01, label="Top P", interactive=True
322
+ )
323
+ temperature = gr.Slider(
324
+ 0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
325
+ )
326
+
327
+
328
+ with gr.Row():
329
+ test1 = gr.Textbox(label="GLM2的最新回答 (可编辑)", lines = 3)
330
+ with gr.Column():
331
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
332
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
333
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
334
+
335
+ tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
336
+
337
+ with gr.Row():
338
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
339
+ audio1 = output_audio
340
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
341
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
342
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
343
+
344
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
345
+
346
+ history = gr.State([])
347
+ past_key_values = gr.State(None)
348
+
349
+ user_input.submit(
350
+ predict,
351
+ [
352
+ RETRY_FLAG,
353
+ user_input,
354
+ chatbot,
355
+ max_length,
356
+ top_p,
357
+ temperature,
358
+ history,
359
+ past_key_values,
360
+ ],
361
+ [chatbot, history, past_key_values, test1],
362
+ show_progress="full",
363
+ )
364
+ submitBtn.click(
365
+ predict,
366
+ [
367
+ RETRY_FLAG,
368
+ user_input,
369
+ chatbot,
370
+ max_length,
371
+ top_p,
372
+ temperature,
373
+ history,
374
+ past_key_values,
375
+ ],
376
+ [chatbot, history, past_key_values, test1],
377
+ show_progress="full",
378
+ api_name="predict",
379
+ )
380
+ submitBtn.click(reset_user_input, [], [user_input])
381
+
382
+ emptyBtn.click(
383
+ reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
384
+ )
385
+
386
+ retryBtn.click(
387
+ retry_last_answer,
388
+ inputs=[
389
+ user_input,
390
+ chatbot,
391
+ max_length,
392
+ top_p,
393
+ temperature,
394
+ history,
395
+ past_key_values,
396
+ ],
397
+ # outputs = [chatbot, history, last_user_message, user_message]
398
+ outputs=[chatbot, history, past_key_values, test1],
399
+ )
400
+ deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
401
+
402
+ with gr.Accordion("For Chat/Translation API", open=False):
403
+ input_text = gr.Text()
404
+ tr_btn = gr.Button("Go", variant="primary")
405
+ out_text = gr.Text()
406
+ tr_btn.click(
407
+ trans_api,
408
+ [input_text, max_length, top_p, temperature],
409
+ out_text,
410
+ # show_progress="full",
411
+ api_name="tr",
412
+ )
413
+ _ = """
414
+ input_text.submit(
415
+ trans_api,
416
+ [input_text, max_length, top_p, temperature],
417
+ out_text,
418
+ show_progress="full",
419
+ api_name="tr1",
420
+ )
421
+ # """
422
+
423
+ gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。</center>")
424
+ gr.Markdown("<center>💡 - 如何使用此程序:输入您对ChatGLM的提问后,依次点击“开始和GLM2交流吧”、“生成对应的音频吧”、“开始AI声音克隆吧”三个按键即可;使用声音克隆功能时,请先上传一段您喜欢的音频</center>")
425
+ gr.HTML('''
426
+ <div class="footer">
427
+ <p>Reedit by Forest, Thanks 明·顾璘</p>
428
+ </div>
429
+ ''')
430
+
431
+ demo.queue().launch(show_error=True, debug=True)
checkpoints/freevc-24.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20
3
+ size 472644351
commons.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size*dilation - dilation)/2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def intersperse(lst, item):
25
+ result = [item] * (len(lst) * 2 + 1)
26
+ result[1::2] = lst
27
+ return result
28
+
29
+
30
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
31
+ """KL(P||Q)"""
32
+ kl = (logs_q - logs_p) - 0.5
33
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def rand_spec_segments(x, x_lengths=None, segment_size=4):
68
+ b, d, t = x.size()
69
+ if x_lengths is None:
70
+ x_lengths = t
71
+ ids_str_max = x_lengths - segment_size
72
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
73
+ ret = slice_segments(x, ids_str, segment_size)
74
+ return ret, ids_str
75
+
76
+
77
+ def get_timing_signal_1d(
78
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
79
+ position = torch.arange(length, dtype=torch.float)
80
+ num_timescales = channels // 2
81
+ log_timescale_increment = (
82
+ math.log(float(max_timescale) / float(min_timescale)) /
83
+ (num_timescales - 1))
84
+ inv_timescales = min_timescale * torch.exp(
85
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
86
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
87
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
88
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
89
+ signal = signal.view(1, channels, length)
90
+ return signal
91
+
92
+
93
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
94
+ b, channels, length = x.size()
95
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
96
+ return x + signal.to(dtype=x.dtype, device=x.device)
97
+
98
+
99
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
100
+ b, channels, length = x.size()
101
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
102
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
103
+
104
+
105
+ def subsequent_mask(length):
106
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
107
+ return mask
108
+
109
+
110
+ @torch.jit.script
111
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
112
+ n_channels_int = n_channels[0]
113
+ in_act = input_a + input_b
114
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
115
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
116
+ acts = t_act * s_act
117
+ return acts
118
+
119
+
120
+ def convert_pad_shape(pad_shape):
121
+ l = pad_shape[::-1]
122
+ pad_shape = [item for sublist in l for item in sublist]
123
+ return pad_shape
124
+
125
+
126
+ def shift_1d(x):
127
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
128
+ return x
129
+
130
+
131
+ def sequence_mask(length, max_length=None):
132
+ if max_length is None:
133
+ max_length = length.max()
134
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
135
+ return x.unsqueeze(0) < length.unsqueeze(1)
136
+
137
+
138
+ def generate_path(duration, mask):
139
+ """
140
+ duration: [b, 1, t_x]
141
+ mask: [b, 1, t_y, t_x]
142
+ """
143
+ device = duration.device
144
+
145
+ b, _, t_y, t_x = mask.shape
146
+ cum_duration = torch.cumsum(duration, -1)
147
+
148
+ cum_duration_flat = cum_duration.view(b * t_x)
149
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
150
+ path = path.view(b, t_x, t_y)
151
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
152
+ path = path.unsqueeze(1).transpose(2,3) * mask
153
+ return path
154
+
155
+
156
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
157
+ if isinstance(parameters, torch.Tensor):
158
+ parameters = [parameters]
159
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
160
+ norm_type = float(norm_type)
161
+ if clip_value is not None:
162
+ clip_value = float(clip_value)
163
+
164
+ total_norm = 0
165
+ for p in parameters:
166
+ param_norm = p.grad.data.norm(norm_type)
167
+ total_norm += param_norm.item() ** norm_type
168
+ if clip_value is not None:
169
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
170
+ total_norm = total_norm ** (1. / norm_type)
171
+ return total_norm
configs/freevc-24.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 10000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 64,
11
+ "fp16_run": false,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8640,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0,
18
+ "use_sr": true,
19
+ "max_speclen": 128,
20
+ "port": "8008"
21
+ },
22
+ "data": {
23
+ "training_files":"filelists/train.txt",
24
+ "validation_files":"filelists/val.txt",
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 16000,
27
+ "filter_length": 1280,
28
+ "hop_length": 320,
29
+ "win_length": 1280,
30
+ "n_mel_channels": 80,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null
33
+ },
34
+ "model": {
35
+ "inter_channels": 192,
36
+ "hidden_channels": 192,
37
+ "filter_channels": 768,
38
+ "n_heads": 2,
39
+ "n_layers": 6,
40
+ "kernel_size": 3,
41
+ "p_dropout": 0.1,
42
+ "resblock": "1",
43
+ "resblock_kernel_sizes": [3,7,11],
44
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
45
+ "upsample_rates": [10,6,4,2],
46
+ "upsample_initial_channel": 512,
47
+ "upsample_kernel_sizes": [16,16,4,4],
48
+ "n_layers_q": 3,
49
+ "use_spectral_norm": false,
50
+ "gin_channels": 256,
51
+ "ssl_dim": 1024,
52
+ "use_spk": true
53
+ }
54
+ }
english/generator.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import openai
3
+ openai.organization = ""
4
+
5
+
6
+ TEMPERATURE = 1
7
+ # COMPLETION_MODEL = "gpt-3.5-turbo-0613"
8
+ COMPLETION_MODEL ='text-davinci-003'
9
+
10
+ def generatorByOpenAI(worlds):
11
+ prompt = f'你是一个非常厉害的英语助手,请将{worlds}组成一篇英语文章,字数限制在100 字以内'
12
+ response = openai.Completion.create(
13
+ prompt=prompt,
14
+ engine = COMPLETION_MODEL,
15
+ temperature = TEMPERATURE,
16
+ max_tokens = 512,
17
+ n = 1,
18
+ stop = None,
19
+ timeout = 20
20
+ )
21
+
22
+ # print(response)
23
+ return response
24
+
25
+ def generatorArticle(worlds):
26
+ result = generatorByOpenAI(worlds)
27
+ return result.choices[0].text
english/models/deltalm/configuration_deltalm.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """ deltalm model configuration"""
5
+
6
+ import warnings
7
+ from transformers.configuration_utils import PretrainedConfig
8
+ from transformers.utils import logging
9
+ logger = logging.get_logger(__name__)
10
+
11
+ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
12
+ "IDEA/Deltalm": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
13
+ # See all deltalm models at https://huggingface.co/models?filter=deltam
14
+ }
15
+
16
+
17
+ class DeltalmConfig(PretrainedConfig):
18
+ r"""
19
+ This is the configuration class to store the configuration of a [`DeltalmModel`]. It is used to instantiate a Deltalm
20
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
21
+ defaults will yield a similar configuration to that of the BART
22
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+ Args:
26
+ vocab_size (`int`, *optional*, defaults to 50265):
27
+ Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
28
+ `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
29
+ d_model (`int`, *optional*, defaults to 1024):
30
+ Dimensionality of the layers and the pooler layer.
31
+ encoder_layers (`int`, *optional*, defaults to 12):
32
+ Number of encoder layers.
33
+ decoder_layers (`int`, *optional*, defaults to 12):
34
+ Number of decoder layers.
35
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
38
+ Number of attention heads for each attention layer in the Transformer decoder.
39
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
40
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
41
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
42
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
43
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
44
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
45
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
46
+ dropout (`float`, *optional*, defaults to 0.1):
47
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
48
+ attention_dropout (`float`, *optional*, defaults to 0.0):
49
+ The dropout ratio for the attention probabilities.
50
+ activation_dropout (`float`, *optional*, defaults to 0.0):
51
+ The dropout ratio for activations inside the fully connected layer.
52
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
53
+ The dropout ratio for classifier.
54
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
55
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
56
+ just in case (e.g., 512 or 1024 or 2048).
57
+ init_std (`float`, *optional*, defaults to 0.02):
58
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
59
+ encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
60
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
61
+ for more details.
62
+ decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
63
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
64
+ for more details.
65
+ scale_embedding (`bool`, *optional*, defaults to `False`):
66
+ Scale embeddings by diving by sqrt(d_model).
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models).
69
+ num_labels: (`int`, *optional*, defaults to 3):
70
+ The number of labels to use in [`BartForSequenceClassification`].
71
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
72
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
73
+ `eos_token_id`.
74
+ Example:
75
+ ```python
76
+ >>> from transformers import BartModel, BartConfig
77
+ >>> # Initializing a BART facebook/bart-large style configuration
78
+ >>> configuration = BartConfig()
79
+ >>> # Initializing a model from the facebook/bart-large style configuration
80
+ >>> model = BartModel(configuration)
81
+ >>> # Accessing the model configuration
82
+ >>> configuration = model.config
83
+ ```"""
84
+ model_type = "Deltalm"
85
+ keys_to_ignore_at_inference = ["past_key_values"]
86
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
87
+
88
+ def __init__(
89
+ self,
90
+ vocab_size=250001,
91
+ max_position_embeddings=1024,
92
+ encoder_layers=12,
93
+ encoder_ffn_dim=3072,
94
+ encoder_attention_heads=12,
95
+ decoder_layers=6,
96
+ decoder_ffn_dim=3072,
97
+ decoder_attention_heads=12,
98
+ encoder_layerdrop=0.0,
99
+ decoder_layerdrop=0.0,
100
+ activation_function="gelu",
101
+ d_model=1024,
102
+ dropout=0.1,
103
+ attention_dropout=0.0,
104
+ activation_dropout=0.0,
105
+ init_std=0.02,
106
+ classifier_dropout=0.0,
107
+ scale_embedding=False,
108
+ use_cache=True,
109
+ num_labels=3,
110
+ pad_token_id=1,
111
+ bos_token_id=0,
112
+ eos_token_id=2,
113
+ is_encoder_decoder=True,
114
+ decoder_start_token_id=0,
115
+ forced_eos_token_id=2,
116
+ label_smoothing=0.1,
117
+ length_penalty=1.0,
118
+ encoder_normalize_before=False,
119
+ **kwargs
120
+ ):
121
+ self.vocab_size = vocab_size
122
+ self.max_position_embeddings = max_position_embeddings
123
+ self.d_model = d_model
124
+ self.encoder_ffn_dim = encoder_ffn_dim
125
+ self.encoder_layers = encoder_layers
126
+ self.encoder_attention_heads = encoder_attention_heads
127
+ self.decoder_ffn_dim = decoder_ffn_dim
128
+ self.decoder_layers = decoder_layers
129
+ self.decoder_attention_heads = decoder_attention_heads
130
+ self.dropout = dropout
131
+ self.attention_dropout = attention_dropout
132
+ self.activation_dropout = activation_dropout
133
+ self.activation_function = activation_function
134
+ self.init_std = init_std
135
+ self.encoder_layerdrop = encoder_layerdrop
136
+ self.decoder_layerdrop = decoder_layerdrop
137
+ self.classifier_dropout = classifier_dropout
138
+ self.use_cache = use_cache
139
+ self.num_hidden_layers = encoder_layers
140
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
141
+ self.label_smoothing = label_smoothing
142
+ self.encoder_normalize_before = encoder_normalize_before
143
+
144
+ super().__init__(
145
+ num_labels=num_labels,
146
+ pad_token_id=pad_token_id,
147
+ bos_token_id=bos_token_id,
148
+ eos_token_id=eos_token_id,
149
+ is_encoder_decoder=is_encoder_decoder,
150
+ decoder_start_token_id=decoder_start_token_id,
151
+ forced_eos_token_id=forced_eos_token_id,
152
+ length_penalty=length_penalty,
153
+ **kwargs,
154
+ )
155
+
156
+ # ensure backward compatibility for BART CNN models
157
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
158
+ self.forced_bos_token_id = self.bos_token_id
159
+ warnings.warn(
160
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
161
+ "The config can simply be saved and uploaded again to be fixed."
162
+ )
163
+
164
+ @property
165
+ def num_attention_heads(self) -> int:
166
+ return self.encoder_attention_heads
167
+
168
+ @property
169
+ def hidden_size(self) -> int:
170
+ return self.d_model
english/models/deltalm/modeling_deltalm.py ADDED
@@ -0,0 +1,1551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ import copy
6
+ import math
7
+ import random
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint
11
+ from torch.nn import CrossEntropyLoss
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutput,
18
+ BaseModelOutputWithPastAndCrossAttentions,
19
+ CausalLMOutputWithCrossAttentions,
20
+ Seq2SeqModelOutput,
21
+ Seq2SeqLMOutput,
22
+ )
23
+ from transformers.file_utils import (
24
+ add_end_docstrings,
25
+ add_start_docstrings,
26
+ add_start_docstrings_to_model_forward,
27
+ replace_return_docstrings,
28
+ )
29
+
30
+ import logging
31
+ from .configuration_deltalm import DeltalmConfig
32
+ logger = logging.getLogger(__name__)
33
+
34
+ _CHECKPOINT_FOR_DOC = "IDEA-CCNL/Randeng-Deltalm-362M-En-Zn"
35
+ _CONFIG_FOR_DOC = "DeltalmConfig"
36
+ _TOKENIZER_FOR_DOC = "DeltalmTokenizer"
37
+
38
+ # Base model docstring
39
+ _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
40
+
41
+
42
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
43
+ """
44
+ Shift input ids one token to the right.
45
+ """
46
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
47
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
48
+ shifted_input_ids[:, 0] = decoder_start_token_id
49
+
50
+ if pad_token_id is None:
51
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
52
+ # replace possible -100 values in labels by `pad_token_id`
53
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
54
+
55
+ return shifted_input_ids
56
+
57
+
58
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
59
+ """
60
+ Make causal mask used for bi-directional self-attention.
61
+ """
62
+ bsz, tgt_len = input_ids_shape
63
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
64
+ mask_cond = torch.arange(mask.size(-1))
65
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
66
+ mask = mask.to(dtype)
67
+
68
+ if past_key_values_length > 0:
69
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
70
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
71
+
72
+
73
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
74
+ """
75
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
76
+ """
77
+ bsz, src_len = mask.size()
78
+ tgt_len = tgt_len if tgt_len is not None else src_len
79
+
80
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
81
+
82
+ inverted_mask = 1.0 - expanded_mask
83
+
84
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
85
+
86
+
87
+ class DeltalmLearnedPositionalEmbedding(nn.Embedding):
88
+ """
89
+ This module learns positional embeddings up to a fixed maximum size.
90
+ """
91
+
92
+ def __init__(self, num_embeddings: int, embedding_dim: int):
93
+ # Deltalm is set up so that if padding_idx is specified then offset the embedding ids by 2
94
+ # and adjust num_embeddings appropriately. Other models don't have this hack
95
+ self.offset = 2
96
+ super().__init__(num_embeddings + self.offset, embedding_dim)
97
+
98
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
99
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
100
+ bsz, seq_len = input_ids_shape[:2]
101
+ positions = torch.arange(
102
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
103
+ )
104
+ return super().forward(positions + self.offset)
105
+
106
+
107
+ class DeltalmAttention(nn.Module):
108
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
109
+
110
+ def __init__(
111
+ self,
112
+ embed_dim: int,
113
+ num_heads: int,
114
+ dropout: float = 0.0,
115
+ is_decoder: bool = False,
116
+ bias: bool = True,
117
+ ):
118
+ super().__init__()
119
+ self.embed_dim = embed_dim
120
+ self.num_heads = num_heads
121
+ self.dropout = dropout
122
+ self.head_dim = embed_dim // num_heads
123
+
124
+ if (self.head_dim * num_heads) != self.embed_dim:
125
+ raise ValueError(
126
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
127
+ f" and `num_heads`: {num_heads})."
128
+ )
129
+ self.scaling = self.head_dim**-0.5
130
+ self.is_decoder = is_decoder
131
+
132
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
133
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
134
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
135
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
136
+
137
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
138
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
139
+
140
+ def forward(
141
+ self,
142
+ hidden_states: torch.Tensor,
143
+ key_value_states: Optional[torch.Tensor] = None,
144
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ layer_head_mask: Optional[torch.Tensor] = None,
147
+ output_attentions: bool = False,
148
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
149
+ """Input shape: Batch x Time x Channel"""
150
+
151
+ # if key_value_states are provided this layer is used as a cross-attention layer
152
+ # for the decoder
153
+ is_cross_attention = key_value_states is not None
154
+
155
+ bsz, tgt_len, _ = hidden_states.size()
156
+
157
+ # get query proj
158
+ query_states = self.q_proj(hidden_states) * self.scaling
159
+ # get key, value proj
160
+ if is_cross_attention and past_key_value is not None:
161
+ # reuse k,v, cross_attentions
162
+ key_states = past_key_value[0]
163
+ value_states = past_key_value[1]
164
+ elif is_cross_attention:
165
+ # cross_attentions
166
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
167
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
168
+ elif past_key_value is not None:
169
+ # reuse k, v, self_attention
170
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
171
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
172
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
173
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
174
+ else:
175
+ # self_attention
176
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
177
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
178
+
179
+ if self.is_decoder:
180
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
181
+ # Further calls to cross_attention layer can then reuse all cross-attention
182
+ # key/value_states (first "if" case)
183
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
184
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
185
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
186
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
187
+ past_key_value = (key_states, value_states)
188
+
189
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
190
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
191
+ key_states = key_states.view(*proj_shape)
192
+ value_states = value_states.view(*proj_shape)
193
+
194
+ src_len = key_states.size(1)
195
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
196
+
197
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
198
+ raise ValueError(
199
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
200
+ f" {attn_weights.size()}"
201
+ )
202
+
203
+ if attention_mask is not None:
204
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
205
+ raise ValueError(
206
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
207
+ )
208
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
209
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
210
+
211
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
212
+
213
+ if layer_head_mask is not None:
214
+ if layer_head_mask.size() != (self.num_heads,):
215
+ raise ValueError(
216
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
217
+ f" {layer_head_mask.size()}"
218
+ )
219
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
220
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
221
+
222
+ if output_attentions:
223
+ # this operation is a bit awkward, but it's required to
224
+ # make sure that attn_weights keeps its gradient.
225
+ # In order to do so, attn_weights have to be reshaped
226
+ # twice and have to be reused in the following
227
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
228
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
229
+ else:
230
+ attn_weights_reshaped = None
231
+
232
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
233
+
234
+ attn_output = torch.bmm(attn_probs, value_states)
235
+
236
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
237
+ raise ValueError(
238
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
239
+ f" {attn_output.size()}"
240
+ )
241
+
242
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
243
+ attn_output = attn_output.transpose(1, 2)
244
+
245
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
246
+ # partitioned aross GPUs when using tensor-parallelism.
247
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
248
+
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights_reshaped, past_key_value
252
+
253
+
254
+ class DeltalmEncoderLayer(nn.Module):
255
+ def __init__(self, config: DeltalmConfig):
256
+ super().__init__()
257
+ self.embed_dim = config.d_model
258
+ self.self_attn = DeltalmAttention(
259
+ embed_dim=self.embed_dim,
260
+ num_heads=config.encoder_attention_heads,
261
+ dropout=config.attention_dropout,
262
+ )
263
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
264
+ self.dropout = config.dropout
265
+ self.activation_fn = ACT2FN[config.activation_function]
266
+ self.activation_dropout = config.activation_dropout
267
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
268
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
269
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states: torch.FloatTensor,
274
+ attention_mask: torch.FloatTensor,
275
+ layer_head_mask: torch.FloatTensor,
276
+ output_attentions: Optional[bool] = False,
277
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
278
+ """
279
+ Args:
280
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
281
+ attention_mask (`torch.FloatTensor`): attention mask of size
282
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
283
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
284
+ `(encoder_attention_heads,)`.
285
+ output_attentions (`bool`, *optional*):
286
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
287
+ returned tensors for more detail.
288
+ """
289
+ residual = hidden_states
290
+ hidden_states, attn_weights, _ = self.self_attn(
291
+ hidden_states=hidden_states,
292
+ attention_mask=attention_mask,
293
+ layer_head_mask=layer_head_mask,
294
+ output_attentions=output_attentions,
295
+ )
296
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
297
+ hidden_states = residual + hidden_states
298
+ hidden_states = self.self_attn_layer_norm(hidden_states)
299
+
300
+ residual = hidden_states
301
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
302
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
303
+ hidden_states = self.fc2(hidden_states)
304
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
305
+ hidden_states = residual + hidden_states
306
+ hidden_states = self.final_layer_norm(hidden_states)
307
+
308
+ if hidden_states.dtype == torch.float16 and (
309
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
310
+ ):
311
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
312
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
313
+
314
+ outputs = (hidden_states,)
315
+
316
+ if output_attentions:
317
+ outputs += (attn_weights,)
318
+
319
+ return outputs
320
+
321
+
322
+ class DeltalmDecoderLayer(nn.Module):
323
+ def __init__(self, config: DeltalmConfig):
324
+ super().__init__()
325
+ self.embed_dim = config.d_model
326
+
327
+ self.self_attn = DeltalmAttention(
328
+ embed_dim=self.embed_dim,
329
+ num_heads=config.decoder_attention_heads,
330
+ dropout=config.attention_dropout,
331
+ is_decoder=True,
332
+ )
333
+ self.dropout = config.dropout
334
+ self.activation_fn = ACT2FN[config.activation_function]
335
+ self.activation_dropout = config.activation_dropout
336
+
337
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
338
+ self.encoder_attn = DeltalmAttention(
339
+ self.embed_dim,
340
+ config.decoder_attention_heads,
341
+ dropout=config.attention_dropout,
342
+ is_decoder=True,
343
+ )
344
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
345
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
346
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
347
+ self.fc3 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
348
+ self.fc4 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
349
+
350
+ self.ffn_layer_norm = nn.LayerNorm(self.embed_dim)
351
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
352
+
353
+ def forward(
354
+ self,
355
+ hidden_states: torch.Tensor,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ encoder_hidden_states: Optional[torch.Tensor] = None,
358
+ encoder_attention_mask: Optional[torch.Tensor] = None,
359
+ layer_head_mask: Optional[torch.Tensor] = None,
360
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
361
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
362
+ output_attentions: Optional[bool] = False,
363
+ use_cache: Optional[bool] = True,
364
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
365
+ """
366
+ Args:
367
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
368
+ attention_mask (`torch.FloatTensor`): attention mask of size
369
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
370
+ encoder_hidden_states (`torch.FloatTensor`):
371
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
372
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
373
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
374
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
375
+ `(encoder_attention_heads,)`.
376
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
377
+ size `(decoder_attention_heads,)`.
378
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
379
+ output_attentions (`bool`, *optional*):
380
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
381
+ returned tensors for more detail.
382
+ """
383
+ residual = hidden_states
384
+
385
+ # Self Attention
386
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
387
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
388
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
389
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
390
+ hidden_states=hidden_states,
391
+ past_key_value=self_attn_past_key_value,
392
+ attention_mask=attention_mask,
393
+ layer_head_mask=layer_head_mask,
394
+ output_attentions=output_attentions,
395
+ )
396
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
397
+ hidden_states = residual + hidden_states
398
+ hidden_states = self.self_attn_layer_norm(hidden_states)
399
+
400
+ # Add another ffn after self-attention to keep the structure same to encoder-layer
401
+ residual = hidden_states
402
+ hidden_states = self.activation_fn(self.fc3(hidden_states))
403
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
404
+ hidden_states = self.fc4(hidden_states)
405
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
406
+ hidden_states = residual + hidden_states
407
+ hidden_states = self.ffn_layer_norm(hidden_states)
408
+
409
+ # Cross-Attention Block
410
+ cross_attn_present_key_value = None
411
+ cross_attn_weights = None
412
+ if encoder_hidden_states is not None:
413
+ residual = hidden_states
414
+
415
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
416
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
417
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
418
+ hidden_states=hidden_states,
419
+ key_value_states=encoder_hidden_states,
420
+ attention_mask=encoder_attention_mask,
421
+ layer_head_mask=cross_attn_layer_head_mask,
422
+ past_key_value=cross_attn_past_key_value,
423
+ output_attentions=output_attentions,
424
+ )
425
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
426
+ hidden_states = residual + hidden_states
427
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
428
+
429
+ # add cross-attn to positions 3,4 of present_key_value tuple
430
+ present_key_value = present_key_value + cross_attn_present_key_value
431
+
432
+ # Fully Connected
433
+ residual = hidden_states
434
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
435
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
436
+ hidden_states = self.fc2(hidden_states)
437
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
438
+ hidden_states = residual + hidden_states
439
+ hidden_states = self.final_layer_norm(hidden_states)
440
+
441
+ outputs = (hidden_states,)
442
+
443
+ if output_attentions:
444
+ outputs += (self_attn_weights, cross_attn_weights)
445
+
446
+ if use_cache:
447
+ outputs += (present_key_value,)
448
+
449
+ return outputs
450
+
451
+
452
+ class DeltalmPretrainedModel(PreTrainedModel):
453
+ config_class = DeltalmConfig
454
+ base_model_prefix = "model"
455
+ supports_gradient_checkpointing = True
456
+
457
+ def _init_weights(self, module):
458
+ std = self.config.init_std
459
+ if isinstance(module, nn.Linear):
460
+ module.weight.data.normal_(mean=0.0, std=std)
461
+ if module.bias is not None:
462
+ module.bias.data.zero_()
463
+ elif isinstance(module, nn.Embedding):
464
+ module.weight.data.normal_(mean=0.0, std=std)
465
+ if module.padding_idx is not None:
466
+ module.weight.data[module.padding_idx].zero_()
467
+
468
+ def _set_gradient_checkpointing(self, module, value=False):
469
+ if isinstance(module, (DeltalmDecoder, DeltalmEncoder)):
470
+ module.gradient_checkpointing = value
471
+
472
+
473
+ class DeltalmDecoder(DeltalmPretrainedModel):
474
+ """
475
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeltalmDecoderLayer`]
476
+ Args:
477
+ config: DeltalmConfig
478
+ embed_tokens (nn.Embedding): output embedding
479
+ """
480
+
481
+ def __init__(self, config: DeltalmConfig, embed_tokens: Optional[nn.Embedding] = None):
482
+ super().__init__(config)
483
+ self.dropout = config.dropout
484
+ self.layerdrop = config.decoder_layerdrop
485
+ self.padding_idx = config.pad_token_id
486
+ self.max_target_positions = config.max_position_embeddings
487
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
488
+
489
+ if embed_tokens is not None:
490
+ self.embed_tokens = embed_tokens
491
+ else:
492
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
493
+
494
+ self.embed_positions = DeltalmLearnedPositionalEmbedding(
495
+ config.max_position_embeddings,
496
+ config.d_model,
497
+ )
498
+ self.layers = nn.ModuleList([DeltalmDecoderLayer(config) for _ in range(config.decoder_layers)])
499
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
500
+
501
+ self.gradient_checkpointing = False
502
+ # Initialize weights and apply final processing
503
+ self.post_init()
504
+
505
+ # fairseq实现了一个 nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5) 对最后的output权重做正态分布转换?
506
+
507
+ def get_input_embeddings(self):
508
+ return self.embed_tokens
509
+
510
+ def set_input_embeddings(self, value):
511
+ self.embed_tokens = value
512
+
513
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
514
+ # create causal mask
515
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
516
+ combined_attention_mask = None
517
+ if input_shape[-1] > 1:
518
+ combined_attention_mask = _make_causal_mask(
519
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
520
+ ).to(inputs_embeds.device)
521
+
522
+ if attention_mask is not None:
523
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
524
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
525
+ combined_attention_mask = (
526
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
527
+ )
528
+
529
+ return combined_attention_mask
530
+
531
+ def forward(
532
+ self,
533
+ input_ids: torch.LongTensor = None,
534
+ attention_mask: Optional[torch.Tensor] = None,
535
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
536
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
537
+ head_mask: Optional[torch.Tensor] = None,
538
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
539
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
540
+ inputs_embeds: Optional[torch.FloatTensor] = None,
541
+ use_cache: Optional[bool] = None,
542
+ output_attentions: Optional[bool] = None,
543
+ output_hidden_states: Optional[bool] = None,
544
+ return_dict: Optional[bool] = None,
545
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
546
+ r"""
547
+ Args:
548
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
549
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
550
+ provide it.
551
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
552
+ [`PreTrainedTokenizer.__call__`] for details.
553
+ [What are input IDs?](../glossary#input-ids)
554
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
555
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
556
+ - 1 for tokens that are **not masked**,
557
+ - 0 for tokens that are **masked**.
558
+ [What are attention masks?](../glossary#attention-mask)
559
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
560
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
561
+ of the decoder.
562
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
563
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
564
+ selected in `[0, 1]`:
565
+ - 1 for tokens that are **not masked**,
566
+ - 0 for tokens that are **masked**.
567
+ [What are attention masks?](../glossary#attention-mask)
568
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
569
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
570
+ - 1 indicates the head is **not masked**,
571
+ - 0 indicates the head is **masked**.
572
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
573
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
574
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
575
+ - 1 indicates the head is **not masked**,
576
+ - 0 indicates the head is **masked**.
577
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
578
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
579
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
580
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
581
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
582
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
583
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
584
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
585
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
586
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
587
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
588
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
589
+ embedding lookup matrix.
590
+ output_attentions (`bool`, *optional*):
591
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
592
+ returned tensors for more detail.
593
+ output_hidden_states (`bool`, *optional*):
594
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
595
+ for more detail.
596
+ return_dict (`bool`, *optional*):
597
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
598
+ """
599
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
600
+ output_hidden_states = (
601
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
602
+ )
603
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
604
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
605
+
606
+ # retrieve input_ids and inputs_embeds
607
+ if input_ids is not None and inputs_embeds is not None:
608
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
609
+ elif input_ids is not None:
610
+ input_shape = input_ids.size()
611
+ input_ids = input_ids.view(-1, input_shape[-1])
612
+ elif inputs_embeds is not None:
613
+ input_shape = inputs_embeds.size()[:-1]
614
+ else:
615
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
616
+
617
+ # past_key_values_length
618
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
619
+
620
+ if inputs_embeds is None:
621
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
622
+
623
+ attention_mask = self._prepare_decoder_attention_mask(
624
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
625
+ )
626
+
627
+ # expand encoder attention mask
628
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
629
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
630
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
631
+
632
+ # embed positions
633
+ positions = self.embed_positions(input_shape, past_key_values_length)
634
+
635
+ hidden_states = inputs_embeds + positions
636
+ hidden_states = self.layernorm_embedding(hidden_states)
637
+
638
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
639
+
640
+ # decoder layers
641
+ all_hidden_states = () if output_hidden_states else None
642
+ all_self_attns = () if output_attentions else None
643
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
644
+ next_decoder_cache = () if use_cache else None
645
+
646
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
647
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
648
+ if attn_mask is not None:
649
+ if attn_mask.size()[0] != (len(self.layers)):
650
+ raise ValueError(
651
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
652
+ f" {head_mask.size()[0]}."
653
+ )
654
+
655
+ for idx, decoder_layer in enumerate(self.layers):
656
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
657
+ if output_hidden_states:
658
+ all_hidden_states += (hidden_states,)
659
+ dropout_probability = random.uniform(0, 1)
660
+ if self.training and (dropout_probability < self.layerdrop):
661
+ continue
662
+
663
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
664
+
665
+ if self.gradient_checkpointing and self.training:
666
+
667
+ if use_cache:
668
+ logger.warning(
669
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
670
+ )
671
+ use_cache = False
672
+
673
+ def create_custom_forward(module):
674
+ def custom_forward(*inputs):
675
+ # None for past_key_value
676
+ return module(*inputs, output_attentions, use_cache)
677
+
678
+ return custom_forward
679
+
680
+ layer_outputs = torch.utils.checkpoint.checkpoint(
681
+ create_custom_forward(decoder_layer),
682
+ hidden_states,
683
+ attention_mask,
684
+ encoder_hidden_states,
685
+ encoder_attention_mask,
686
+ head_mask[idx] if head_mask is not None else None,
687
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
688
+ None,
689
+ )
690
+ else:
691
+
692
+ layer_outputs = decoder_layer(
693
+ hidden_states,
694
+ attention_mask=attention_mask,
695
+ encoder_hidden_states=encoder_hidden_states,
696
+ encoder_attention_mask=encoder_attention_mask,
697
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
698
+ cross_attn_layer_head_mask=(
699
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
700
+ ),
701
+ past_key_value=past_key_value,
702
+ output_attentions=output_attentions,
703
+ use_cache=use_cache,
704
+ )
705
+ hidden_states = layer_outputs[0]
706
+
707
+ if use_cache:
708
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
709
+
710
+ if output_attentions:
711
+ all_self_attns += (layer_outputs[1],)
712
+
713
+ if encoder_hidden_states is not None:
714
+ all_cross_attentions += (layer_outputs[2],)
715
+
716
+ # add hidden states from the last decoder layer
717
+ if output_hidden_states:
718
+ all_hidden_states += (hidden_states,)
719
+
720
+ next_cache = next_decoder_cache if use_cache else None
721
+ if not return_dict:
722
+ return tuple(
723
+ v
724
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
725
+ if v is not None
726
+ )
727
+ return BaseModelOutputWithPastAndCrossAttentions(
728
+ last_hidden_state=hidden_states,
729
+ past_key_values=next_cache,
730
+ hidden_states=all_hidden_states,
731
+ attentions=all_self_attns,
732
+ cross_attentions=all_cross_attentions,
733
+ )
734
+
735
+
736
+ DELTALM_START_DOCSTRING = r"""
737
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
738
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
739
+ etc.)
740
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
741
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
742
+ and behavior.
743
+ Parameters:
744
+ config ([`DeltalmConfig`]):
745
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
746
+ load the weights associated with the model, only the configuration. Check out the
747
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
748
+ """
749
+
750
+ DELTALM_GENERATION_EXAMPLE = r"""
751
+ Summarization example:
752
+ ```python
753
+ >>> from transformers import DeltalmTokenizer, DeltalmForConditionalGeneration
754
+ >>> model = DeltalmForConditionalGeneration.from_pretrained("facebook/deltalm-large-cnn")
755
+ >>> tokenizer = DeltalmTokenizer.from_pretrained("facebook/deltalm-large-cnn")
756
+ >>> ARTICLE_TO_SUMMARIZE = (
757
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
758
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
759
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
760
+ ... )
761
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
762
+ >>> # Generate Summary
763
+ >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
764
+ >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
765
+ 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
766
+ ```
767
+ Mask filling example:
768
+ ```python
769
+ >>> from transformers import DeltalmTokenizer, DeltalmForConditionalGeneration
770
+ >>> tokenizer = DeltalmTokenizer.from_pretrained("facebook/deltalm-base")
771
+ >>> model = DeltalmForConditionalGeneration.from_pretrained("facebook/deltalm-base")
772
+ >>> TXT = "My friends are <mask> but they eat too many carbs."
773
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
774
+ >>> logits = model(input_ids).logits
775
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
776
+ >>> probs = logits[0, masked_index].softmax(dim=0)
777
+ >>> values, predictions = probs.topk(5)
778
+ >>> tokenizer.decode(predictions).split()
779
+ ['not', 'good', 'healthy', 'great', 'very']
780
+ ```
781
+ """
782
+
783
+ DELTALM_INPUTS_DOCSTRING = r"""
784
+ Args:
785
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
786
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
787
+ it.
788
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
789
+ [`PreTrainedTokenizer.__call__`] for details.
790
+ [What are input IDs?](../glossary#input-ids)
791
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
792
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
793
+ - 1 for tokens that are **not masked**,
794
+ - 0 for tokens that are **masked**.
795
+ [What are attention masks?](../glossary#attention-mask)
796
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
797
+ Indices of decoder input sequence tokens in the vocabulary.
798
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
799
+ [`PreTrainedTokenizer.__call__`] for details.
800
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
801
+ Deltalm uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
802
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
803
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
804
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
805
+ for denoising pre-training following the paper.
806
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
807
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
808
+ be used by default.
809
+ If you want to change padding behavior, you should read [`modeling_deltalm._prepare_decoder_attention_mask`]
810
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
811
+ information on the default strategy.
812
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
813
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
814
+ - 1 indicates the head is **not masked**,
815
+ - 0 indicates the head is **masked**.
816
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
817
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
818
+ - 1 indicates the head is **not masked**,
819
+ - 0 indicates the head is **masked**.
820
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
821
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
822
+ 1]`:
823
+ - 1 indicates the head is **not masked**,
824
+ - 0 indicates the head is **masked**.
825
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
826
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
827
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
828
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
829
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
830
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
831
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
832
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
833
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
834
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
835
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
836
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
837
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
838
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
839
+ can choose to directly pass an embedded representation. This is useful if you want more control over how to
840
+ convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
841
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
842
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
843
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
844
+ input (see `past_key_values`). This is useful if you want more control over how to convert
845
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
846
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
847
+ of `inputs_embeds`.
848
+ use_cache (`bool`, *optional*):
849
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
850
+ `past_key_values`).
851
+ output_attentions (`bool`, *optional*):
852
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
853
+ tensors for more detail.
854
+ output_hidden_states (`bool`, *optional*):
855
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
856
+ more detail.
857
+ return_dict (`bool`, *optional*):
858
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
859
+ """
860
+
861
+
862
+ class DeltalmEncoder(DeltalmPretrainedModel):
863
+ """
864
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
865
+ [`DeltalmEncoderLayer`].
866
+ Args:
867
+ config: DeltalmConfig
868
+ embed_tokens (nn.Embedding): output embedding
869
+ """
870
+
871
+ def __init__(self, config: DeltalmConfig, embed_tokens: Optional[nn.Embedding] = None):
872
+ super().__init__(config)
873
+
874
+ self.dropout = config.dropout
875
+ self.layerdrop = config.encoder_layerdrop
876
+
877
+ embed_dim = config.d_model
878
+ self.padding_idx = config.pad_token_id
879
+ self.max_source_positions = config.max_position_embeddings
880
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
881
+
882
+ if embed_tokens is not None:
883
+ self.embed_tokens = embed_tokens
884
+ else:
885
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
886
+
887
+ self.embed_positions = DeltalmLearnedPositionalEmbedding(
888
+ config.max_position_embeddings,
889
+ embed_dim,
890
+ )
891
+ self.layers = nn.ModuleList([DeltalmEncoderLayer(config) for _ in range(config.encoder_layers)])
892
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
893
+
894
+ self.gradient_checkpointing = False
895
+ if config.encoder_normalize_before:
896
+ self.layer_norm = nn.LayerNorm(embed_dim)
897
+ else:
898
+ self.layer_norm = None
899
+ # Initialize weights and apply final processing
900
+ self.post_init()
901
+
902
+ def get_input_embeddings(self):
903
+ return self.embed_tokens
904
+
905
+ def set_input_embeddings(self, value):
906
+ self.embed_tokens = value
907
+
908
+ def forward(
909
+ self,
910
+ input_ids: torch.LongTensor = None,
911
+ attention_mask: Optional[torch.Tensor] = None,
912
+ head_mask: Optional[torch.Tensor] = None,
913
+ inputs_embeds: Optional[torch.FloatTensor] = None,
914
+ output_attentions: Optional[bool] = None,
915
+ output_hidden_states: Optional[bool] = None,
916
+ return_dict: Optional[bool] = None,
917
+ ) -> Union[Tuple, BaseModelOutput]:
918
+ r"""
919
+ Args:
920
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
921
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
922
+ provide it.
923
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
924
+ [`PreTrainedTokenizer.__call__`] for details.
925
+ [What are input IDs?](../glossary#input-ids)
926
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
927
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
928
+ - 1 for tokens that are **not masked**,
929
+ - 0 for tokens that are **masked**.
930
+ [What are attention masks?](../glossary#attention-mask)
931
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
932
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
933
+ - 1 indicates the head is **not masked**,
934
+ - 0 indicates the head is **masked**.
935
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
936
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
937
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
938
+ than the model's internal embedding lookup matrix.
939
+ output_attentions (`bool`, *optional*):
940
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
941
+ returned tensors for more detail.
942
+ output_hidden_states (`bool`, *optional*):
943
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
944
+ for more detail.
945
+ return_dict (`bool`, *optional*):
946
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
947
+ """
948
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
949
+ output_hidden_states = (
950
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
951
+ )
952
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
953
+
954
+ # retrieve input_ids and inputs_embeds
955
+ if input_ids is not None and inputs_embeds is not None:
956
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
957
+ elif input_ids is not None:
958
+ input_shape = input_ids.size()
959
+ input_ids = input_ids.view(-1, input_shape[-1])
960
+ elif inputs_embeds is not None:
961
+ input_shape = inputs_embeds.size()[:-1]
962
+ else:
963
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
964
+
965
+ if inputs_embeds is None:
966
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
967
+
968
+ embed_pos = self.embed_positions(input_shape)
969
+
970
+ hidden_states = inputs_embeds + embed_pos
971
+ hidden_states = self.layernorm_embedding(hidden_states)
972
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
973
+
974
+ # expand attention_mask
975
+ if attention_mask is not None:
976
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
977
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
978
+
979
+ encoder_states = () if output_hidden_states else None
980
+ all_attentions = () if output_attentions else None
981
+
982
+ # check if head_mask has a correct number of layers specified if desired
983
+ if head_mask is not None:
984
+ if head_mask.size()[0] != (len(self.layers)):
985
+ raise ValueError(
986
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
987
+ f" {head_mask.size()[0]}."
988
+ )
989
+
990
+ for idx, encoder_layer in enumerate(self.layers):
991
+ if output_hidden_states:
992
+ encoder_states = encoder_states + (hidden_states,)
993
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
994
+ dropout_probability = random.uniform(0, 1)
995
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
996
+ layer_outputs = (None, None)
997
+ else:
998
+ if self.gradient_checkpointing and self.training:
999
+
1000
+ def create_custom_forward(module):
1001
+ def custom_forward(*inputs):
1002
+ return module(*inputs, output_attentions)
1003
+
1004
+ return custom_forward
1005
+
1006
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1007
+ create_custom_forward(encoder_layer),
1008
+ hidden_states,
1009
+ attention_mask,
1010
+ (head_mask[idx] if head_mask is not None else None),
1011
+ )
1012
+ else:
1013
+ layer_outputs = encoder_layer(
1014
+ hidden_states,
1015
+ attention_mask,
1016
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1017
+ output_attentions=output_attentions,
1018
+ )
1019
+
1020
+ hidden_states = layer_outputs[0]
1021
+
1022
+ if output_attentions:
1023
+ all_attentions = all_attentions + (layer_outputs[1],)
1024
+
1025
+ if self.layer_norm is not None:
1026
+ hidden_states = self.layer_norm(hidden_states)
1027
+ # hidden_states = self.layernorm_embedding(hidden_states)
1028
+
1029
+ if output_hidden_states:
1030
+ encoder_states = encoder_states + (hidden_states,)
1031
+
1032
+ if not return_dict:
1033
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1034
+ return BaseModelOutput(
1035
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1036
+ )
1037
+
1038
+
1039
+ class DeltalmModel(DeltalmPretrainedModel):
1040
+ def __init__(self, config: DeltalmConfig):
1041
+ super().__init__(config)
1042
+
1043
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1044
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
1045
+
1046
+ self.encoder = DeltalmEncoder(config, self.shared)
1047
+ self.decoder = DeltalmDecoder(config, self.shared)
1048
+
1049
+ # Initialize weights and apply final processing
1050
+ self.post_init()
1051
+
1052
+ def get_input_embeddings(self):
1053
+ return self.shared
1054
+
1055
+ def set_input_embeddings(self, value):
1056
+ self.shared = value
1057
+ self.encoder.embed_tokens = self.shared
1058
+ self.decoder.embed_tokens = self.shared
1059
+
1060
+ def get_encoder(self):
1061
+ return self.encoder
1062
+
1063
+ def get_decoder(self):
1064
+ return self.decoder
1065
+
1066
+ @add_start_docstrings_to_model_forward(DELTALM_INPUTS_DOCSTRING)
1067
+ # @add_code_sample_docstrings(
1068
+ # processor_class=_TOKENIZER_FOR_DOC,
1069
+ # checkpoint=_CHECKPOINT_FOR_DOC,
1070
+ # output_type=Seq2SeqModelOutput,
1071
+ # config_class=_CONFIG_FOR_DOC,
1072
+ # expected_output=_EXPECTED_OUTPUT_SHAPE,
1073
+ # )
1074
+ def forward(
1075
+ self,
1076
+ input_ids: torch.LongTensor = None,
1077
+ attention_mask: Optional[torch.Tensor] = None,
1078
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1079
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1080
+ head_mask: Optional[torch.Tensor] = None,
1081
+ decoder_head_mask: Optional[torch.Tensor] = None,
1082
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1083
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1084
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1085
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1086
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1087
+ use_cache: Optional[bool] = None,
1088
+ output_attentions: Optional[bool] = None,
1089
+ output_hidden_states: Optional[bool] = None,
1090
+ return_dict: Optional[bool] = None,
1091
+ ) -> Union[Tuple, Seq2SeqModelOutput]:
1092
+
1093
+ # different to other models, Deltalm automatically creates decoder_input_ids from
1094
+ # input_ids if no decoder_input_ids are provided
1095
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1096
+ if input_ids is None:
1097
+ raise ValueError(
1098
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
1099
+ "passed, `input_ids` cannot be `None`. Please pass either "
1100
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
1101
+ )
1102
+
1103
+ decoder_input_ids = shift_tokens_right(
1104
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
1105
+ )
1106
+
1107
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1108
+ output_hidden_states = (
1109
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1110
+ )
1111
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1112
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1113
+
1114
+ if encoder_outputs is None:
1115
+ encoder_outputs = self.encoder(
1116
+ input_ids=input_ids,
1117
+ attention_mask=attention_mask,
1118
+ head_mask=head_mask,
1119
+ inputs_embeds=inputs_embeds,
1120
+ output_attentions=output_attentions,
1121
+ output_hidden_states=output_hidden_states,
1122
+ return_dict=return_dict,
1123
+ )
1124
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1125
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1126
+ encoder_outputs = BaseModelOutput(
1127
+ last_hidden_state=encoder_outputs[0],
1128
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1129
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1130
+ )
1131
+
1132
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1133
+ decoder_outputs = self.decoder(
1134
+ input_ids=decoder_input_ids,
1135
+ attention_mask=decoder_attention_mask,
1136
+ encoder_hidden_states=encoder_outputs[0],
1137
+ encoder_attention_mask=attention_mask,
1138
+ head_mask=decoder_head_mask,
1139
+ cross_attn_head_mask=cross_attn_head_mask,
1140
+ past_key_values=past_key_values,
1141
+ inputs_embeds=decoder_inputs_embeds,
1142
+ use_cache=use_cache,
1143
+ output_attentions=output_attentions,
1144
+ output_hidden_states=output_hidden_states,
1145
+ return_dict=return_dict,
1146
+ )
1147
+
1148
+ if not return_dict:
1149
+ return decoder_outputs + encoder_outputs
1150
+
1151
+ logger.debug("last_hidden_state.size: %s", decoder_outputs.last_hidden_state)
1152
+ return Seq2SeqModelOutput(
1153
+ last_hidden_state=decoder_outputs.last_hidden_state,
1154
+ past_key_values=decoder_outputs.past_key_values,
1155
+ decoder_hidden_states=decoder_outputs.hidden_states,
1156
+ decoder_attentions=decoder_outputs.attentions,
1157
+ cross_attentions=decoder_outputs.cross_attentions,
1158
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1159
+ encoder_hidden_states=encoder_outputs.hidden_states,
1160
+ encoder_attentions=encoder_outputs.attentions,
1161
+ )
1162
+
1163
+
1164
+ @add_start_docstrings(
1165
+ "The DELTALM Model with a language modeling head. Can be used for translation.", DELTALM_START_DOCSTRING
1166
+ )
1167
+ class DeltalmForConditionalGeneration(DeltalmPretrainedModel):
1168
+ base_model_prefix = "model"
1169
+ _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
1170
+
1171
+ def __init__(self, config: DeltalmConfig):
1172
+ super().__init__(config)
1173
+ self.model = DeltalmModel(config)
1174
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1175
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1176
+
1177
+ # Initialize weights and apply final processing
1178
+ self.post_init()
1179
+
1180
+ def get_encoder(self):
1181
+ return self.model.get_encoder()
1182
+
1183
+ def get_decoder(self):
1184
+ return self.model.get_decoder()
1185
+
1186
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
1187
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
1188
+ self._resize_final_logits_bias(new_num_tokens)
1189
+ return new_embeddings
1190
+
1191
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
1192
+ logger.debug("Debug: coming to _resize_final_logits_bias")
1193
+ old_num_tokens = self.final_logits_bias.shape[-1]
1194
+ if new_num_tokens <= old_num_tokens:
1195
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
1196
+ else:
1197
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1198
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1199
+ self.register_buffer("final_logits_bias", new_bias)
1200
+
1201
+ def get_output_embeddings(self):
1202
+ return self.lm_head
1203
+
1204
+ def set_output_embeddings(self, new_embeddings):
1205
+ self.lm_head = new_embeddings
1206
+
1207
+ @add_start_docstrings_to_model_forward(DELTALM_INPUTS_DOCSTRING)
1208
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1209
+ @add_end_docstrings(DELTALM_GENERATION_EXAMPLE)
1210
+ def forward(
1211
+ self,
1212
+ input_ids: torch.LongTensor = None,
1213
+ attention_mask: Optional[torch.Tensor] = None,
1214
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1215
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1216
+ head_mask: Optional[torch.Tensor] = None,
1217
+ decoder_head_mask: Optional[torch.Tensor] = None,
1218
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1219
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1220
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1221
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1222
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1223
+ labels: Optional[torch.LongTensor] = None,
1224
+ use_cache: Optional[bool] = None,
1225
+ output_attentions: Optional[bool] = None,
1226
+ output_hidden_states: Optional[bool] = None,
1227
+ return_dict: Optional[bool] = None,
1228
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
1229
+ r"""
1230
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1231
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1232
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1233
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1234
+ Returns:
1235
+ """
1236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1237
+
1238
+ logger.debug("Comming to Generation!")
1239
+
1240
+ if labels is not None:
1241
+ logger.debug("Debug: *************** Before label ***************** ")
1242
+ logger.debug("Debug: %s", labels.size())
1243
+ if use_cache:
1244
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
1245
+ use_cache = False
1246
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1247
+ decoder_input_ids = shift_tokens_right(
1248
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1249
+ )
1250
+
1251
+ logger.debug("Debug: ************ After labels ************")
1252
+ logger.debug("Debug: %s", labels.size())
1253
+
1254
+ outputs = self.model(
1255
+ input_ids,
1256
+ attention_mask=attention_mask,
1257
+ decoder_input_ids=decoder_input_ids,
1258
+ encoder_outputs=encoder_outputs,
1259
+ decoder_attention_mask=decoder_attention_mask,
1260
+ head_mask=head_mask,
1261
+ decoder_head_mask=decoder_head_mask,
1262
+ cross_attn_head_mask=cross_attn_head_mask,
1263
+ past_key_values=past_key_values,
1264
+ inputs_embeds=inputs_embeds,
1265
+ decoder_inputs_embeds=decoder_inputs_embeds,
1266
+ use_cache=use_cache,
1267
+ output_attentions=output_attentions,
1268
+ output_hidden_states=output_hidden_states,
1269
+ return_dict=return_dict,
1270
+ )
1271
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1272
+ # print(self.lm_head)
1273
+ logger.debug("Debug: logit_size: %s", lm_logits.size())
1274
+
1275
+ # logger.debug("Debug: change logit size: ", lm_logits.view(-1, self.config.vocab_size).size())
1276
+ # logger.debug("Debug: change label size: ", labels.view(-1).size())
1277
+ masked_lm_loss = None
1278
+
1279
+ if labels is not None:
1280
+ # logger.debug("Debug: model label_size: %s", labels.size())
1281
+ # loss_fct = CrossEntropyLoss()
1282
+ # masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1283
+ loss_fct = CrossEntropyLoss()
1284
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1285
+ # label_smoothing = self.config.label_smoothing
1286
+ # # logger.debug("Debug: label.size: ", )
1287
+ # if label_smoothing == 0:
1288
+ # # compute label smoothed loss
1289
+ # loss_fct = CrossEntropyLoss()
1290
+ # masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1291
+ # else:
1292
+ # m = torch.nn.LogSoftmax(dim=-1)
1293
+ # lprobs = m(lm_logits.float())
1294
+ # # lprobs = m(lm_logits)
1295
+ # # # torch.set_printoptions(linewidth=200)
1296
+ # loss_fn = label_smoothed_nll_loss
1297
+ # masked_lm_loss, _ = loss_fn(lprobs.view(-1, lprobs.size(-1)), labels.view(-1), label_smoothing, self.config.pad_token_id)
1298
+
1299
+ if not return_dict:
1300
+ logger.debug("Debug: not return dict")
1301
+ output = (lm_logits,) + outputs[1:]
1302
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1303
+
1304
+ return Seq2SeqLMOutput(
1305
+ loss=masked_lm_loss,
1306
+ logits=lm_logits,
1307
+ past_key_values=outputs.past_key_values,
1308
+ decoder_hidden_states=outputs.decoder_hidden_states,
1309
+ decoder_attentions=outputs.decoder_attentions,
1310
+ cross_attentions=outputs.cross_attentions,
1311
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1312
+ encoder_hidden_states=outputs.encoder_hidden_states,
1313
+ encoder_attentions=outputs.encoder_attentions,
1314
+ )
1315
+
1316
+ def prepare_inputs_for_generation(
1317
+ self,
1318
+ decoder_input_ids,
1319
+ past=None,
1320
+ attention_mask=None,
1321
+ head_mask=None,
1322
+ decoder_head_mask=None,
1323
+ cross_attn_head_mask=None,
1324
+ use_cache=None,
1325
+ encoder_outputs=None,
1326
+ **kwargs
1327
+ ):
1328
+ # cut decoder_input_ids if past is used
1329
+ if past is not None:
1330
+ decoder_input_ids = decoder_input_ids[:, -1:]
1331
+
1332
+ return {
1333
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1334
+ "encoder_outputs": encoder_outputs,
1335
+ "past_key_values": past,
1336
+ "decoder_input_ids": decoder_input_ids,
1337
+ "attention_mask": attention_mask,
1338
+ "head_mask": head_mask,
1339
+ "decoder_head_mask": decoder_head_mask,
1340
+ "cross_attn_head_mask": cross_attn_head_mask,
1341
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1342
+ }
1343
+
1344
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1345
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1346
+
1347
+ @staticmethod
1348
+ def _reorder_cache(past, beam_idx):
1349
+ reordered_past = ()
1350
+ for layer_past in past:
1351
+ # cached cross_attention states don't have to be reordered -> they are always the same
1352
+ reordered_past += (
1353
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1354
+ )
1355
+ return reordered_past
1356
+
1357
+
1358
+ class DeltalmDecoderWrapper(DeltalmPretrainedModel):
1359
+ """
1360
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
1361
+ used in combination with the [`EncoderDecoderModel`] framework.
1362
+ """
1363
+
1364
+ def __init__(self, config):
1365
+ super().__init__(config)
1366
+ self.decoder = DeltalmDecoder(config)
1367
+
1368
+ def forward(self, *args, **kwargs):
1369
+ return self.decoder(*args, **kwargs)
1370
+
1371
+
1372
+ class DeltalmForCausalLM(DeltalmPretrainedModel):
1373
+ def __init__(self, config):
1374
+ config = copy.deepcopy(config)
1375
+ config.is_decoder = True
1376
+ config.is_encoder_decoder = False
1377
+ super().__init__(config)
1378
+ self.model = DeltalmDecoderWrapper(config)
1379
+
1380
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1381
+
1382
+ # Initialize weights and apply final processing
1383
+ self.post_init()
1384
+
1385
+ def get_input_embeddings(self):
1386
+ return self.model.decoder.embed_tokens
1387
+
1388
+ def set_input_embeddings(self, value):
1389
+ self.model.decoder.embed_tokens = value
1390
+
1391
+ def get_output_embeddings(self):
1392
+ return self.lm_head
1393
+
1394
+ def set_output_embeddings(self, new_embeddings):
1395
+ self.lm_head = new_embeddings
1396
+
1397
+ def set_decoder(self, decoder):
1398
+ self.model.decoder = decoder
1399
+
1400
+ def get_decoder(self):
1401
+ return self.model.decoder
1402
+
1403
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1404
+ def forward(
1405
+ self,
1406
+ input_ids: torch.LongTensor = None,
1407
+ attention_mask: Optional[torch.Tensor] = None,
1408
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1409
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1410
+ head_mask: Optional[torch.Tensor] = None,
1411
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1412
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1413
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1414
+ labels: Optional[torch.LongTensor] = None,
1415
+ use_cache: Optional[bool] = None,
1416
+ output_attentions: Optional[bool] = None,
1417
+ output_hidden_states: Optional[bool] = None,
1418
+ return_dict: Optional[bool] = None,
1419
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1420
+ r"""
1421
+ Args:
1422
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1423
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1424
+ provide it.
1425
+ Indices can be obtained using [`DeltalmTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1426
+ [`PreTrainedTokenizer.__call__`] for details.
1427
+ [What are input IDs?](../glossary#input-ids)
1428
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1429
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1430
+ - 1 for tokens that are **not masked**,
1431
+ - 0 for tokens that are **masked**.
1432
+ [What are attention masks?](../glossary#attention-mask)
1433
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1434
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1435
+ if the model is configured as a decoder.
1436
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1437
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
1438
+ in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1439
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1440
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1441
+ - 1 indicates the head is **not masked**,
1442
+ - 0 indicates the head is **masked**.
1443
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1444
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
1445
+ - 1 indicates the head is **not masked**,
1446
+ - 0 indicates the head is **masked**.
1447
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1448
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1449
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1450
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1451
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
1452
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1453
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1454
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1455
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1456
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1457
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1458
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1459
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1460
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1461
+ use_cache (`bool`, *optional*):
1462
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1463
+ (see `past_key_values`).
1464
+ - 1 for tokens that are **not masked**,
1465
+ - 0 for tokens that are **masked**.
1466
+ output_attentions (`bool`, *optional*):
1467
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1468
+ returned tensors for more detail.
1469
+ output_hidden_states (`bool`, *optional*):
1470
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1471
+ for more detail.
1472
+ return_dict (`bool`, *optional*):
1473
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1474
+ Returns:
1475
+ Example:
1476
+ ```python
1477
+ >>> from transformers import DeltalmTokenizer, DeltalmForCausalLM
1478
+ >>> tokenizer = DeltalmTokenizer.from_pretrained("facebook/deltalm-base")
1479
+ >>> model = DeltalmForCausalLM.from_pretrained("facebook/deltalm-base", add_cross_attention=False)
1480
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
1481
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1482
+ >>> outputs = model(**inputs)
1483
+ >>> logits = outputs.logits
1484
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
1485
+ >>> list(logits.shape) == expected_shape
1486
+ True
1487
+ ```"""
1488
+
1489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1490
+ output_hidden_states = (
1491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1492
+ )
1493
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1494
+
1495
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1496
+ outputs = self.model.decoder(
1497
+ input_ids=input_ids,
1498
+ attention_mask=attention_mask,
1499
+ encoder_hidden_states=encoder_hidden_states,
1500
+ encoder_attention_mask=encoder_attention_mask,
1501
+ head_mask=head_mask,
1502
+ cross_attn_head_mask=cross_attn_head_mask,
1503
+ past_key_values=past_key_values,
1504
+ inputs_embeds=inputs_embeds,
1505
+ use_cache=use_cache,
1506
+ output_attentions=output_attentions,
1507
+ output_hidden_states=output_hidden_states,
1508
+ return_dict=return_dict,
1509
+ )
1510
+
1511
+ logits = self.lm_head(outputs[0])
1512
+
1513
+ loss = None
1514
+ if labels is not None:
1515
+ loss_fct = CrossEntropyLoss()
1516
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1517
+
1518
+ if not return_dict:
1519
+ output = (logits,) + outputs[1:]
1520
+ return (loss,) + output if loss is not None else output
1521
+
1522
+ return CausalLMOutputWithCrossAttentions(
1523
+ loss=loss,
1524
+ logits=logits,
1525
+ past_key_values=outputs.past_key_values,
1526
+ hidden_states=outputs.hidden_states,
1527
+ attentions=outputs.attentions,
1528
+ cross_attentions=outputs.cross_attentions,
1529
+ )
1530
+
1531
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
1532
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1533
+ if attention_mask is None:
1534
+ attention_mask = input_ids.new_ones(input_ids.shape)
1535
+
1536
+ if past:
1537
+ input_ids = input_ids[:, -1:]
1538
+ # first step, decoder_cached_states are empty
1539
+ return {
1540
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
1541
+ "attention_mask": attention_mask,
1542
+ "past_key_values": past,
1543
+ "use_cache": use_cache,
1544
+ }
1545
+
1546
+ @staticmethod
1547
+ def _reorder_cache(past, beam_idx):
1548
+ reordered_past = ()
1549
+ for layer_past in past:
1550
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1551
+ return reordered_past
english/models/deltalm/tokenizer_deltalm.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import re
6
+ import warnings
7
+ from shutil import copyfile
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ # import sentencepiece as spm
11
+
12
+ from transformers.tokenization_utils import PreTrainedTokenizer
13
+ from transformers.utils import logging
14
+
15
+
16
+ SPIECE_UNDERLINE = "▁"
17
+
18
+ VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}
19
+
20
+ PRETRAINED_VOCAB_FILES_MAP = {
21
+ "vocab_file": {"IDEA-CCNL/deltalm": "https://huggingface.co/IDEA-CCNL/Randeng-Deltalm-362M-En-Zn/resolve/main/spm.model"}
22
+ }
23
+
24
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
25
+ "IDEA-CCNL/deltalm": 512,
26
+ }
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class DeltalmTokenizer(PreTrainedTokenizer):
33
+ """
34
+ Construct a T5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
35
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
36
+ this superclass for more information regarding those methods.
37
+ Args:
38
+ vocab_file (`str`):
39
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
40
+ contains the vocabulary necessary to instantiate a tokenizer.
41
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
42
+ The end of sequence token.
43
+ <Tip>
44
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
45
+ The token used is the `sep_token`.
46
+ </Tip>
47
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
48
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
49
+ token instead.
50
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
51
+ The token used for padding, for example when batching sequences of different lengths.
52
+ extra_ids (`int`, *optional*, defaults to 100):
53
+ Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
54
+ accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
55
+ indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
56
+ like in T5 preprocessing see
57
+ [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
58
+ additional_special_tokens (`List[str]`, *optional*):
59
+ Additional special tokens used by the tokenizer.
60
+ sp_model_kwargs (`dict`, *optional*):
61
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
62
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
63
+ to set:
64
+ - `enable_sampling`: Enable subword regularization.
65
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
66
+ - `nbest_size = {0,1}`: No sampling is performed.
67
+ - `nbest_size > 1`: samples from the nbest_size results.
68
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
69
+ using forward-filtering-and-backward-sampling algorithm.
70
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
71
+ BPE-dropout.
72
+ Attributes:
73
+ sp_model (`SentencePieceProcessor`):
74
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
75
+ """
76
+
77
+ vocab_files_names = VOCAB_FILES_NAMES
78
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
79
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
80
+ model_input_names = ["input_ids", "attention_mask"]
81
+
82
+ def __init__(
83
+ self,
84
+ vocab_file,
85
+ bos_token="<s>",
86
+ eos_token="</s>",
87
+ unk_token="<unk>",
88
+ pad_token="<pad>",
89
+ extra_ids=0,
90
+ additional_special_tokens=None,
91
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
92
+ **kwargs
93
+ ) -> None:
94
+ # Add extra_ids to the special token list
95
+ if extra_ids > 0 and additional_special_tokens is None:
96
+ additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
97
+ elif extra_ids > 0 and additional_special_tokens is not None:
98
+ # Check that we have the right number of extra_id special tokens
99
+ extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
100
+ if extra_tokens != extra_ids:
101
+ raise ValueError(
102
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
103
+ " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
104
+ " tokens"
105
+ )
106
+
107
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
108
+ super().__init__(
109
+ bos_token=bos_token,
110
+ eos_token=eos_token,
111
+ unk_token=unk_token,
112
+ pad_token=pad_token,
113
+ additional_special_tokens=additional_special_tokens,
114
+ extra_ids=extra_ids,
115
+ sp_model_kwargs=self.sp_model_kwargs,
116
+ **kwargs,
117
+ )
118
+
119
+ self.vocab_file = vocab_file
120
+ self.offset = 1
121
+ self._extra_ids = extra_ids
122
+
123
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
124
+ self.sp_model.Load(vocab_file)
125
+
126
+ self.encoder: Dict[int, str] = {
127
+ 0: self.bos_token,
128
+ 1: self.pad_token,
129
+ 2: self.eos_token,
130
+ 3: self.unk_token,
131
+ }
132
+
133
+ self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}
134
+
135
+ @staticmethod
136
+ def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
137
+ if pretrained_model_name_or_path in DeltalmTokenizer.max_model_input_sizes:
138
+ deprecated_max_model_length = DeltalmTokenizer.max_model_input_sizes[pretrained_model_name_or_path]
139
+ if init_max_model_length is not None and init_max_model_length != max_model_length:
140
+ return init_max_model_length
141
+ elif init_max_model_length is None:
142
+ warnings.warn(
143
+ "This tokenizer was incorrectly instantiated with a model max length of"
144
+ f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
145
+ " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
146
+ " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
147
+ f" {pretrained_model_name_or_path} automatically truncating your input to"
148
+ f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
149
+ f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
150
+ " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
151
+ " instantiate this tokenizer with `model_max_length` set to your preferred value.",
152
+ FutureWarning,
153
+ )
154
+
155
+ return max_model_length
156
+
157
+ @property
158
+ def vocab_size(self):
159
+ return self.sp_model.get_piece_size() # + self._extra_ids
160
+
161
+ def get_vocab(self):
162
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
163
+ vocab.update(self.added_tokens_encoder)
164
+ return vocab
165
+
166
+ def get_special_tokens_mask(
167
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
168
+ ) -> List[int]:
169
+ """
170
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
171
+ special tokens using the tokenizer `prepare_for_model` method.
172
+ Args:
173
+ token_ids_0 (`List[int]`):
174
+ List of IDs.
175
+ token_ids_1 (`List[int]`, *optional*):
176
+ Optional second list of IDs for sequence pairs.
177
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
178
+ Whether or not the token list is already formatted with special tokens for the model.
179
+ Returns:
180
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
181
+ """
182
+ if already_has_special_tokens:
183
+ return super().get_special_tokens_mask(
184
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
185
+ )
186
+
187
+ # normal case: some special tokens
188
+ if token_ids_1 is None:
189
+ return ([0] * len(token_ids_0)) + [1]
190
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
191
+
192
+ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
193
+ """Do not add eos again if user already added it."""
194
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
195
+ warnings.warn(
196
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
197
+ " eos tokens being added."
198
+ )
199
+ return token_ids
200
+ else:
201
+ return token_ids + [self.eos_token_id]
202
+
203
+ def create_token_type_ids_from_sequences(
204
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
205
+ ) -> List[int]:
206
+ """
207
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
208
+ use of token type ids, therefore a list of zeros is returned.
209
+ Args:
210
+ token_ids_0 (`List[int]`):
211
+ List of IDs.
212
+ token_ids_1 (`List[int]`, *optional*):
213
+ Optional second list of IDs for sequence pairs.
214
+ Returns:
215
+ `List[int]`: List of zeros.
216
+ """
217
+ eos = [self.eos_token_id]
218
+
219
+ if token_ids_1 is None:
220
+ return len(token_ids_0 + eos) * [0]
221
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
222
+
223
+ def build_inputs_with_special_tokens(
224
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
225
+ ) -> List[int]:
226
+ """
227
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
228
+ adding special tokens. A sequence has the following format:
229
+ - single sequence: `X </s>`
230
+ - pair of sequences: `A </s> B </s>`
231
+ Args:
232
+ token_ids_0 (`List[int]`):
233
+ List of IDs to which the special tokens will be added.
234
+ token_ids_1 (`List[int]`, *optional*):
235
+ Optional second list of IDs for sequence pairs.
236
+ Returns:
237
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
238
+ """
239
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
240
+ if token_ids_1 is None:
241
+ return token_ids_0
242
+ else:
243
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
244
+ return token_ids_0 + token_ids_1
245
+
246
+ def __getstate__(self):
247
+ state = self.__dict__.copy()
248
+ state["sp_model"] = None
249
+ return state
250
+
251
+ def __setstate__(self, d):
252
+ self.__dict__ = d
253
+
254
+ # for backward compatibility
255
+ if not hasattr(self, "sp_model_kwargs"):
256
+ self.sp_model_kwargs = {}
257
+
258
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
259
+ self.sp_model.Load(self.vocab_file)
260
+
261
+ def _tokenize(self, text: str) -> List[str]:
262
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
263
+ return self.sp_model.encode(text, out_type=str)
264
+
265
+ def _convert_token_to_id(self, token):
266
+ """Converts a token (str) in an id using the vocab."""
267
+ if token.startswith("<extra_id_"):
268
+ match = re.match(r"<extra_id_(\d+)>", token)
269
+ num = int(match.group(1))
270
+ return self.vocab_size - num - 1
271
+ elif token in self.decoder:
272
+ return self.decoder[token]
273
+
274
+ sp_id = self.sp_model.piece_to_id(token)
275
+ return sp_id + self.offset
276
+
277
+ def _convert_id_to_token(self, index):
278
+ """Converts an index (integer) in a token (str) using the vocab."""
279
+ # if index < self.sp_model.get_piece_size():
280
+ # token = self.sp_model.IdToPiece(index)
281
+ # else:
282
+ # token = f"<extra_id_{self.vocab_size - 1 - index}>"
283
+ # return token
284
+ if index in self.encoder:
285
+ return self.encoder[index]
286
+ elif index in self.added_tokens_encoder:
287
+ return self.added_tokens_encoder[index]
288
+ elif index < self.sp_model.get_piece_size() + 4:
289
+ token = self.sp_model.IdToPiece(index-self.offset)
290
+ else:
291
+ token = f"<extra_id_{self.vocab_size - 1 - index}>"
292
+ return token
293
+
294
+ def convert_tokens_to_string(self, tokens):
295
+ """Converts a sequence of tokens (string) in a single string."""
296
+ current_sub_tokens = []
297
+ out_string = ""
298
+ for token in tokens:
299
+ # make sure that special tokens are not decoded using sentencepiece model
300
+ if token in self.all_special_tokens:
301
+ out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " "
302
+ current_sub_tokens = []
303
+ else:
304
+ current_sub_tokens.append(token)
305
+ out_string += self.sp_model.decode_pieces(current_sub_tokens)
306
+ return out_string.strip()
307
+
308
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
309
+ if not os.path.isdir(save_directory):
310
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
311
+ return
312
+ out_vocab_file = os.path.join(
313
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
314
+ )
315
+
316
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
317
+ copyfile(self.vocab_file, out_vocab_file)
318
+ elif not os.path.isfile(self.vocab_file):
319
+ with open(out_vocab_file, "wb") as fi:
320
+ content_spiece_model = self.sp_model.serialized_model_proto()
321
+ fi.write(content_spiece_model)
322
+
323
+ return (out_vocab_file,)
english/split_text.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nltk.tokenize import sent_tokenize
2
+
3
+ def sentence_token_nltk(str):
4
+ print("start split sentence_token_nltk...\n")
5
+ sent_tokenize_list = sent_tokenize(str)
6
+ return sent_tokenize_list
7
+
8
+ def sentence_split(str_centence):
9
+ list_ret = list()
10
+ for s_str in str_centence.split('.'):
11
+ if '?' in s_str:
12
+ list_ret.extend(s_str.split('?'))
13
+ elif '!' in s_str:
14
+ list_ret.extend(s_str.split('!'))
15
+ else:
16
+ list_ret.append(s_str)
17
+ return list_ret
18
+
19
+
english/translate.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from english.models.deltalm.modeling_deltalm import DeltalmForConditionalGeneration
2
+ from transformers import AutoTokenizer
3
+ import os
4
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
5
+
6
+ class Translate:
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.model = DeltalmForConditionalGeneration.from_pretrained("IDEA-CCNL/Randeng-Deltalm-362M-En-Zn")
10
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/infoxlm-base")
11
+
12
+ def translateToZh(self,text):
13
+ inputs = self.tokenizer(text, max_length=512, truncation=True,return_tensors="pt")
14
+ generate_ids = self.model.generate(inputs["input_ids"], max_length=512)
15
+ output = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
16
+ print(output)
17
+
18
+ return output
mel_processing.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ import numpy as np
9
+ import librosa
10
+ import librosa.util as librosa_util
11
+ from librosa.util import normalize, pad_center, tiny
12
+ from scipy.signal import get_window
13
+ from scipy.io.wavfile import read
14
+ from librosa.filters import mel as librosa_mel_fn
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ """
21
+ PARAMS
22
+ ------
23
+ C: compression factor
24
+ """
25
+ return torch.log(torch.clamp(x, min=clip_val) * C)
26
+
27
+
28
+ def dynamic_range_decompression_torch(x, C=1):
29
+ """
30
+ PARAMS
31
+ ------
32
+ C: compression factor used to compress
33
+ """
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
+ if torch.min(y) < -1.:
53
+ print('min value is ', torch.min(y))
54
+ if torch.max(y) > 1.:
55
+ print('max value is ', torch.max(y))
56
+
57
+ global hann_window
58
+ dtype_device = str(y.dtype) + '_' + str(y.device)
59
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
60
+ if wnsize_dtype_device not in hann_window:
61
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62
+
63
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64
+ y = y.squeeze(1)
65
+
66
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
68
+
69
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70
+ return spec
71
+
72
+
73
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74
+ global mel_basis
75
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
76
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
77
+ if fmax_dtype_device not in mel_basis:
78
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81
+ spec = spectral_normalize_torch(spec)
82
+ return spec
83
+
84
+
85
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86
+ if torch.min(y) < -1.:
87
+ print('min value is ', torch.min(y))
88
+ if torch.max(y) > 1.:
89
+ print('max value is ', torch.max(y))
90
+
91
+ global mel_basis, hann_window
92
+ dtype_device = str(y.dtype) + '_' + str(y.device)
93
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
94
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
95
+ if fmax_dtype_device not in mel_basis:
96
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
97
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98
+ if wnsize_dtype_device not in hann_window:
99
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100
+
101
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102
+ y = y.squeeze(1)
103
+
104
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
106
+
107
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108
+
109
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110
+ spec = spectral_normalize_torch(spec)
111
+
112
+ return spec
models.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ import modules
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from commons import init_weights, get_padding
13
+
14
+
15
+ class ResidualCouplingBlock(nn.Module):
16
+ def __init__(self,
17
+ channels,
18
+ hidden_channels,
19
+ kernel_size,
20
+ dilation_rate,
21
+ n_layers,
22
+ n_flows=4,
23
+ gin_channels=0):
24
+ super().__init__()
25
+ self.channels = channels
26
+ self.hidden_channels = hidden_channels
27
+ self.kernel_size = kernel_size
28
+ self.dilation_rate = dilation_rate
29
+ self.n_layers = n_layers
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.flows = nn.ModuleList()
34
+ for i in range(n_flows):
35
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
36
+ self.flows.append(modules.Flip())
37
+
38
+ def forward(self, x, x_mask, g=None, reverse=False):
39
+ if not reverse:
40
+ for flow in self.flows:
41
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
42
+ else:
43
+ for flow in reversed(self.flows):
44
+ x = flow(x, x_mask, g=g, reverse=reverse)
45
+ return x
46
+
47
+
48
+ class Encoder(nn.Module):
49
+ def __init__(self,
50
+ in_channels,
51
+ out_channels,
52
+ hidden_channels,
53
+ kernel_size,
54
+ dilation_rate,
55
+ n_layers,
56
+ gin_channels=0):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ self.out_channels = out_channels
60
+ self.hidden_channels = hidden_channels
61
+ self.kernel_size = kernel_size
62
+ self.dilation_rate = dilation_rate
63
+ self.n_layers = n_layers
64
+ self.gin_channels = gin_channels
65
+
66
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
67
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
68
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
69
+
70
+ def forward(self, x, x_lengths, g=None):
71
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
72
+ x = self.pre(x) * x_mask
73
+ x = self.enc(x, x_mask, g=g)
74
+ stats = self.proj(x) * x_mask
75
+ m, logs = torch.split(stats, self.out_channels, dim=1)
76
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
77
+ return z, m, logs, x_mask
78
+
79
+
80
+ class Generator(torch.nn.Module):
81
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
82
+ super(Generator, self).__init__()
83
+ self.num_kernels = len(resblock_kernel_sizes)
84
+ self.num_upsamples = len(upsample_rates)
85
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
86
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
87
+
88
+ self.ups = nn.ModuleList()
89
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
90
+ self.ups.append(weight_norm(
91
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
92
+ k, u, padding=(k-u)//2)))
93
+
94
+ self.resblocks = nn.ModuleList()
95
+ for i in range(len(self.ups)):
96
+ ch = upsample_initial_channel//(2**(i+1))
97
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
98
+ self.resblocks.append(resblock(ch, k, d))
99
+
100
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
101
+ self.ups.apply(init_weights)
102
+
103
+ if gin_channels != 0:
104
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
105
+
106
+ def forward(self, x, g=None):
107
+ x = self.conv_pre(x)
108
+ if g is not None:
109
+ x = x + self.cond(g)
110
+
111
+ for i in range(self.num_upsamples):
112
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
113
+ x = self.ups[i](x)
114
+ xs = None
115
+ for j in range(self.num_kernels):
116
+ if xs is None:
117
+ xs = self.resblocks[i*self.num_kernels+j](x)
118
+ else:
119
+ xs += self.resblocks[i*self.num_kernels+j](x)
120
+ x = xs / self.num_kernels
121
+ x = F.leaky_relu(x)
122
+ x = self.conv_post(x)
123
+ x = torch.tanh(x)
124
+
125
+ return x
126
+
127
+ def remove_weight_norm(self):
128
+ print('Removing weight norm...')
129
+ for l in self.ups:
130
+ remove_weight_norm(l)
131
+ for l in self.resblocks:
132
+ l.remove_weight_norm()
133
+
134
+
135
+ class DiscriminatorP(torch.nn.Module):
136
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
137
+ super(DiscriminatorP, self).__init__()
138
+ self.period = period
139
+ self.use_spectral_norm = use_spectral_norm
140
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
141
+ self.convs = nn.ModuleList([
142
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
143
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
144
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
145
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
146
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
147
+ ])
148
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
149
+
150
+ def forward(self, x):
151
+ fmap = []
152
+
153
+ # 1d to 2d
154
+ b, c, t = x.shape
155
+ if t % self.period != 0: # pad first
156
+ n_pad = self.period - (t % self.period)
157
+ x = F.pad(x, (0, n_pad), "reflect")
158
+ t = t + n_pad
159
+ x = x.view(b, c, t // self.period, self.period)
160
+
161
+ for l in self.convs:
162
+ x = l(x)
163
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
164
+ fmap.append(x)
165
+ x = self.conv_post(x)
166
+ fmap.append(x)
167
+ x = torch.flatten(x, 1, -1)
168
+
169
+ return x, fmap
170
+
171
+
172
+ class DiscriminatorS(torch.nn.Module):
173
+ def __init__(self, use_spectral_norm=False):
174
+ super(DiscriminatorS, self).__init__()
175
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
176
+ self.convs = nn.ModuleList([
177
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
178
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
179
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
180
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
181
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
182
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
183
+ ])
184
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
185
+
186
+ def forward(self, x):
187
+ fmap = []
188
+
189
+ for l in self.convs:
190
+ x = l(x)
191
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
192
+ fmap.append(x)
193
+ x = self.conv_post(x)
194
+ fmap.append(x)
195
+ x = torch.flatten(x, 1, -1)
196
+
197
+ return x, fmap
198
+
199
+
200
+ class MultiPeriodDiscriminator(torch.nn.Module):
201
+ def __init__(self, use_spectral_norm=False):
202
+ super(MultiPeriodDiscriminator, self).__init__()
203
+ periods = [2,3,5,7,11]
204
+
205
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
206
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
207
+ self.discriminators = nn.ModuleList(discs)
208
+
209
+ def forward(self, y, y_hat):
210
+ y_d_rs = []
211
+ y_d_gs = []
212
+ fmap_rs = []
213
+ fmap_gs = []
214
+ for i, d in enumerate(self.discriminators):
215
+ y_d_r, fmap_r = d(y)
216
+ y_d_g, fmap_g = d(y_hat)
217
+ y_d_rs.append(y_d_r)
218
+ y_d_gs.append(y_d_g)
219
+ fmap_rs.append(fmap_r)
220
+ fmap_gs.append(fmap_g)
221
+
222
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
223
+
224
+
225
+ class SpeakerEncoder(torch.nn.Module):
226
+ def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
227
+ super(SpeakerEncoder, self).__init__()
228
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
229
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
230
+ self.relu = nn.ReLU()
231
+
232
+ def forward(self, mels):
233
+ self.lstm.flatten_parameters()
234
+ _, (hidden, _) = self.lstm(mels)
235
+ embeds_raw = self.relu(self.linear(hidden[-1]))
236
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
237
+
238
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
239
+ mel_slices = []
240
+ for i in range(0, total_frames-partial_frames, partial_hop):
241
+ mel_range = torch.arange(i, i+partial_frames)
242
+ mel_slices.append(mel_range)
243
+
244
+ return mel_slices
245
+
246
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
247
+ mel_len = mel.size(1)
248
+ last_mel = mel[:,-partial_frames:]
249
+
250
+ if mel_len > partial_frames:
251
+ mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
252
+ mels = list(mel[:,s] for s in mel_slices)
253
+ mels.append(last_mel)
254
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
255
+
256
+ with torch.no_grad():
257
+ partial_embeds = self(mels)
258
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
259
+ #embed = embed / torch.linalg.norm(embed, 2)
260
+ else:
261
+ with torch.no_grad():
262
+ embed = self(last_mel)
263
+
264
+ return embed
265
+
266
+
267
+ class SynthesizerTrn(nn.Module):
268
+ """
269
+ Synthesizer for Training
270
+ """
271
+
272
+ def __init__(self,
273
+ spec_channels,
274
+ segment_size,
275
+ inter_channels,
276
+ hidden_channels,
277
+ filter_channels,
278
+ n_heads,
279
+ n_layers,
280
+ kernel_size,
281
+ p_dropout,
282
+ resblock,
283
+ resblock_kernel_sizes,
284
+ resblock_dilation_sizes,
285
+ upsample_rates,
286
+ upsample_initial_channel,
287
+ upsample_kernel_sizes,
288
+ gin_channels,
289
+ ssl_dim,
290
+ use_spk,
291
+ **kwargs):
292
+
293
+ super().__init__()
294
+ self.spec_channels = spec_channels
295
+ self.inter_channels = inter_channels
296
+ self.hidden_channels = hidden_channels
297
+ self.filter_channels = filter_channels
298
+ self.n_heads = n_heads
299
+ self.n_layers = n_layers
300
+ self.kernel_size = kernel_size
301
+ self.p_dropout = p_dropout
302
+ self.resblock = resblock
303
+ self.resblock_kernel_sizes = resblock_kernel_sizes
304
+ self.resblock_dilation_sizes = resblock_dilation_sizes
305
+ self.upsample_rates = upsample_rates
306
+ self.upsample_initial_channel = upsample_initial_channel
307
+ self.upsample_kernel_sizes = upsample_kernel_sizes
308
+ self.segment_size = segment_size
309
+ self.gin_channels = gin_channels
310
+ self.ssl_dim = ssl_dim
311
+ self.use_spk = use_spk
312
+
313
+ self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16)
314
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
315
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
316
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
317
+
318
+ if not self.use_spk:
319
+ self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels)
320
+
321
+ def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
322
+ if c_lengths == None:
323
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
324
+ if spec_lengths == None:
325
+ spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
326
+
327
+ if not self.use_spk:
328
+ g = self.enc_spk(mel.transpose(1,2))
329
+ g = g.unsqueeze(-1)
330
+
331
+ _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
332
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
333
+ z_p = self.flow(z, spec_mask, g=g)
334
+
335
+ z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
336
+ o = self.dec(z_slice, g=g)
337
+
338
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
339
+
340
+ def infer(self, c, g=None, mel=None, c_lengths=None):
341
+ if c_lengths == None:
342
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
343
+ if not self.use_spk:
344
+ g = self.enc_spk.embed_utterance(mel.transpose(1,2))
345
+ g = g.unsqueeze(-1)
346
+
347
+ z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
348
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
349
+ o = self.dec(z * c_mask, g=g)
350
+
351
+ return o
modules.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+
15
+
16
+ LRELU_SLOPE = 0.1
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.gamma = nn.Parameter(torch.ones(channels))
26
+ self.beta = nn.Parameter(torch.zeros(channels))
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, -1)
30
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
+ return x.transpose(1, -1)
32
+
33
+
34
+ class ConvReluNorm(nn.Module):
35
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36
+ super().__init__()
37
+ self.in_channels = in_channels
38
+ self.hidden_channels = hidden_channels
39
+ self.out_channels = out_channels
40
+ self.kernel_size = kernel_size
41
+ self.n_layers = n_layers
42
+ self.p_dropout = p_dropout
43
+ assert n_layers > 1, "Number of layers should be larger than 0."
44
+
45
+ self.conv_layers = nn.ModuleList()
46
+ self.norm_layers = nn.ModuleList()
47
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
48
+ self.norm_layers.append(LayerNorm(hidden_channels))
49
+ self.relu_drop = nn.Sequential(
50
+ nn.ReLU(),
51
+ nn.Dropout(p_dropout))
52
+ for _ in range(n_layers-1):
53
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56
+ self.proj.weight.data.zero_()
57
+ self.proj.bias.data.zero_()
58
+
59
+ def forward(self, x, x_mask):
60
+ x_org = x
61
+ for i in range(self.n_layers):
62
+ x = self.conv_layers[i](x * x_mask)
63
+ x = self.norm_layers[i](x)
64
+ x = self.relu_drop(x)
65
+ x = x_org + self.proj(x)
66
+ return x * x_mask
67
+
68
+
69
+ class DDSConv(nn.Module):
70
+ """
71
+ Dialted and Depth-Separable Convolution
72
+ """
73
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.kernel_size = kernel_size
77
+ self.n_layers = n_layers
78
+ self.p_dropout = p_dropout
79
+
80
+ self.drop = nn.Dropout(p_dropout)
81
+ self.convs_sep = nn.ModuleList()
82
+ self.convs_1x1 = nn.ModuleList()
83
+ self.norms_1 = nn.ModuleList()
84
+ self.norms_2 = nn.ModuleList()
85
+ for i in range(n_layers):
86
+ dilation = kernel_size ** i
87
+ padding = (kernel_size * dilation - dilation) // 2
88
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89
+ groups=channels, dilation=dilation, padding=padding
90
+ ))
91
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92
+ self.norms_1.append(LayerNorm(channels))
93
+ self.norms_2.append(LayerNorm(channels))
94
+
95
+ def forward(self, x, x_mask, g=None):
96
+ if g is not None:
97
+ x = x + g
98
+ for i in range(self.n_layers):
99
+ y = self.convs_sep[i](x * x_mask)
100
+ y = self.norms_1[i](y)
101
+ y = F.gelu(y)
102
+ y = self.convs_1x1[i](y)
103
+ y = self.norms_2[i](y)
104
+ y = F.gelu(y)
105
+ y = self.drop(y)
106
+ x = x + y
107
+ return x * x_mask
108
+
109
+
110
+ class WN(torch.nn.Module):
111
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
112
+ super(WN, self).__init__()
113
+ assert(kernel_size % 2 == 1)
114
+ self.hidden_channels =hidden_channels
115
+ self.kernel_size = kernel_size,
116
+ self.dilation_rate = dilation_rate
117
+ self.n_layers = n_layers
118
+ self.gin_channels = gin_channels
119
+ self.p_dropout = p_dropout
120
+
121
+ self.in_layers = torch.nn.ModuleList()
122
+ self.res_skip_layers = torch.nn.ModuleList()
123
+ self.drop = nn.Dropout(p_dropout)
124
+
125
+ if gin_channels != 0:
126
+ cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
127
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
128
+
129
+ for i in range(n_layers):
130
+ dilation = dilation_rate ** i
131
+ padding = int((kernel_size * dilation - dilation) / 2)
132
+ in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
133
+ dilation=dilation, padding=padding)
134
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
135
+ self.in_layers.append(in_layer)
136
+
137
+ # last one is not necessary
138
+ if i < n_layers - 1:
139
+ res_skip_channels = 2 * hidden_channels
140
+ else:
141
+ res_skip_channels = hidden_channels
142
+
143
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
144
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
145
+ self.res_skip_layers.append(res_skip_layer)
146
+
147
+ def forward(self, x, x_mask, g=None, **kwargs):
148
+ output = torch.zeros_like(x)
149
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
150
+
151
+ if g is not None:
152
+ g = self.cond_layer(g)
153
+
154
+ for i in range(self.n_layers):
155
+ x_in = self.in_layers[i](x)
156
+ if g is not None:
157
+ cond_offset = i * 2 * self.hidden_channels
158
+ g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
159
+ else:
160
+ g_l = torch.zeros_like(x_in)
161
+
162
+ acts = commons.fused_add_tanh_sigmoid_multiply(
163
+ x_in,
164
+ g_l,
165
+ n_channels_tensor)
166
+ acts = self.drop(acts)
167
+
168
+ res_skip_acts = self.res_skip_layers[i](acts)
169
+ if i < self.n_layers - 1:
170
+ res_acts = res_skip_acts[:,:self.hidden_channels,:]
171
+ x = (x + res_acts) * x_mask
172
+ output = output + res_skip_acts[:,self.hidden_channels:,:]
173
+ else:
174
+ output = output + res_skip_acts
175
+ return output * x_mask
176
+
177
+ def remove_weight_norm(self):
178
+ if self.gin_channels != 0:
179
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
180
+ for l in self.in_layers:
181
+ torch.nn.utils.remove_weight_norm(l)
182
+ for l in self.res_skip_layers:
183
+ torch.nn.utils.remove_weight_norm(l)
184
+
185
+
186
+ class ResBlock1(torch.nn.Module):
187
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
188
+ super(ResBlock1, self).__init__()
189
+ self.convs1 = nn.ModuleList([
190
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
191
+ padding=get_padding(kernel_size, dilation[0]))),
192
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
193
+ padding=get_padding(kernel_size, dilation[1]))),
194
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
195
+ padding=get_padding(kernel_size, dilation[2])))
196
+ ])
197
+ self.convs1.apply(init_weights)
198
+
199
+ self.convs2 = nn.ModuleList([
200
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201
+ padding=get_padding(kernel_size, 1))),
202
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203
+ padding=get_padding(kernel_size, 1))),
204
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
205
+ padding=get_padding(kernel_size, 1)))
206
+ ])
207
+ self.convs2.apply(init_weights)
208
+
209
+ def forward(self, x, x_mask=None):
210
+ for c1, c2 in zip(self.convs1, self.convs2):
211
+ xt = F.leaky_relu(x, LRELU_SLOPE)
212
+ if x_mask is not None:
213
+ xt = xt * x_mask
214
+ xt = c1(xt)
215
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
216
+ if x_mask is not None:
217
+ xt = xt * x_mask
218
+ xt = c2(xt)
219
+ x = xt + x
220
+ if x_mask is not None:
221
+ x = x * x_mask
222
+ return x
223
+
224
+ def remove_weight_norm(self):
225
+ for l in self.convs1:
226
+ remove_weight_norm(l)
227
+ for l in self.convs2:
228
+ remove_weight_norm(l)
229
+
230
+
231
+ class ResBlock2(torch.nn.Module):
232
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
233
+ super(ResBlock2, self).__init__()
234
+ self.convs = nn.ModuleList([
235
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]))),
237
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
238
+ padding=get_padding(kernel_size, dilation[1])))
239
+ ])
240
+ self.convs.apply(init_weights)
241
+
242
+ def forward(self, x, x_mask=None):
243
+ for c in self.convs:
244
+ xt = F.leaky_relu(x, LRELU_SLOPE)
245
+ if x_mask is not None:
246
+ xt = xt * x_mask
247
+ xt = c(xt)
248
+ x = xt + x
249
+ if x_mask is not None:
250
+ x = x * x_mask
251
+ return x
252
+
253
+ def remove_weight_norm(self):
254
+ for l in self.convs:
255
+ remove_weight_norm(l)
256
+
257
+
258
+ class Log(nn.Module):
259
+ def forward(self, x, x_mask, reverse=False, **kwargs):
260
+ if not reverse:
261
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
262
+ logdet = torch.sum(-y, [1, 2])
263
+ return y, logdet
264
+ else:
265
+ x = torch.exp(x) * x_mask
266
+ return x
267
+
268
+
269
+ class Flip(nn.Module):
270
+ def forward(self, x, *args, reverse=False, **kwargs):
271
+ x = torch.flip(x, [1])
272
+ if not reverse:
273
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
274
+ return x, logdet
275
+ else:
276
+ return x
277
+
278
+
279
+ class ElementwiseAffine(nn.Module):
280
+ def __init__(self, channels):
281
+ super().__init__()
282
+ self.channels = channels
283
+ self.m = nn.Parameter(torch.zeros(channels,1))
284
+ self.logs = nn.Parameter(torch.zeros(channels,1))
285
+
286
+ def forward(self, x, x_mask, reverse=False, **kwargs):
287
+ if not reverse:
288
+ y = self.m + torch.exp(self.logs) * x
289
+ y = y * x_mask
290
+ logdet = torch.sum(self.logs * x_mask, [1,2])
291
+ return y, logdet
292
+ else:
293
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
294
+ return x
295
+
296
+
297
+ class ResidualCouplingLayer(nn.Module):
298
+ def __init__(self,
299
+ channels,
300
+ hidden_channels,
301
+ kernel_size,
302
+ dilation_rate,
303
+ n_layers,
304
+ p_dropout=0,
305
+ gin_channels=0,
306
+ mean_only=False):
307
+ assert channels % 2 == 0, "channels should be divisible by 2"
308
+ super().__init__()
309
+ self.channels = channels
310
+ self.hidden_channels = hidden_channels
311
+ self.kernel_size = kernel_size
312
+ self.dilation_rate = dilation_rate
313
+ self.n_layers = n_layers
314
+ self.half_channels = channels // 2
315
+ self.mean_only = mean_only
316
+
317
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
318
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
319
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
320
+ self.post.weight.data.zero_()
321
+ self.post.bias.data.zero_()
322
+
323
+ def forward(self, x, x_mask, g=None, reverse=False):
324
+ x0, x1 = torch.split(x, [self.half_channels]*2, 1)
325
+ h = self.pre(x0) * x_mask
326
+ h = self.enc(h, x_mask, g=g)
327
+ stats = self.post(h) * x_mask
328
+ if not self.mean_only:
329
+ m, logs = torch.split(stats, [self.half_channels]*2, 1)
330
+ else:
331
+ m = stats
332
+ logs = torch.zeros_like(m)
333
+
334
+ if not reverse:
335
+ x1 = m + x1 * torch.exp(logs) * x_mask
336
+ x = torch.cat([x0, x1], 1)
337
+ logdet = torch.sum(logs, [1,2])
338
+ return x, logdet
339
+ else:
340
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
341
+ x = torch.cat([x0, x1], 1)
342
+ return x
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.22.0
2
+ scipy
3
+ torch
4
+ transformers
5
+ librosa==0.8.1
6
+ webrtcvad==2.0.10
7
+ protobuf
8
+ cpm_kernels
9
+ mdtex2html
10
+ sentencepiece
11
+ accelerate
12
+ loguru
13
+ edge_tts
14
+ altair
15
+ gradio==3.36.1
16
+ nltk
17
+ openai
speaker_encoder/__init__.py ADDED
File without changes
speaker_encoder/audio.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from speaker_encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ import numpy as np
6
+ import webrtcvad
7
+ import librosa
8
+ import struct
9
+
10
+ int16_max = (2 ** 15) - 1
11
+
12
+
13
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14
+ source_sr: Optional[int] = None):
15
+ """
16
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18
+
19
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20
+ just .wav), either the waveform as a numpy array of floats.
21
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
23
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24
+ this argument will be ignored.
25
+ """
26
+ # Load the wav from disk if needed
27
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28
+ wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29
+ else:
30
+ wav = fpath_or_wav
31
+
32
+ # Resample the wav if needed
33
+ if source_sr is not None and source_sr != sampling_rate:
34
+ wav = librosa.resample(wav, source_sr, sampling_rate)
35
+
36
+ # Apply the preprocessing: normalize volume and shorten long silences
37
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38
+ wav = trim_long_silences(wav)
39
+
40
+ return wav
41
+
42
+
43
+ def wav_to_mel_spectrogram(wav):
44
+ """
45
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46
+ Note: this not a log-mel spectrogram.
47
+ """
48
+ frames = librosa.feature.melspectrogram(
49
+ y=wav,
50
+ sr=sampling_rate,
51
+ n_fft=int(sampling_rate * mel_window_length / 1000),
52
+ hop_length=int(sampling_rate * mel_window_step / 1000),
53
+ n_mels=mel_n_channels
54
+ )
55
+ return frames.astype(np.float32).T
56
+
57
+
58
+ def trim_long_silences(wav):
59
+ """
60
+ Ensures that segments without voice in the waveform remain no longer than a
61
+ threshold determined by the VAD parameters in params.py.
62
+
63
+ :param wav: the raw waveform as a numpy array of floats
64
+ :return: the same waveform with silences trimmed away (length <= original wav length)
65
+ """
66
+ # Compute the voice detection window size
67
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
68
+
69
+ # Trim the end of the audio to have a multiple of the window size
70
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71
+
72
+ # Convert the float waveform to 16-bit mono PCM
73
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74
+
75
+ # Perform voice activation detection
76
+ voice_flags = []
77
+ vad = webrtcvad.Vad(mode=3)
78
+ for window_start in range(0, len(wav), samples_per_window):
79
+ window_end = window_start + samples_per_window
80
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81
+ sample_rate=sampling_rate))
82
+ voice_flags = np.array(voice_flags)
83
+
84
+ # Smooth the voice detection with a moving average
85
+ def moving_average(array, width):
86
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87
+ ret = np.cumsum(array_padded, dtype=float)
88
+ ret[width:] = ret[width:] - ret[:-width]
89
+ return ret[width - 1:] / width
90
+
91
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
92
+ audio_mask = np.round(audio_mask).astype(np.bool)
93
+
94
+ # Dilate the voiced regions
95
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96
+ audio_mask = np.repeat(audio_mask, samples_per_window)
97
+
98
+ return wav[audio_mask == True]
99
+
100
+
101
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102
+ if increase_only and decrease_only:
103
+ raise ValueError("Both increase only and decrease only are set")
104
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106
+ return wav
107
+ return wav * (10 ** (dBFS_change / 20))
speaker_encoder/ckpt/pretrained_bak_5805000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
3
+ size 17090379
speaker_encoder/compute_embed.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder import inference as encoder
2
+ from multiprocessing.pool import Pool
3
+ from functools import partial
4
+ from pathlib import Path
5
+ # from utils import logmmse
6
+ # from tqdm import tqdm
7
+ # import numpy as np
8
+ # import librosa
9
+
10
+
11
+ def embed_utterance(fpaths, encoder_model_fpath):
12
+ if not encoder.is_loaded():
13
+ encoder.load_model(encoder_model_fpath)
14
+
15
+ # Compute the speaker embedding of the utterance
16
+ wav_fpath, embed_fpath = fpaths
17
+ wav = np.load(wav_fpath)
18
+ wav = encoder.preprocess_wav(wav)
19
+ embed = encoder.embed_utterance(wav)
20
+ np.save(embed_fpath, embed, allow_pickle=False)
21
+
22
+
23
+ def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24
+
25
+ wav_dir = outdir_root.joinpath("audio")
26
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
27
+ assert wav_dir.exists() and metadata_fpath.exists()
28
+ embed_dir = synthesizer_root.joinpath("embeds")
29
+ embed_dir.mkdir(exist_ok=True)
30
+
31
+ # Gather the input wave filepath and the target output embed filepath
32
+ with metadata_fpath.open("r") as metadata_file:
33
+ metadata = [line.split("|") for line in metadata_file]
34
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35
+
36
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37
+ # Embed the utterances in separate threads
38
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39
+ job = Pool(n_processes).imap(func, fpaths)
40
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
speaker_encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
speaker_encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
speaker_encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
speaker_encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
speaker_encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+
5
+ class SpeakerBatch:
6
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7
+ self.speakers = speakers
8
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9
+
10
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
speaker_encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+ from speaker_encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
speaker_encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
speaker_encoder/hparams.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Mel-filterbank
2
+ mel_window_length = 25 # In milliseconds
3
+ mel_window_step = 10 # In milliseconds
4
+ mel_n_channels = 40
5
+
6
+
7
+ ## Audio
8
+ sampling_rate = 16000
9
+ # Number of spectrogram frames in a partial utterance
10
+ partials_n_frames = 160 # 1600 ms
11
+
12
+
13
+ ## Voice Activation Detection
14
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15
+ # This sets the granularity of the VAD. Should not need to be changed.
16
+ vad_window_length = 30 # In milliseconds
17
+ # Number of frames to average together when performing the moving average smoothing.
18
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
19
+ vad_moving_average_width = 8
20
+ # Maximum number of consecutive silent frames a segment can have.
21
+ vad_max_silence_length = 6
22
+
23
+
24
+ ## Audio volume normalization
25
+ audio_norm_target_dBFS = -30
26
+
27
+
28
+ ## Model parameters
29
+ model_hidden_size = 256
30
+ model_embedding_size = 256
31
+ model_num_layers = 3
speaker_encoder/inference.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_data import *
2
+ from speaker_encoder.model import SpeakerEncoder
3
+ from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ _model = None # type: SpeakerEncoder
12
+ _device = None # type: torch.device
13
+
14
+
15
+ def load_model(weights_fpath: Path, device=None):
16
+ """
17
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
18
+ first call to embed_frames() with the default weights file.
19
+
20
+ :param weights_fpath: the path to saved model weights.
21
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
23
+ If None, will default to your GPU if it"s available, otherwise your CPU.
24
+ """
25
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
26
+ # was saved on. Worth investigating.
27
+ global _model, _device
28
+ if device is None:
29
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ elif isinstance(device, str):
31
+ _device = torch.device(device)
32
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath)
34
+ _model.load_state_dict(checkpoint["model_state"])
35
+ _model.eval()
36
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37
+
38
+
39
+ def is_loaded():
40
+ return _model is not None
41
+
42
+
43
+ def embed_frames_batch(frames_batch):
44
+ """
45
+ Computes embeddings for a batch of mel spectrogram.
46
+
47
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
48
+ (batch_size, n_frames, n_channels)
49
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50
+ """
51
+ if _model is None:
52
+ raise Exception("Model was not loaded. Call load_model() before inference.")
53
+
54
+ frames = torch.from_numpy(frames_batch).to(_device)
55
+ embed = _model.forward(frames).detach().cpu().numpy()
56
+ return embed
57
+
58
+
59
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60
+ min_pad_coverage=0.75, overlap=0.5):
61
+ """
62
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
64
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
66
+ defined in params_data.py.
67
+
68
+ The returned ranges may be indexing further than the length of the waveform. It is
69
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70
+
71
+ :param n_samples: the number of samples in the waveform
72
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73
+ utterance
74
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
76
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
79
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80
+ utterances are entirely disjoint.
81
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
83
+ utterances.
84
+ """
85
+ assert 0 <= overlap < 1
86
+ assert 0 < min_pad_coverage <= 1
87
+
88
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91
+
92
+ # Compute the slices
93
+ wav_slices, mel_slices = [], []
94
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95
+ for i in range(0, steps, frame_step):
96
+ mel_range = np.array([i, i + partial_utterance_n_frames])
97
+ wav_range = mel_range * samples_per_frame
98
+ mel_slices.append(slice(*mel_range))
99
+ wav_slices.append(slice(*wav_range))
100
+
101
+ # Evaluate whether extra padding is warranted or not
102
+ last_wav_range = wav_slices[-1]
103
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
105
+ mel_slices = mel_slices[:-1]
106
+ wav_slices = wav_slices[:-1]
107
+
108
+ return wav_slices, mel_slices
109
+
110
+
111
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ # TODO: handle multiple wavs to benefit from batching on GPU
116
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117
+ :param using_partials: if True, then the utterance is split in partial utterances of
118
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
119
+ normalized average. If False, the utterance is instead computed from feeding the entire
120
+ spectogram to the network.
121
+ :param return_partials: if True, the partial embeddings will also be returned along with the
122
+ wav slices that correspond to the partial embeddings.
123
+ :param kwargs: additional arguments to compute_partial_splits()
124
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
126
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
128
+ instead.
129
+ """
130
+ # Process the entire utterance if not using partials
131
+ if not using_partials:
132
+ frames = audio.wav_to_mel_spectrogram(wav)
133
+ embed = embed_frames_batch(frames[None, ...])[0]
134
+ if return_partials:
135
+ return embed, None, None
136
+ return embed
137
+
138
+ # Compute where to split the utterance into partials and pad if necessary
139
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140
+ max_wave_length = wave_slices[-1].stop
141
+ if max_wave_length >= len(wav):
142
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143
+
144
+ # Split the utterance into partials
145
+ frames = audio.wav_to_mel_spectrogram(wav)
146
+ frames_batch = np.array([frames[s] for s in mel_slices])
147
+ partial_embeds = embed_frames_batch(frames_batch)
148
+
149
+ # Compute the utterance embedding from the partial embeddings
150
+ raw_embed = np.mean(partial_embeds, axis=0)
151
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
152
+
153
+ if return_partials:
154
+ return embed, partial_embeds, wave_slices
155
+ return embed
156
+
157
+
158
+ def embed_speaker(wavs, **kwargs):
159
+ raise NotImplemented()
160
+
161
+
162
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ cbar.set_clim(*color_range)
175
+
176
+ ax.set_xticks([]), ax.set_yticks([])
177
+ ax.set_title(title)
speaker_encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_model import *
2
+ from speaker_encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19
+ hidden_size=model_hidden_size, # 256
20
+ num_layers=model_num_layers, # 3
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
speaker_encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
speaker_encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
speaker_encoder/preprocess.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from speaker_encoder.params_data import *
3
+ from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from speaker_encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if not param_name in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+
121
+ # Function to preprocess utterances for one speaker
122
+ def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
123
+ # Give a name to the speaker that includes its dataset
124
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
125
+
126
+ # Create an output directory with that name, as well as a txt file containing a
127
+ # reference to each source file.
128
+ speaker_out_dir = out_dir.joinpath(speaker_name)
129
+ speaker_out_dir.mkdir(exist_ok=True)
130
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
131
+
132
+ # There's a possibility that the preprocessing was interrupted earlier, check if
133
+ # there already is a sources file.
134
+ # if sources_fpath.exists():
135
+ # try:
136
+ # with sources_fpath.open("r") as sources_file:
137
+ # existing_fnames = {line.split(",")[0] for line in sources_file}
138
+ # except:
139
+ # existing_fnames = {}
140
+ # else:
141
+ # existing_fnames = {}
142
+ existing_fnames = {}
143
+ # Gather all audio files for that speaker recursively
144
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
145
+
146
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
147
+ # Check if the target output file already exists
148
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
149
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
150
+ if skip_existing and out_fname in existing_fnames:
151
+ continue
152
+
153
+ # Load and preprocess the waveform
154
+ wav = audio.preprocess_wav(in_fpath)
155
+ if len(wav) == 0:
156
+ continue
157
+
158
+ # Create the mel spectrogram, discard those that are too short
159
+ frames = audio.wav_to_mel_spectrogram(wav)
160
+ if len(frames) < partials_n_frames:
161
+ continue
162
+
163
+ out_fpath = speaker_out_dir.joinpath(out_fname)
164
+ np.save(out_fpath, frames)
165
+ # logger.add_sample(duration=len(wav) / sampling_rate)
166
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
167
+
168
+ sources_file.close()
169
+ return len(wav)
170
+
171
+ def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
172
+ skip_existing, logger):
173
+ # from multiprocessing import Pool, cpu_count
174
+ from pathos.multiprocessing import ProcessingPool as Pool
175
+ # Function to preprocess utterances for one speaker
176
+ def __preprocess_speaker(speaker_dir: Path):
177
+ # Give a name to the speaker that includes its dataset
178
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
179
+
180
+ # Create an output directory with that name, as well as a txt file containing a
181
+ # reference to each source file.
182
+ speaker_out_dir = out_dir.joinpath(speaker_name)
183
+ speaker_out_dir.mkdir(exist_ok=True)
184
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
185
+
186
+ existing_fnames = {}
187
+ # Gather all audio files for that speaker recursively
188
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
189
+ wav_lens = []
190
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
191
+ # Check if the target output file already exists
192
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
193
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
194
+ if skip_existing and out_fname in existing_fnames:
195
+ continue
196
+
197
+ # Load and preprocess the waveform
198
+ wav = audio.preprocess_wav(in_fpath)
199
+ if len(wav) == 0:
200
+ continue
201
+
202
+ # Create the mel spectrogram, discard those that are too short
203
+ frames = audio.wav_to_mel_spectrogram(wav)
204
+ if len(frames) < partials_n_frames:
205
+ continue
206
+
207
+ out_fpath = speaker_out_dir.joinpath(out_fname)
208
+ np.save(out_fpath, frames)
209
+ # logger.add_sample(duration=len(wav) / sampling_rate)
210
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
211
+ wav_lens.append(len(wav))
212
+ sources_file.close()
213
+ return wav_lens
214
+
215
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
216
+ # Process the utterances for each speaker
217
+ # with ThreadPool(8) as pool:
218
+ # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
219
+ # unit="speakers"))
220
+ pool = Pool(processes=20)
221
+ for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
222
+ for wav_len in wav_lens:
223
+ logger.add_sample(duration=wav_len / sampling_rate)
224
+ print(f'{i}/{len(speaker_dirs)} \r')
225
+
226
+ logger.finalize()
227
+ print("Done preprocessing %s.\n" % dataset_name)
228
+
229
+
230
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
231
+ for dataset_name in librispeech_datasets["train"]["other"]:
232
+ # Initialize the preprocessing
233
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
234
+ if not dataset_root:
235
+ return
236
+
237
+ # Preprocess all speakers
238
+ speaker_dirs = list(dataset_root.glob("*"))
239
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
240
+ skip_existing, logger)
241
+
242
+
243
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
244
+ # Initialize the preprocessing
245
+ dataset_name = "VoxCeleb1"
246
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
247
+ if not dataset_root:
248
+ return
249
+
250
+ # Get the contents of the meta file
251
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
252
+ metadata = [line.split("\t") for line in metafile][1:]
253
+
254
+ # Select the ID and the nationality, filter out non-anglophone speakers
255
+ nationalities = {line[0]: line[3] for line in metadata}
256
+ # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
257
+ # nationality.lower() in anglophone_nationalites]
258
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
259
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
260
+ (len(keep_speaker_ids), len(nationalities)))
261
+
262
+ # Get the speaker directories for anglophone speakers only
263
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
264
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
265
+ speaker_dir.name in keep_speaker_ids]
266
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
267
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
268
+
269
+ # Preprocess all speakers
270
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
271
+ skip_existing, logger)
272
+
273
+
274
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
275
+ # Initialize the preprocessing
276
+ dataset_name = "VoxCeleb2"
277
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
278
+ if not dataset_root:
279
+ return
280
+
281
+ # Get the speaker directories
282
+ # Preprocess all speakers
283
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
284
+ _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
285
+ skip_existing, logger)
speaker_encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.visualizations import Visualizations
2
+ from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from speaker_encoder.params_model import *
4
+ from speaker_encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # FIXME
11
+ return
12
+ # For correct profiling (cuda operations are async)
13
+ if device.type == "cuda":
14
+ torch.cuda.synchronize(device)
15
+
16
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
17
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
18
+ no_visdom: bool):
19
+ # Create a dataset and a dataloader
20
+ dataset = SpeakerVerificationDataset(clean_data_root)
21
+ loader = SpeakerVerificationDataLoader(
22
+ dataset,
23
+ speakers_per_batch, # 64
24
+ utterances_per_speaker, # 10
25
+ num_workers=8,
26
+ )
27
+
28
+ # Setup the device on which to run the forward pass and the loss. These can be different,
29
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
30
+ # hyperparameters) faster on the CPU.
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ # FIXME: currently, the gradient is None if loss_device is cuda
33
+ loss_device = torch.device("cpu")
34
+
35
+ # Create the model and the optimizer
36
+ model = SpeakerEncoder(device, loss_device)
37
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
38
+ init_step = 1
39
+
40
+ # Configure file path for the model
41
+ state_fpath = models_dir.joinpath(run_id + ".pt")
42
+ backup_dir = models_dir.joinpath(run_id + "_backups")
43
+
44
+ # Load any existing model
45
+ if not force_restart:
46
+ if state_fpath.exists():
47
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
48
+ checkpoint = torch.load(state_fpath)
49
+ init_step = checkpoint["step"]
50
+ model.load_state_dict(checkpoint["model_state"])
51
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
52
+ optimizer.param_groups[0]["lr"] = learning_rate_init
53
+ else:
54
+ print("No model \"%s\" found, starting training from scratch." % run_id)
55
+ else:
56
+ print("Starting the training from scratch.")
57
+ model.train()
58
+
59
+ # Initialize the visualization environment
60
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
61
+ vis.log_dataset(dataset)
62
+ vis.log_params()
63
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
64
+ vis.log_implementation({"Device": device_name})
65
+
66
+ # Training loop
67
+ profiler = Profiler(summarize_every=10, disabled=False)
68
+ for step, speaker_batch in enumerate(loader, init_step):
69
+ profiler.tick("Blocking, waiting for batch (threaded)")
70
+
71
+ # Forward pass
72
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
73
+ sync(device)
74
+ profiler.tick("Data to %s" % device)
75
+ embeds = model(inputs)
76
+ sync(device)
77
+ profiler.tick("Forward pass")
78
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
79
+ loss, eer = model.loss(embeds_loss)
80
+ sync(loss_device)
81
+ profiler.tick("Loss")
82
+
83
+ # Backward pass
84
+ model.zero_grad()
85
+ loss.backward()
86
+ profiler.tick("Backward pass")
87
+ model.do_gradient_ops()
88
+ optimizer.step()
89
+ profiler.tick("Parameter update")
90
+
91
+ # Update visualizations
92
+ # learning_rate = optimizer.param_groups[0]["lr"]
93
+ vis.update(loss.item(), eer, step)
94
+
95
+ # Draw projections and save them to the backup folder
96
+ if umap_every != 0 and step % umap_every == 0:
97
+ print("Drawing and saving projections (step %d)" % step)
98
+ backup_dir.mkdir(exist_ok=True)
99
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
100
+ embeds = embeds.detach().cpu().numpy()
101
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
102
+ vis.save()
103
+
104
+ # Overwrite the latest version of the model
105
+ if save_every != 0 and step % save_every == 0:
106
+ print("Saving the model (step %d)" % step)
107
+ torch.save({
108
+ "step": step + 1,
109
+ "model_state": model.state_dict(),
110
+ "optimizer_state": optimizer.state_dict(),
111
+ }, state_fpath)
112
+
113
+ # Make a backup
114
+ if backup_every != 0 and step % backup_every == 0:
115
+ print("Making a backup (step %d)" % step)
116
+ backup_dir.mkdir(exist_ok=True)
117
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
118
+ torch.save({
119
+ "step": step + 1,
120
+ "model_state": model.state_dict(),
121
+ "optimizer_state": optimizer.state_dict(),
122
+ }, backup_fpath)
123
+
124
+ profiler.tick("Extras (visualizations, saving)")
125
+
speaker_encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=np.float) / 255
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from speaker_encoder import params_data
69
+ from speaker_encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
speaker_encoder/voice_encoder.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.hparams import *
2
+ from speaker_encoder import audio
3
+ from pathlib import Path
4
+ from typing import Union, List
5
+ from torch import nn
6
+ from time import perf_counter as timer
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ class SpeakerEncoder(nn.Module):
12
+ def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
13
+ """
14
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
15
+ If None, defaults to cuda if it is available on your machine, otherwise the model will
16
+ run on cpu. Outputs are always returned on the cpu, as numpy arrays.
17
+ """
18
+ super().__init__()
19
+
20
+ # Define the network
21
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
22
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
23
+ self.relu = nn.ReLU()
24
+
25
+ # Get the target device
26
+ if device is None:
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ elif isinstance(device, str):
29
+ device = torch.device(device)
30
+ self.device = device
31
+
32
+ # Load the pretrained model'speaker weights
33
+ # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
34
+ # if not weights_fpath.exists():
35
+ # raise Exception("Couldn't find the voice encoder pretrained model at %s." %
36
+ # weights_fpath)
37
+
38
+ start = timer()
39
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
40
+
41
+ self.load_state_dict(checkpoint["model_state"], strict=False)
42
+ self.to(device)
43
+
44
+ if verbose:
45
+ print("Loaded the voice encoder model on %s in %.2f seconds." %
46
+ (device.type, timer() - start))
47
+
48
+ def forward(self, mels: torch.FloatTensor):
49
+ """
50
+ Computes the embeddings of a batch of utterance spectrograms.
51
+ :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
52
+ (batch_size, n_frames, n_channels)
53
+ :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
54
+ Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
55
+ """
56
+ # Pass the input through the LSTM layers and retrieve the final hidden state of the last
57
+ # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
58
+ _, (hidden, _) = self.lstm(mels)
59
+ embeds_raw = self.relu(self.linear(hidden[-1]))
60
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
61
+
62
+ @staticmethod
63
+ def compute_partial_slices(n_samples: int, rate, min_coverage):
64
+ """
65
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to
66
+ obtain partial utterances of <partials_n_frames> each. Both the waveform and the
67
+ mel spectrogram slices are returned, so as to make each partial utterance waveform
68
+ correspond to its spectrogram.
69
+
70
+ The returned ranges may be indexing further than the length of the waveform. It is
71
+ recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
72
+
73
+ :param n_samples: the number of samples in the waveform
74
+ :param rate: how many partial utterances should occur per second. Partial utterances must
75
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
76
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
77
+ the minimum rate is thus 0.625.
78
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
79
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
80
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
81
+ it will be discarded. If there aren't enough frames for one partial utterance,
82
+ this parameter is ignored so that the function always returns at least one slice.
83
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
84
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
85
+ utterances.
86
+ """
87
+ assert 0 < min_coverage <= 1
88
+
89
+ # Compute how many frames separate two partial utterances
90
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
91
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
92
+ frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
93
+ assert 0 < frame_step, "The rate is too high"
94
+ assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
95
+ (sampling_rate / (samples_per_frame * partials_n_frames))
96
+
97
+ # Compute the slices
98
+ wav_slices, mel_slices = [], []
99
+ steps = max(1, n_frames - partials_n_frames + frame_step + 1)
100
+ for i in range(0, steps, frame_step):
101
+ mel_range = np.array([i, i + partials_n_frames])
102
+ wav_range = mel_range * samples_per_frame
103
+ mel_slices.append(slice(*mel_range))
104
+ wav_slices.append(slice(*wav_range))
105
+
106
+ # Evaluate whether extra padding is warranted or not
107
+ last_wav_range = wav_slices[-1]
108
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
109
+ if coverage < min_coverage and len(mel_slices) > 1:
110
+ mel_slices = mel_slices[:-1]
111
+ wav_slices = wav_slices[:-1]
112
+
113
+ return wav_slices, mel_slices
114
+
115
+ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
116
+ """
117
+ Computes an embedding for a single utterance. The utterance is divided in partial
118
+ utterances and an embedding is computed for each. The complete utterance embedding is the
119
+ L2-normed average embedding of the partial utterances.
120
+
121
+ TODO: independent batched version of this function
122
+
123
+ :param wav: a preprocessed utterance waveform as a numpy array of float32
124
+ :param return_partials: if True, the partial embeddings will also be returned along with
125
+ the wav slices corresponding to each partial utterance.
126
+ :param rate: how many partial utterances should occur per second. Partial utterances must
127
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
128
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
129
+ the minimum rate is thus 0.625.
130
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
131
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
132
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
133
+ it will be discarded. If there aren't enough frames for one partial utterance,
134
+ this parameter is ignored so that the function always returns at least one slice.
135
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
136
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
137
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
138
+ returned.
139
+ """
140
+ # Compute where to split the utterance into partials and pad the waveform with zeros if
141
+ # the partial utterances cover a larger range.
142
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
143
+ max_wave_length = wav_slices[-1].stop
144
+ if max_wave_length >= len(wav):
145
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
146
+
147
+ # Split the utterance into partials and forward them through the model
148
+ mel = audio.wav_to_mel_spectrogram(wav)
149
+ mels = np.array([mel[s] for s in mel_slices])
150
+ with torch.no_grad():
151
+ mels = torch.from_numpy(mels).to(self.device)
152
+ partial_embeds = self(mels).cpu().numpy()
153
+
154
+ # Compute the utterance embedding from the partial embeddings
155
+ raw_embed = np.mean(partial_embeds, axis=0)
156
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
157
+
158
+ if return_partials:
159
+ return embed, partial_embeds, wav_slices
160
+ return embed
161
+
162
+ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
163
+ """
164
+ Compute the embedding of a collection of wavs (presumably from the same speaker) by
165
+ averaging their embedding and L2-normalizing it.
166
+
167
+ :param wavs: list of wavs a numpy arrays of float32.
168
+ :param kwargs: extra arguments to embed_utterance()
169
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
170
+ """
171
+ raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
172
+ for wav in wavs], axis=0)
173
+ return raw_embed / np.linalg.norm(raw_embed, 2)
tts_voice.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tts_order_voice = {'英语 (美国)-Jenny-女': 'en-US-JennyNeural',
2
+ '英语 (美国)-Guy-男': 'en-US-GuyNeural',
3
+ '英语 (美国)-Ana-女': 'en-US-AnaNeural',
4
+ '英语 (美国)-Aria-女': 'en-US-AriaNeural',
5
+ '英语 (美国)-Christopher-男': 'en-US-ChristopherNeural',
6
+ '英语 (美国)-Eric-男': 'en-US-EricNeural',
7
+ '英语 (美国)-Michelle-女': 'en-US-MichelleNeural',
8
+ '英语 (美国)-Roger-男': 'en-US-RogerNeural',
9
+ '西班牙语 (墨西哥)-Dalia-女': 'es-MX-DaliaNeural',
10
+ '西班牙语 (墨西哥)-Jorge-男': 'es-MX-JorgeNeural',
11
+ '韩语 (韩国)-Sun-Hi-女': 'ko-KR-SunHiNeural',
12
+ '韩语 (韩国)-InJoon-男': 'ko-KR-InJoonNeural',
13
+ '泰语 (泰国)-Premwadee-女': 'th-TH-PremwadeeNeural',
14
+ '泰语 (泰国)-Niwat-男': 'th-TH-NiwatNeural',
15
+ '越南语 (越南)-HoaiMy-女': 'vi-VN-HoaiMyNeural',
16
+ '越南语 (越南)-NamMinh-男': 'vi-VN-NamMinhNeural',
17
+ '日语 (日本)-Nanami-女': 'ja-JP-NanamiNeural',
18
+ '日语 (日本)-Keita-男': 'ja-JP-KeitaNeural',
19
+ '法语 (法国)-Denise-女': 'fr-FR-DeniseNeural',
20
+ '法语 (法国)-Eloise-女': 'fr-FR-EloiseNeural',
21
+ '法语 (法国)-Henri-男': 'fr-FR-HenriNeural',
22
+ '葡萄牙语 (巴西)-Francisca-女': 'pt-BR-FranciscaNeural',
23
+ '葡萄牙语 (巴西)-Antonio-男': 'pt-BR-AntonioNeural',
24
+ '印度尼西亚语 (印度尼西亚)-Ardi-男': 'id-ID-ArdiNeural',
25
+ '印度尼西亚语 (印度尼西亚)-Gadis-女': 'id-ID-GadisNeural',
26
+ '希伯来语 (以色列)-Avri-男': 'he-IL-AvriNeural',
27
+ '希伯来语 (以色列)-Hila-女': 'he-IL-HilaNeural',
28
+ '意大利语 (意大利)-Isabella-女': 'it-IT-IsabellaNeural',
29
+ '意大利语 (意大利)-Diego-男': 'it-IT-DiegoNeural',
30
+ '意大利语 (意大利)-Elsa-女': 'it-IT-ElsaNeural',
31
+ '荷兰语 (荷兰)-Colette-女': 'nl-NL-ColetteNeural',
32
+ '荷兰语 (荷兰)-Fenna-女': 'nl-NL-FennaNeural',
33
+ '荷兰语 (荷兰)-Maarten-男': 'nl-NL-MaartenNeural',
34
+ '马来语 (马来西亚)-Osman-男': 'ms-MY-OsmanNeural',
35
+ '马来语 (马来西亚)-Yasmin-女': 'ms-MY-YasminNeural',
36
+ '挪威语 (挪威)-Pernille-女': 'nb-NO-PernilleNeural',
37
+ '挪威语 (挪威)-Finn-男': 'nb-NO-FinnNeural',
38
+ '瑞典语 (瑞典)-Sofie-女': 'sv-SE-SofieNeural',
39
+ '瑞典语 (瑞典)-Mattias-男': 'sv-SE-MattiasNeural',
40
+ '阿拉伯语 (沙特阿拉伯)-Hamed-男': 'ar-SA-HamedNeural',
41
+ '阿拉伯语 (沙特阿拉伯)-Zariyah-女': 'ar-SA-ZariyahNeural',
42
+ '希腊语 (希腊)-Athina-女': 'el-GR-AthinaNeural',
43
+ '希腊语 (希腊)-Nestoras-男': 'el-GR-NestorasNeural',
44
+ '德语 (德国)-Katja-女': 'de-DE-KatjaNeural',
45
+ '德语 (德国)-Amala-女': 'de-DE-AmalaNeural',
46
+ '德语 (德国)-Conrad-男': 'de-DE-ConradNeural',
47
+ '德语 (德国)-Killian-男': 'de-DE-KillianNeural',
48
+ '阿拉伯语 (南非)-Adri-女': 'af-ZA-AdriNeural',
49
+ '阿拉伯语 (南非)-Willem-男': 'af-ZA-WillemNeural',
50
+ '阿姆哈拉语 (埃塞俄比亚)-Ameha-男': 'am-ET-AmehaNeural',
51
+ '阿姆哈拉语 (埃塞俄比亚)-Mekdes-女': 'am-ET-MekdesNeural',
52
+ '阿拉伯语 (阿拉伯联合酋长国)-Fatima-女': 'ar-AE-FatimaNeural',
53
+ '阿拉伯语 (阿拉伯联合酋长国)-Hamdan-男': 'ar-AE-HamdanNeural',
54
+ '阿拉伯语 (巴林)-Ali-男': 'ar-BH-AliNeural',
55
+ '阿拉伯语 (巴林)-Laila-女': 'ar-BH-LailaNeural',
56
+ '阿拉伯语 (阿尔及利亚)-Ismael-男': 'ar-DZ-IsmaelNeural',
57
+ '阿拉伯语 (埃及)-Salma-女': 'ar-EG-SalmaNeural',
58
+ '阿拉伯语 (埃及)-Shakir-男': 'ar-EG-ShakirNeural',
59
+ '阿拉伯语 (伊拉克)-Bassel-男': 'ar-IQ-BasselNeural',
60
+ '阿拉伯语 (伊拉克)-Rana-女': 'ar-IQ-RanaNeural',
61
+ '阿拉伯语 (约旦)-Sana-女': 'ar-JO-SanaNeural',
62
+ '阿拉伯语 (约旦)-Taim-男': 'ar-JO-TaimNeural',
63
+ '阿拉伯语 (科威特)-Fahed-男': 'ar-KW-FahedNeural',
64
+ '阿拉伯语 (科威特)-Noura-女': 'ar-KW-NouraNeural',
65
+ '阿拉伯语 (黎巴嫩)-Layla-女': 'ar-LB-LaylaNeural',
66
+ '阿拉伯语 (黎巴嫩)-Rami-男': 'ar-LB-RamiNeural',
67
+ '阿拉伯语 (利比亚)-Iman-女': 'ar-LY-ImanNeural',
68
+ '阿拉伯语 (利比亚)-Omar-男': 'ar-LY-OmarNeural',
69
+ '阿拉伯语 (摩洛哥)-Jamal-男': 'ar-MA-JamalNeural',
70
+ '阿拉伯语 (摩洛哥)-Mouna-女': 'ar-MA-MounaNeural',
71
+ '阿拉伯语 (阿曼)-Abdullah-男': 'ar-OM-AbdullahNeural',
72
+ '阿拉伯语 (阿曼)-Aysha-女': 'ar-OM-AyshaNeural',
73
+ '阿拉伯语 (卡塔尔)-Amal-女': 'ar-QA-AmalNeural',
74
+ '阿拉伯语 (卡塔尔)-Moaz-男': 'ar-QA-MoazNeural',
75
+ '阿拉伯语 (叙利亚)-Amany-女': 'ar-SY-AmanyNeural',
76
+ '阿拉伯语 (叙利亚)-Laith-男': 'ar-SY-LaithNeural',
77
+ '阿拉伯语 (突尼斯)-Hedi-男': 'ar-TN-HediNeural',
78
+ '阿拉伯语 (突尼斯)-Reem-女': 'ar-TN-ReemNeural',
79
+ '阿拉伯语 (也门)-Maryam-女': 'ar-YE-MaryamNeural',
80
+ '阿拉伯语 (也门)-Saleh-男': 'ar-YE-SalehNeural',
81
+ '阿塞拜疆语 (阿塞拜疆)-Babek-男': 'az-AZ-BabekNeural',
82
+ '阿塞拜疆语 (阿塞拜疆)-Banu-女': 'az-AZ-BanuNeural',
83
+ '保加利亚语 (保加利亚)-Borislav-男': 'bg-BG-BorislavNeural',
84
+ '保加利亚语 (保加利亚)-Kalina-女': 'bg-BG-KalinaNeural',
85
+ '孟加拉语 (孟加拉国)-Nabanita-女': 'bn-BD-NabanitaNeural',
86
+ '孟加拉语 (孟加拉国)-Pradeep-男': 'bn-BD-PradeepNeural',
87
+ '孟加拉语 (印度)-Bashkar-男': 'bn-IN-BashkarNeural',
88
+ '孟加拉语 (印度)-Tanishaa-女': 'bn-IN-TanishaaNeural',
89
+ '波斯尼亚语 (波斯尼亚和黑塞哥维那)-Goran-男': 'bs-BA-GoranNeural',
90
+ '波斯尼亚语 (波斯尼亚和黑塞哥维那)-Vesna-女': 'bs-BA-VesnaNeural',
91
+ '加泰罗尼亚语 (西班牙)-Joana-女': 'ca-ES-JoanaNeural',
92
+ '加泰罗尼亚语 (西班牙)-Enric-男': 'ca-ES-EnricNeural',
93
+ '捷克语 (捷克共和国)-Antonin-男': 'cs-CZ-AntoninNeural',
94
+ '捷克语 (捷克共和国)-Vlasta-女': 'cs-CZ-VlastaNeural',
95
+ '威尔士语 (英国)-Aled-男': 'cy-GB-AledNeural',
96
+ '威尔士语 (英国)-Nia-女': 'cy-GB-NiaNeural',
97
+ '丹麦语 (丹麦)-Christel-女': 'da-DK-ChristelNeural',
98
+ '丹麦语 (丹麦)-Jeppe-男': 'da-DK-JeppeNeural',
99
+ '德语 (奥地利)-Ingrid-女': 'de-AT-IngridNeural',
100
+ '德语 (奥地利)-Jonas-男': 'de-AT-JonasNeural',
101
+ '德语 (瑞士)-Jan-男': 'de-CH-JanNeural',
102
+ '德语 (瑞士)-Leni-女': 'de-CH-LeniNeural',
103
+ '英语 (澳大利亚)-Natasha-女': 'en-AU-NatashaNeural',
104
+ '英语 (澳大利亚)-William-男': 'en-AU-WilliamNeural',
105
+ '英语 (加拿大)-Clara-女': 'en-CA-ClaraNeural',
106
+ '英语 (加拿大)-Liam-男': 'en-CA-LiamNeural',
107
+ '英语 (英国)-Libby-女': 'en-GB-LibbyNeural',
108
+ '英语 (英国)-Maisie-女': 'en-GB-MaisieNeural',
109
+ '英语 (英国)-Ryan-男': 'en-GB-RyanNeural',
110
+ '英语 (英国)-Sonia-女': 'en-GB-SoniaNeural',
111
+ '英语 (英国)-Thomas-男': 'en-GB-ThomasNeural',
112
+ '英语 (香港)-Sam-男': 'en-HK-SamNeural',
113
+ '英语 (香港)-Yan-女': 'en-HK-YanNeural',
114
+ '英语 (爱尔兰)-Connor-男': 'en-IE-ConnorNeural',
115
+ '英语 (爱尔兰)-Emily-女': 'en-IE-EmilyNeural',
116
+ '英语 (印度)-Neerja-女': 'en-IN-NeerjaNeural',
117
+ '英语 (印度)-Prabhat-男': 'en-IN-PrabhatNeural',
118
+ '英语 (肯尼亚)-Asilia-女': 'en-KE-AsiliaNeural',
119
+ '英语 (肯尼亚)-Chilemba-男': 'en-KE-ChilembaNeural',
120
+ '英语 (尼日利亚)-Abeo-男': 'en-NG-AbeoNeural',
121
+ '英语 (尼日利亚)-Ezinne-女': 'en-NG-EzinneNeural',
122
+ '英语 (新西兰)-Mitchell-男': 'en-NZ-MitchellNeural',
123
+ '英语 (菲律宾)-James-男': 'en-PH-JamesNeural',
124
+ '英语 (菲律宾)-Rosa-女': 'en-PH-RosaNeural',
125
+ '英语 (新加坡)-Luna-女': 'en-SG-LunaNeural',
126
+ '英语 (新加坡)-Wayne-男': 'en-SG-WayneNeural',
127
+ '英语 (坦桑尼亚)-Elimu-男': 'en-TZ-ElimuNeural',
128
+ '英语 (坦桑尼亚)-Imani-女': 'en-TZ-ImaniNeural',
129
+ '英语 (南非)-Leah-女': 'en-ZA-LeahNeural',
130
+ '英语 (南非)-Luke-男': 'en-ZA-LukeNeural',
131
+ '西班牙语 (阿根廷)-Elena-女': 'es-AR-ElenaNeural',
132
+ '西班牙语 (阿根廷)-Tomas-男': 'es-AR-TomasNeural',
133
+ '西班牙语 (玻利维亚)-Marcelo-男': 'es-BO-MarceloNeural',
134
+ '西班牙语 (玻利维亚)-Sofia-女': 'es-BO-SofiaNeural',
135
+ '西班牙语 (哥伦比亚)-Gonzalo-男': 'es-CO-GonzaloNeural',
136
+ '西班牙语 (哥伦比亚)-Salome-女': 'es-CO-SalomeNeural',
137
+ '西班牙语 (哥斯达黎加)-Juan-男': 'es-CR-JuanNeural',
138
+ '西班牙语 (哥斯达黎加)-Maria-女': 'es-CR-MariaNeural',
139
+ '西班牙语 (古巴)-Belkys-女': 'es-CU-BelkysNeural',
140
+ '西班牙语 (多米尼加共和国)-Emilio-男': 'es-DO-EmilioNeural',
141
+ '西班牙语 (多米尼加共和国)-Ramona-女': 'es-DO-RamonaNeural',
142
+ '西班牙语 (厄瓜多尔)-Andrea-女': 'es-EC-AndreaNeural',
143
+ '西班牙语 (厄瓜多尔)-Luis-男': 'es-EC-LuisNeural',
144
+ '西班牙语 (西班牙)-Alvaro-男': 'es-ES-AlvaroNeural',
145
+ '西班牙语 (西班牙)-Elvira-女': 'es-ES-ElviraNeural',
146
+ '西班牙语 (赤道几内亚)-Teresa-女': 'es-GQ-TeresaNeural',
147
+ '西班牙语 (危地马拉)-Andres-男': 'es-GT-AndresNeural',
148
+ '西班牙语 (危地马拉)-Marta-女': 'es-GT-MartaNeural',
149
+ '西班牙语 (洪都拉斯)-Carlos-男': 'es-HN-CarlosNeural',
150
+ '西班牙语 (洪都拉斯)-Karla-女': 'es-HN-KarlaNeural',
151
+ '西班牙语 (尼加拉瓜)-Federico-男': 'es-NI-FedericoNeural',
152
+ '西班牙语 (尼加拉瓜)-Yolanda-女': 'es-NI-YolandaNeural',
153
+ '西班牙语 (巴拿马)-Margarita-女': 'es-PA-MargaritaNeural',
154
+ '西班牙语 (巴拿马)-Roberto-男': 'es-PA-RobertoNeural',
155
+ '西班牙语 (秘鲁)-Alex-男': 'es-PE-AlexNeural',
156
+ '西班牙语 (秘鲁)-Camila-女': 'es-PE-CamilaNeural',
157
+ '西班牙语 (波多黎各)-Karina-女': 'es-PR-KarinaNeural',
158
+ '西班牙语 (波多黎各)-Victor-男': 'es-PR-VictorNeural',
159
+ '西班牙语 (巴拉圭)-Mario-男': 'es-PY-MarioNeural',
160
+ '西班牙语 (巴拉圭)-Tania-女': 'es-PY-TaniaNeural',
161
+ '西班牙语 (萨尔瓦多)-Lorena-女': 'es-SV-LorenaNeural',
162
+ '西班牙语 (萨尔瓦多)-Rodrigo-男': 'es-SV-RodrigoNeural',
163
+ '西班牙语 (美国)-Alonso-男': 'es-US-AlonsoNeural',
164
+ '西班牙语 (美国)-Paloma-女': 'es-US-PalomaNeural',
165
+ '西班牙语 (乌拉圭)-Mateo-男': 'es-UY-MateoNeural',
166
+ '西班牙语 (乌拉圭)-Valentina-女': 'es-UY-ValentinaNeural',
167
+ '西班牙语 (委内瑞拉)-Paola-女': 'es-VE-PaolaNeural',
168
+ '西班牙语 (委内瑞拉)-Sebastian-男': 'es-VE-SebastianNeural',
169
+ '爱沙尼亚语 (爱沙尼亚)-Anu-���': 'et-EE-AnuNeural',
170
+ '爱沙尼亚语 (爱沙尼亚)-Kert-男': 'et-EE-KertNeural',
171
+ '波斯语 (伊朗)-Dilara-女': 'fa-IR-DilaraNeural',
172
+ '波斯语 (伊朗)-Farid-男': 'fa-IR-FaridNeural',
173
+ '芬兰语 (芬兰)-Harri-男': 'fi-FI-HarriNeural',
174
+ '芬兰语 (芬兰)-Noora-女': 'fi-FI-NooraNeural',
175
+ '法语 (比利时)-Charline-女': 'fr-BE-CharlineNeural',
176
+ '法语 (比利时)-Gerard-男': 'fr-BE-GerardNeural',
177
+ '法语 (加拿大)-Sylvie-女': 'fr-CA-SylvieNeural',
178
+ '法语 (加拿大)-Antoine-男': 'fr-CA-AntoineNeural',
179
+ '法语 (加拿大)-Jean-男': 'fr-CA-JeanNeural',
180
+ '法语 (瑞士)-Ariane-女': 'fr-CH-ArianeNeural',
181
+ '法语 (瑞士)-Fabrice-男': 'fr-CH-FabriceNeural',
182
+ '爱尔兰语 (爱尔兰)-Colm-男': 'ga-IE-ColmNeural',
183
+ '爱尔兰语 (爱尔兰)-Orla-女': 'ga-IE-OrlaNeural',
184
+ '加利西亚语 (西班牙)-Roi-男': 'gl-ES-RoiNeural',
185
+ '加利西亚语 (西班牙)-Sabela-女': 'gl-ES-SabelaNeural',
186
+ '古吉拉特语 (印度)-Dhwani-女': 'gu-IN-DhwaniNeural',
187
+ '古吉拉特语 (印度)-Niranjan-男': 'gu-IN-NiranjanNeural',
188
+ '印地语 (印度)-Madhur-男': 'hi-IN-MadhurNeural',
189
+ '印地语 (印度)-Swara-女': 'hi-IN-SwaraNeural',
190
+ '克罗地亚语 (克罗地亚)-Gabrijela-女': 'hr-HR-GabrijelaNeural',
191
+ '克罗地亚语 (克罗地亚)-Srecko-男': 'hr-HR-SreckoNeural',
192
+ '匈牙利语 (匈牙利)-Noemi-女': 'hu-HU-NoemiNeural',
193
+ '匈牙利语 (匈牙利)-Tamas-男': 'hu-HU-TamasNeural',
194
+ '冰岛语 (冰岛)-Gudrun-女': 'is-IS-GudrunNeural',
195
+ '冰岛语 (冰岛)-Gunnar-男': 'is-IS-GunnarNeural',
196
+ '爪哇语 (印度尼西亚)-Dimas-男': 'jv-ID-DimasNeural',
197
+ '爪哇语 (印度尼西亚)-Siti-女': 'jv-ID-SitiNeural',
198
+ '格鲁吉亚语 (格鲁吉亚)-Eka-女': 'ka-GE-EkaNeural',
199
+ '格鲁吉亚语 (格鲁吉亚)-Giorgi-男': 'ka-GE-GiorgiNeural',
200
+ '哈萨克语 (哈萨克斯坦)-Aigul-女': 'kk-KZ-AigulNeural',
201
+ '哈萨克语 (哈萨克斯坦)-Daulet-男': 'kk-KZ-DauletNeural',
202
+ '高棉语 (柬埔寨)-Piseth-男': 'km-KH-PisethNeural',
203
+ '高棉语 (柬埔寨)-Sreymom-女': 'km-KH-SreymomNeural',
204
+ '卡纳达语 (印度)-Gagan-男': 'kn-IN-GaganNeural',
205
+ '卡纳达语 (印度)-Sapna-女': 'kn-IN-SapnaNeural',
206
+ '老挝语 (老挝)-Chanthavong-男': 'lo-LA-ChanthavongNeural',
207
+ '老挝语 (老挝)-Keomany-女': 'lo-LA-KeomanyNeural',
208
+ '立陶宛语 (立陶宛)-Leonas-男': 'lt-LT-LeonasNeural',
209
+ '立陶宛语 (立陶宛)-Ona-女': 'lt-LT-OnaNeural',
210
+ '拉脱维亚语 (拉脱维亚)-Everita-女': 'lv-LV-EveritaNeural',
211
+ '拉脱维亚语 (拉脱维亚)-Nils-男': 'lv-LV-NilsNeural',
212
+ '马其顿语 (北马其顿共和国)-Aleksandar-男': 'mk-MK-AleksandarNeural',
213
+ '马其顿语 (北马其顿共和国)-Marija-女': 'mk-MK-MarijaNeural',
214
+ '马拉雅拉姆语 (印度)-Midhun-男': 'ml-IN-MidhunNeural',
215
+ '马拉雅拉姆语 (印度)-Sobhana-女': 'ml-IN-SobhanaNeural',
216
+ '蒙古语 (蒙古)-Bataa-男': 'mn-MN-BataaNeural',
217
+ '蒙古语 (蒙古)-Yesui-女': 'mn-MN-YesuiNeural',
218
+ '马拉地语 (印度)-Aarohi-女': 'mr-IN-AarohiNeural',
219
+ '马拉地语 (印度)-Manohar-男': 'mr-IN-ManoharNeural',
220
+ '马耳他语 (马耳他)-Grace-女': 'mt-MT-GraceNeural',
221
+ '马耳他语 (马耳他)-Joseph-男': 'mt-MT-JosephNeural',
222
+ '缅甸语 (缅甸)-Nilar-女': 'my-MM-NilarNeural',
223
+ '缅甸语 (缅甸)-Thiha-男': 'my-MM-ThihaNeural',
224
+ '尼泊尔语 (尼泊尔)-Hemkala-女': 'ne-NP-HemkalaNeural',
225
+ '尼泊尔语 (尼泊尔)-Sagar-男': 'ne-NP-SagarNeural',
226
+ '荷兰语 (比利时)-Arnaud-男': 'nl-BE-ArnaudNeural',
227
+ '荷兰语 (比利时)-Dena-女': 'nl-BE-DenaNeural',
228
+ '波兰语 (波兰)-Marek-男': 'pl-PL-MarekNeural',
229
+ '波兰语 (波兰)-Zofia-女': 'pl-PL-ZofiaNeural',
230
+ '普什图语 (阿富汗)-Gul Nawaz-男': 'ps-AF-GulNawazNeural',
231
+ '普什图语 (阿富汗)-Latifa-女': 'ps-AF-LatifaNeural',
232
+ '葡萄牙语 (葡萄牙)-Duarte-男': 'pt-PT-DuarteNeural',
233
+ '葡萄牙语 (葡萄牙)-Raquel-女': 'pt-PT-RaquelNeural',
234
+ '罗马尼亚语 (罗马尼亚)-Alina-女': 'ro-RO-AlinaNeural',
235
+ '罗马尼亚语 (罗马尼亚)-Emil-男': 'ro-RO-EmilNeural',
236
+ '俄语 (俄罗斯)-Svetlana-女': 'ru-RU-SvetlanaNeural',
237
+ '俄语 (俄罗斯)-Dmitry-男': 'ru-RU-DmitryNeural',
238
+ '僧伽罗语 (斯里兰卡)-Sameera-男': 'si-LK-SameeraNeural',
239
+ '僧伽罗语 (斯里兰卡)-Thilini-女': 'si-LK-ThiliniNeural',
240
+ '斯洛伐克语 (斯洛伐克)-Lukas-男': 'sk-SK-LukasNeural',
241
+ '斯洛伐克语 (斯洛伐克)-Viktoria-女': 'sk-SK-ViktoriaNeural',
242
+ '斯洛文尼亚语 (斯洛文尼亚)-Petra-女': 'sl-SI-PetraNeural',
243
+ '斯洛文尼亚语 (斯洛文尼亚)-Rok-男': 'sl-SI-RokNeural',
244
+ '索马里语 (索马里)-Muuse-男': 'so-SO-MuuseNeural',
245
+ '索马里语 (索马里)-Ubax-女': 'so-SO-UbaxNeural',
246
+ '阿尔巴尼亚语 (阿尔巴尼亚)-Anila-女': 'sq-AL-AnilaNeural',
247
+ '阿尔巴尼亚语 (阿尔巴尼亚)-Ilir-男': 'sq-AL-IlirNeural',
248
+ '塞尔维亚语 (塞尔维亚)-Nicholas-男': 'sr-RS-NicholasNeural',
249
+ '塞尔维亚语 (塞尔维亚)-Sophie-女': 'sr-RS-SophieNeural',
250
+ '巽他语 (印度尼西亚)-Jajang-男': 'su-ID-JajangNeural',
251
+ '巽他语 (印度尼��亚)-Tuti-女': 'su-ID-TutiNeural',
252
+ '斯瓦希里语 (肯尼亚)-Rafiki-男': 'sw-KE-RafikiNeural',
253
+ '斯瓦希里语 (肯尼亚)-Zuri-女': 'sw-KE-ZuriNeural',
254
+ '斯瓦希里语 (坦桑尼亚)-Daudi-男': 'sw-TZ-DaudiNeural',
255
+ '斯瓦希里语 (坦桑尼亚)-Rehema-女': 'sw-TZ-RehemaNeural',
256
+ '泰米尔语 (印度)-Pallavi-女': 'ta-IN-PallaviNeural',
257
+ '泰米尔语 (印度)-Valluvar-男': 'ta-IN-ValluvarNeural',
258
+ '泰米尔语 (斯里兰卡)-Kumar-男': 'ta-LK-KumarNeural',
259
+ '泰米尔语 (斯里兰卡)-Saranya-女': 'ta-LK-SaranyaNeural',
260
+ '泰米尔语 (马来西亚)-Kani-女': 'ta-MY-KaniNeural',
261
+ '泰米尔语 (马来西亚)-Surya-男': 'ta-MY-SuryaNeural',
262
+ '泰米尔语 (新加坡)-Anbu-男': 'ta-SG-AnbuNeural',
263
+ '泰卢固语 (印度)-Mohan-男': 'te-IN-MohanNeural',
264
+ '泰卢固语 (印度)-Shruti-女': 'te-IN-ShrutiNeural',
265
+ '土耳其语 (土耳其)-Ahmet-男': 'tr-TR-AhmetNeural',
266
+ '土耳其语 (土耳其)-Emel-女': 'tr-TR-EmelNeural',
267
+ '乌克兰语 (乌克兰)-Ostap-男': 'uk-UA-OstapNeural',
268
+ '乌克兰语 (乌克兰)-Polina-女': 'uk-UA-PolinaNeural',
269
+ '乌尔都语 (印度)-Gul-女': 'ur-IN-GulNeural',
270
+ '乌尔都语 (印度)-Salman-男': 'ur-IN-SalmanNeural',
271
+ '乌尔都语 (巴基斯坦)-Asad-男': 'ur-PK-AsadNeural',
272
+ '乌尔都语 (巴基斯坦)-Uzma-女': 'ur-PK-UzmaNeural',
273
+ '乌兹别克语 (乌兹别克斯坦)-Madina-女': 'uz-UZ-MadinaNeural',
274
+ '乌兹别克语 (乌兹别克斯坦)-Sardor-男': 'uz-UZ-SardorNeural',
275
+ '普通话 (中国大陆)-Xiaoxiao-女': 'zh-CN-XiaoxiaoNeural',
276
+ '普通话 (中国大陆)-Yunyang-男': 'zh-CN-YunyangNeural',
277
+ '普通话 (中国大陆)-Yunxi-男': 'zh-CN-YunxiNeural',
278
+ '普通话 (中国大陆)-Xiaoyi-女': 'zh-CN-XiaoyiNeural',
279
+ '普通话 (中国大陆)-Yunjian-男': 'zh-CN-YunjianNeural',
280
+ '普通话 (中国大陆)-Yunxia-男': 'zh-CN-YunxiaNeural',
281
+ '东北话 (中国大陆)-Xiaobei-女': 'zh-CN-liaoning-XiaobeiNeural',
282
+ '中原官话 (中国陕西)-Xiaoni-女': 'zh-CN-shaanxi-XiaoniNeural',
283
+ '粤语 (中国香港)-HiuMaan-女': 'zh-HK-HiuMaanNeural',
284
+ '粤语 (中国香港)-HiuGaai-女': 'zh-HK-HiuGaaiNeural',
285
+ '粤语 (中国香港)-WanLung-男': 'zh-HK-WanLungNeural',
286
+ '台湾普通话-HsiaoChen-女': 'zh-TW-HsiaoChenNeural',
287
+ '台湾普通话-HsiaoYu-女': 'zh-TW-HsiaoYuNeural',
288
+ '台湾普通话-YunJhe-男': 'zh-TW-YunJheNeural',
289
+ '祖鲁语 (南非)-Thando-女': 'zu-ZA-ThandoNeural',
290
+ '祖鲁语 (南非)-Themba-男': 'zu-ZA-ThembaNeural'}
utils.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import logging
5
+ import json
6
+ import subprocess
7
+ import numpy as np
8
+ from scipy.io.wavfile import read
9
+ import torch
10
+ from torch.nn import functional as F
11
+ from commons import sequence_mask
12
+
13
+ MATPLOTLIB_FLAG = False
14
+
15
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
16
+ logger = logging
17
+
18
+
19
+ def get_cmodel(rank):
20
+ checkpoint = torch.load('wavlm/WavLM-Large.pt')
21
+ cfg = WavLMConfig(checkpoint['cfg'])
22
+ cmodel = WavLM(cfg).cuda(rank)
23
+ cmodel.load_state_dict(checkpoint['model'])
24
+ cmodel.eval()
25
+ return cmodel
26
+
27
+
28
+ def get_content(cmodel, y):
29
+ with torch.no_grad():
30
+ c = cmodel.extract_features(y.squeeze(1))[0]
31
+ c = c.transpose(1, 2)
32
+ return c
33
+
34
+
35
+ def get_vocoder(rank):
36
+ with open("hifigan/config.json", "r") as f:
37
+ config = json.load(f)
38
+ config = hifigan.AttrDict(config)
39
+ vocoder = hifigan.Generator(config)
40
+ ckpt = torch.load("hifigan/generator_v1")
41
+ vocoder.load_state_dict(ckpt["generator"])
42
+ vocoder.eval()
43
+ vocoder.remove_weight_norm()
44
+ vocoder.cuda(rank)
45
+ return vocoder
46
+
47
+
48
+ def transform(mel, height): # 68-92
49
+ #r = np.random.random()
50
+ #rate = r * 0.3 + 0.85 # 0.85-1.15
51
+ #height = int(mel.size(-2) * rate)
52
+ tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
53
+ if height >= mel.size(-2):
54
+ return tgt[:, :mel.size(-2), :]
55
+ else:
56
+ silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1)
57
+ silence += torch.randn_like(silence) / 10
58
+ return torch.cat((tgt, silence), 1)
59
+
60
+
61
+ def stretch(mel, width): # 0.5-2
62
+ return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))
63
+
64
+
65
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
66
+ assert os.path.isfile(checkpoint_path)
67
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
68
+ iteration = checkpoint_dict['iteration']
69
+ learning_rate = checkpoint_dict['learning_rate']
70
+ if optimizer is not None:
71
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
72
+ saved_state_dict = checkpoint_dict['model']
73
+ if hasattr(model, 'module'):
74
+ state_dict = model.module.state_dict()
75
+ else:
76
+ state_dict = model.state_dict()
77
+ new_state_dict= {}
78
+ for k, v in state_dict.items():
79
+ try:
80
+ new_state_dict[k] = saved_state_dict[k]
81
+ except:
82
+ logger.info("%s is not in the checkpoint" % k)
83
+ new_state_dict[k] = v
84
+ if hasattr(model, 'module'):
85
+ model.module.load_state_dict(new_state_dict)
86
+ else:
87
+ model.load_state_dict(new_state_dict)
88
+ logger.info("Loaded checkpoint '{}' (iteration {})" .format(
89
+ checkpoint_path, iteration))
90
+ return model, optimizer, learning_rate, iteration
91
+
92
+
93
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
94
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
95
+ iteration, checkpoint_path))
96
+ if hasattr(model, 'module'):
97
+ state_dict = model.module.state_dict()
98
+ else:
99
+ state_dict = model.state_dict()
100
+ torch.save({'model': state_dict,
101
+ 'iteration': iteration,
102
+ 'optimizer': optimizer.state_dict(),
103
+ 'learning_rate': learning_rate}, checkpoint_path)
104
+
105
+
106
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
107
+ for k, v in scalars.items():
108
+ writer.add_scalar(k, v, global_step)
109
+ for k, v in histograms.items():
110
+ writer.add_histogram(k, v, global_step)
111
+ for k, v in images.items():
112
+ writer.add_image(k, v, global_step, dataformats='HWC')
113
+ for k, v in audios.items():
114
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
115
+
116
+
117
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
118
+ f_list = glob.glob(os.path.join(dir_path, regex))
119
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
120
+ x = f_list[-1]
121
+ print(x)
122
+ return x
123
+
124
+
125
+ def plot_spectrogram_to_numpy(spectrogram):
126
+ global MATPLOTLIB_FLAG
127
+ if not MATPLOTLIB_FLAG:
128
+ import matplotlib
129
+ matplotlib.use("Agg")
130
+ MATPLOTLIB_FLAG = True
131
+ mpl_logger = logging.getLogger('matplotlib')
132
+ mpl_logger.setLevel(logging.WARNING)
133
+ import matplotlib.pylab as plt
134
+ import numpy as np
135
+
136
+ fig, ax = plt.subplots(figsize=(10,2))
137
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
138
+ interpolation='none')
139
+ plt.colorbar(im, ax=ax)
140
+ plt.xlabel("Frames")
141
+ plt.ylabel("Channels")
142
+ plt.tight_layout()
143
+
144
+ fig.canvas.draw()
145
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
146
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
147
+ plt.close()
148
+ return data
149
+
150
+
151
+ def plot_alignment_to_numpy(alignment, info=None):
152
+ global MATPLOTLIB_FLAG
153
+ if not MATPLOTLIB_FLAG:
154
+ import matplotlib
155
+ matplotlib.use("Agg")
156
+ MATPLOTLIB_FLAG = True
157
+ mpl_logger = logging.getLogger('matplotlib')
158
+ mpl_logger.setLevel(logging.WARNING)
159
+ import matplotlib.pylab as plt
160
+ import numpy as np
161
+
162
+ fig, ax = plt.subplots(figsize=(6, 4))
163
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
164
+ interpolation='none')
165
+ fig.colorbar(im, ax=ax)
166
+ xlabel = 'Decoder timestep'
167
+ if info is not None:
168
+ xlabel += '\n\n' + info
169
+ plt.xlabel(xlabel)
170
+ plt.ylabel('Encoder timestep')
171
+ plt.tight_layout()
172
+
173
+ fig.canvas.draw()
174
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
175
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
176
+ plt.close()
177
+ return data
178
+
179
+
180
+ def load_wav_to_torch(full_path):
181
+ sampling_rate, data = read(full_path)
182
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
183
+
184
+
185
+ def load_filepaths_and_text(filename, split="|"):
186
+ with open(filename, encoding='utf-8') as f:
187
+ filepaths_and_text = [line.strip().split(split) for line in f]
188
+ return filepaths_and_text
189
+
190
+
191
+ def get_hparams(init=True):
192
+ parser = argparse.ArgumentParser()
193
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
194
+ help='JSON file for configuration')
195
+ parser.add_argument('-m', '--model', type=str, required=True,
196
+ help='Model name')
197
+
198
+ args = parser.parse_args()
199
+ model_dir = os.path.join("./logs", args.model)
200
+
201
+ if not os.path.exists(model_dir):
202
+ os.makedirs(model_dir)
203
+
204
+ config_path = args.config
205
+ config_save_path = os.path.join(model_dir, "config.json")
206
+ if init:
207
+ with open(config_path, "r") as f:
208
+ data = f.read()
209
+ with open(config_save_path, "w") as f:
210
+ f.write(data)
211
+ else:
212
+ with open(config_save_path, "r") as f:
213
+ data = f.read()
214
+ config = json.loads(data)
215
+
216
+ hparams = HParams(**config)
217
+ hparams.model_dir = model_dir
218
+ return hparams
219
+
220
+
221
+ def get_hparams_from_dir(model_dir):
222
+ config_save_path = os.path.join(model_dir, "config.json")
223
+ with open(config_save_path, "r") as f:
224
+ data = f.read()
225
+ config = json.loads(data)
226
+
227
+ hparams =HParams(**config)
228
+ hparams.model_dir = model_dir
229
+ return hparams
230
+
231
+
232
+ def get_hparams_from_file(config_path):
233
+ with open(config_path, "r") as f:
234
+ data = f.read()
235
+ config = json.loads(data)
236
+
237
+ hparams =HParams(**config)
238
+ return hparams
239
+
240
+
241
+ def check_git_hash(model_dir):
242
+ source_dir = os.path.dirname(os.path.realpath(__file__))
243
+ if not os.path.exists(os.path.join(source_dir, ".git")):
244
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
245
+ source_dir
246
+ ))
247
+ return
248
+
249
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
250
+
251
+ path = os.path.join(model_dir, "githash")
252
+ if os.path.exists(path):
253
+ saved_hash = open(path).read()
254
+ if saved_hash != cur_hash:
255
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
256
+ saved_hash[:8], cur_hash[:8]))
257
+ else:
258
+ open(path, "w").write(cur_hash)
259
+
260
+
261
+ def get_logger(model_dir, filename="train.log"):
262
+ global logger
263
+ logger = logging.getLogger(os.path.basename(model_dir))
264
+ logger.setLevel(logging.DEBUG)
265
+
266
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
267
+ if not os.path.exists(model_dir):
268
+ os.makedirs(model_dir)
269
+ h = logging.FileHandler(os.path.join(model_dir, filename))
270
+ h.setLevel(logging.DEBUG)
271
+ h.setFormatter(formatter)
272
+ logger.addHandler(h)
273
+ return logger
274
+
275
+
276
+ class HParams():
277
+ def __init__(self, **kwargs):
278
+ for k, v in kwargs.items():
279
+ if type(v) == dict:
280
+ v = HParams(**v)
281
+ self[k] = v
282
+
283
+ def keys(self):
284
+ return self.__dict__.keys()
285
+
286
+ def items(self):
287
+ return self.__dict__.items()
288
+
289
+ def values(self):
290
+ return self.__dict__.values()
291
+
292
+ def __len__(self):
293
+ return len(self.__dict__)
294
+
295
+ def __getitem__(self, key):
296
+ return getattr(self, key)
297
+
298
+ def __setitem__(self, key, value):
299
+ return setattr(self, key, value)
300
+
301
+ def __contains__(self, key):
302
+ return key in self.__dict__
303
+
304
+ def __repr__(self):
305
+ return self.__dict__.__repr__()