keithhon committed
Commit
1955797
1 Parent(s): 937d113

Upload toolbox/__init__.py with huggingface_hub

Files changed (1)
  1. toolbox/__init__.py +357 -0
toolbox/__init__.py ADDED
@@ -0,0 +1,357 @@
from toolbox.ui import UI
from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from vocoder import inference as vocoder
from pathlib import Path
from time import perf_counter as timer
from toolbox.utterance import Utterance
import numpy as np
import traceback
import sys
import torch
import librosa
from audioread.exceptions import NoBackendError

# Use this directory structure for your datasets, or modify it to fit your needs
recognized_datasets = [
    "LibriSpeech/dev-clean",
    "LibriSpeech/dev-other",
    "LibriSpeech/test-clean",
    "LibriSpeech/test-other",
    "LibriSpeech/train-clean-100",
    "LibriSpeech/train-clean-360",
    "LibriSpeech/train-other-500",
    "LibriTTS/dev-clean",
    "LibriTTS/dev-other",
    "LibriTTS/test-clean",
    "LibriTTS/test-other",
    "LibriTTS/train-clean-100",
    "LibriTTS/train-clean-360",
    "LibriTTS/train-other-500",
    "LJSpeech-1.1",
    "VoxCeleb1/wav",
    "VoxCeleb1/test_wav",
    "VoxCeleb2/dev/aac",
    "VoxCeleb2/test/aac",
    "VCTK-Corpus/wav48",
]

# Maximum number of generated wavs to keep in memory
MAX_WAVES = 15
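
A quick sketch (not part of the uploaded file) of how recognized_datasets is meant to map onto datasets_root: each entry is simply a subdirectory to look for. It assumes only the list above and the standard library.

def available_datasets(datasets_root):
    # Return the recognized dataset subdirectories that actually exist
    # under datasets_root (e.g. "~/datasets/LibriSpeech/dev-clean").
    root = Path(datasets_root).expanduser()
    return [d for d in recognized_datasets if (root / d).is_dir()]
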
class Toolbox:
    def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support):
        if not no_mp3_support:
            try:
                librosa.load("samples/6829_00000.mp3")
            except NoBackendError:
                print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
                      "Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
                exit(-1)
        self.no_mp3_support = no_mp3_support
        sys.excepthook = self.excepthook
        self.datasets_root = datasets_root
        self.utterances = set()
        self.current_generated = (None, None, None, None)  # speaker_name, spec, breaks, wav

        self.synthesizer = None  # type: Synthesizer
        self.current_wav = None
        self.waves_list = []
        self.waves_count = 0
        self.waves_namelist = []

        # Check for webrtcvad (enables removal of silences in vocoder output)
        try:
            import webrtcvad
            self.trim_silences = True
        except ImportError:
            self.trim_silences = False

        # Initialize the events and the interface
        self.ui = UI()
        self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
        self.setup_events()
        self.ui.start()

    def excepthook(self, exc_type, exc_value, exc_tb):
        traceback.print_exception(exc_type, exc_value, exc_tb)
        self.ui.log("Exception: %s" % exc_value)

    def setup_events(self):
        # Dataset, speaker and utterance selection
        self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
        random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
                                                                     recognized_datasets,
                                                                     level)
        self.ui.random_dataset_button.clicked.connect(random_func(0))
        self.ui.random_speaker_button.clicked.connect(random_func(1))
        self.ui.random_utterance_button.clicked.connect(random_func(2))
        self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
        self.ui.speaker_box.currentIndexChanged.connect(random_func(2))

        # Model selection
        self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
        def func():
            self.synthesizer = None
        self.ui.synthesizer_box.currentIndexChanged.connect(func)
        self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)

        # Utterance selection
        func = lambda: self.load_from_browser(self.ui.browse_file())
        self.ui.browser_browse_button.clicked.connect(func)
        func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
        self.ui.utterance_history.currentIndexChanged.connect(func)
        func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
        self.ui.play_button.clicked.connect(func)
        self.ui.stop_button.clicked.connect(self.ui.stop)
        self.ui.record_button.clicked.connect(self.record)

        # Audio
        self.ui.setup_audio_devices(Synthesizer.sample_rate)

        # Wav playback & save
        func = lambda: self.replay_last_wav()
        self.ui.replay_wav_button.clicked.connect(func)
        func = lambda: self.export_current_wave()
        self.ui.export_wav_button.clicked.connect(func)
        self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # Generation
        func = lambda: self.synthesize() or self.vocode()
        self.ui.generate_button.clicked.connect(func)
        self.ui.synthesize_button.clicked.connect(self.synthesize)
        self.ui.vocode_button.clicked.connect(self.vocode)
        self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)

        # UMAP legend
        self.ui.clear_button.clicked.connect(self.clear_utterances)

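One subtlety worth calling out: random_func is a lambda factory. The outer lambda binds level at creation time, so each button gets a handler with its own browsing level; a single shared closure would late-bind the variable. A standalone illustration of the difference:

# Late binding: every closure sees the final value of `level`.
naive = [lambda: level for level in range(3)]
print([f() for f in naive])   # [2, 2, 2]

# A factory in the random_func style freezes the value at creation time.
make = lambda level: lambda: level
bound = [make(level) for level in range(3)]
print([f() for f in bound])   # [0, 1, 2]
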
    def set_current_wav(self, index):
        self.current_wav = self.waves_list[index]

    def export_current_wave(self):
        self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)

    def replay_last_wav(self):
        self.ui.play(self.current_wav, Synthesizer.sample_rate)

    def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
        self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
        self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
        self.ui.populate_gen_options(seed, self.trim_silences)

    def load_from_browser(self, fpath=None):
        if fpath is None:
            fpath = Path(self.datasets_root,
                         self.ui.current_dataset_name,
                         self.ui.current_speaker_name,
                         self.ui.current_utterance_name)
            name = str(fpath.relative_to(self.datasets_root))
            speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name

            # Select the next utterance
            if self.ui.auto_next_checkbox.isChecked():
                self.ui.browser_select_next()
        elif fpath == "":
            return
        else:
            name = fpath.name
            speaker_name = fpath.parent.name

        if fpath.suffix.lower() == ".mp3" and self.no_mp3_support:
            self.ui.log("Error: mp3 support was disabled with --no_mp3_support, but an mp3 file was selected")
            return

        # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
        # playback, so as to have a fair comparison with the generated audio
        wav = Synthesizer.load_preprocess_wav(fpath)
        self.ui.log("Loaded %s" % name)

        self.add_real_utterance(wav, name, speaker_name)

    def record(self):
        wav = self.ui.record_one(encoder.sampling_rate, 5)
        if wav is None:
            return
        self.ui.play(wav, encoder.sampling_rate)

        speaker_name = "user01"
        name = speaker_name + "_rec_%05d" % np.random.randint(100000)
        self.add_real_utterance(wav, name, speaker_name)

    def add_real_utterance(self, wav, name, speaker_name):
        # Compute the mel spectrogram
        spec = Synthesizer.make_spectrogram(wav)
        self.ui.draw_spec(spec, "current")

        # Compute the embedding
        if not encoder.is_loaded():
            self.init_encoder()
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Add the utterance
        utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
        self.utterances.add(utterance)
        self.ui.register_utterance(utterance)

        # Plot it
        self.ui.draw_embed(embed, name, "current")
        self.ui.draw_umap_projections(self.utterances)

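The three encoder calls in add_real_utterance also work headlessly, which is handy for scripting outside the GUI. A minimal sketch, assuming an illustrative checkpoint path and "some_utterance.wav" as a stand-in input file (neither is part of this commit):

encoder.load_model(Path("encoder/saved_models/pretrained.pt"))  # illustrative path
wav = Synthesizer.load_preprocess_wav("some_utterance.wav")     # illustrative file
encoder_wav = encoder.preprocess_wav(wav)
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
# embed is the utterance-level speaker embedding used for plotting and synthesis
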
    def clear_utterances(self):
        self.utterances.clear()
        self.ui.draw_umap_projections(self.utterances)

    def synthesize(self):
        self.ui.log("Generating the mel spectrogram...")
        self.ui.set_loading(1)

        # Update the synthesizer random seed
        if self.ui.random_seed_checkbox.isChecked():
            seed = int(self.ui.seed_textbox.text())
            self.ui.populate_gen_options(seed, self.trim_silences)
        else:
            seed = None

        if seed is not None:
            torch.manual_seed(seed)

        # Synthesize the spectrogram
        if self.synthesizer is None or seed is not None:
            self.init_synthesizer()

        texts = self.ui.text_prompt.toPlainText().split("\n")
        embed = self.ui.selected_utterance.embed
        embeds = [embed] * len(texts)
        specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
        breaks = [spec.shape[1] for spec in specs]
        spec = np.concatenate(specs, axis=1)

        self.ui.draw_spec(spec, "generated")
        self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
        self.ui.set_loading(0)

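Note how breaks doubles as bookkeeping: it records the frame width of each per-line spectrogram so that vocode can later split the waveform of the concatenated spec back apart. In isolation, with toy shapes:

specs = [np.zeros((80, 120)), np.zeros((80, 95))]  # one mel spec per input line
breaks = [s.shape[1] for s in specs]               # [120, 95]
spec = np.concatenate(specs, axis=1)               # shape (80, 215)
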
    def vocode(self):
        speaker_name, spec, breaks, _ = self.current_generated
        assert spec is not None

        # Initialize the vocoder model and make it deterministic, if user provides a seed
        if self.ui.random_seed_checkbox.isChecked():
            seed = int(self.ui.seed_textbox.text())
            self.ui.populate_gen_options(seed, self.trim_silences)
        else:
            seed = None

        if seed is not None:
            torch.manual_seed(seed)

        # Synthesize the waveform
        if not vocoder.is_loaded() or seed is not None:
            self.init_vocoder()

        def vocoder_progress(i, seq_len, b_size, gen_rate):
            real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
            line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
                   % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
            self.ui.log(line, "overwrite")
            self.ui.set_loading(i, seq_len)
        if self.ui.current_vocoder_fpath is not None:
            self.ui.log("")
            wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
        else:
            self.ui.log("Waveform generation with Griffin-Lim... ")
            wav = Synthesizer.griffin_lim(spec)
        self.ui.set_loading(0)
        self.ui.log(" Done!", "append")

        # Add breaks
        b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
        b_starts = np.concatenate(([0], b_ends[:-1]))
        wavs = [wav[start:end] for start, end in zip(b_starts, b_ends)]
        breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
        wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])

        # Trim excessive silences
        if self.ui.trim_silences_checkbox.isChecked():
            wav = encoder.preprocess_wav(wav)

        # Play it
        wav = wav / np.abs(wav).max() * 0.97
        self.ui.play(wav, Synthesizer.sample_rate)

        # Name it (history displayed in combobox)
        # TODO better naming for the combobox items?
        wav_name = str(self.waves_count + 1)

        # Update waves combobox
        self.waves_count += 1
        if self.waves_count > MAX_WAVES:
            self.waves_list.pop()
            self.waves_namelist.pop()
        self.waves_list.insert(0, wav)
        self.waves_namelist.insert(0, wav_name)

        self.ui.waves_cb.disconnect()
        self.ui.waves_cb_model.setStringList(self.waves_namelist)
        self.ui.waves_cb.setCurrentIndex(0)
        self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # Update current wav
        self.set_current_wav(0)

        # Enable replay and save buttons:
        self.ui.replay_wav_button.setDisabled(False)
        self.ui.export_wav_button.setDisabled(False)

        # Compute the embedding
        # TODO: this is problematic with different sampling rates, gotta fix it
        if not encoder.is_loaded():
            self.init_encoder()
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Add the utterance
        name = speaker_name + "_gen_%05d" % np.random.randint(100000)
        utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
        self.utterances.add(utterance)

        # Plot it
        self.ui.draw_embed(embed, name, "generated")
        self.ui.draw_umap_projections(self.utterances)

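The "Add breaks" block above splits the generated waveform at the spectrogram boundaries and rejoins the pieces with 0.15 s of silence between text lines. The same arithmetic in isolation, with assumed hop_size and sample_rate values standing in for Synthesizer's hparams:

hop_size, sample_rate = 200, 16000                # assumed values for the sketch
breaks = [120, 95]                                # frames per synthesized line
wav = np.ones((120 + 95) * hop_size)              # toy waveform aligned to the spec
b_ends = np.cumsum(np.array(breaks) * hop_size)   # [24000, 43000]
b_starts = np.concatenate(([0], b_ends[:-1]))     # [0, 24000]
wavs = [wav[s:e] for s, e in zip(b_starts, b_ends)]
silence = np.zeros(int(0.15 * sample_rate))       # 2400 samples of silence
out = np.concatenate([x for w in wavs for x in (w, silence)])
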
    def init_encoder(self):
        model_fpath = self.ui.current_encoder_fpath

        self.ui.log("Loading the encoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        encoder.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_synthesizer(self):
        model_fpath = self.ui.current_synthesizer_fpath

        self.ui.log("Loading the synthesizer %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        self.synthesizer = Synthesizer(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_vocoder(self):
        model_fpath = self.ui.current_vocoder_fpath
        # Case of Griffin-Lim
        if model_fpath is None:
            return

        self.ui.log("Loading the vocoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        vocoder.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def update_seed_textbox(self):
        self.ui.update_seed_textbox()
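
For context, the class is meant to be constructed once and then blocks inside self.ui.start(). A hedged sketch of a driver script in the style of the upstream Real-Time-Voice-Cloning demo (all paths and flag values here are illustrative, not part of this commit):

from pathlib import Path
from toolbox import Toolbox

Toolbox(datasets_root=Path("datasets"),                  # illustrative path
        enc_models_dir=Path("encoder/saved_models"),     # illustrative path
        syn_models_dir=Path("synthesizer/saved_models"), # illustrative path
        voc_models_dir=Path("vocoder/saved_models"),     # illustrative path
        seed=None,               # None keeps generation non-deterministic
        no_mp3_support=True)     # skip the ffmpeg/mp3 check in __init__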