kevinwang676 commited on
Commit
6af2279
0 Parent(s):

Duplicate from kevinwang676/Bark-UI-with-Voice-Cloning

Browse files
.gitattributes ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ SE_checkpoint.pth.tar filter=lfs diff=lfs merge=lfs -text
36
+ best_model.pth.tar filter=lfs diff=lfs merge=lfs -text
37
+ nana_longest_vocal.wav filter=lfs diff=lfs merge=lfs -text
38
+ test.wav filter=lfs diff=lfs merge=lfs -text
39
+ reference.wav filter=lfs diff=lfs merge=lfs -text
40
+ ref.wav filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Bark with Voice Cloning
3
+ emoji: 📊
4
+ colorFrom: purple
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: kevinwang676/Bark-UI-with-Voice-Cloning
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SE_checkpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f96efb20cbeeefd81fd8336d7f0155bf8902f82f9474e58ccb19d9e12345172
3
+ size 44610930
app.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ os.system("git clone https://github.com/C0untFloyd/bark-gui.git")
5
+ sys.path.append("./bark-gui/")
6
+
7
+ from cProfile import label
8
+ from distutils.command.check import check
9
+ from doctest import Example
10
+ import gradio as gr
11
+ import numpy as np
12
+ import logging
13
+ import torch
14
+ import pytorch_seed
15
+ import time
16
+
17
+ from xml.sax import saxutils
18
+ from bark.api import generate_with_settings
19
+ from bark.api import save_as_prompt
20
+ from settings import Settings
21
+ #import nltk
22
+
23
+ from bark import SAMPLE_RATE
24
+ from bark.clonevoice import clone_voice
25
+ from bark.generation import SAMPLE_RATE, preload_models
26
+ from scipy.io.wavfile import write as write_wav
27
+ from parseinput import split_and_recombine_text, build_ssml, is_ssml, create_clips_from_ssml
28
+ from datetime import datetime
29
+ from tqdm.auto import tqdm
30
+ from id3tagging import add_id3_tag
31
+
32
+ import shutil
33
+
34
+ import string
35
+ import argparse
36
+ import json
37
+
38
+ from TTS.tts.utils.synthesis import synthesis
39
+ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
40
+ try:
41
+ from TTS.utils.audio import AudioProcessor
42
+ except:
43
+ from TTS.utils.audio import AudioProcessor
44
+
45
+
46
+ from TTS.tts.models import setup_model
47
+ from TTS.config import load_config
48
+ from TTS.tts.models.vits import *
49
+
50
+ from TTS.tts.utils.speakers import SpeakerManager
51
+ from pydub import AudioSegment
52
+
53
+ # from google.colab import files
54
+ import librosa
55
+
56
+ from scipy.io.wavfile import write, read
57
+
58
+ import subprocess
59
+
60
+
61
+ OUTPUTFOLDER = "Outputs"
62
+
63
+
64
+ def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, eos_prob, quick_generation, complete_settings, seed, progress=gr.Progress(track_tqdm=True)):
65
+ if text == None or len(text) < 1:
66
+ raise gr.Error('No text entered!')
67
+
68
+ # Chunk the text into smaller pieces then combine the generated audio
69
+
70
+ # generation settings
71
+ if selected_speaker == 'None':
72
+ selected_speaker = None
73
+ if seed != None and seed > 2**32 - 1:
74
+ logger.warning(f"Seed {seed} > 2**32 - 1 (max), setting to random")
75
+ seed = None
76
+ if seed == None or seed <= 0:
77
+ seed = np.random.default_rng().integers(1, 2**32 - 1)
78
+ assert(0 < seed and seed < 2**32)
79
+
80
+ voice_name = selected_speaker
81
+ use_last_generation_as_history = "Use last generation as history" in complete_settings
82
+ save_last_generation = "Save generation as Voice" in complete_settings
83
+ progress(0, desc="Generating")
84
+
85
+ silenceshort = np.zeros(int((float(settings.silence_sentence) / 1000.0) * SAMPLE_RATE), dtype=np.float32) # quarter second of silence
86
+ silencelong = np.zeros(int((float(settings.silence_speakers) / 1000.0) * SAMPLE_RATE), dtype=np.float32) # half a second of silence
87
+ full_generation = None
88
+
89
+ all_parts = []
90
+ complete_text = ""
91
+ text = text.lstrip()
92
+ if is_ssml(text):
93
+ list_speak = create_clips_from_ssml(text)
94
+ prev_speaker = None
95
+ for i, clip in tqdm(enumerate(list_speak), total=len(list_speak)):
96
+ selected_speaker = clip[0]
97
+ # Add pause break between speakers
98
+ if i > 0 and selected_speaker != prev_speaker:
99
+ all_parts += [silencelong.copy()]
100
+ prev_speaker = selected_speaker
101
+ text = clip[1]
102
+ text = saxutils.unescape(text)
103
+ if selected_speaker == "None":
104
+ selected_speaker = None
105
+
106
+ print(f"\nGenerating Text ({i+1}/{len(list_speak)}) -> {selected_speaker} (Seed {seed}):`{text}`")
107
+ complete_text += text
108
+ with pytorch_seed.SavedRNG(seed):
109
+ audio_array = generate_with_settings(text_prompt=text, voice_name=selected_speaker, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)
110
+ seed = torch.random.initial_seed()
111
+ if len(list_speak) > 1:
112
+ filename = create_filename(OUTPUTFOLDER, seed, "audioclip",".wav")
113
+ save_wav(audio_array, filename)
114
+ add_id3_tag(filename, text, selected_speaker, seed)
115
+
116
+ all_parts += [audio_array]
117
+ else:
118
+ texts = split_and_recombine_text(text, settings.input_text_desired_length, settings.input_text_max_length)
119
+ for i, text in tqdm(enumerate(texts), total=len(texts)):
120
+ print(f"\nGenerating Text ({i+1}/{len(texts)}) -> {selected_speaker} (Seed {seed}):`{text}`")
121
+ complete_text += text
122
+ if quick_generation == True:
123
+ with pytorch_seed.SavedRNG(seed):
124
+ audio_array = generate_with_settings(text_prompt=text, voice_name=selected_speaker, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)
125
+ seed = torch.random.initial_seed()
126
+ else:
127
+ full_output = use_last_generation_as_history or save_last_generation
128
+ if full_output:
129
+ full_generation, audio_array = generate_with_settings(text_prompt=text, voice_name=voice_name, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob, output_full=True)
130
+ else:
131
+ audio_array = generate_with_settings(text_prompt=text, voice_name=voice_name, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)
132
+
133
+ # Noticed this in the HF Demo - convert to 16bit int -32767/32767 - most used audio format
134
+ # audio_array = (audio_array * 32767).astype(np.int16)
135
+
136
+ if len(texts) > 1:
137
+ filename = create_filename(OUTPUTFOLDER, seed, "audioclip",".wav")
138
+ save_wav(audio_array, filename)
139
+ add_id3_tag(filename, text, selected_speaker, seed)
140
+
141
+ if quick_generation == False and (save_last_generation == True or use_last_generation_as_history == True):
142
+ # save to npz
143
+ voice_name = create_filename(OUTPUTFOLDER, seed, "audioclip", ".npz")
144
+ save_as_prompt(voice_name, full_generation)
145
+ if use_last_generation_as_history:
146
+ selected_speaker = voice_name
147
+
148
+ all_parts += [audio_array]
149
+ # Add short pause between sentences
150
+ if text[-1] in "!?.\n" and i > 1:
151
+ all_parts += [silenceshort.copy()]
152
+
153
+ # save & play audio
154
+ result = create_filename(OUTPUTFOLDER, seed, "final",".wav")
155
+ save_wav(np.concatenate(all_parts), result)
156
+ # write id3 tag with text truncated to 60 chars, as a precaution...
157
+ add_id3_tag(result, complete_text, selected_speaker, seed)
158
+ return result
159
+
160
+ def create_filename(path, seed, name, extension):
161
+ now = datetime.now()
162
+ date_str =now.strftime("%m-%d-%Y")
163
+ outputs_folder = os.path.join(os.getcwd(), path)
164
+ if not os.path.exists(outputs_folder):
165
+ os.makedirs(outputs_folder)
166
+
167
+ sub_folder = os.path.join(outputs_folder, date_str)
168
+ if not os.path.exists(sub_folder):
169
+ os.makedirs(sub_folder)
170
+
171
+ time_str = now.strftime("%H-%M-%S")
172
+ file_name = f"{name}_{time_str}_s{seed}{extension}"
173
+ return os.path.join(sub_folder, file_name)
174
+
175
+
176
+ def save_wav(audio_array, filename):
177
+ write_wav(filename, SAMPLE_RATE, audio_array)
178
+
179
+ def save_voice(filename, semantic_prompt, coarse_prompt, fine_prompt):
180
+ np.savez_compressed(
181
+ filename,
182
+ semantic_prompt=semantic_prompt,
183
+ coarse_prompt=coarse_prompt,
184
+ fine_prompt=fine_prompt
185
+ )
186
+
187
+
188
+ def on_quick_gen_changed(checkbox):
189
+ if checkbox == False:
190
+ return gr.CheckboxGroup.update(visible=True)
191
+ return gr.CheckboxGroup.update(visible=False)
192
+
193
+ def delete_output_files(checkbox_state):
194
+ if checkbox_state:
195
+ outputs_folder = os.path.join(os.getcwd(), OUTPUTFOLDER)
196
+ if os.path.exists(outputs_folder):
197
+ purgedir(outputs_folder)
198
+ return False
199
+
200
+
201
+ # https://stackoverflow.com/a/54494779
202
+ def purgedir(parent):
203
+ for root, dirs, files in os.walk(parent):
204
+ for item in files:
205
+ # Delete subordinate files
206
+ filespec = os.path.join(root, item)
207
+ os.unlink(filespec)
208
+ for item in dirs:
209
+ # Recursively perform this operation for subordinate directories
210
+ purgedir(os.path.join(root, item))
211
+
212
+ def convert_text_to_ssml(text, selected_speaker):
213
+ return build_ssml(text, selected_speaker)
214
+
215
+
216
+ def apply_settings(themes, input_server_name, input_server_port, input_server_public, input_desired_len, input_max_len, input_silence_break, input_silence_speaker):
217
+ settings.selected_theme = themes
218
+ settings.server_name = input_server_name
219
+ settings.server_port = input_server_port
220
+ settings.server_share = input_server_public
221
+ settings.input_text_desired_length = input_desired_len
222
+ settings.input_text_max_length = input_max_len
223
+ settings.silence_sentence = input_silence_break
224
+ settings.silence_speaker = input_silence_speaker
225
+ settings.save()
226
+
227
+ def restart():
228
+ global restart_server
229
+ restart_server = True
230
+
231
+
232
+ def create_version_html():
233
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
234
+ versions_html = f"""
235
+ python: <span title="{sys.version}">{python_version}</span>
236
+  • 
237
+ torch: {getattr(torch, '__long_version__',torch.__version__)}
238
+  • 
239
+ gradio: {gr.__version__}
240
+ """
241
+ return versions_html
242
+
243
+
244
+
245
+ logger = logging.getLogger(__name__)
246
+ APPTITLE = "Bark UI Enhanced v0.4.6"
247
+
248
+
249
+ autolaunch = False
250
+
251
+ if len(sys.argv) > 1:
252
+ autolaunch = "-autolaunch" in sys.argv
253
+
254
+
255
+ if torch.cuda.is_available() == False:
256
+ os.environ['BARK_FORCE_CPU'] = 'True'
257
+ logger.warning("No CUDA detected, fallback to CPU!")
258
+
259
+ print(f'smallmodels={os.environ.get("SUNO_USE_SMALL_MODELS", False)}')
260
+ print(f'enablemps={os.environ.get("SUNO_ENABLE_MPS", False)}')
261
+ print(f'offloadcpu={os.environ.get("SUNO_OFFLOAD_CPU", False)}')
262
+ print(f'forcecpu={os.environ.get("BARK_FORCE_CPU", False)}')
263
+ print(f'autolaunch={autolaunch}\n\n')
264
+
265
+ #print("Updating nltk\n")
266
+ #nltk.download('punkt')
267
+
268
+ print("Preloading Models\n")
269
+ preload_models()
270
+
271
+ settings = Settings('config.yaml')
272
+
273
+ # Collect all existing speakers/voices in dir
274
+ speakers_list = []
275
+
276
+ for root, dirs, files in os.walk("./bark/assets/prompts"):
277
+ for file in files:
278
+ if(file.endswith(".npz")):
279
+ pathpart = root.replace("./bark/assets/prompts", "")
280
+ name = os.path.join(pathpart, file[:-4])
281
+ if name.startswith("/") or name.startswith("\\"):
282
+ name = name[1:]
283
+ speakers_list.append(name)
284
+
285
+ speakers_list = sorted(speakers_list, key=lambda x: x.lower())
286
+ speakers_list.insert(0, 'None')
287
+
288
+ available_themes = ["Default", "gradio/glass", "gradio/monochrome", "gradio/seafoam", "gradio/soft", "gstaff/xkcd", "freddyaboulton/dracula_revamped", "ysharma/steampunk"]
289
+
290
+ seed = -1
291
+ server_name = settings.server_name
292
+ if len(server_name) < 1:
293
+ server_name = None
294
+ server_port = settings.server_port
295
+ if server_port <= 0:
296
+ server_port = None
297
+ global run_server
298
+ global restart_server
299
+
300
+ run_server = True
301
+
302
+
303
+
304
+
305
+ '''
306
+ from google.colab import drive
307
+ drive.mount('/content/drive')
308
+ src_path = os.path.join(os.path.join(os.path.join(os.path.join(os.getcwd(), 'drive'), 'MyDrive'), 'Colab Notebooks'), 'best_model_latest.pth.tar')
309
+ dst_path = os.path.join(os.getcwd(), 'best_model.pth.tar')
310
+ shutil.copy(src_path, dst_path)
311
+ '''
312
+
313
+ TTS_PATH = "TTS/"
314
+
315
+ # add libraries into environment
316
+ sys.path.append(TTS_PATH) # set this if TTS is not installed globally
317
+
318
+ # Paths definition
319
+
320
+ OUT_PATH = 'out/'
321
+
322
+ # create output path
323
+ os.makedirs(OUT_PATH, exist_ok=True)
324
+
325
+ # model vars
326
+ MODEL_PATH = 'best_model.pth.tar'
327
+ CONFIG_PATH = 'config.json'
328
+ TTS_LANGUAGES = "language_ids.json"
329
+ TTS_SPEAKERS = "speakers.json"
330
+ USE_CUDA = torch.cuda.is_available()
331
+
332
+ # load the config
333
+ C = load_config(CONFIG_PATH)
334
+
335
+ # load the audio processor
336
+ ap = AudioProcessor(**C.audio)
337
+
338
+ speaker_embedding = None
339
+
340
+ C.model_args['d_vector_file'] = TTS_SPEAKERS
341
+ C.model_args['use_speaker_encoder_as_loss'] = False
342
+
343
+ model = setup_model(C)
344
+ model.language_manager.set_language_ids_from_file(TTS_LANGUAGES)
345
+ # print(model.language_manager.num_languages, model.embedded_language_dim)
346
+ # print(model.emb_l)
347
+ cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
348
+ # remove speaker encoder
349
+ model_weights = cp['model'].copy()
350
+ for key in list(model_weights.keys()):
351
+ if "speaker_encoder" in key:
352
+ del model_weights[key]
353
+
354
+ model.load_state_dict(model_weights)
355
+
356
+ model.eval()
357
+
358
+ if USE_CUDA:
359
+ model = model.cuda()
360
+
361
+ # synthesize voice
362
+ use_griffin_lim = False
363
+
364
+ # Paths definition
365
+
366
+ CONFIG_SE_PATH = "config_se.json"
367
+ CHECKPOINT_SE_PATH = "SE_checkpoint.pth.tar"
368
+
369
+ # Load the Speaker encoder
370
+
371
+ SE_speaker_manager = SpeakerManager(encoder_model_path=CHECKPOINT_SE_PATH, encoder_config_path=CONFIG_SE_PATH, use_cuda=USE_CUDA)
372
+
373
+ # Define helper function
374
+
375
+ def compute_spec(ref_file):
376
+ y, sr = librosa.load(ref_file, sr=ap.sample_rate)
377
+ spec = ap.spectrogram(y)
378
+ spec = torch.FloatTensor(spec).unsqueeze(0)
379
+ return spec
380
+
381
+
382
+ def voice_conversion(ta, ra, da):
383
+
384
+ target_audio = 'target.wav'
385
+ reference_audio = 'reference.wav'
386
+ driving_audio = 'driving.wav'
387
+
388
+ write(target_audio, ta[0], ta[1])
389
+ write(reference_audio, ra[0], ra[1])
390
+ write(driving_audio, da[0], da[1])
391
+
392
+ # !ffmpeg-normalize $target_audio -nt rms -t=-27 -o $target_audio -ar 16000 -f
393
+ # !ffmpeg-normalize $reference_audio -nt rms -t=-27 -o $reference_audio -ar 16000 -f
394
+ # !ffmpeg-normalize $driving_audio -nt rms -t=-27 -o $driving_audio -ar 16000 -f
395
+
396
+ files = [target_audio, reference_audio, driving_audio]
397
+
398
+ for file in files:
399
+ subprocess.run(["ffmpeg-normalize", file, "-nt", "rms", "-t=-27", "-o", file, "-ar", "16000", "-f"])
400
+
401
+ # ta_ = read(target_audio)
402
+
403
+ target_emb = SE_speaker_manager.compute_d_vector_from_clip([target_audio])
404
+ target_emb = torch.FloatTensor(target_emb).unsqueeze(0)
405
+
406
+ driving_emb = SE_speaker_manager.compute_d_vector_from_clip([reference_audio])
407
+ driving_emb = torch.FloatTensor(driving_emb).unsqueeze(0)
408
+
409
+ # Convert the voice
410
+
411
+ driving_spec = compute_spec(driving_audio)
412
+ y_lengths = torch.tensor([driving_spec.size(-1)])
413
+ if USE_CUDA:
414
+ ref_wav_voc, _, _ = model.voice_conversion(driving_spec.cuda(), y_lengths.cuda(), driving_emb.cuda(), target_emb.cuda())
415
+ ref_wav_voc = ref_wav_voc.squeeze().cpu().detach().numpy()
416
+ else:
417
+ ref_wav_voc, _, _ = model.voice_conversion(driving_spec, y_lengths, driving_emb, target_emb)
418
+ ref_wav_voc = ref_wav_voc.squeeze().detach().numpy()
419
+
420
+ # print("Reference Audio after decoder:")
421
+ # IPython.display.display(Audio(ref_wav_voc, rate=ap.sample_rate))
422
+
423
+ return (ap.sample_rate, ref_wav_voc)
424
+
425
+
426
+ while run_server:
427
+ print(f'Launching {APPTITLE} Server')
428
+
429
+ # Create Gradio Blocks
430
+
431
+ with gr.Blocks(title=f"{APPTITLE}", mode=f"{APPTITLE}", theme=settings.selected_theme) as barkgui:
432
+ with gr.Row():
433
+ with gr.Column():
434
+ gr.Markdown(f"### [{APPTITLE}](https://github.com/C0untFloyd/bark-gui)")
435
+ with gr.Column():
436
+ gr.HTML(create_version_html(), elem_id="versions")
437
+
438
+ with gr.Tab("TTS"):
439
+ with gr.Row():
440
+ with gr.Column():
441
+ placeholder = "Enter text here."
442
+ input_text = gr.Textbox(label="Input Text", lines=4, placeholder=placeholder)
443
+ with gr.Column():
444
+ seedcomponent = gr.Number(label="Seed (default -1 = Random)", precision=0, value=-1)
445
+ convert_to_ssml_button = gr.Button("Convert Text to SSML")
446
+ with gr.Row():
447
+ with gr.Column():
448
+ examples = [
449
+ "Special meanings: [laughter] [laughs] [sighs] [music] [gasps] [clears throat] MAN: WOMAN:",
450
+ "♪ Never gonna make you cry, never gonna say goodbye, never gonna tell a lie and hurt you ♪",
451
+ "And now — a picture of a larch [laughter]",
452
+ """
453
+ WOMAN: I would like an oatmilk latte please.
454
+ MAN: Wow, that's expensive!
455
+ """,
456
+ """<?xml version="1.0"?>
457
+ <speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis"
458
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
459
+ xsi:schemaLocation="http://www.w3.org/2001/10/synthesis
460
+ http://www.w3.org/TR/speech-synthesis/synthesis.xsd"
461
+ xml:lang="en-US">
462
+ <voice name="en_speaker_9">Look at that drunk guy!</voice>
463
+ <voice name="en_speaker_3">Who is he?</voice>
464
+ <voice name="en_speaker_9">WOMAN: [clears throat] 10 years ago, he proposed me and I rejected him.</voice>
465
+ <voice name="en_speaker_3">Oh my God [laughs] he is still celebrating</voice>
466
+ </speak>"""
467
+ ]
468
+ examples = gr.Examples(examples=examples, inputs=input_text)
469
+
470
+ with gr.Row():
471
+ with gr.Column():
472
+ gr.Markdown("[Voice Prompt Library](https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c)")
473
+ speaker = gr.Dropdown(speakers_list, value=speakers_list[0], label="Voice")
474
+ with gr.Column():
475
+ text_temp = gr.Slider(0.1, 1.0, value=0.6, label="Generation Temperature", info="1.0 more diverse, 0.1 more conservative")
476
+ waveform_temp = gr.Slider(0.1, 1.0, value=0.7, label="Waveform temperature", info="1.0 more diverse, 0.1 more conservative")
477
+
478
+ with gr.Row():
479
+ with gr.Column():
480
+ quick_gen_checkbox = gr.Checkbox(label="Quick Generation", value=True)
481
+ settings_checkboxes = ["Use last generation as history", "Save generation as Voice"]
482
+ complete_settings = gr.CheckboxGroup(choices=settings_checkboxes, value=settings_checkboxes, label="Detailed Generation Settings", type="value", interactive=True, visible=False)
483
+ with gr.Column():
484
+ eos_prob = gr.Slider(0.0, 0.5, value=0.05, label="End of sentence probability")
485
+
486
+ with gr.Row():
487
+ with gr.Column():
488
+ tts_create_button = gr.Button("Generate")
489
+ with gr.Column():
490
+ hidden_checkbox = gr.Checkbox(visible=False)
491
+ button_stop_generation = gr.Button("Stop generation")
492
+ with gr.Row():
493
+ output_audio = gr.Audio(label="Generated Audio")
494
+
495
+ with gr.Row():
496
+ inp1 = gr.Audio(label='Target Speaker - Reference Clip')
497
+ inp2 = output_audio
498
+ inp3 = output_audio
499
+ btn = gr.Button("Generate")
500
+ out1 = gr.Audio(label='Target Speaker - Converted Clip')
501
+ btn.click(voice_conversion, [inp1, inp2, inp3], [out1])
502
+
503
+
504
+
505
+ with gr.Tab("Clone Voice"):
506
+ input_audio_filename = gr.Audio(label="Input audio.wav", source="upload", type="filepath")
507
+ transcription_text = gr.Textbox(label="Transcription Text", lines=1, placeholder="Enter Text of your Audio Sample here...")
508
+ initialname = "./bark/assets/prompts/custom/MeMyselfAndI"
509
+ output_voice = gr.Textbox(label="Filename of trained Voice", lines=1, placeholder=initialname, value=initialname)
510
+ clone_voice_button = gr.Button("Create Voice")
511
+ dummy = gr.Text(label="Progress")
512
+
513
+ with gr.Tab("Settings"):
514
+ with gr.Row():
515
+ themes = gr.Dropdown(available_themes, label="Theme", info="Change needs complete restart", value=settings.selected_theme)
516
+ with gr.Row():
517
+ input_server_name = gr.Textbox(label="Server Name", lines=1, info="Leave blank to run locally", value=settings.server_name)
518
+ input_server_port = gr.Number(label="Server Port", precision=0, info="Leave at 0 to use default", value=settings.server_port)
519
+ share_checkbox = gr.Checkbox(label="Public Server", value=settings.server_share)
520
+ with gr.Row():
521
+ input_desired_len = gr.Slider(100, 150, value=settings.input_text_desired_length, label="Desired Input Text Length", info="Ideal length to split input sentences")
522
+ input_max_len = gr.Slider(150, 256, value=settings.input_text_max_length, label="Max Input Text Length", info="Maximum Input Text Length")
523
+ with gr.Row():
524
+ input_silence_break = gr.Slider(1, 1000, value=settings.silence_sentence, label="Sentence Pause Time (ms)", info="Silence between sentences in milliseconds")
525
+ input_silence_speakers = gr.Slider(1, 5000, value=settings.silence_speakers, label="Speaker Pause Time (ms)", info="Silence between different speakers in milliseconds")
526
+
527
+ with gr.Row():
528
+ button_apply_settings = gr.Button("Apply Settings")
529
+ button_apply_restart = gr.Button("Restart Server")
530
+ button_delete_files = gr.Button("Clear output folder")
531
+
532
+ quick_gen_checkbox.change(fn=on_quick_gen_changed, inputs=quick_gen_checkbox, outputs=complete_settings)
533
+ convert_to_ssml_button.click(convert_text_to_ssml, inputs=[input_text, speaker],outputs=input_text)
534
+ gen_click = tts_create_button.click(generate_text_to_speech, inputs=[input_text, speaker, text_temp, waveform_temp, eos_prob, quick_gen_checkbox, complete_settings, seedcomponent],outputs=output_audio)
535
+ button_stop_generation.click(fn=None, inputs=None, outputs=None, cancels=[gen_click])
536
+ # Javascript hack to display modal confirmation dialog
537
+ js = "(x) => confirm('Are you sure? This will remove all files from output folder')"
538
+ button_delete_files.click(None, None, hidden_checkbox, _js=js)
539
+ hidden_checkbox.change(delete_output_files, [hidden_checkbox], [hidden_checkbox])
540
+ clone_voice_button.click(clone_voice, inputs=[input_audio_filename, transcription_text, output_voice], outputs=dummy)
541
+ button_apply_settings.click(apply_settings, inputs=[themes, input_server_name, input_server_port, share_checkbox, input_desired_len, input_max_len, input_silence_break, input_silence_speakers])
542
+ button_apply_restart.click(restart)
543
+ restart_server = False
544
+ try:
545
+ barkgui.queue().launch(show_error=True)
546
+ except:
547
+ restart_server = True
548
+ run_server = False
549
+ try:
550
+ while restart_server == False:
551
+ time.sleep(1.0)
552
+ except (KeyboardInterrupt, OSError):
553
+ print("Keyboard interruption in main thread... closing server.")
554
+ run_server = False
555
+ barkgui.close()
bark/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
2
+ from .generation import SAMPLE_RATE, preload_models
bark/api.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Union
2
+
3
+ import numpy as np
4
+
5
+ from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic
6
+
7
+
8
+ def generate_with_settings(text_prompt, semantic_temp=0.6, eos_p=0.2, coarse_temp=0.7, fine_temp=0.5, voice_name=None, output_full=False):
9
+
10
+ # generation with more control
11
+ x_semantic = generate_text_semantic(
12
+ text_prompt,
13
+ history_prompt=voice_name,
14
+ temp=semantic_temp,
15
+ min_eos_p = eos_p,
16
+ use_kv_caching=True
17
+ )
18
+
19
+ x_coarse_gen = generate_coarse(
20
+ x_semantic,
21
+ history_prompt=voice_name,
22
+ temp=coarse_temp,
23
+ use_kv_caching=True
24
+ )
25
+ x_fine_gen = generate_fine(
26
+ x_coarse_gen,
27
+ history_prompt=voice_name,
28
+ temp=fine_temp,
29
+ )
30
+
31
+ if output_full:
32
+ full_generation = {
33
+ 'semantic_prompt': x_semantic,
34
+ 'coarse_prompt': x_coarse_gen,
35
+ 'fine_prompt': x_fine_gen
36
+ }
37
+ return full_generation, codec_decode(x_fine_gen)
38
+ return codec_decode(x_fine_gen)
39
+
40
+
41
+ def text_to_semantic(
42
+ text: str,
43
+ history_prompt: Optional[Union[Dict, str]] = None,
44
+ temp: float = 0.7,
45
+ silent: bool = False,
46
+ ):
47
+ """Generate semantic array from text.
48
+
49
+ Args:
50
+ text: text to be turned into audio
51
+ history_prompt: history choice for audio cloning
52
+ temp: generation temperature (1.0 more diverse, 0.0 more conservative)
53
+ silent: disable progress bar
54
+
55
+ Returns:
56
+ numpy semantic array to be fed into `semantic_to_waveform`
57
+ """
58
+ x_semantic = generate_text_semantic(
59
+ text,
60
+ history_prompt=history_prompt,
61
+ temp=temp,
62
+ silent=silent,
63
+ use_kv_caching=True
64
+ )
65
+ return x_semantic
66
+
67
+
68
+ def semantic_to_waveform(
69
+ semantic_tokens: np.ndarray,
70
+ history_prompt: Optional[Union[Dict, str]] = None,
71
+ temp: float = 0.7,
72
+ silent: bool = False,
73
+ output_full: bool = False,
74
+ ):
75
+ """Generate audio array from semantic input.
76
+
77
+ Args:
78
+ semantic_tokens: semantic token output from `text_to_semantic`
79
+ history_prompt: history choice for audio cloning
80
+ temp: generation temperature (1.0 more diverse, 0.0 more conservative)
81
+ silent: disable progress bar
82
+ output_full: return full generation to be used as a history prompt
83
+
84
+ Returns:
85
+ numpy audio array at sample frequency 24khz
86
+ """
87
+ coarse_tokens = generate_coarse(
88
+ semantic_tokens,
89
+ history_prompt=history_prompt,
90
+ temp=temp,
91
+ silent=silent,
92
+ use_kv_caching=True
93
+ )
94
+ fine_tokens = generate_fine(
95
+ coarse_tokens,
96
+ history_prompt=history_prompt,
97
+ temp=0.5,
98
+ )
99
+ audio_arr = codec_decode(fine_tokens)
100
+ if output_full:
101
+ full_generation = {
102
+ "semantic_prompt": semantic_tokens,
103
+ "coarse_prompt": coarse_tokens,
104
+ "fine_prompt": fine_tokens,
105
+ }
106
+ return full_generation, audio_arr
107
+ return audio_arr
108
+
109
+
110
+ def save_as_prompt(filepath, full_generation):
111
+ assert(filepath.endswith(".npz"))
112
+ assert(isinstance(full_generation, dict))
113
+ assert("semantic_prompt" in full_generation)
114
+ assert("coarse_prompt" in full_generation)
115
+ assert("fine_prompt" in full_generation)
116
+ np.savez(filepath, **full_generation)
117
+
118
+
119
+ def generate_audio(
120
+ text: str,
121
+ history_prompt: Optional[Union[Dict, str]] = None,
122
+ text_temp: float = 0.7,
123
+ waveform_temp: float = 0.7,
124
+ silent: bool = False,
125
+ output_full: bool = False,
126
+ ):
127
+ """Generate audio array from input text.
128
+
129
+ Args:
130
+ text: text to be turned into audio
131
+ history_prompt: history choice for audio cloning
132
+ text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
133
+ waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
134
+ silent: disable progress bar
135
+ output_full: return full generation to be used as a history prompt
136
+
137
+ Returns:
138
+ numpy audio array at sample frequency 24khz
139
+ """
140
+ semantic_tokens = text_to_semantic(
141
+ text,
142
+ history_prompt=history_prompt,
143
+ temp=text_temp,
144
+ silent=silent,
145
+ )
146
+ out = semantic_to_waveform(
147
+ semantic_tokens,
148
+ history_prompt=history_prompt,
149
+ temp=waveform_temp,
150
+ silent=silent,
151
+ output_full=output_full,
152
+ )
153
+ if output_full:
154
+ full_generation, audio_arr = out
155
+ return full_generation, audio_arr
156
+ else:
157
+ audio_arr = out
158
+ return audio_arr
bark/assets/prompts/announcer.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26f2d1a9e3b6fe453cf5fc8191de26cbfae6276c5b0f7c376c6a0f3c35867f83
3
+ size 16794
bark/assets/prompts/en_speaker_0.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:932f40d879ba8659f1ca26319ba64ea3b0647b2050fe24313bf42b0dff1fe241
3
+ size 28100
bark/assets/prompts/en_speaker_1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e7f18015e1ab9b6302ded1e28a971af5306a72f193bb6c411f1948a083c8578
3
+ size 25220
bark/assets/prompts/en_speaker_2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d218990680ece5f2d4fc18ea4783b016b3ae353ec413eaee2058f2d57263c9b3
3
+ size 26236
bark/assets/prompts/en_speaker_3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92c2e2a29145c83738e9b63f082fd1c873d9422468a155463cb27f814aeaea66
3
+ size 34980
bark/assets/prompts/en_speaker_4.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:992f91991a9a5359d72f00b09a11a550e71bb8ebfc0cfd877e39d7d41f98b714
3
+ size 23780
bark/assets/prompts/en_speaker_5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18831c3f6014e4a2ff60ad5169b1fae06e28ed07f43f8a3616aafb84515091bf
3
+ size 24740
bark/assets/prompts/en_speaker_6.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fab38dc6b6bc9226bcc414f4c5a9524bc1b2441865a586153fb620127a8faa4e
3
+ size 25540
bark/assets/prompts/en_speaker_7.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f4c4eb33f5994be8de5cfd1744ebce13da1618a6da3a7d244514178c61ef7db
3
+ size 22716
bark/assets/prompts/en_speaker_8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fc9f11b539588f51bbf78150a73e0365c49b2306bd72e5a22b28ef09c4fb15d
3
+ size 23300
bark/assets/prompts/en_speaker_9.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78b3ba32eb9aeb9ed34556856c40633ecc8332d1c3ae3c81e6f5015ac3eefbd5
3
+ size 30180
bark/assets/prompts/zh_speaker_0.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd7ac118a3e944b3f20c89f2446056a00850a630ee16318922acc6572ce80929
3
+ size 20636
bark/assets/prompts/zh_speaker_1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0eacf5c862dfd3c5ac825f2ebb26f323e64309cb712e7e264cbd31c5bca3f038
3
+ size 19836
bark/assets/prompts/zh_speaker_2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e324b47f8250e5798c314f395d4e049575e7ca369d0b6074e91c7bba70e9f26d
3
+ size 21060
bark/assets/prompts/zh_speaker_3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98c476abc7bf634ffb2d71d363284e7bd8c8abd5e33ec5ca21d4aa5b15730d18
3
+ size 31300
bark/assets/prompts/zh_speaker_4.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fa8673a9895ad3302d13ac94193b5ad5da481f1cc276e6181fa895acaae133b
3
+ size 29964
bark/assets/prompts/zh_speaker_5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:226edfe5fabc72eeb83a13e350599bc8babe5adc2264b3cdb661fd1258dc4044
3
+ size 17436
bark/assets/prompts/zh_speaker_6.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:285d51fbe81cc263636b5b487fbb6633e6f3cf92c53ca9ab8e6b7f55d4b4a31d
3
+ size 16900
bark/assets/prompts/zh_speaker_7.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0967cdb14ffa79895747b0d52df9f15bdad80d6c55b7630894345c9a7ec87c91
3
+ size 21060
bark/assets/prompts/zh_speaker_8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c028f78530013f29ab8c0c1cf4fe2138106fbe5252951f5f36e0168056779549
3
+ size 19300
bark/assets/prompts/zh_speaker_9.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6265bb827008d7af8a45a8e057fe3e91efb347d56208180a9ed990ad54e4d75e
3
+ size 16156
bark/clonevoice.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bark.generation import load_codec_model, generate_text_semantic, grab_best_device
2
+ from encodec.utils import convert_audio
3
+ import torchaudio
4
+ import torch
5
+ import os
6
+ import gradio
7
+
8
+
9
+ def clone_voice(audio_filepath, text, dest_filename, progress=gradio.Progress(track_tqdm=True)):
10
+ if len(text) < 1:
11
+ raise gradio.Error('No transcription text entered!')
12
+
13
+ use_gpu = not os.environ.get("BARK_FORCE_CPU", False)
14
+ progress(0, desc="Loading Codec")
15
+ model = load_codec_model(use_gpu=use_gpu)
16
+ progress(0.25, desc="Converting WAV")
17
+
18
+ # Load and pre-process the audio waveform
19
+ device = grab_best_device(use_gpu)
20
+ wav, sr = torchaudio.load(audio_filepath)
21
+ wav = convert_audio(wav, sr, model.sample_rate, model.channels)
22
+ wav = wav.unsqueeze(0).to(device)
23
+ progress(0.5, desc="Extracting codes")
24
+
25
+ # Extract discrete codes from EnCodec
26
+ with torch.no_grad():
27
+ encoded_frames = model.encode(wav)
28
+ codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]
29
+
30
+ # get seconds of audio
31
+ seconds = wav.shape[-1] / model.sample_rate
32
+ # generate semantic tokens
33
+ semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)
34
+
35
+ # move codes to cpu
36
+ codes = codes.cpu().numpy()
37
+
38
+ import numpy as np
39
+ output_path = dest_filename + '.npz'
40
+ np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
41
+ return "Finished"
bark/generation.py ADDED
@@ -0,0 +1,865 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import gc
3
+ import os
4
+ import re
5
+ import requests
6
+ import gc
7
+ import sys
8
+
9
+ from encodec import EncodecModel
10
+ import funcy
11
+ import logging
12
+ import numpy as np
13
+ from scipy.special import softmax
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import tqdm
17
+ from transformers import BertTokenizer
18
+ from huggingface_hub import hf_hub_download, hf_hub_url
19
+
20
+ from .model import GPTConfig, GPT
21
+ from .model_fine import FineGPT, FineGPTConfig
22
+ from .settings import initenv
23
+
24
+ initenv(sys.argv)
25
+ global_force_cpu = os.environ.get("BARK_FORCE_CPU", False)
26
+ if (
27
+ global_force_cpu != True and
28
+ torch.cuda.is_available() and
29
+ hasattr(torch.cuda, "amp") and
30
+ hasattr(torch.cuda.amp, "autocast") and
31
+ hasattr(torch.cuda, "is_bf16_supported") and
32
+ torch.cuda.is_bf16_supported()
33
+ ):
34
+ autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
35
+ else:
36
+ @contextlib.contextmanager
37
+ def autocast():
38
+ yield
39
+
40
+
41
+ # hold models in global scope to lazy load
42
+ global models
43
+ models = {}
44
+
45
+ global models_devices
46
+ models_devices = {}
47
+
48
+
49
+ CONTEXT_WINDOW_SIZE = 1024
50
+
51
+ SEMANTIC_RATE_HZ = 49.9
52
+ SEMANTIC_VOCAB_SIZE = 10_000
53
+
54
+ CODEBOOK_SIZE = 1024
55
+ N_COARSE_CODEBOOKS = 2
56
+ N_FINE_CODEBOOKS = 8
57
+ COARSE_RATE_HZ = 75
58
+
59
+ SAMPLE_RATE = 24_000
60
+
61
+
62
+ SUPPORTED_LANGS = [
63
+ ("English", "en"),
64
+ ("German", "de"),
65
+ ("Spanish", "es"),
66
+ ("French", "fr"),
67
+ ("Hindi", "hi"),
68
+ ("Italian", "it"),
69
+ ("Japanese", "ja"),
70
+ ("Korean", "ko"),
71
+ ("Polish", "pl"),
72
+ ("Portuguese", "pt"),
73
+ ("Russian", "ru"),
74
+ ("Turkish", "tr"),
75
+ ("Chinese", "zh"),
76
+ ]
77
+
78
+ ALLOWED_PROMPTS = {"announcer"}
79
+ for _, lang in SUPPORTED_LANGS:
80
+ for prefix in ("", f"v2{os.path.sep}"):
81
+ for n in range(10):
82
+ ALLOWED_PROMPTS.add(f"{prefix}{lang}_speaker_{n}")
83
+
84
+
85
+ logger = logging.getLogger(__name__)
86
+
87
+
88
+ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
89
+
90
+
91
+ #default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
92
+ #CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
93
+ #CACHE_DIR = os.path.join(os.getcwd(), "models"
94
+ CACHE_DIR = "./models"
95
+
96
+
97
+ def _cast_bool_env_var(s):
98
+ return s.lower() in ('true', '1', 't')
99
+
100
+ USE_SMALL_MODELS = _cast_bool_env_var(os.environ.get("SUNO_USE_SMALL_MODELS", "False"))
101
+ GLOBAL_ENABLE_MPS = _cast_bool_env_var(os.environ.get("SUNO_ENABLE_MPS", "False"))
102
+ OFFLOAD_CPU = _cast_bool_env_var(os.environ.get("SUNO_OFFLOAD_CPU", "False"))
103
+
104
+ REMOTE_MODEL_PATHS = {
105
+ "text_small": {
106
+ "repo_id": "suno/bark",
107
+ "file_name": "text.pt",
108
+ },
109
+ "coarse_small": {
110
+ "repo_id": "suno/bark",
111
+ "file_name": "coarse.pt",
112
+ },
113
+ "fine_small": {
114
+ "repo_id": "suno/bark",
115
+ "file_name": "fine.pt",
116
+ },
117
+ "text": {
118
+ "repo_id": "suno/bark",
119
+ "file_name": "text_2.pt",
120
+ },
121
+ "coarse": {
122
+ "repo_id": "suno/bark",
123
+ "file_name": "coarse_2.pt",
124
+ },
125
+ "fine": {
126
+ "repo_id": "suno/bark",
127
+ "file_name": "fine_2.pt",
128
+ },
129
+ }
130
+
131
+
132
+ if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
133
+ logger.warning(
134
+ "torch version does not support flash attention. You will get faster" +
135
+ " inference speed by upgrade torch to newest nightly version."
136
+ )
137
+
138
+
139
+ def grab_best_device(use_gpu=True):
140
+ if torch.cuda.device_count() > 0 and use_gpu:
141
+ device = "cuda"
142
+ elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
143
+ device = "mps"
144
+ else:
145
+ device = "cpu"
146
+ return device
147
+
148
+
149
+ def _get_ckpt_path(model_type, use_small=False):
150
+ key = model_type
151
+ if use_small or USE_SMALL_MODELS:
152
+ key += "_small"
153
+ return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
154
+
155
+ """
156
+ def _download(from_hf_path, file_name, destfilename):
157
+ os.makedirs(CACHE_DIR, exist_ok=True)
158
+ hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR, local_dir_use_symlinks=False)
159
+ # Bug in original repo? Downloaded name differs from expected...
160
+ if not os.path.exists(destfilename):
161
+ localname = os.path.join(CACHE_DIR, file_name)
162
+ os.rename(localname, destfilename)
163
+ """
164
+ def _download(from_hf_path, file_name):
165
+ os.makedirs(CACHE_DIR, exist_ok=True)
166
+ hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
167
+
168
+
169
+ class InferenceContext:
170
+ def __init__(self, benchmark=False):
171
+ # we can't expect inputs to be the same length, so disable benchmarking by default
172
+ self._chosen_cudnn_benchmark = benchmark
173
+ self._cudnn_benchmark = None
174
+
175
+ def __enter__(self):
176
+ self._cudnn_benchmark = torch.backends.cudnn.benchmark
177
+ torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark
178
+
179
+ def __exit__(self, exc_type, exc_value, exc_traceback):
180
+ torch.backends.cudnn.benchmark = self._cudnn_benchmark
181
+
182
+
183
+ if torch.cuda.is_available():
184
+ torch.backends.cuda.matmul.allow_tf32 = True
185
+ torch.backends.cudnn.allow_tf32 = True
186
+
187
+
188
+ @contextlib.contextmanager
189
+ def _inference_mode():
190
+ with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
191
+ yield
192
+
193
+
194
+ def _clear_cuda_cache():
195
+ if torch.cuda.is_available():
196
+ torch.cuda.empty_cache()
197
+ torch.cuda.synchronize()
198
+
199
+
200
+ def clean_models(model_key=None):
201
+ global models
202
+ model_keys = [model_key] if model_key is not None else models.keys()
203
+ for k in model_keys:
204
+ if k in models:
205
+ del models[k]
206
+ _clear_cuda_cache()
207
+ gc.collect()
208
+
209
+
210
+ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
211
+ if model_type == "text":
212
+ ConfigClass = GPTConfig
213
+ ModelClass = GPT
214
+ elif model_type == "coarse":
215
+ ConfigClass = GPTConfig
216
+ ModelClass = GPT
217
+ elif model_type == "fine":
218
+ ConfigClass = FineGPTConfig
219
+ ModelClass = FineGPT
220
+ else:
221
+ raise NotImplementedError()
222
+
223
+ # Force-remove Models to allow running on >12Gb GPU
224
+ # CF: Probably not needed anymore
225
+ #global models
226
+ #models.clear()
227
+ #gc.collect()
228
+ #torch.cuda.empty_cache()
229
+ # to here...
230
+
231
+ model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
232
+ model_info = REMOTE_MODEL_PATHS[model_key]
233
+ if not os.path.exists(ckpt_path):
234
+ logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
235
+ ## added next two lines to make it super clear which model is being downloaded
236
+ remote_filename = hf_hub_url(model_info["repo_id"], model_info["file_name"])
237
+ print(f"Downloading {model_key} {model_info['repo_id']} remote model file {remote_filename} {model_info['file_name']} to {CACHE_DIR}")
238
+ _download(model_info["repo_id"], model_info["file_name"])
239
+ # add next line to make it super clear which model is being loaded
240
+ print(f"Loading {model_key} model from {ckpt_path} to {device}") # added
241
+ checkpoint = torch.load(ckpt_path, map_location=device)
242
+ # this is a hack
243
+ model_args = checkpoint["model_args"]
244
+ if "input_vocab_size" not in model_args:
245
+ model_args["input_vocab_size"] = model_args["vocab_size"]
246
+ model_args["output_vocab_size"] = model_args["vocab_size"]
247
+ del model_args["vocab_size"]
248
+ gptconf = ConfigClass(**checkpoint["model_args"])
249
+ model = ModelClass(gptconf)
250
+ state_dict = checkpoint["model"]
251
+ # fixup checkpoint
252
+ unwanted_prefix = "_orig_mod."
253
+ for k, v in list(state_dict.items()):
254
+ if k.startswith(unwanted_prefix):
255
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
256
+ extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
257
+ extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
258
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
259
+ missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
260
+ if len(extra_keys) != 0:
261
+ raise ValueError(f"extra keys found: {extra_keys}")
262
+ if len(missing_keys) != 0:
263
+ raise ValueError(f"missing keys: {missing_keys}")
264
+ model.load_state_dict(state_dict, strict=False)
265
+ n_params = model.get_num_params()
266
+ val_loss = checkpoint["best_val_loss"].item()
267
+ logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
268
+ model.eval()
269
+ model.to(device)
270
+ del checkpoint, state_dict
271
+ _clear_cuda_cache()
272
+ if model_type == "text":
273
+ tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
274
+ return {
275
+ "model": model,
276
+ "tokenizer": tokenizer,
277
+ }
278
+ return model
279
+
280
+
281
+ def _load_codec_model(device):
282
+ model = EncodecModel.encodec_model_24khz()
283
+ model.set_target_bandwidth(6.0)
284
+ model.eval()
285
+ model.to(device)
286
+ _clear_cuda_cache()
287
+ return model
288
+
289
+
290
+ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
291
+ _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
292
+ if model_type not in ("text", "coarse", "fine"):
293
+ raise NotImplementedError()
294
+ global models
295
+ global models_devices
296
+ device = grab_best_device(use_gpu=use_gpu)
297
+ model_key = f"{model_type}"
298
+ if OFFLOAD_CPU:
299
+ models_devices[model_key] = device
300
+ device = "cpu"
301
+ if model_key not in models or force_reload:
302
+ ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
303
+ clean_models(model_key=model_key)
304
+ model = _load_model_f(ckpt_path, device)
305
+ models[model_key] = model
306
+ if model_type == "text":
307
+ models[model_key]["model"].to(device)
308
+ else:
309
+ models[model_key].to(device)
310
+ return models[model_key]
311
+
312
+
313
+ def load_codec_model(use_gpu=True, force_reload=False):
314
+ global models
315
+ global models_devices
316
+ device = grab_best_device(use_gpu=use_gpu)
317
+ if device == "mps":
318
+ # encodec doesn't support mps
319
+ device = "cpu"
320
+ model_key = "codec"
321
+ if OFFLOAD_CPU:
322
+ models_devices[model_key] = device
323
+ device = "cpu"
324
+ if model_key not in models or force_reload:
325
+ clean_models(model_key=model_key)
326
+ model = _load_codec_model(device)
327
+ models[model_key] = model
328
+ models[model_key].to(device)
329
+ return models[model_key]
330
+
331
+
332
+ def preload_models(
333
+ text_use_gpu=True,
334
+ text_use_small=False,
335
+ coarse_use_gpu=True,
336
+ coarse_use_small=False,
337
+ fine_use_gpu=True,
338
+ fine_use_small=False,
339
+ codec_use_gpu=True,
340
+ force_reload=False
341
+ ):
342
+ """Load all the necessary models for the pipeline."""
343
+ if grab_best_device() == "cpu" and (
344
+ text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
345
+ ):
346
+ logger.warning("No GPU being used. Careful, inference might be very slow!")
347
+ _ = load_model(
348
+ model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
349
+ )
350
+ _ = load_model(
351
+ model_type="coarse",
352
+ use_gpu=coarse_use_gpu,
353
+ use_small=coarse_use_small,
354
+ force_reload=force_reload,
355
+ )
356
+ _ = load_model(
357
+ model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
358
+ )
359
+ _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
360
+
361
+
362
+ ####
363
+ # Generation Functionality
364
+ ####
365
+
366
+
367
+ def _tokenize(tokenizer, text):
368
+ return tokenizer.encode(text, add_special_tokens=False)
369
+
370
+
371
+ def _detokenize(tokenizer, enc_text):
372
+ return tokenizer.decode(enc_text)
373
+
374
+
375
+ def _normalize_whitespace(text):
376
+ return re.sub(r"\s+", " ", text).strip()
377
+
378
+
379
+ TEXT_ENCODING_OFFSET = 10_048
380
+ SEMANTIC_PAD_TOKEN = 10_000
381
+ TEXT_PAD_TOKEN = 129_595
382
+ SEMANTIC_INFER_TOKEN = 129_599
383
+
384
+
385
+ def _load_history_prompt(history_prompt_input):
386
+ if isinstance(history_prompt_input, str) and history_prompt_input.endswith(".npz"):
387
+ history_prompt = np.load(history_prompt_input)
388
+ elif isinstance(history_prompt_input, str):
389
+ # make sure this works on non-ubuntu
390
+ history_prompt_input = os.path.join(*history_prompt_input.split("/"))
391
+ # if history_prompt_input not in ALLOWED_PROMPTS:
392
+ # raise ValueError("history prompt not found")
393
+ history_prompt = np.load(
394
+ os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt_input}.npz")
395
+ )
396
+ elif isinstance(history_prompt_input, dict):
397
+ assert("semantic_prompt" in history_prompt_input)
398
+ assert("coarse_prompt" in history_prompt_input)
399
+ assert("fine_prompt" in history_prompt_input)
400
+ history_prompt = history_prompt_input
401
+ else:
402
+ raise ValueError("history prompt format unrecognized")
403
+ return history_prompt
404
+
405
+
406
+ def generate_text_semantic(
407
+ text,
408
+ history_prompt=None,
409
+ temp=0.7,
410
+ top_k=None,
411
+ top_p=None,
412
+ silent=False,
413
+ min_eos_p=0.2,
414
+ max_gen_duration_s=None,
415
+ allow_early_stop=True,
416
+ use_kv_caching=False,
417
+ ):
418
+ """Generate semantic tokens from text."""
419
+ assert isinstance(text, str)
420
+ text = _normalize_whitespace(text)
421
+ assert len(text.strip()) > 0
422
+ if history_prompt is not None:
423
+ history_prompt = _load_history_prompt(history_prompt)
424
+ semantic_history = history_prompt["semantic_prompt"]
425
+ assert (
426
+ isinstance(semantic_history, np.ndarray)
427
+ and len(semantic_history.shape) == 1
428
+ and len(semantic_history) > 0
429
+ and semantic_history.min() >= 0
430
+ and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
431
+ )
432
+ else:
433
+ semantic_history = None
434
+ # load models if not yet exist
435
+ global models
436
+ global models_devices
437
+ if "text" not in models:
438
+ preload_models()
439
+ model_container = models["text"]
440
+ model = model_container["model"]
441
+ tokenizer = model_container["tokenizer"]
442
+ encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
443
+ if OFFLOAD_CPU:
444
+ model.to(models_devices["text"])
445
+ device = next(model.parameters()).device
446
+ if len(encoded_text) > 256:
447
+ p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
448
+ logger.warning(f"warning, text too long, lopping of last {p}%")
449
+ encoded_text = encoded_text[:256]
450
+ encoded_text = np.pad(
451
+ encoded_text,
452
+ (0, 256 - len(encoded_text)),
453
+ constant_values=TEXT_PAD_TOKEN,
454
+ mode="constant",
455
+ )
456
+ if semantic_history is not None:
457
+ semantic_history = semantic_history.astype(np.int64)
458
+ # lop off if history is too long, pad if needed
459
+ semantic_history = semantic_history[-256:]
460
+ semantic_history = np.pad(
461
+ semantic_history,
462
+ (0, 256 - len(semantic_history)),
463
+ constant_values=SEMANTIC_PAD_TOKEN,
464
+ mode="constant",
465
+ )
466
+ else:
467
+ semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
468
+ x = torch.from_numpy(
469
+ np.hstack([
470
+ encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
471
+ ]).astype(np.int64)
472
+ )[None]
473
+ assert x.shape[1] == 256 + 256 + 1
474
+ with _inference_mode():
475
+ x = x.to(device)
476
+ n_tot_steps = 768
477
+ # custom tqdm updates since we don't know when eos will occur
478
+ pbar = tqdm.tqdm(disable=silent, total=100)
479
+ pbar_state = 0
480
+ tot_generated_duration_s = 0
481
+ kv_cache = None
482
+ for n in range(n_tot_steps):
483
+ if use_kv_caching and kv_cache is not None:
484
+ x_input = x[:, [-1]]
485
+ else:
486
+ x_input = x
487
+ logits, kv_cache = model(
488
+ x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
489
+ )
490
+ relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
491
+ if allow_early_stop:
492
+ relevant_logits = torch.hstack(
493
+ (relevant_logits, logits[0, 0, [SEMANTIC_PAD_TOKEN]]) # eos
494
+ )
495
+ if top_p is not None:
496
+ # faster to convert to numpy
497
+ logits_device = relevant_logits.device
498
+ logits_dtype = relevant_logits.type()
499
+ relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
500
+ sorted_indices = np.argsort(relevant_logits)[::-1]
501
+ sorted_logits = relevant_logits[sorted_indices]
502
+ cumulative_probs = np.cumsum(softmax(sorted_logits))
503
+ sorted_indices_to_remove = cumulative_probs > top_p
504
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
505
+ sorted_indices_to_remove[0] = False
506
+ relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
507
+ relevant_logits = torch.from_numpy(relevant_logits)
508
+ relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
509
+ if top_k is not None:
510
+ v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
511
+ relevant_logits[relevant_logits < v[-1]] = -float("Inf")
512
+ probs = F.softmax(relevant_logits / temp, dim=-1)
513
+ # multinomial bugged on mps: shuttle to cpu if necessary
514
+ inf_device = probs.device
515
+ if probs.device.type == "mps":
516
+ probs = probs.to("cpu")
517
+ item_next = torch.multinomial(probs, num_samples=1)
518
+ probs = probs.to(inf_device)
519
+ item_next = item_next.to(inf_device)
520
+ if allow_early_stop and (
521
+ item_next == SEMANTIC_VOCAB_SIZE
522
+ or (min_eos_p is not None and probs[-1] >= min_eos_p)
523
+ ):
524
+ # eos found, so break
525
+ pbar.update(100 - pbar_state)
526
+ break
527
+ x = torch.cat((x, item_next[None]), dim=1)
528
+ tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
529
+ if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
530
+ pbar.update(100 - pbar_state)
531
+ break
532
+ if n == n_tot_steps - 1:
533
+ pbar.update(100 - pbar_state)
534
+ break
535
+ del logits, relevant_logits, probs, item_next
536
+ req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))])
537
+ if req_pbar_state > pbar_state:
538
+ pbar.update(req_pbar_state - pbar_state)
539
+ pbar_state = req_pbar_state
540
+ pbar.close()
541
+ out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
542
+ if OFFLOAD_CPU:
543
+ model.to("cpu")
544
+ assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
545
+ _clear_cuda_cache()
546
+ return out
547
+
548
+
549
+ def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
550
+ assert len(arr.shape) == 2
551
+ arr = arr.copy()
552
+ if offset_size is not None:
553
+ for n in range(1, arr.shape[0]):
554
+ arr[n, :] += offset_size * n
555
+ flat_arr = arr.ravel("F")
556
+ return flat_arr
557
+
558
+
559
+ COARSE_SEMANTIC_PAD_TOKEN = 12_048
560
+ COARSE_INFER_TOKEN = 12_050
561
+
562
+
563
+ def generate_coarse(
564
+ x_semantic,
565
+ history_prompt=None,
566
+ temp=0.7,
567
+ top_k=None,
568
+ top_p=None,
569
+ silent=False,
570
+ max_coarse_history=630, # min 60 (faster), max 630 (more context)
571
+ sliding_window_len=60,
572
+ use_kv_caching=False,
573
+ ):
574
+ """Generate coarse audio codes from semantic tokens."""
575
+ assert (
576
+ isinstance(x_semantic, np.ndarray)
577
+ and len(x_semantic.shape) == 1
578
+ and len(x_semantic) > 0
579
+ and x_semantic.min() >= 0
580
+ and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
581
+ )
582
+ assert 60 <= max_coarse_history <= 630
583
+ assert max_coarse_history + sliding_window_len <= 1024 - 256
584
+ semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
585
+ max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
586
+ if history_prompt is not None:
587
+ history_prompt = _load_history_prompt(history_prompt)
588
+ x_semantic_history = history_prompt["semantic_prompt"]
589
+ x_coarse_history = history_prompt["coarse_prompt"]
590
+ assert (
591
+ isinstance(x_semantic_history, np.ndarray)
592
+ and len(x_semantic_history.shape) == 1
593
+ and len(x_semantic_history) > 0
594
+ and x_semantic_history.min() >= 0
595
+ and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
596
+ and isinstance(x_coarse_history, np.ndarray)
597
+ and len(x_coarse_history.shape) == 2
598
+ and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS
599
+ and x_coarse_history.shape[-1] >= 0
600
+ and x_coarse_history.min() >= 0
601
+ and x_coarse_history.max() <= CODEBOOK_SIZE - 1
602
+ and (
603
+ round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
604
+ == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
605
+ )
606
+ )
607
+ x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
608
+ # trim histories correctly
609
+ n_semantic_hist_provided = np.min(
610
+ [
611
+ max_semantic_history,
612
+ len(x_semantic_history) - len(x_semantic_history) % 2,
613
+ int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
614
+ ]
615
+ )
616
+ n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
617
+ x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
618
+ x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
619
+ # TODO: bit of a hack for time alignment (sounds better)
620
+ x_coarse_history = x_coarse_history[:-2]
621
+ else:
622
+ x_semantic_history = np.array([], dtype=np.int32)
623
+ x_coarse_history = np.array([], dtype=np.int32)
624
+ # load models if not yet exist
625
+ global models
626
+ global models_devices
627
+ if "coarse" not in models:
628
+ preload_models()
629
+ model = models["coarse"]
630
+ if OFFLOAD_CPU:
631
+ model.to(models_devices["coarse"])
632
+ device = next(model.parameters()).device
633
+ # start loop
634
+ n_steps = int(
635
+ round(
636
+ np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
637
+ * N_COARSE_CODEBOOKS
638
+ )
639
+ )
640
+ assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
641
+ x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
642
+ x_coarse = x_coarse_history.astype(np.int32)
643
+ base_semantic_idx = len(x_semantic_history)
644
+ with _inference_mode():
645
+ x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
646
+ x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
647
+ n_window_steps = int(np.ceil(n_steps / sliding_window_len))
648
+ n_step = 0
649
+ for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
650
+ semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
651
+ # pad from right side
652
+ x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
653
+ x_in = x_in[:, :256]
654
+ x_in = F.pad(
655
+ x_in,
656
+ (0, 256 - x_in.shape[-1]),
657
+ "constant",
658
+ COARSE_SEMANTIC_PAD_TOKEN,
659
+ )
660
+ x_in = torch.hstack(
661
+ [
662
+ x_in,
663
+ torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
664
+ x_coarse_in[:, -max_coarse_history:],
665
+ ]
666
+ )
667
+ kv_cache = None
668
+ for _ in range(sliding_window_len):
669
+ if n_step >= n_steps:
670
+ continue
671
+ is_major_step = n_step % N_COARSE_CODEBOOKS == 0
672
+
673
+ if use_kv_caching and kv_cache is not None:
674
+ x_input = x_in[:, [-1]]
675
+ else:
676
+ x_input = x_in
677
+
678
+ logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
679
+ logit_start_idx = (
680
+ SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
681
+ )
682
+ logit_end_idx = (
683
+ SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
684
+ )
685
+ relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
686
+ if top_p is not None:
687
+ # faster to convert to numpy
688
+ logits_device = relevant_logits.device
689
+ logits_dtype = relevant_logits.type()
690
+ relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
691
+ sorted_indices = np.argsort(relevant_logits)[::-1]
692
+ sorted_logits = relevant_logits[sorted_indices]
693
+ cumulative_probs = np.cumsum(softmax(sorted_logits))
694
+ sorted_indices_to_remove = cumulative_probs > top_p
695
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
696
+ sorted_indices_to_remove[0] = False
697
+ relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
698
+ relevant_logits = torch.from_numpy(relevant_logits)
699
+ relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
700
+ if top_k is not None:
701
+ v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
702
+ relevant_logits[relevant_logits < v[-1]] = -float("Inf")
703
+ probs = F.softmax(relevant_logits / temp, dim=-1)
704
+ # multinomial bugged on mps: shuttle to cpu if necessary
705
+ inf_device = probs.device
706
+ if probs.device.type == "mps":
707
+ probs = probs.to("cpu")
708
+ item_next = torch.multinomial(probs, num_samples=1)
709
+ probs = probs.to(inf_device)
710
+ item_next = item_next.to(inf_device)
711
+ item_next += logit_start_idx
712
+ x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
713
+ x_in = torch.cat((x_in, item_next[None]), dim=1)
714
+ del logits, relevant_logits, probs, item_next
715
+ n_step += 1
716
+ del x_in
717
+ del x_semantic_in
718
+ if OFFLOAD_CPU:
719
+ model.to("cpu")
720
+ gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
721
+ del x_coarse_in
722
+ assert len(gen_coarse_arr) == n_steps
723
+ gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
724
+ for n in range(1, N_COARSE_CODEBOOKS):
725
+ gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
726
+ _clear_cuda_cache()
727
+ return gen_coarse_audio_arr
728
+
729
+
730
+ def generate_fine(
731
+ x_coarse_gen,
732
+ history_prompt=None,
733
+ temp=0.5,
734
+ silent=True,
735
+ ):
736
+ """Generate full audio codes from coarse audio codes."""
737
+ assert (
738
+ isinstance(x_coarse_gen, np.ndarray)
739
+ and len(x_coarse_gen.shape) == 2
740
+ and 1 <= x_coarse_gen.shape[0] <= N_FINE_CODEBOOKS - 1
741
+ and x_coarse_gen.shape[1] > 0
742
+ and x_coarse_gen.min() >= 0
743
+ and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
744
+ )
745
+ if history_prompt is not None:
746
+ history_prompt = _load_history_prompt(history_prompt)
747
+ x_fine_history = history_prompt["fine_prompt"]
748
+ assert (
749
+ isinstance(x_fine_history, np.ndarray)
750
+ and len(x_fine_history.shape) == 2
751
+ and x_fine_history.shape[0] == N_FINE_CODEBOOKS
752
+ and x_fine_history.shape[1] >= 0
753
+ and x_fine_history.min() >= 0
754
+ and x_fine_history.max() <= CODEBOOK_SIZE - 1
755
+ )
756
+ else:
757
+ x_fine_history = None
758
+ n_coarse = x_coarse_gen.shape[0]
759
+ # load models if not yet exist
760
+ global models
761
+ global models_devices
762
+ if "fine" not in models:
763
+ preload_models()
764
+ model = models["fine"]
765
+ if OFFLOAD_CPU:
766
+ model.to(models_devices["fine"])
767
+ device = next(model.parameters()).device
768
+ # make input arr
769
+ in_arr = np.vstack(
770
+ [
771
+ x_coarse_gen,
772
+ np.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1]))
773
+ + CODEBOOK_SIZE, # padding
774
+ ]
775
+ ).astype(np.int32)
776
+ # prepend history if available (max 512)
777
+ if x_fine_history is not None:
778
+ x_fine_history = x_fine_history.astype(np.int32)
779
+ in_arr = np.hstack(
780
+ [
781
+ x_fine_history[:, -512:].astype(np.int32),
782
+ in_arr,
783
+ ]
784
+ )
785
+ n_history = x_fine_history[:, -512:].shape[1]
786
+ else:
787
+ n_history = 0
788
+ n_remove_from_end = 0
789
+ # need to pad if too short (since non-causal model)
790
+ if in_arr.shape[1] < 1024:
791
+ n_remove_from_end = 1024 - in_arr.shape[1]
792
+ in_arr = np.hstack(
793
+ [
794
+ in_arr,
795
+ np.zeros((N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + CODEBOOK_SIZE,
796
+ ]
797
+ )
798
+ # we can be lazy about fractional loop and just keep overwriting codebooks
799
+ n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1
800
+ with _inference_mode():
801
+ in_arr = torch.tensor(in_arr.T).to(device)
802
+ for n in tqdm.tqdm(range(n_loops), disable=silent):
803
+ start_idx = np.min([n * 512, in_arr.shape[0] - 1024])
804
+ start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512])
805
+ rel_start_fill_idx = start_fill_idx - start_idx
806
+ in_buffer = in_arr[start_idx : start_idx + 1024, :][None]
807
+ for nn in range(n_coarse, N_FINE_CODEBOOKS):
808
+ logits = model(nn, in_buffer)
809
+ if temp is None:
810
+ relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE]
811
+ codebook_preds = torch.argmax(relevant_logits, -1)
812
+ else:
813
+ relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
814
+ probs = F.softmax(relevant_logits, dim=-1)
815
+ # multinomial bugged on mps: shuttle to cpu if necessary
816
+ inf_device = probs.device
817
+ if probs.device.type == "mps":
818
+ probs = probs.to("cpu")
819
+ codebook_preds = torch.hstack(
820
+ [
821
+ torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
822
+ for nnn in range(rel_start_fill_idx, 1024)
823
+ ]
824
+ )
825
+ in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
826
+ del logits, codebook_preds
827
+ # transfer over info into model_in and convert to numpy
828
+ for nn in range(n_coarse, N_FINE_CODEBOOKS):
829
+ in_arr[
830
+ start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn
831
+ ] = in_buffer[0, rel_start_fill_idx:, nn]
832
+ del in_buffer
833
+ gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
834
+ del in_arr
835
+ if OFFLOAD_CPU:
836
+ model.to("cpu")
837
+ gen_fine_arr = gen_fine_arr[:, n_history:]
838
+ if n_remove_from_end > 0:
839
+ gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
840
+ assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1]
841
+ _clear_cuda_cache()
842
+ return gen_fine_arr
843
+
844
+
845
+ def codec_decode(fine_tokens):
846
+ """Turn quantized audio codes into audio array using encodec."""
847
+ # load models if not yet exist
848
+ global models
849
+ global models_devices
850
+ if "codec" not in models:
851
+ preload_models()
852
+ model = models["codec"]
853
+ if OFFLOAD_CPU:
854
+ model.to(models_devices["codec"])
855
+ device = next(model.parameters()).device
856
+ arr = torch.from_numpy(fine_tokens)[None]
857
+ arr = arr.to(device)
858
+ arr = arr.transpose(0, 1)
859
+ emb = model.quantizer.decode(arr)
860
+ out = model.decoder(emb)
861
+ audio_arr = out.detach().cpu().numpy().squeeze()
862
+ del arr, emb, out
863
+ if OFFLOAD_CPU:
864
+ model.to("cpu")
865
+ return audio_arr
bark/model.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Much of this code is adapted from Andrej Karpathy's NanoGPT
3
+ (https://github.com/karpathy/nanoGPT)
4
+ """
5
+ import math
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+ class LayerNorm(nn.Module):
13
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
14
+
15
+ def __init__(self, ndim, bias):
16
+ super().__init__()
17
+ self.weight = nn.Parameter(torch.ones(ndim))
18
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
19
+
20
+ def forward(self, input):
21
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
22
+
23
+ class CausalSelfAttention(nn.Module):
24
+
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ assert config.n_embd % config.n_head == 0
28
+ # key, query, value projections for all heads, but in a batch
29
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
30
+ # output projection
31
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
32
+ # regularization
33
+ self.attn_dropout = nn.Dropout(config.dropout)
34
+ self.resid_dropout = nn.Dropout(config.dropout)
35
+ self.n_head = config.n_head
36
+ self.n_embd = config.n_embd
37
+ self.dropout = config.dropout
38
+ # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
39
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
40
+ if not self.flash:
41
+ # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
42
+ # causal mask to ensure that attention is only applied to the left in the input sequence
43
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
44
+ .view(1, 1, config.block_size, config.block_size))
45
+
46
+ def forward(self, x, past_kv=None, use_cache=False):
47
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
48
+
49
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
50
+ q, k ,v = self.c_attn(x).split(self.n_embd, dim=2)
51
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
52
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
53
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
54
+
55
+ if past_kv is not None:
56
+ past_key = past_kv[0]
57
+ past_value = past_kv[1]
58
+ k = torch.cat((past_key, k), dim=-2)
59
+ v = torch.cat((past_value, v), dim=-2)
60
+
61
+ FULL_T = k.shape[-2]
62
+
63
+ if use_cache is True:
64
+ present = (k, v)
65
+ else:
66
+ present = None
67
+
68
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
69
+ if self.flash:
70
+ # efficient attention using Flash Attention CUDA kernels
71
+ if past_kv is not None:
72
+ # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
73
+ # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
74
+ # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
75
+ # to work around this we set is_causal=False.
76
+ is_causal = False
77
+ else:
78
+ is_causal = True
79
+
80
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
81
+ else:
82
+ # manual implementation of attention
83
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
84
+ att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
85
+ att = F.softmax(att, dim=-1)
86
+ att = self.attn_dropout(att)
87
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
88
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
89
+
90
+ # output projection
91
+ y = self.resid_dropout(self.c_proj(y))
92
+ return (y, present)
93
+
94
+ class MLP(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
99
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
100
+ self.dropout = nn.Dropout(config.dropout)
101
+ self.gelu = nn.GELU()
102
+
103
+ def forward(self, x):
104
+ x = self.c_fc(x)
105
+ x = self.gelu(x)
106
+ x = self.c_proj(x)
107
+ x = self.dropout(x)
108
+ return x
109
+
110
+ class Block(nn.Module):
111
+
112
+ def __init__(self, config, layer_idx):
113
+ super().__init__()
114
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
115
+ self.attn = CausalSelfAttention(config)
116
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
117
+ self.mlp = MLP(config)
118
+ self.layer_idx = layer_idx
119
+
120
+ def forward(self, x, past_kv=None, use_cache=False):
121
+ attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
122
+ x = x + attn_output
123
+ x = x + self.mlp(self.ln_2(x))
124
+ return (x, prev_kvs)
125
+
126
+ @dataclass
127
+ class GPTConfig:
128
+ block_size: int = 1024
129
+ input_vocab_size: int = 10_048
130
+ output_vocab_size: int = 10_048
131
+ n_layer: int = 12
132
+ n_head: int = 12
133
+ n_embd: int = 768
134
+ dropout: float = 0.0
135
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
136
+
137
+ class GPT(nn.Module):
138
+
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ assert config.input_vocab_size is not None
142
+ assert config.output_vocab_size is not None
143
+ assert config.block_size is not None
144
+ self.config = config
145
+
146
+ self.transformer = nn.ModuleDict(dict(
147
+ wte = nn.Embedding(config.input_vocab_size, config.n_embd),
148
+ wpe = nn.Embedding(config.block_size, config.n_embd),
149
+ drop = nn.Dropout(config.dropout),
150
+ h = nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
151
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
152
+ ))
153
+ self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
154
+
155
+ def get_num_params(self, non_embedding=True):
156
+ """
157
+ Return the number of parameters in the model.
158
+ For non-embedding count (default), the position embeddings get subtracted.
159
+ The token embeddings would too, except due to the parameter sharing these
160
+ params are actually used as weights in the final layer, so we include them.
161
+ """
162
+ n_params = sum(p.numel() for p in self.parameters())
163
+ if non_embedding:
164
+ n_params -= self.transformer.wte.weight.numel()
165
+ n_params -= self.transformer.wpe.weight.numel()
166
+ return n_params
167
+
168
+ def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
169
+ device = idx.device
170
+ b, t = idx.size()
171
+ if past_kv is not None:
172
+ assert t == 1
173
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
174
+ else:
175
+ if merge_context:
176
+ assert(idx.shape[1] >= 256+256+1)
177
+ t = idx.shape[1] - 256
178
+ else:
179
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
180
+
181
+ # forward the GPT model itself
182
+ if merge_context:
183
+ tok_emb = torch.cat([
184
+ self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
185
+ self.transformer.wte(idx[:,256+256:])
186
+ ], dim=1)
187
+ else:
188
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
189
+
190
+ if past_kv is None:
191
+ past_length = 0
192
+ past_kv = tuple([None] * len(self.transformer.h))
193
+ else:
194
+ past_length = past_kv[0][0].size(-2)
195
+
196
+ if position_ids is None:
197
+ position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
198
+ position_ids = position_ids.unsqueeze(0) # shape (1, t)
199
+ assert position_ids.shape == (1, t)
200
+
201
+ pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
202
+
203
+ x = self.transformer.drop(tok_emb + pos_emb)
204
+
205
+ new_kv = () if use_cache else None
206
+
207
+ for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
208
+ x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
209
+
210
+ if use_cache:
211
+ new_kv = new_kv + (kv,)
212
+
213
+ x = self.transformer.ln_f(x)
214
+
215
+ # inference-time mini-optimization: only forward the lm_head on the very last position
216
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
217
+
218
+ return (logits, new_kv)
bark/model_fine.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Much of this code is adapted from Andrej Karpathy's NanoGPT
3
+ (https://github.com/karpathy/nanoGPT)
4
+ """
5
+ from dataclasses import dataclass
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+ from .model import GPT, GPTConfig, MLP
13
+
14
+
15
+ class NonCausalSelfAttention(nn.Module):
16
+ def __init__(self, config):
17
+ super().__init__()
18
+ assert config.n_embd % config.n_head == 0
19
+ # key, query, value projections for all heads, but in a batch
20
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
21
+ # output projection
22
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
23
+ # regularization
24
+ self.attn_dropout = nn.Dropout(config.dropout)
25
+ self.resid_dropout = nn.Dropout(config.dropout)
26
+ self.n_head = config.n_head
27
+ self.n_embd = config.n_embd
28
+ self.dropout = config.dropout
29
+ # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
30
+ self.flash = (
31
+ hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
32
+ )
33
+
34
+ def forward(self, x):
35
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
36
+
37
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
38
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
39
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
40
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
41
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
42
+
43
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
44
+ if self.flash:
45
+ # efficient attention using Flash Attention CUDA kernels
46
+ y = torch.nn.functional.scaled_dot_product_attention(
47
+ q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
48
+ )
49
+ else:
50
+ # manual implementation of attention
51
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
52
+ att = F.softmax(att, dim=-1)
53
+ att = self.attn_dropout(att)
54
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
55
+ y = (
56
+ y.transpose(1, 2).contiguous().view(B, T, C)
57
+ ) # re-assemble all head outputs side by side
58
+
59
+ # output projection
60
+ y = self.resid_dropout(self.c_proj(y))
61
+ return y
62
+
63
+
64
+ class FineBlock(nn.Module):
65
+ def __init__(self, config):
66
+ super().__init__()
67
+ self.ln_1 = nn.LayerNorm(config.n_embd)
68
+ self.attn = NonCausalSelfAttention(config)
69
+ self.ln_2 = nn.LayerNorm(config.n_embd)
70
+ self.mlp = MLP(config)
71
+
72
+ def forward(self, x):
73
+ x = x + self.attn(self.ln_1(x))
74
+ x = x + self.mlp(self.ln_2(x))
75
+ return x
76
+
77
+
78
+ class FineGPT(GPT):
79
+ def __init__(self, config):
80
+ super().__init__(config)
81
+ del self.lm_head
82
+ self.config = config
83
+ self.n_codes_total = config.n_codes_total
84
+ self.transformer = nn.ModuleDict(
85
+ dict(
86
+ wtes=nn.ModuleList(
87
+ [
88
+ nn.Embedding(config.input_vocab_size, config.n_embd)
89
+ for _ in range(config.n_codes_total)
90
+ ]
91
+ ),
92
+ wpe=nn.Embedding(config.block_size, config.n_embd),
93
+ drop=nn.Dropout(config.dropout),
94
+ h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
95
+ ln_f=nn.LayerNorm(config.n_embd),
96
+ )
97
+ )
98
+ self.lm_heads = nn.ModuleList(
99
+ [
100
+ nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
101
+ for _ in range(config.n_codes_given, self.n_codes_total)
102
+ ]
103
+ )
104
+ for i in range(self.n_codes_total - config.n_codes_given):
105
+ self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight
106
+
107
+ def forward(self, pred_idx, idx):
108
+ device = idx.device
109
+ b, t, codes = idx.size()
110
+ assert (
111
+ t <= self.config.block_size
112
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
113
+ assert pred_idx > 0, "cannot predict 0th codebook"
114
+ assert codes == self.n_codes_total, (b, t, codes)
115
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
116
+
117
+ # forward the GPT model itself
118
+ tok_embs = [
119
+ wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes)
120
+ ] # token embeddings of shape (b, t, n_embd)
121
+ tok_emb = torch.cat(tok_embs, dim=-1)
122
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
123
+ x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
124
+ x = self.transformer.drop(x + pos_emb)
125
+ for block in self.transformer.h:
126
+ x = block(x)
127
+ x = self.transformer.ln_f(x)
128
+ logits = self.lm_heads[pred_idx - self.config.n_codes_given](x)
129
+ return logits
130
+
131
+ def get_num_params(self, non_embedding=True):
132
+ """
133
+ Return the number of parameters in the model.
134
+ For non-embedding count (default), the position embeddings get subtracted.
135
+ The token embeddings would too, except due to the parameter sharing these
136
+ params are actually used as weights in the final layer, so we include them.
137
+ """
138
+ n_params = sum(p.numel() for p in self.parameters())
139
+ if non_embedding:
140
+ for wte in self.transformer.wtes:
141
+ n_params -= wte.weight.numel()
142
+ n_params -= self.transformer.wpe.weight.numel()
143
+ return n_params
144
+
145
+
146
+ @dataclass
147
+ class FineGPTConfig(GPTConfig):
148
+ n_codes_total: int = 8
149
+ n_codes_given: int = 1
bark/settings.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ def initenv(args):
4
+ os.environ['SUNO_USE_SMALL_MODELS'] = str("-smallmodels" in args)
5
+ os.environ['BARK_FORCE_CPU'] = str("-forcecpu" in args)
6
+ os.environ['SUNO_ENABLE_MPS'] = str("-enablemps" in args)
7
+ os.environ['SUNO_OFFLOAD_CPU'] = str("-offloadcpu" in args)
best_model.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:017bfd8907c80bb5857d65d0223f0e4e4b9d699ef52e2a853d9cc7eb7e308cf0
3
+ size 379957289
config.json ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": "vits",
3
+ "run_name": "vits_tts-portuguese",
4
+ "run_description": "",
5
+ "epochs": 1000,
6
+ "batch_size": 52,
7
+ "eval_batch_size": 52,
8
+ "mixed_precision": false,
9
+ "scheduler_after_epoch": true,
10
+ "run_eval": true,
11
+ "test_delay_epochs": -1,
12
+ "print_eval": true,
13
+ "dashboard_logger": "tensorboard",
14
+ "print_step": 25,
15
+ "plot_step": 100,
16
+ "model_param_stats": false,
17
+ "project_name": null,
18
+ "log_model_step": 10000,
19
+ "wandb_entity": null,
20
+ "save_step": 10000,
21
+ "checkpoint": true,
22
+ "keep_all_best": false,
23
+ "keep_after": 10000,
24
+ "num_loader_workers": 4,
25
+ "num_eval_loader_workers": 4,
26
+ "use_noise_augment": false,
27
+ "use_language_weighted_sampler": true,
28
+ "output_path": "../checkpoints/VITS-multilingual/VITS_fixes/new/new-SE/use_noise_aument_false/xlarge-ZS-PT-VCTK/pt-en+LibriTTS-fr/speaker_encoder_as_loss_9_alpha/mixed-p-false-bug-SDP-fixed/",
29
+ "distributed_backend": "nccl",
30
+ "distributed_url": "tcp://localhost:54321",
31
+ "audio": {
32
+ "fft_size": 1024,
33
+ "win_length": 1024,
34
+ "hop_length": 256,
35
+ "frame_shift_ms": null,
36
+ "frame_length_ms": null,
37
+ "stft_pad_mode": "reflect",
38
+ "sample_rate": 16000,
39
+ "resample": false,
40
+ "preemphasis": 0.0,
41
+ "ref_level_db": 20,
42
+ "do_sound_norm": false,
43
+ "log_func": "np.log",
44
+ "do_trim_silence": true,
45
+ "trim_db": 45,
46
+ "power": 1.5,
47
+ "griffin_lim_iters": 60,
48
+ "num_mels": 80,
49
+ "mel_fmin": 0.0,
50
+ "mel_fmax": null,
51
+ "spec_gain": 1,
52
+ "do_amp_to_db_linear": false,
53
+ "do_amp_to_db_mel": true,
54
+ "signal_norm": false,
55
+ "min_level_db": -100,
56
+ "symmetric_norm": true,
57
+ "max_norm": 4.0,
58
+ "clip_norm": true,
59
+ "stats_path": null
60
+ },
61
+ "use_phonemes": false,
62
+ "use_espeak_phonemes": false,
63
+ "phoneme_language": "pt-br",
64
+ "compute_input_seq_cache": false,
65
+ "text_cleaner": "multilingual_cleaners",
66
+ "enable_eos_bos_chars": false,
67
+ "test_sentences_file": "",
68
+ "phoneme_cache_path": null,
69
+ "characters": {
70
+ "pad": "_",
71
+ "eos": "&",
72
+ "bos": "*",
73
+ "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\u00af\u00b7\u00df\u00e0\u00e1\u00e2\u00e3\u00e4\u00e6\u00e7\u00e8\u00e9\u00ea\u00eb\u00ec\u00ed\u00ee\u00ef\u00f1\u00f2\u00f3\u00f4\u00f5\u00f6\u00f9\u00fa\u00fb\u00fc\u00ff\u0101\u0105\u0107\u0113\u0119\u011b\u012b\u0131\u0142\u0144\u014d\u0151\u0153\u015b\u016b\u0171\u017a\u017c\u01ce\u01d0\u01d2\u01d4\u0430\u0431\u0432\u0433\u0434\u0435\u0436\u0437\u0438\u0439\u043a\u043b\u043c\u043d\u043e\u043f\u0440\u0441\u0442\u0443\u0444\u0445\u0446\u0447\u0448\u0449\u044a\u044b\u044c\u044d\u044e\u044f\u0451\u0454\u0456\u0457\u0491\u2013!'(),-.:;? ",
74
+ "punctuations": "!'(),-.:;? ",
75
+ "phonemes": "iy\u0268\u0289\u026fu\u026a\u028f\u028ae\u00f8\u0258\u0259\u0275\u0264o\u025b\u0153\u025c\u025e\u028c\u0254\u00e6\u0250a\u0276\u0251\u0252\u1d7b\u0298\u0253\u01c0\u0257\u01c3\u0284\u01c2\u0260\u01c1\u029bpbtd\u0288\u0256c\u025fk\u0261q\u0262\u0294\u0274\u014b\u0272\u0273n\u0271m\u0299r\u0280\u2c71\u027e\u027d\u0278\u03b2fv\u03b8\u00f0sz\u0283\u0292\u0282\u0290\u00e7\u029dx\u0263\u03c7\u0281\u0127\u0295h\u0266\u026c\u026e\u028b\u0279\u027bj\u0270l\u026d\u028e\u029f\u02c8\u02cc\u02d0\u02d1\u028dw\u0265\u029c\u02a2\u02a1\u0255\u0291\u027a\u0267\u025a\u02de\u026b'\u0303' ",
76
+ "unique": true
77
+ },
78
+ "batch_group_size": 0,
79
+ "loss_masking": null,
80
+ "min_seq_len": 90,
81
+ "max_seq_len": 270,
82
+ "compute_f0": false,
83
+ "compute_linear_spec": true,
84
+ "add_blank": true,
85
+ "datasets": [
86
+ {
87
+ "name": "vctk",
88
+ "path": "../../datasets/VCTK-Corpus-removed-silence_16Khz/",
89
+ "meta_file_train": null,
90
+ "ununsed_speakers": [
91
+ "p225",
92
+ "p234",
93
+ "p238",
94
+ "p245",
95
+ "p248",
96
+ "p261",
97
+ "p294",
98
+ "p302",
99
+ "p326",
100
+ "p335",
101
+ "p347"
102
+ ],
103
+ "language": "en",
104
+ "meta_file_val": null,
105
+ "meta_file_attn_mask": ""
106
+ },
107
+ {
108
+ "name": "libri_tts",
109
+ "path": "../../datasets/LibriTTS/LibriTTS/dataset-preprocessed-clean-100-and-360/dataset-22k/",
110
+ "meta_file_train": "metadata_all.csv",
111
+ "ununsed_speakers": null,
112
+ "language": "en",
113
+ "meta_file_val": "dev-clean_500.csv",
114
+ "meta_file_attn_mask": ""
115
+ },
116
+ {
117
+ "name": "brspeech",
118
+ "path": "../../datasets/TTS-Portuguese-Corpus_16khz/",
119
+ "meta_file_train": "train_TTS-Portuguese_Corpus_metadata.csv",
120
+ "ununsed_speakers": null,
121
+ "language": "pt-br",
122
+ "meta_file_val": "eval_TTS-Portuguese_Corpus_metadata.csv",
123
+ "meta_file_attn_mask": ""
124
+ },
125
+ {
126
+ "name": "mailabs",
127
+ "path": "../../datasets/M-AILABS/fr_FR",
128
+ "meta_file_train": "",
129
+ "ununsed_speakers": null,
130
+ "language": "fr-fr",
131
+ "meta_file_val": null,
132
+ "meta_file_attn_mask": null
133
+ }
134
+ ],
135
+ "optimizer": "AdamW",
136
+ "optimizer_params": {
137
+ "betas": [
138
+ 0.8,
139
+ 0.99
140
+ ],
141
+ "eps": 1e-09,
142
+ "weight_decay": 0.01
143
+ },
144
+ "lr_scheduler": "",
145
+ "lr_scheduler_params": null,
146
+ "test_sentences": [
147
+ [
148
+ "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
149
+ "VCTK_p225",
150
+ null,
151
+ "en"
152
+ ],
153
+ [
154
+ "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
155
+ "ED",
156
+ null,
157
+ "en"
158
+ ],
159
+ [
160
+ "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
161
+ "bernard",
162
+ null,
163
+ "en"
164
+ ],
165
+ [
166
+ "This cake is great. It's so delicious and moist.",
167
+ "VCTK_p234",
168
+ null,
169
+ "en"
170
+ ],
171
+ [
172
+ "This cake is great. It's so delicious and moist.",
173
+ "ED",
174
+ null,
175
+ "en"
176
+ ],
177
+ [
178
+ "This cake is great. It's so delicious and moist.",
179
+ "ezwa",
180
+ null,
181
+ "en"
182
+ ],
183
+ [
184
+ "Hoje \u00e9 fundamental encontrar a raz\u00e3o da exist\u00eancia humana.",
185
+ "ED",
186
+ null,
187
+ "pt-br"
188
+ ],
189
+ [
190
+ "Hoje \u00e9 fundamental encontrar a raz\u00e3o da exist\u00eancia humana.",
191
+ "VCTK_p238",
192
+ null,
193
+ "pt-br"
194
+ ],
195
+ [
196
+ "Hoje \u00e9 fundamental encontrar a raz\u00e3o da exist\u00eancia humana.",
197
+ "gilles_g_le_blanc",
198
+ null,
199
+ "pt-br"
200
+ ],
201
+ [
202
+ "Em muitas cidades a popula\u00e7\u00e3o est\u00e1 diminuindo.",
203
+ "ED",
204
+ null,
205
+ "pt-br"
206
+ ],
207
+ [
208
+ "Em muitas cidades a popula\u00e7\u00e3o est\u00e1 diminuindo.",
209
+ "VCTK_p245",
210
+ null,
211
+ "pt-br"
212
+ ],
213
+ [
214
+ "Em muitas cidades a popula\u00e7\u00e3o est\u00e1 diminuindo.",
215
+ "nadine_eckert_boulet",
216
+ null,
217
+ "pt-br"
218
+ ],
219
+ [
220
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
221
+ "VCTK_p245",
222
+ null,
223
+ "fr-fr"
224
+ ],
225
+ [
226
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
227
+ "ED",
228
+ null,
229
+ "fr-fr"
230
+ ],
231
+ [
232
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
233
+ "ezwa",
234
+ null,
235
+ "fr-fr"
236
+ ],
237
+ [
238
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
239
+ "bernard",
240
+ null,
241
+ "fr-fr"
242
+ ],
243
+ [
244
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
245
+ "gilles_g_le_blanc",
246
+ null,
247
+ "fr-fr"
248
+ ],
249
+ [
250
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
251
+ "nadine_eckert_boulet",
252
+ null,
253
+ "fr-fr"
254
+ ],
255
+ [
256
+ "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
257
+ "zeckou",
258
+ null,
259
+ "fr-fr"
260
+ ]
261
+ ],
262
+ "use_speaker_embedding": true,
263
+ "use_d_vector_file": true,
264
+ "d_vector_dim": 512,
265
+ "model_args": {
266
+ "num_chars": 165,
267
+ "out_channels": 513,
268
+ "spec_segment_size": 62,
269
+ "hidden_channels": 192,
270
+ "hidden_channels_ffn_text_encoder": 768,
271
+ "num_heads_text_encoder": 2,
272
+ "num_layers_text_encoder": 10,
273
+ "kernel_size_text_encoder": 3,
274
+ "dropout_p_text_encoder": 0.1,
275
+ "dropout_p_duration_predictor": 0.5,
276
+ "kernel_size_posterior_encoder": 5,
277
+ "dilation_rate_posterior_encoder": 1,
278
+ "num_layers_posterior_encoder": 16,
279
+ "kernel_size_flow": 5,
280
+ "dilation_rate_flow": 1,
281
+ "num_layers_flow": 4,
282
+ "resblock_type_decoder": 1,
283
+ "resblock_kernel_sizes_decoder": [
284
+ 3,
285
+ 7,
286
+ 11
287
+ ],
288
+ "resblock_dilation_sizes_decoder": [
289
+ [
290
+ 1,
291
+ 3,
292
+ 5
293
+ ],
294
+ [
295
+ 1,
296
+ 3,
297
+ 5
298
+ ],
299
+ [
300
+ 1,
301
+ 3,
302
+ 5
303
+ ]
304
+ ],
305
+ "upsample_rates_decoder": [
306
+ 8,
307
+ 8,
308
+ 2,
309
+ 2
310
+ ],
311
+ "upsample_initial_channel_decoder": 512,
312
+ "upsample_kernel_sizes_decoder": [
313
+ 16,
314
+ 16,
315
+ 4,
316
+ 4
317
+ ],
318
+ "use_sdp": true,
319
+ "noise_scale": 1.0,
320
+ "inference_noise_scale": 0.667,
321
+ "length_scale": 1,
322
+ "noise_scale_dp": 1.0,
323
+ "inference_noise_scale_dp": 0.8,
324
+ "max_inference_len": null,
325
+ "init_discriminator": true,
326
+ "use_spectral_norm_disriminator": false,
327
+ "use_speaker_embedding": true,
328
+ "num_speakers": 1244,
329
+ "speakers_file": null,
330
+ "d_vector_file": "../speaker_embeddings/new-SE/VCTK-LibriTTS+TTS-PT+MAILABS-FR/speakers.json",
331
+ "speaker_embedding_channels": 512,
332
+ "use_d_vector_file": true,
333
+ "d_vector_dim": 512,
334
+ "detach_dp_input": true,
335
+ "use_language_embedding": true,
336
+ "embedded_language_dim": 4,
337
+ "num_languages": 3,
338
+ "use_speaker_encoder_as_loss": true,
339
+ "speaker_encoder_config_path": "../checkpoints/Speaker_Encoder/Resnet-original-paper/config.json",
340
+ "speaker_encoder_model_path": "../checkpoints/Speaker_Encoder/Resnet-original-paper/converted_checkpoint.pth.tar",
341
+ "fine_tuning_mode": 0,
342
+ "freeze_encoder": false,
343
+ "freeze_DP": false,
344
+ "freeze_PE": false,
345
+ "freeze_flow_decoder": false,
346
+ "freeze_waveform_decoder": false
347
+ },
348
+ "grad_clip": [
349
+ 5.0,
350
+ 5.0
351
+ ],
352
+ "lr_gen": 0.0002,
353
+ "lr_disc": 0.0002,
354
+ "lr_scheduler_gen": "ExponentialLR",
355
+ "lr_scheduler_gen_params": {
356
+ "gamma": 0.999875,
357
+ "last_epoch": -1
358
+ },
359
+ "lr_scheduler_disc": "ExponentialLR",
360
+ "lr_scheduler_disc_params": {
361
+ "gamma": 0.999875,
362
+ "last_epoch": -1
363
+ },
364
+ "kl_loss_alpha": 1.0,
365
+ "disc_loss_alpha": 1.0,
366
+ "gen_loss_alpha": 1.0,
367
+ "feat_loss_alpha": 1.0,
368
+ "mel_loss_alpha": 45.0,
369
+ "dur_loss_alpha": 1.0,
370
+ "speaker_encoder_loss_alpha": 9.0,
371
+ "return_wav": true,
372
+ "r": 1
373
+ }
config.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ input_text_desired_length: 110
2
+ input_text_max_length: 170
3
+ selected_theme: gradio/soft
4
+ server_name: ''
5
+ server_port: 0
6
+ server_share: false
7
+ silence_between_sentences: 250
8
+ silence_between_speakers: 500
config_se.json ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": "speaker_encoder",
3
+ "run_name": "speaker_encoder",
4
+ "run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev",
5
+ "epochs": 100000,
6
+ "batch_size": null,
7
+ "eval_batch_size": null,
8
+ "mixed_precision": false,
9
+ "run_eval": true,
10
+ "test_delay_epochs": 0,
11
+ "print_eval": false,
12
+ "print_step": 50,
13
+ "tb_plot_step": 100,
14
+ "tb_model_param_stats": false,
15
+ "save_step": 1000,
16
+ "checkpoint": true,
17
+ "keep_all_best": false,
18
+ "keep_after": 10000,
19
+ "num_loader_workers": 8,
20
+ "num_val_loader_workers": 0,
21
+ "use_noise_augment": false,
22
+ "output_path": "../checkpoints/speaker_encoder/language_balanced/normalized/angleproto-4-samples-by-speakers/",
23
+ "distributed_backend": "nccl",
24
+ "distributed_url": "tcp://localhost:54321",
25
+ "audio": {
26
+ "fft_size": 512,
27
+ "win_length": 400,
28
+ "hop_length": 160,
29
+ "frame_shift_ms": null,
30
+ "frame_length_ms": null,
31
+ "stft_pad_mode": "reflect",
32
+ "sample_rate": 16000,
33
+ "resample": false,
34
+ "preemphasis": 0.97,
35
+ "ref_level_db": 20,
36
+ "do_sound_norm": false,
37
+ "do_trim_silence": false,
38
+ "trim_db": 60,
39
+ "power": 1.5,
40
+ "griffin_lim_iters": 60,
41
+ "num_mels": 64,
42
+ "mel_fmin": 0.0,
43
+ "mel_fmax": 8000.0,
44
+ "spec_gain": 20,
45
+ "signal_norm": false,
46
+ "min_level_db": -100,
47
+ "symmetric_norm": false,
48
+ "max_norm": 4.0,
49
+ "clip_norm": false,
50
+ "stats_path": null
51
+ },
52
+ "datasets": [
53
+ {
54
+ "name": "voxceleb2",
55
+ "path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/",
56
+ "meta_file_train": null,
57
+ "ununsed_speakers": null,
58
+ "meta_file_val": null,
59
+ "meta_file_attn_mask": "",
60
+ "language": "voxceleb"
61
+ }
62
+ ],
63
+ "model_params": {
64
+ "model_name": "resnet",
65
+ "input_dim": 64,
66
+ "use_torch_spec": true,
67
+ "log_input": true,
68
+ "proj_dim": 512
69
+ },
70
+ "audio_augmentation": {
71
+ "p": 0.5,
72
+ "rir": {
73
+ "rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/",
74
+ "conv_mode": "full"
75
+ },
76
+ "additive": {
77
+ "sounds_path": "/workspace/store/ecasanova/ComParE/musan/",
78
+ "speech": {
79
+ "min_snr_in_db": 13,
80
+ "max_snr_in_db": 20,
81
+ "min_num_noises": 1,
82
+ "max_num_noises": 1
83
+ },
84
+ "noise": {
85
+ "min_snr_in_db": 0,
86
+ "max_snr_in_db": 15,
87
+ "min_num_noises": 1,
88
+ "max_num_noises": 1
89
+ },
90
+ "music": {
91
+ "min_snr_in_db": 5,
92
+ "max_snr_in_db": 15,
93
+ "min_num_noises": 1,
94
+ "max_num_noises": 1
95
+ }
96
+ },
97
+ "gaussian": {
98
+ "p": 0.0,
99
+ "min_amplitude": 0.0,
100
+ "max_amplitude": 1e-05
101
+ }
102
+ },
103
+ "storage": {
104
+ "sample_from_storage_p": 0.5,
105
+ "storage_size": 40
106
+ },
107
+ "max_train_step": 1000000,
108
+ "loss": "angleproto",
109
+ "grad_clip": 3.0,
110
+ "lr": 0.0001,
111
+ "lr_decay": false,
112
+ "warmup_steps": 4000,
113
+ "wd": 1e-06,
114
+ "steps_plot_stats": 100,
115
+ "num_speakers_in_batch": 100,
116
+ "num_utters_per_speaker": 4,
117
+ "skip_speakers": true,
118
+ "voice_len": 2.0
119
+ }
id3tagging.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mutagen.wave import WAVE
2
+ from mutagen.id3._frames import *
3
+
4
+ def add_id3_tag(filename, text, speakername, seed):
5
+ audio = WAVE(filename)
6
+ if speakername == None:
7
+ speakername = "Unconditional"
8
+
9
+ # write id3 tag with text truncated to 60 chars, as a precaution...
10
+ audio["TIT2"] = TIT2(encoding=3, text=text[:60])
11
+ audio["TPE1"] = TPE1(encoding=3, text=f"Voice {speakername} using Seed={seed}")
12
+ audio["TPUB"] = TPUB(encoding=3, text="Bark by Suno AI")
13
+ audio["COMMENT"] = COMM(encoding=3, text="Generated with Bark GUI - Text-Prompted Generative Audio Model. Visit https://github.com/C0untFloyd/bark-gui")
14
+ audio.save()
language_ids.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "en": 0,
3
+ "fr-fr": 1,
4
+ "pt-br": 2
5
+ }
parseinput.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import xml.etree.ElementTree as ET
3
+ from xml.sax import saxutils
4
+ #import nltk
5
+
6
+ # Chunked generation originally from https://github.com/serp-ai/bark-with-voice-clone
7
+ def split_and_recombine_text(text, desired_length=100, max_length=150):
8
+ # return nltk.sent_tokenize(text)
9
+
10
+ # from https://github.com/neonbjb/tortoise-tts
11
+ """Split text it into chunks of a desired length trying to keep sentences intact."""
12
+ # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
13
+ text = re.sub(r"\n\n+", "\n", text)
14
+ text = re.sub(r"\s+", " ", text)
15
+ text = re.sub(r"[“”]", '"', text)
16
+
17
+ rv = []
18
+ in_quote = False
19
+ current = ""
20
+ split_pos = []
21
+ pos = -1
22
+ end_pos = len(text) - 1
23
+
24
+ def seek(delta):
25
+ nonlocal pos, in_quote, current
26
+ is_neg = delta < 0
27
+ for _ in range(abs(delta)):
28
+ if is_neg:
29
+ pos -= 1
30
+ current = current[:-1]
31
+ else:
32
+ pos += 1
33
+ current += text[pos]
34
+ if text[pos] == '"':
35
+ in_quote = not in_quote
36
+ return text[pos]
37
+
38
+ def peek(delta):
39
+ p = pos + delta
40
+ return text[p] if p < end_pos and p >= 0 else ""
41
+
42
+ def commit():
43
+ nonlocal rv, current, split_pos
44
+ rv.append(current)
45
+ current = ""
46
+ split_pos = []
47
+
48
+ while pos < end_pos:
49
+ c = seek(1)
50
+ # do we need to force a split?
51
+ if len(current) >= max_length:
52
+ if len(split_pos) > 0 and len(current) > (desired_length / 2):
53
+ # we have at least one sentence and we are over half the desired length, seek back to the last split
54
+ d = pos - split_pos[-1]
55
+ seek(-d)
56
+ else:
57
+ # no full sentences, seek back until we are not in the middle of a word and split there
58
+ while c not in "!?.,\n " and pos > 0 and len(current) > desired_length:
59
+ c = seek(-1)
60
+ commit()
61
+ # check for sentence boundaries
62
+ elif not in_quote and (c in "!?]\n" or (c == "." and peek(1) in "\n ")):
63
+ # seek forward if we have consecutive boundary markers but still within the max length
64
+ while (
65
+ pos < len(text) - 1 and len(current) < max_length and peek(1) in "!?.]"
66
+ ):
67
+ c = seek(1)
68
+ split_pos.append(pos)
69
+ if len(current) >= desired_length:
70
+ commit()
71
+ # treat end of quote as a boundary if its followed by a space or newline
72
+ elif in_quote and peek(1) == '"' and peek(2) in "\n ":
73
+ seek(2)
74
+ split_pos.append(pos)
75
+ rv.append(current)
76
+
77
+ # clean up, remove lines with only whitespace or punctuation
78
+ rv = [s.strip() for s in rv]
79
+ rv = [s for s in rv if len(s) > 0 and not re.match(r"^[\s\.,;:!?]*$", s)]
80
+
81
+ return rv
82
+
83
+ def is_ssml(value):
84
+ try:
85
+ ET.fromstring(value)
86
+ except ET.ParseError:
87
+ return False
88
+ return True
89
+
90
+ def build_ssml(rawtext, selected_voice):
91
+ texts = rawtext.split("\n")
92
+ joinedparts = ""
93
+ for textpart in texts:
94
+ textpart = textpart.strip()
95
+ if len(textpart) < 1:
96
+ continue
97
+ joinedparts = joinedparts + f"\n<voice name=\"{selected_voice}\">{saxutils.escape(textpart)}</voice>"
98
+ ssml = f"""<?xml version="1.0"?>
99
+ <speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis"
100
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
101
+ xsi:schemaLocation="http://www.w3.org/2001/10/synthesis
102
+ http://www.w3.org/TR/speech-synthesis/synthesis.xsd"
103
+ xml:lang="en-US">
104
+ {joinedparts}
105
+ </speak>
106
+ """
107
+ return ssml
108
+
109
+ def create_clips_from_ssml(ssmlinput):
110
+ # Parse the XML
111
+ tree = ET.ElementTree(ET.fromstring(ssmlinput))
112
+ root = tree.getroot()
113
+
114
+ # Create an empty list
115
+ voice_list = []
116
+
117
+ # Loop through all voice tags
118
+ for voice in root.iter('{http://www.w3.org/2001/10/synthesis}voice'):
119
+ # Extract the voice name attribute and the content text
120
+ voice_name = voice.attrib['name']
121
+ voice_content = voice.text.strip() if voice.text else ''
122
+ if(len(voice_content) > 0):
123
+ parts = split_and_recombine_text(voice_content)
124
+ for p in parts:
125
+ if(len(p) > 1):
126
+ # add to tuple list
127
+ voice_list.append((voice_name, p))
128
+ return voice_list
129
+
pyproject.toml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "bark-ui-enhanced"
7
+ version = "0.4.7"
8
+ description = "Bark text to audio model with addition features and a Web UI"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ authors = [
12
+ {name = "Suno Inc (original Bark)", email = "[email protected]"},
13
+ {name = "Count Floyd"},
14
+ ]
15
+ # MIT License
16
+ license = {file = "LICENSE"}
17
+
18
+ dependencies = [
19
+ "boto3",
20
+ "encodec",
21
+ "funcy",
22
+ "mutagen",
23
+ "numpy",
24
+ "pytorch_seed",
25
+ "scipy",
26
+ "tokenizers",
27
+ "torch",
28
+ "tqdm",
29
+ "transformers",
30
+ "pyyaml"
31
+ ]
32
+
33
+ [project.urls]
34
+ source = "https://github.com/C0untFloyd/bark-gui"
35
+
36
+ [project.optional-dependencies]
37
+ dev = [
38
+ "bandit",
39
+ "black",
40
+ "codecov",
41
+ "flake8",
42
+ "huggingface-hub>=0.14.1",
43
+ "hypothesis>=6.14,<7",
44
+ "isort>=5.0.0,<6",
45
+ "jupyter",
46
+ "mypy",
47
+ "nbconvert",
48
+ "nbformat",
49
+ "pydocstyle",
50
+ "pylint",
51
+ "pytest",
52
+ "pytest-cov",
53
+ ]
54
+
55
+ [tool.setuptools]
56
+ packages = ["bark"]
57
+
58
+ [tool.setuptools.package-data]
59
+ bark = ["assets/prompts/*.npz", "assets/prompts/v2/*.npz"]
60
+
61
+
62
+ [tool.black]
63
+ line-length = 100
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ pytorch_seed
2
+ encodec
3
+ funcy
4
+ transformers
5
+ scipy
6
+ mutagen
7
+ git+https://github.com/Edresson/Coqui-TTS@multilingual-torchaudio-SE
8
+ torchaudio
9
+ pydub
10
+ ffmpeg-normalize==1.21.0
settings.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+
3
+ class Settings:
4
+ def __init__(self, config_file):
5
+ self.config_file = config_file
6
+ self.load()
7
+
8
+ def load(self):
9
+ try:
10
+ with open(self.config_file, 'r') as f:
11
+ data = yaml.load(f, Loader=yaml.FullLoader)
12
+ self.selected_theme = data.get('selected_theme', "gstaff/xkcd")
13
+ self.server_name = data.get('server_name', "")
14
+ self.server_port = data.get('server_port', 0)
15
+ self.server_share = data.get('server_share', False)
16
+ self.input_text_desired_length = data.get('input_text_desired_length', 110)
17
+ self.input_text_max_length = data.get('input_text_max_length', 170)
18
+ self.silence_sentence = data.get('silence_between_sentences', 250)
19
+ self.silence_speakers = data.get('silence_between_speakers', 500)
20
+
21
+ except:
22
+ self.selected_theme = "gstaff/xkcd"
23
+
24
+ def save(self):
25
+ data = {
26
+ 'selected_theme': self.selected_theme,
27
+ 'server_name': self.server_name,
28
+ 'server_port': self.server_port,
29
+ 'server_share': self.server_share,
30
+ 'input_text_desired_length' : self.input_text_desired_length,
31
+ 'input_text_max_length' : self.input_text_max_length,
32
+ 'silence_between_sentences': self.silence_sentence,
33
+ 'silence_between_speakers': self.silence_speakers,
34
+ }
35
+ with open(self.config_file, 'w') as f:
36
+ yaml.dump(data, f)
37
+
38
+
39
+
setup.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from setuptools import setup
2
+
3
+ setup()
speakers.json ADDED
The diff for this file is too large to render. See raw diff