import os
import logging

import torch
import numpy as np
import gdown
import langid

# Restrict language identification to the three languages the model supports.
langid.set_languages(['en', 'zh', 'ja'])

# Checkpoints pickled on one OS can contain pathlib objects of the other
# flavor; aliasing the foreign Path class lets torch.load unpickle them.
import pathlib
import platform
if platform.system().lower() == 'windows':
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
elif platform.system().lower() == 'linux':
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath

from data.tokenizer import (
    AudioTokenizer,
    tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
from utils.sentence_cutter import split_text_into_sentences

from macros import *

# Prefer the first CUDA device when available, otherwise run on CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda", 0)

# Google Drive page hosting the pretrained checkpoint; the file id passed to
# gdown below is taken from this URL.
url = 'https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing'

checkpoints_dir = "./checkpoints/"
model_checkpoint_name = "vallex-checkpoint.pt"

# Populated by preload_models(); module-level so the model and codec are
# loaded once and reused across calls.
model = None
codec = None

text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()

def preload_models():
    """Download the VALL-E X checkpoint if needed, then load the model and codec."""
    global model, codec
    os.makedirs(checkpoints_dir, exist_ok=True)
    if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):
        gdown.download(id="10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl",
                       output=os.path.join(checkpoints_dir, model_checkpoint_name),
                       quiet=False)

    # Instantiate the model with the hyperparameters the checkpoint was trained with.
    model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    ).to(device)
    checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name),
                            map_location='cpu')
    missing_keys, unexpected_keys = model.load_state_dict(
        checkpoint["model"], strict=True
    )
    assert not missing_keys
    model.eval()

    # EnCodec-based tokenizer used to turn generated codes back into a waveform.
    codec = AudioTokenizer(device)

@torch.no_grad()
def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
    """Synthesize `text`, optionally cloning the voice in `prompt`, and return the waveform as a numpy array."""
    global model, codec, text_tokenizer, text_collater
    text = text.replace("\n", "").strip(" ")

    # Detect the language if requested, then wrap the text in language tokens.
    if language == "auto":
        language = langid.classify(text)[0]
    lang_token = lang2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    # Resolve the voice prompt: a direct path first, then ./presets/, then
    # ./customs/. Prompt files are .npz archives holding 'audio_tokens',
    # 'text_tokens', and 'lang_code' entries.
    if prompt is not None:
        prompt_path = prompt
        if not os.path.exists(prompt_path):
            prompt_path = "./presets/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            prompt_path = "./customs/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")
        prompt_data = np.load(prompt_path)
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = prompt_data['lang_code']
        lang_pr = code2lang[int(lang_pr)]

        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        # No prompt: condition on empty audio and text prompts.
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = lang if lang != 'mix' else 'en'

    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater(
        [
            phone_tokens
        ]
    )
    # Prepend the prompt's text tokens so the model conditions on them.
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens

    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    # Decode the generated codec frames back into a waveform.
    samples = codec.decode(
        [(encoded_frames.transpose(2, 1), None)]
    )

    return samples[0][0].cpu().numpy()

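# Minimal usage sketch for generate_audio(). Assumptions, not part of this
# module: the process runs from the repository root, `soundfile` is installed,
# and the codec's output rate is EnCodec's 24 kHz.
#
#   import soundfile as sf
#   preload_models()
#   wav = generate_audio("Hello, world!", prompt=None, language='auto')
#   sf.write("hello.wav", wav, 24000)
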
@torch.no_grad()
def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
    """
    Synthesize long text sentence by sentence. Two modes are available.

    fixed-prompt: reuse the user-provided prompt for every sentence. This keeps
    the speaker stable, but requires a prompt.
    sliding-window: use the previously generated sentence as the prompt for the
    next one, which may let the speaker identity drift over long passages.
    """
    global model, codec, text_tokenizer, text_collater
    # Without a user prompt, only sliding-window mode can bootstrap itself.
    if prompt is None or prompt == "":
        mode = 'sliding-window'
    sentences = split_text_into_sentences(text)

    if language == "auto":
        language = langid.classify(text)[0]

    # Resolve the voice prompt exactly as in generate_audio().
    if prompt is not None and prompt != "":
        prompt_path = prompt
        if not os.path.exists(prompt_path):
            prompt_path = "./presets/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            prompt_path = "./customs/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")
        prompt_data = np.load(prompt_path)
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = prompt_data['lang_code']
        lang_pr = code2lang[int(lang_pr)]

        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = language if language != 'mix' else 'en'

    if mode == 'fixed-prompt':
        # Accumulate codec frames for all sentences, then decode once at the end.
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens

            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = model.inference(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
                enroll_x_lens=enroll_x_lens,
                top_k=-100,
                temperature=1,
                prompt_language=lang_pr,
                text_language=langs if accent == "no-accent" else lang,
            )
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
        samples = codec.decode(
            [(complete_tokens, None)]
        )
        return samples[0][0].cpu().numpy()
    elif mode == "sliding-window":
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        original_audio_prompts = audio_prompts
        original_text_prompts = text_prompts
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens

            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = model.inference(
                text_tokens.to(device),
                text_tokens_lens.to(device),
                audio_prompts,
                enroll_x_lens=enroll_x_lens,
                top_k=-100,
                temperature=1,
                prompt_language=lang_pr,
                text_language=langs if accent == "no-accent" else lang,
            )
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
            # Randomly either continue from the sentence just generated or
            # reset to the original prompt, to limit cumulative speaker drift.
            if torch.rand(1) < 0.5:
                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
                text_prompts = text_tokens[:, enroll_x_lens:]
            else:
                audio_prompts = original_audio_prompts
                text_prompts = original_text_prompts
        samples = codec.decode(
            [(complete_tokens, None)]
        )
        return samples[0][0].cpu().numpy()
    else:
        raise ValueError(f"No such mode {mode}")
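
# Demo sketch, with assumptions that go beyond this module: it is run from the
# repository root, `soundfile` is installed, and 24000 is the codec's output
# sample rate (EnCodec). Pass a preset name or .npz path via `prompt=` to clone
# a specific voice; with None the model synthesizes from an empty prompt.
if __name__ == "__main__":
    import soundfile as sf

    preload_models()
    long_text = (
        "VALL-E X synthesizes long passages sentence by sentence. "
        "This demo stitches the generated pieces into a single waveform."
    )
    wav = generate_audio_from_long_text(long_text, prompt=None,
                                        language='auto', mode='sliding-window')
    sf.write("long_demo.wav", wav, 24000)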