import os
import re
import json
import codecs
import ffmpeg
import argparse
import platform

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence

import scipy
import scipy.io.wavfile
import librosa
from scipy.io.wavfile import write
import numpy as np
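# The xvapitch modules live at different paths depending on where this file is
# run from (the packaged resources/app tree, the repo root, the python/ folder,
# or this directory directly), so fall through the possible import locations.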
try:
    import sys
    sys.path.append(".")
    from resources.app.python.xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
    from resources.app.python.xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
except ModuleNotFoundError:
    try:
        from python.xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
        from python.xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
    except ModuleNotFoundError:
        try:
            from xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
            from xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
        except ModuleNotFoundError:
            from text import ALL_SYMBOLS, get_text_preprocessor, lang_names
            from xvapitch_model import xVAPitch as xVAPitchModel
class xVAPitch(object):
    def __init__(self, logger, PROD, device, models_manager):
        super(xVAPitch, self).__init__()

        self.logger = logger
        self.PROD = PROD
        self.models_manager = models_manager
        self.device = device
        self.ckpt_path = None
        self.arpabet_dict = {}

        # torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.benchmark = False

        self.base_dir = f'{"./resources/app" if self.PROD else "."}/python/xvapitch/text'
        self.lang_tp = {}
        self.lang_tp["en"] = get_text_preprocessor("en", self.base_dir, logger=self.logger)

        self.language_id_mapping = {name: i for i, name in enumerate(sorted(list(lang_names.keys())))}

        # Pre-extracted pitch and emotion embeddings, loaded from .npy files and
        # attached to the model for inference-time control
        self.pitch_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/pitch_emb.npy')).unsqueeze(0).unsqueeze(-1)
        self.angry_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/angry.npy')).unsqueeze(0).unsqueeze(-1)
        self.happy_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/happy.npy')).unsqueeze(0).unsqueeze(-1)
        self.sad_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/sad.npy')).unsqueeze(0).unsqueeze(-1)
        self.surprise_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/surprise.npy')).unsqueeze(0).unsqueeze(-1)

        self.base_lang = "en"
        self.init_model()

        self.model.pitch_emb_values = self.pitch_emb_values.to(self.models_manager.device)
        self.model.angry_emb_values = self.angry_emb_values.to(self.models_manager.device)
        self.model.happy_emb_values = self.happy_emb_values.to(self.models_manager.device)
        self.model.sad_emb_values = self.sad_emb_values.to(self.models_manager.device)
        self.model.surprise_emb_values = self.surprise_emb_values.to(self.models_manager.device)

        self.isReady = True
    def init_model(self):
        parser = argparse.ArgumentParser()
        args = parser.parse_args()

        # Params from training
        args.pitch = 1
        args.pe_scaling = 0.1
        args.expanded_flow = 0
        args.ow_flow = 0
        args.energy = 0

        self.model = xVAPitchModel(args).to(self.device)
        self.model.eval()
        self.model.device = self.device
    def load_state_dict(self, ckpt_path, ckpt, n_speakers=1, base_lang="en"):
        self.logger.info(f'load_state_dict base_lang: {base_lang}')
        if base_lang not in self.lang_tp.keys():
            self.lang_tp[base_lang] = get_text_preprocessor(base_lang, self.base_dir, logger=self.logger)
        self.base_lang = base_lang
        self.ckpt_path = ckpt_path

        if os.path.exists(ckpt_path.replace(".pt", ".json")):
            with open(ckpt_path.replace(".pt", ".json"), "r") as f:
                data = json.load(f)
                self.base_emb = data["games"][0]["base_speaker_emb"]

        if 'model' in ckpt:
            ckpt = ckpt['model']

        # Resize the language embedding table to match the checkpoint (models
        # were trained with either 31 or 50 languages)
        if ckpt["emb_l.weight"].shape[0]==31:
            self.model.emb_l = nn.Embedding(31, self.model.embedded_language_dim).to(self.models_manager.device)
        elif ckpt["emb_l.weight"].shape[0]==50:
            num_languages = 50
            self.model.emb_l = nn.Embedding(num_languages, self.model.embedded_language_dim).to(self.models_manager.device)

        self.model.load_state_dict(ckpt, strict=False)
        self.model = self.model.float()
        self.model.eval()
    def init_arpabet_dicts(self):
        if len(list(self.arpabet_dict.keys()))==0:
            self.refresh_arpabet_dicts()

    def refresh_arpabet_dicts(self):
        self.arpabet_dict = {}
        json_files = sorted(os.listdir(f'{"./resources/app" if self.PROD else "."}/arpabet'))
        json_files = [fname for fname in json_files if fname.endswith(".json")]

        for fname in json_files:
            with codecs.open(f'{"./resources/app" if self.PROD else "."}/arpabet/{fname}', encoding="utf-8") as f:
                json_data = json.load(f)
                for word in list(json_data["data"].keys()):
                    if json_data["data"][word]["enabled"]==True:
                        self.arpabet_dict[word] = json_data["data"][word]["arpabet"]
    def run_speech_to_speech(self, audiopath, audio_out_path, style_emb, models_manager, plugin_manager, vc_strength=1, useSR=False, useCleanup=False):

        if ".wav" in style_emb:
            self.logger.info(f'Getting style emb from: {style_emb}')
            style_emb = models_manager.models("speaker_rep").compute_embedding(style_emb).squeeze()
        else:
            self.logger.info('Given style emb')
            style_emb = torch.tensor(style_emb).squeeze()

        try:
            content_emb = models_manager.models("speaker_rep").compute_embedding(audiopath).squeeze()
        except Exception:
            # The speaker encoder needs a minimum amount of input audio
            return "TOO_SHORT"

        style_emb = F.normalize(style_emb.unsqueeze(0), dim=1).unsqueeze(-1).to(self.models_manager.device)
        content_emb = F.normalize(content_emb.unsqueeze(0), dim=1).unsqueeze(-1).to(self.models_manager.device)
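        # Blend the two speaker embeddings: content + (1 - vc_strength) * (style - content).
        # With vc_strength==1 the content embedding is unchanged; values below 1
        # pull it towards the style embedding, values above 1 push it away.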
        content_emb = content_emb + (-(vc_strength-1) * (style_emb - content_emb))

        y, sr = librosa.load(audiopath, sr=22050)
        D = librosa.stft(
            y=y,
            n_fft=1024,
            hop_length=256,
            win_length=1024,
            pad_mode="reflect",
            window="hann",
            center=True,
        )
        spec = np.abs(D).astype(np.float32)
        ref_spectrogram = torch.FloatTensor(spec).unsqueeze(0)

        y_lengths = torch.tensor([ref_spectrogram.size(-1)]).to(self.models_manager.device)
        y = ref_spectrogram.to(self.models_manager.device)

        wav = self.model.voice_conversion(y=y, y_lengths=y_lengths, spk1_emb=content_emb, spk2_emb=style_emb)
        wav = wav.squeeze().cpu().detach().numpy()
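        # Peak-normalize into the int16 range for 16-bit PCM output; the 0.01
        # floor guards against division by zero on (near-)silent output.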
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))

        if useCleanup:
            ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
            if useSR:
                scipy.io.wavfile.write(audio_out_path.replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
            else:
                scipy.io.wavfile.write(audio_out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
                stream = ffmpeg.input(audio_out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
                ffmpeg_options = {"ar": 48000}
                output_path = audio_out_path.replace(".wav", "_preCleanup.wav")
                stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
                out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
                os.remove(audio_out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
        else:
            scipy.io.wavfile.write(audio_out_path.replace(".wav", "_preSR.wav") if useSR else audio_out_path, 22050, wav_norm.astype(np.int16))

        if useSR:
            self.models_manager.init_model("nuwave2")
            self.models_manager.models("nuwave2").sr_audio(audio_out_path.replace(".wav", "_preSR.wav"), audio_out_path.replace(".wav", "_preCleanup.wav") if useCleanup else audio_out_path)
        if useCleanup:
            self.models_manager.init_model("deepfilternet2")
            self.models_manager.models("deepfilternet2").cleanup_audio(audio_out_path.replace(".wav", "_preCleanup.wav"), audio_out_path)

        return
    def infer_batch(self, plugin_manager, linesBatch, outputJSON, vocoder, speaker_i, old_sequence=None, useSR=False, useCleanup=False):
        print(f'Inferring batch of {len(linesBatch)} lines')

        text_sequences = []
        cleaned_text_sequences = []
        lang_embs = []
        speaker_embs = []

        # record layout: [sequence, pitch, duration, pace, tempFileLocation, outPath, outFolder, pitch_amp, base_lang, base_emb, vc_content, vc_style]
        vc_input = []
        tts_input = []
        for ri,record in enumerate(linesBatch):
            if record[-2]: # If a VC content file has been given, handle this as VC
                vc_input.append(record)
            else:
                tts_input.append(record)

        # =================
        # ======= Handle VC
        # =================
        if len(vc_input):
            for ri,record in enumerate(vc_input):
                content_emb = self.models_manager.models("speaker_rep").compute_embedding(record[-2]).squeeze()
                style_emb = self.models_manager.models("speaker_rep").compute_embedding(record[-1]).squeeze()
                # content_emb = F.normalize(content_emb.unsqueeze(0), dim=1).squeeze(0)
                # style_emb = F.normalize(style_emb.unsqueeze(0), dim=1).squeeze(0)
                content_emb = content_emb.unsqueeze(0).unsqueeze(-1).to(self.models_manager.device)
                style_emb = style_emb.unsqueeze(0).unsqueeze(-1).to(self.models_manager.device)

                y, sr = librosa.load(record[-2], sr=22050)
                D = librosa.stft(
                    y=y,
                    n_fft=1024,
                    hop_length=256,
                    win_length=1024,
                    pad_mode="reflect",
                    window="hann",
                    center=True,
                )
                spec = np.abs(D).astype(np.float32)
                ref_spectrogram = torch.FloatTensor(spec).unsqueeze(0)

                y_lengths = torch.tensor([ref_spectrogram.size(-1)]).to(self.models_manager.device)
                y = ref_spectrogram.to(self.models_manager.device)

                # Run Voice Conversion
                self.model.logger = self.logger
                wav = self.model.voice_conversion(y=y, y_lengths=y_lengths, spk1_emb=content_emb, spk2_emb=style_emb)
                wav = wav.squeeze().cpu().detach().numpy()
                wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))

                if useCleanup:
                    ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
                    if useSR:
                        scipy.io.wavfile.write(vc_input[ri][4].replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
                    else:
                        scipy.io.wavfile.write(vc_input[ri][4].replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
                        stream = ffmpeg.input(vc_input[ri][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
                        ffmpeg_options = {"ar": 48000}
                        output_path = vc_input[ri][4].replace(".wav", "_preCleanup.wav")
                        stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
                        out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
                        os.remove(vc_input[ri][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
                else:
                    scipy.io.wavfile.write(vc_input[ri][4].replace(".wav", "_preSR.wav") if useSR else vc_input[ri][4], 22050, wav_norm.astype(np.int16))

                if useSR:
                    self.models_manager.init_model("nuwave2")
                    self.models_manager.models("nuwave2").sr_audio(vc_input[ri][4].replace(".wav", "_preSR.wav"), vc_input[ri][4].replace(".wav", "_preCleanup.wav") if useCleanup else vc_input[ri][4])
                    os.remove(vc_input[ri][4].replace(".wav", "_preSR.wav"))
                if useCleanup:
                    self.models_manager.init_model("deepfilternet2")
                    self.models_manager.models("deepfilternet2").cleanup_audio(vc_input[ri][4].replace(".wav", "_preCleanup.wav"), vc_input[ri][4])
                    os.remove(vc_input[ri][4].replace(".wav", "_preCleanup.wav"))

        # ==================
        # ======= Handle TTS
        # ==================
        if len(tts_input):
            lang_embs_sizes = []
            for ri,record in enumerate(tts_input):
                # Pre-process text
                text = record[0].replace("/lang", "\\lang")
                base_lang = record[-4]
                self.logger.info(f'[infer_batch] text: {text}')
                sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)

                # Make sure all languages' text processors are initialized
                for subSequence in sequenceSplitByLanguage:
                    langCode = list(subSequence.keys())[0]
                    if langCode not in self.lang_tp.keys():
                        self.lang_tp[langCode] = get_text_preprocessor(langCode, self.base_dir, logger=self.logger)

                try:
                    pad_symb = len(ALL_SYMBOLS)-2 # <PAD> symbol, inserted between language sub-sequences
                    all_sequence = []
                    all_cleaned_text = []
                    all_text = []
                    all_lang_ids = []

                    # Collapse same-language words into phrases, so that heteronyms can still be detected
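                    # e.g. [{"en": "Hello "}, {"en": "world "}, {"de": "Hallo"}]
                    # ->   [{"en": "Hello world "}, {"de": "Hallo"}]   (illustrative)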
                    sequenceSplitByLanguage_grouped = []
                    last_lang_group = None
                    group = ""
                    for ssi, subSequence in enumerate(sequenceSplitByLanguage):
                        if list(subSequence.keys())[0]!=last_lang_group:
                            if last_lang_group is not None:
                                sequenceSplitByLanguage_grouped.append({last_lang_group: group})
                            group = ""
                            last_lang_group = list(subSequence.keys())[0]
                        group += subSequence[last_lang_group]
                    if len(group):
                        sequenceSplitByLanguage_grouped.append({last_lang_group: group})

                    for ssi, subSequence in enumerate(sequenceSplitByLanguage_grouped):
                        langCode = list(subSequence.keys())[0]
                        subSeq = subSequence[langCode]

                        sequence, cleaned_text = self.lang_tp[langCode].text_to_sequence(subSeq)

                        if ssi<len(sequenceSplitByLanguage_grouped)-1:
                            sequence = sequence + [pad_symb]
                        all_sequence.append(sequence)

                        all_cleaned_text += ("|"+cleaned_text) if len(all_cleaned_text) else cleaned_text
                        if ssi<len(sequenceSplitByLanguage_grouped)-1:
                            all_cleaned_text = all_cleaned_text + ["|<PAD>"]

                        all_text.append(torch.LongTensor(sequence))

                        language_id = self.language_id_mapping[langCode]
                        all_lang_ids += [language_id for _ in range(len(sequence))]

                except ValueError as e:
                    self.logger.info("====")
                    self.logger.info(str(e))
                    self.logger.info("====--")
                    if "not in list" in str(e):
                        symbol_not_in_list = str(e).split("is not in list")[0].split("ValueError:")[-1].replace("'", "").strip()
                        return f'ERR: ARPABET_NOT_IN_LIST: {symbol_not_in_list}'

                all_cleaned_text = "".join(all_cleaned_text)

                text = torch.cat(all_text, dim=0)
                cleaned_text_sequences.append(all_cleaned_text)
                text = torch.LongTensor(text)
                text_sequences.append(text)

                lang_ids = torch.tensor(all_lang_ids).to(self.models_manager.device)
                lang_embs.append(lang_ids)
                lang_embs_sizes.append(lang_ids.shape[0])
                speaker_embs.append(torch.tensor(tts_input[ri][-3]).unsqueeze(-1))

            lang_embs = pad_sequence(lang_embs, batch_first=True).to(self.models_manager.device)
            text_sequences = pad_sequence(text_sequences, batch_first=True).to(self.models_manager.device)
            speaker_embs = pad_sequence(speaker_embs, batch_first=True).to(self.models_manager.device)
            pace = torch.tensor([record[3] for record in tts_input]).unsqueeze(1).to(self.device)
            pitch_amp = torch.tensor([record[7] for record in tts_input]).unsqueeze(1).to(self.device)

            # Could pass indexes (and get them returned) to the tts inference fn
            # Do the same to the vc infer fn
            # Then merge them into their place in an output array?
            out = self.model.infer_advanced(self.logger, plugin_manager, [cleaned_text_sequences], text_sequences, lang_embs=lang_embs, speaker_embs=speaker_embs, pace=pace, old_sequence=None, pitch_amp=pitch_amp)
            if isinstance(out, str):
                return out
            else:
                output_wav, dur_pred, pitch_pred, energy_pred, _, _, _, _ = out

                for i,wav in enumerate(output_wav):
                    wav = wav.squeeze().cpu().detach().numpy()
                    wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))

                    if useCleanup:
                        ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
                        if useSR:
                            scipy.io.wavfile.write(tts_input[i][4].replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
                        else:
                            scipy.io.wavfile.write(tts_input[i][4].replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
                            stream = ffmpeg.input(tts_input[i][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
                            ffmpeg_options = {"ar": 48000}
                            output_path = tts_input[i][4].replace(".wav", "_preCleanup.wav")
                            stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
                            out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
                            os.remove(tts_input[i][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
                    else:
                        scipy.io.wavfile.write(tts_input[i][4].replace(".wav", "_preSR.wav") if useSR else tts_input[i][4], 22050, wav_norm.astype(np.int16))

                    if useSR:
                        self.models_manager.init_model("nuwave2")
                        self.models_manager.models("nuwave2").sr_audio(tts_input[i][4].replace(".wav", "_preSR.wav"), tts_input[i][4].replace(".wav", "_preCleanup.wav") if useCleanup else tts_input[i][4])
                        os.remove(tts_input[i][4].replace(".wav", "_preSR.wav"))
                    if useCleanup:
                        self.models_manager.init_model("deepfilternet2")
                        self.models_manager.models("deepfilternet2").cleanup_audio(tts_input[i][4].replace(".wav", "_preCleanup.wav"), tts_input[i][4])
                        os.remove(tts_input[i][4].replace(".wav", "_preCleanup.wav"))

                if outputJSON:
                    for ri, record in enumerate(tts_input):
                        # tts_input: sequence, pitch, duration, pace, tempFileLocation, outPath, outFolder
                        output_fname = tts_input[ri][5].replace(".wav", ".json")
                        containing_folder = "/".join(output_fname.split("/")[:-1])
                        os.makedirs(containing_folder, exist_ok=True)

                        with open(output_fname, "w+") as f:
                            data = {}
                            data["modelType"] = "xVAPitch"
                            data["inputSequence"] = str(tts_input[ri][0])
                            data["pacing"] = float(tts_input[ri][3])
                            data["letters"] = [char.replace("{", "").replace("}", "") for char in list(cleaned_text_sequences[ri].split("|"))]
                            data["currentVoice"] = self.ckpt_path.split("/")[-1].replace(".pt", "")
                            # data["resetEnergy"] = [float(val) for val in list(energy_pred[ri].cpu().detach().numpy())]
                            # Energy prediction is disabled (args.energy=0), so emit flat 1.0 values of matching length
                            data["resetEnergy"] = [float(1) for val in list(pitch_pred[ri][0].cpu().detach().numpy())]
                            data["resetPitch"] = [float(val) for val in list(pitch_pred[ri][0].cpu().detach().numpy())]
                            data["resetDurs"] = [float(val) for val in list(dur_pred[ri][0].cpu().detach().numpy())]
                            data["ampFlatCounter"] = 0
                            data["pitchNew"] = data["resetPitch"]
                            data["energyNew"] = data["resetEnergy"]
                            data["dursNew"] = data["resetDurs"]
                            f.write(json.dumps(data, indent=4))

        return ""
    # Split words by space, while also breaking out the \lang[code][text] formatting
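    # e.g. "\lang[de][Welt]" -> ["\lang[de][", "Welt", "]"]   (rough illustration)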
    def splitWords(self, sequence, addSpace=False):
        words = []
        for word in sequence:
            if word.startswith("\\lang["):
                words.append(word.split("][")[0]+"][")
                word = word.split("][")[1]

            for char in ["}","]","[","{"]:
                if word.startswith(char):
                    words.append(char)
                    word = word[1:]

            end_extras = []
            for char in ["}","]","[","{"]:
                if word.endswith(char):
                    end_extras.append(char)
                    word = word[:-1]

            words.append(word)
            end_extras.reverse()
            for extra in end_extras:
                words.append(extra)
            if addSpace:
                words.append(" ")
        return words
    def preprocess_prompt_language(self, sequence, base_lang):

        # Separate the ARPAbet brackets from punctuation
        for punct in [".", "!", "?", ",", "\"", "'", "-"]:
            sequence = sequence.replace("}"+punct, "} "+punct)
            sequence = sequence.replace(punct+"{", punct+" {")
        sequence = sequence.replace("})", "} )")
        sequence = sequence.replace("({", "( {")
        # Prepare the input sequence for processing. Run it a few times to catch edge cases
        sequence = self.splitWords(sequence.split(" "), True)
        sequence = self.splitWords(sequence)
        sequence = self.splitWords(sequence)
        sequence = self.splitWords(sequence)

        subSequences = []
        openedLangs = 0
        langs_stack = [base_lang]

        for word in sequence:
            skip_word = False
            if word.startswith("\\lang["):
                openedLangs += 1
                langs_stack.append(word.split("lang[")[1].split("]")[0])
                skip_word = True
            if word.endswith("]"):
                openedLangs -= 1
                langs_stack.pop()
                skip_word = True

            # Add the word if it's not being skipped, not empty, and not a second consecutive space
            if not skip_word and len(word) and (word!=" " or len(subSequences)==0 or subSequences[-1][list(subSequences[-1].keys())[0]]!=" "):
                subSequences.append({langs_stack[-1]: word})

        subSequences_collapsed = []
        current_open_arpabet = []
        last_lang = None
        is_in_arpabet = False

        # Collapse groups of inlined ARPAbet symbols, to have them treated as such
        for subSequence in subSequences:
            ss_lang = list(subSequence.keys())[0]
            ss_val = subSequence[ss_lang]

            if ss_lang != last_lang:
                if len(current_open_arpabet):
                    subSequences_collapsed.append({ss_lang: "{"+" ".join(current_open_arpabet).replace("  "," ")+"}"})
                    current_open_arpabet = []
                last_lang = ss_lang

            if ss_val.strip()=="{":
                is_in_arpabet = True
            elif ss_val.strip()=="}":
                subSequences_collapsed.append({ss_lang: "{"+" ".join(current_open_arpabet).replace("  "," ")+"}"})
                current_open_arpabet = []
                is_in_arpabet = False
            else:
                if is_in_arpabet:
                    current_open_arpabet.append(ss_val)
                else:
                    subSequences_collapsed.append({ss_lang: ss_val})

        return subSequences_collapsed
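    # Rough illustration (assumed tokenization): preprocess_prompt_language("Hello \lang[de][Welt]", "en")
    # yields language-tagged word/space tokens along the lines of
    # [{"en": "Hello"}, {"en": " "}, {"de": "Welt"}, ...]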
    def getG2P(self, text, base_lang):
        sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)

        # Make sure all languages' text processors are initialized
        for subSequence in sequenceSplitByLanguage:
            langCode = list(subSequence.keys())[0]
            if langCode not in self.lang_tp.keys():
                self.lang_tp[langCode] = get_text_preprocessor(langCode, self.base_dir, logger=self.logger)

        returnString = "{"
        langs_stack = [base_lang]
        last_lang = base_lang

        for subSequence in sequenceSplitByLanguage:
            langCode = list(subSequence.keys())[0]
            subSeq = subSequence[langCode]
            sequence, cleaned_text = self.lang_tp[langCode].text_to_sequence(subSeq)

            if langCode != last_lang:
                last_lang = langCode
                if len(langs_stack)>1 and langs_stack[-2]==langCode:
                    langs_stack.pop()
                    if returnString[-1]=="}":
                        returnString = returnString[:-1]
                    returnString += "]}"
                else:
                    langs_stack.append(langCode)
                    if returnString[-1]=="{":
                        returnString = returnString[:-1]
                    returnString += f'\\lang[{langCode}][' + "{"

            returnString += " ".join([symb for symb in cleaned_text.split("|") if symb != "<PAD>"]).replace("_", "} {")

        if returnString[-1]=="{":
            returnString = returnString[:-1]
        else:
            returnString = returnString+"}"

        returnString = returnString.replace(".}", "}.")
        returnString = returnString.replace(",}", "},")
        returnString = returnString.replace("!}", "}!")
        returnString = returnString.replace("?}", "}?")
        returnString = returnString.replace("]}", "}]")
        returnString = returnString.replace("}]}", "}]")
        returnString = returnString.replace("{"+"}", "")
        returnString = returnString.replace("}"+"}", "}")
        returnString = returnString.replace("{"+"{", "{")

        return returnString
    def infer(self, plugin_manager, text, out_path, vocoder, speaker_i, pace=1.0, editor_data=None, old_sequence=None, globalAmplitudeModifier=None, base_lang="en", base_emb=None, useSR=False, useCleanup=False):
        sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)

        # Make sure all languages' text processors are initialized
        for subSequence in sequenceSplitByLanguage:
            langCode = list(subSequence.keys())[0]
            if langCode not in self.lang_tp.keys():
                self.lang_tp[langCode] = get_text_preprocessor(langCode, self.base_dir, logger=self.logger)

        try:
            pad_symb = len(ALL_SYMBOLS)-2 # <PAD> symbol, inserted between language sub-sequences
            all_sequence = []
            all_cleaned_text = []
            all_text = []
            all_lang_ids = []

            # Collapse same-language words into phrases, so that heteronyms can still be detected
            sequenceSplitByLanguage_grouped = []
            last_lang_group = None
            group = ""
            for ssi, subSequence in enumerate(sequenceSplitByLanguage):
                if list(subSequence.keys())[0]!=last_lang_group:
                    if last_lang_group is not None:
                        sequenceSplitByLanguage_grouped.append({last_lang_group: group})
                    group = ""
                    last_lang_group = list(subSequence.keys())[0]
                group += subSequence[last_lang_group]
            if len(group):
                sequenceSplitByLanguage_grouped.append({last_lang_group: group})

            for ssi, subSequence in enumerate(sequenceSplitByLanguage_grouped):
                langCode = list(subSequence.keys())[0]
                subSeq = subSequence[langCode]

                sequence, cleaned_text = self.lang_tp[langCode].text_to_sequence(subSeq)

                if ssi<len(sequenceSplitByLanguage_grouped)-1:
                    sequence = sequence + [pad_symb]
                all_sequence.append(sequence)

                all_cleaned_text += ("|"+cleaned_text) if len(all_cleaned_text) else cleaned_text
                if ssi<len(sequenceSplitByLanguage_grouped)-1:
                    all_cleaned_text = all_cleaned_text + ["|<PAD>"]

                all_text.append(torch.LongTensor(sequence))

                language_id = self.language_id_mapping[langCode]
                all_lang_ids += [language_id for _ in range(len(sequence))]

        except ValueError as e:
            self.logger.info("====")
            self.logger.info(str(e))
            self.logger.info("====--")
            if "not in list" in str(e):
                symbol_not_in_list = str(e).split("is not in list")[0].split("ValueError:")[-1].replace("'", "").strip()
                return f'ERR: ARPABET_NOT_IN_LIST: {symbol_not_in_list}'

        all_cleaned_text = "".join(all_cleaned_text)

        text = torch.cat(all_text, dim=0)
        text = pad_sequence([text], batch_first=True).to(self.models_manager.device)

        with torch.no_grad():
            if old_sequence is not None:
                old_sequence = re.sub(r'[^a-zA-Z\s\(\)\[\]0-9\?\.\,\!\'\{\}\_\@]+', '', old_sequence)
                old_sequence, clean_old_sequence = self.lang_tp[base_lang].text_to_sequence(old_sequence)
                old_sequence = torch.LongTensor(old_sequence)
                old_sequence = pad_sequence([old_sequence], batch_first=True).to(self.models_manager.device)

            lang_ids = torch.tensor(all_lang_ids).to(self.models_manager.device)
            num_embs = text.shape[1]

            base_emb = [float(val) for val in base_emb.split(",")] if "," in base_emb else self.base_emb
            speaker_embs = [torch.tensor(base_emb).unsqueeze(dim=0)[0].unsqueeze(-1)]
            speaker_embs = torch.stack(speaker_embs, dim=0).to(self.models_manager.device)
            speaker_embs = speaker_embs.repeat(1,1,num_embs)

            # Do interpolations of speaker style embeddings
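            # Each editor style pulls the speaker embedding towards that style's
            # embedding with a per-dimension lerp, weighted by the style's slider
            # values: emb = emb*(1 - s) + s*style_emb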
            editorStyles = None
            if editor_data is not None:
                editorStyles = editor_data[-1]
                if editorStyles is not None:
                    style_keys = list(editorStyles.keys())
                    for style_key in style_keys:
                        emb = editorStyles[style_key]["embedding"]
                        sliders_vals = editorStyles[style_key]["sliders"]

                        style_embs = [torch.tensor(emb).unsqueeze(dim=0)[0].unsqueeze(-1)]
                        style_embs = torch.stack(style_embs, dim=0).to(self.models_manager.device)
                        style_embs = style_embs.repeat(1,1,num_embs)

                        sliders_vals = torch.tensor(sliders_vals).to(self.models_manager.device)
                        speaker_embs = speaker_embs*(1-sliders_vals) + sliders_vals*style_embs

            speaker_embs = speaker_embs.float()

            lang_embs = lang_ids # TODO, use pre-extracted trained language embeddings, for interpolation

            out = self.model.infer_advanced(self.logger, plugin_manager, [all_cleaned_text], text, lang_embs=lang_embs, speaker_embs=speaker_embs, pace=pace, editor_data=editor_data, old_sequence=old_sequence)
            if isinstance(out, str):
                return f'ERR:{out}'
            else:
                output_wav, dur_pred, pitch_pred, energy_pred, em_pred, start_index, end_index, wav_mult = out
                [em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred] = em_pred

                wav = output_wav.squeeze().cpu().detach().numpy()
                wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
                if wav_mult is not None:
                    wav_norm = wav_norm * wav_mult

                if useCleanup:
                    ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
                    if useSR:
                        scipy.io.wavfile.write(out_path.replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
                    else:
                        scipy.io.wavfile.write(out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
                        stream = ffmpeg.input(out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
                        ffmpeg_options = {"ar": 48000}
                        output_path = out_path.replace(".wav", "_preCleanup.wav")
                        stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
                        out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
                        os.remove(out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
                else:
                    scipy.io.wavfile.write(out_path.replace(".wav", "_preSR.wav") if useSR else out_path, 22050, wav_norm.astype(np.int16))

                if useSR:
                    self.models_manager.init_model("nuwave2")
                    self.models_manager.models("nuwave2").sr_audio(out_path.replace(".wav", "_preSR.wav"), out_path.replace(".wav", "_preCleanup.wav") if useCleanup else out_path)
                if useCleanup:
                    self.models_manager.init_model("deepfilternet2")
                    self.models_manager.models("deepfilternet2").cleanup_audio(out_path.replace(".wav", "_preCleanup.wav"), out_path)

                [pitch, durations, energy, em_angry, em_happy, em_sad, em_surprise] = [
                    pitch_pred.squeeze().cpu().detach().numpy(),
                    dur_pred.squeeze().cpu().detach().numpy(),
                    energy_pred.cpu().detach().numpy() if energy_pred is not None else [],
                    em_angry_pred.squeeze().cpu().detach().numpy() if em_angry_pred is not None else [],
                    em_happy_pred.squeeze().cpu().detach().numpy() if em_happy_pred is not None else [],
                    em_sad_pred.squeeze().cpu().detach().numpy() if em_sad_pred is not None else [],
                    em_surprise_pred.squeeze().cpu().detach().numpy() if em_surprise_pred is not None else [],
                ]
                pitch = [float(v) for v in pitch]
                durations = [float(v) for v in durations]
                energy = [float(v) for v in energy]
                em_angry = [float(v) for v in em_angry]
                em_happy = [float(v) for v in em_happy]
                em_sad = [float(v) for v in em_sad]
                em_surprise = [float(v) for v in em_surprise]

                del pitch_pred, dur_pred, energy_pred, text, sequence

                return {
                    "pitch": pitch,
                    "durations": durations,
                    "energy": energy,
                    "em_angry": em_angry,
                    "em_happy": em_happy,
                    "em_sad": em_sad,
                    "em_surprise": em_surprise,
                    "editorStyles": json.dumps(editorStyles),
                    "arpabet": all_cleaned_text
                }
    def set_device(self, device):
        self.device = device
        self.model = self.model.to(device)
        self.model.pitch_emb_values = self.model.pitch_emb_values.to(device)
        self.model.device = device
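
# Rough usage sketch (illustrative only; `logger`, `models_manager`, and
# `plugin_manager` are assumed to be the host app's own objects, not defined here):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   tts = xVAPitch(logger, PROD=False, device=device, models_manager=models_manager)
#   ckpt = torch.load("path/to/voice.pt", map_location="cpu")
#   tts.load_state_dict("path/to/voice.pt", ckpt, base_lang="en")
#   tts.infer(plugin_manager, "Hello world", "out.wav", vocoder="", speaker_i=0, base_emb="")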