# (Page metadata captured with this file; not part of the program.)
# Pendrokar's picture
# ndarray serialization fix
# bd8f4e0
# raw
# history blame
# No virus
# 37.7 kB
import os
import re
import json
import codecs
import ffmpeg
import argparse
import platform
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
import scipy
import scipy.io.wavfile
import librosa
from scipy.io.wavfile import write
import numpy as np
try:
import sys
sys.path.append(".")
from resources.app.python.xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from resources.app.python.xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
except ModuleNotFoundError:
try:
from python.xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from python.xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
except ModuleNotFoundError:
try:
from xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
except ModuleNotFoundError:
from text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from xvapitch_model import xVAPitch as xVAPitchModel
class xVAPitch(object):
    """Inference wrapper around the xVAPitch TTS / voice-conversion model.

    Holds the underlying ``xVAPitchModel``, lazily-created per-language text
    preprocessors, the user ARPAbet dictionary and the pre-extracted
    pitch/emotion embedding tables, and exposes the entry points used by the
    app: ``infer`` (single line, with editor data), ``infer_batch``,
    ``run_speech_to_speech`` and ``getG2P``.
    """

    # (attribute name, file under python/xvapitch/embs/) for every pre-extracted
    # embedding table that is loaded in __init__ and attached to the model.
    _EMB_FILES = (
        ("pitch_emb_values", "pitch_emb.npy"),
        ("angry_emb_values", "angry.npy"),
        ("happy_emb_values", "happy.npy"),
        ("sad_emb_values", "sad.npy"),
        ("surprise_emb_values", "surprise.npy"),
    )

    # Sample rate at which model output is written and VC input is loaded.
    _MODEL_SR = 22050
    # Rate the raw output is resampled to (via ffmpeg) before cleanup when
    # super-resolution is not used.
    _CLEANUP_SR = 48000

    def __init__(self, logger, PROD, device, models_manager):
        super().__init__()
        self.logger = logger
        self.PROD = PROD
        self.models_manager = models_manager
        self.device = device
        self.ckpt_path = None
        self.arpabet_dict = {}

        # torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.benchmark = False

        self.base_dir = f'{self._app_root}/python/xvapitch/text'
        self.lang_tp = {}
        self.lang_tp["en"] = get_text_preprocessor("en", self.base_dir, logger=self.logger)

        # Stable language -> id mapping: ids follow the sorted language codes.
        self.language_id_mapping = {name: i for i, name in enumerate(sorted(lang_names.keys()))}

        # Each table is stored with shape (1, *loaded_shape, 1).
        for attr, fname in self._EMB_FILES:
            values = np.load(f'{self._app_root}/python/xvapitch/embs/{fname}')
            setattr(self, attr, torch.tensor(values).unsqueeze(0).unsqueeze(-1))

        self.base_lang = "en"
        self.init_model()
        for attr, _ in self._EMB_FILES:
            setattr(self.model, attr, getattr(self, attr).to(self.models_manager.device))
        self.isReady = True

    # ------------------------------------------------------------------ helpers

    @property
    def _app_root(self):
        """Root folder of the app resources: packaged (PROD) or dev checkout."""
        return "./resources/app" if self.PROD else "."

    def _ffmpeg_path(self):
        """ffmpeg executable: the bundled .exe on Windows, otherwise the one on PATH."""
        if platform.system() == 'Windows':
            return f'{self._app_root}/python/ffmpeg.exe'
        return 'ffmpeg'

    def _ensure_text_processor(self, lang_code):
        """Create the text preprocessor for ``lang_code`` if not already cached."""
        if lang_code not in self.lang_tp:
            self.lang_tp[lang_code] = get_text_preprocessor(lang_code, self.base_dir, logger=self.logger)

    def _ensure_text_processors(self, sequence_split):
        """Create preprocessors for every language in a ``preprocess_prompt_language`` result."""
        for sub_sequence in sequence_split:
            self._ensure_text_processor(next(iter(sub_sequence)))

    def _load_spectrogram(self, audiopath):
        """Load ``audiopath`` and return ``(magnitude STFT, lengths)`` on the models device.

        The magnitude spectrogram has a leading batch dimension of 1; lengths is
        a 1-element tensor holding its frame count.
        """
        y, _ = librosa.load(audiopath, sr=self._MODEL_SR)
        D = librosa.stft(
            y=y,
            n_fft=1024,
            hop_length=256,
            win_length=1024,
            pad_mode="reflect",
            window="hann",
            center=True,
        )
        spec = torch.FloatTensor(np.abs(D).astype(np.float32)).unsqueeze(0)
        y_lengths = torch.tensor([spec.size(-1)]).to(self.models_manager.device)
        return spec.to(self.models_manager.device), y_lengths

    @staticmethod
    def _to_int16_pcm(wav, wav_mult=None):
        """Peak-normalise a float numpy waveform to int16 PCM.

        The peak is floored at 0.01 so near-silent output is not scaled up to
        full range. ``wav_mult`` optionally scales the normalised signal; the
        result is clipped to the int16 range so a multiplier > 1 saturates
        instead of wrapping around.
        """
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        if wav_mult is not None:
            wav_norm = wav_norm * wav_mult
        return np.clip(wav_norm, -32768, 32767).astype(np.int16)

    @staticmethod
    def _squeezed_floats(pred):
        """Squeeze a tensor and return its values as Python floats ([] for None).

        ``np.atleast_1d`` keeps single-value predictions iterable after squeeze.
        """
        if pred is None:
            return []
        return [float(v) for v in np.atleast_1d(pred.squeeze().cpu().detach().numpy())]

    def _save_output_audio(self, wav_int16, out_path, useSR, useCleanup, remove_intermediates=False):
        """Write int16 model audio to ``out_path``, optionally via SR and/or cleanup.

        Pipeline (intermediate files are derived from ``out_path``):
          - neither: write ``out_path`` directly;
          - SR only: write ``_preSR.wav`` -> nuwave2 -> ``out_path``;
          - cleanup only: write ``_preCleanupPreFFmpeg.wav`` -> ffmpeg resample to
            48kHz -> ``_preCleanup.wav`` -> deepfilternet2 -> ``out_path``;
          - both: ``_preSR.wav`` -> nuwave2 -> ``_preCleanup.wav`` -> deepfilternet2.
        When ``remove_intermediates`` is true the ``_preSR``/``_preCleanup`` files
        are deleted afterwards; the ffmpeg input file is always deleted.
        """
        pre_sr_path = out_path.replace(".wav", "_preSR.wav")
        pre_cleanup_path = out_path.replace(".wav", "_preCleanup.wav")

        if useCleanup:
            if useSR:
                scipy.io.wavfile.write(pre_sr_path, self._MODEL_SR, wav_int16)
            else:
                # Without SR the raw output is resampled with ffmpeg before cleanup
                # (presumably the rate the cleanup model works at).
                pre_ffmpeg_path = out_path.replace(".wav", "_preCleanupPreFFmpeg.wav")
                scipy.io.wavfile.write(pre_ffmpeg_path, self._MODEL_SR, wav_int16)
                stream = ffmpeg.input(pre_ffmpeg_path)
                ffmpeg_options = {"ar": self._CLEANUP_SR}
                stream = ffmpeg.output(stream, pre_cleanup_path, **ffmpeg_options)
                ffmpeg.run(stream, cmd=self._ffmpeg_path(), capture_stdout=True, capture_stderr=True, overwrite_output=True)
                os.remove(pre_ffmpeg_path)
        else:
            scipy.io.wavfile.write(pre_sr_path if useSR else out_path, self._MODEL_SR, wav_int16)

        if useSR:
            self.models_manager.init_model("nuwave2")
            self.models_manager.models("nuwave2").sr_audio(pre_sr_path, pre_cleanup_path if useCleanup else out_path)
            if remove_intermediates:
                os.remove(pre_sr_path)
        if useCleanup:
            self.models_manager.init_model("deepfilternet2")
            self.models_manager.models("deepfilternet2").cleanup_audio(pre_cleanup_path, out_path)
            if remove_intermediates:
                os.remove(pre_cleanup_path)

    def _encode_text(self, text, base_lang):
        """Turn a prompt into model inputs.

        The prompt may contain ``\\lang[code][...]`` blocks and ``{ARPAbet}``
        groups. Consecutive same-language words are encoded together so that
        heteronyms can still be disambiguated, and a pad symbol separates
        language groups.

        Returns:
            ``(text_ids, cleaned_text, lang_ids)``: a 1-D LongTensor of symbol ids,
            the "|"-joined cleaned symbols (with "|<PAD>" between groups) and one
            language id per symbol; or the string
            ``'ERR: ARPABET_NOT_IN_LIST: <symbol>'`` when a symbol is unknown.

        Raises:
            ValueError: any other text-processing ValueError is re-raised rather
                than silently producing truncated text.
        """
        sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)
        self._ensure_text_processors(sequenceSplitByLanguage)

        # Collapse same-language words into phrases, so that heteronyms can still be detected
        grouped = []
        last_lang_group = None
        group = ""
        for sub_sequence in sequenceSplitByLanguage:
            lang_code = next(iter(sub_sequence))
            if lang_code != last_lang_group:
                if last_lang_group is not None:
                    grouped.append((last_lang_group, group))
                group = ""
                last_lang_group = lang_code
            group += sub_sequence[lang_code]
        if group:
            grouped.append((last_lang_group, group))

        pad_symb = len(ALL_SYMBOLS) - 2
        last_index = len(grouped) - 1
        all_text = []
        all_cleaned_text = ""
        all_lang_ids = []
        try:
            for index, (lang_code, phrase) in enumerate(grouped):
                sequence, cleaned_text = self.lang_tp[lang_code].text_to_sequence(phrase)
                if index < last_index:
                    sequence = sequence + [pad_symb]
                all_cleaned_text += ("|" + cleaned_text) if all_cleaned_text else cleaned_text
                if index < last_index:
                    all_cleaned_text += "|<PAD>"
                all_text.append(torch.LongTensor(sequence))
                all_lang_ids += [self.language_id_mapping[lang_code]] * len(sequence)
        except ValueError as e:
            self.logger.info("====")
            self.logger.info(str(e))
            self.logger.info("====--")
            if "not in list" in str(e):
                symbol_not_in_list = str(e).split("is not in list")[0].split("ValueError:")[-1].replace("'", "").strip()
                return f'ERR: ARPABET_NOT_IN_LIST: {symbol_not_in_list}'
            raise

        return torch.cat(all_text, dim=0), all_cleaned_text, all_lang_ids

    # ---------------------------------------------------------- model lifecycle

    def init_model(self):
        """Build a fresh ``xVAPitchModel`` in eval mode on ``self.device``."""
        # Build the namespace directly: parse_args() would read the host
        # process's sys.argv and exit on any flag it does not recognise.
        args = argparse.Namespace()
        # Params from training
        args.pitch = 1
        args.pe_scaling = 0.1
        args.expanded_flow = 0
        args.ow_flow = 0
        args.energy = 0
        self.model = xVAPitchModel(args).to(self.device)
        self.model.eval()
        self.model.device = self.device

    def load_state_dict(self, ckpt_path, ckpt, n_speakers=1, base_lang="en"):
        """Load checkpoint weights and per-voice metadata.

        Args:
            ckpt_path: path of the ``.pt`` file; a sibling ``.json`` (same name)
                provides ``games[0].base_speaker_emb`` when present.
            ckpt: the loaded state dict, optionally wrapped under a ``'model'`` key.
            n_speakers: unused, kept for interface compatibility.
            base_lang: the voice's main language.
        """
        self.logger.info(f'load_state_dict base_lang: {base_lang}')
        self._ensure_text_processor(base_lang)
        self.base_lang = base_lang
        self.ckpt_path = ckpt_path

        json_path = ckpt_path.replace(".pt", ".json")
        if os.path.exists(json_path):
            with open(json_path, "r") as f:
                data = json.load(f)
            self.base_emb = data["games"][0]["base_speaker_emb"]

        if 'model' in ckpt:
            ckpt = ckpt['model']

        # Checkpoints were trained with different language counts (e.g. 31 or
        # 50); resize the language embedding to match before loading, since
        # load_state_dict raises on size mismatches even with strict=False.
        num_languages = ckpt["emb_l.weight"].shape[0]
        current_emb_l = getattr(self.model, "emb_l", None)
        if getattr(current_emb_l, "num_embeddings", None) != num_languages:
            self.model.emb_l = nn.Embedding(num_languages, self.model.embedded_language_dim).to(self.models_manager.device)

        self.model.load_state_dict(ckpt, strict=False)
        self.model = self.model.float()
        self.model.eval()

    def set_device(self, device):
        """Move the model and every attached embedding table to ``device``."""
        self.device = device
        self.model = self.model.to(device)
        # The emotion tables are attached as plain attributes alongside the pitch
        # table, so Module.to() does not move them; move all of them explicitly.
        for attr, _ in self._EMB_FILES:
            values = getattr(self.model, attr, None)
            if values is not None:
                setattr(self.model, attr, values.to(device))
        self.model.device = device

    # ------------------------------------------------------------ ARPAbet dict

    def init_arpabet_dicts(self):
        """Load the ARPAbet dictionaries on first use."""
        if not self.arpabet_dict:
            self.refresh_arpabet_dicts()

    def refresh_arpabet_dicts(self):
        """Rebuild ``arpabet_dict`` from every ``arpabet/*.json`` file (enabled entries only).

        Files are read in sorted order, so later files override earlier ones for
        the same word.
        """
        self.arpabet_dict = {}
        arpabet_dir = f'{self._app_root}/arpabet'
        json_files = [fname for fname in sorted(os.listdir(arpabet_dir)) if fname.endswith(".json")]
        for fname in json_files:
            with codecs.open(f'{arpabet_dir}/{fname}', encoding="utf-8") as f:
                json_data = json.load(f)
            for word, entry in json_data["data"].items():
                if entry["enabled"] == True:  # noqa: E712 - keep original equality semantics
                    self.arpabet_dict[word] = entry["arpabet"]

    # ---------------------------------------------------------------- inference

    def run_speech_to_speech(self, audiopath, audio_out_path, style_emb, models_manager, plugin_manager, vc_strength=1, useSR=False, useCleanup=False):
        """Voice-convert ``audiopath`` towards ``style_emb`` and write ``audio_out_path``.

        Args:
            style_emb: a path containing ".wav" (embedded with the speaker_rep
                model) or an embedding sequence.
            vc_strength: 1 keeps the source speaker embedding; other values move
                the content embedding linearly relative to the style embedding.

        Returns:
            "TOO_SHORT" if the source audio could not be embedded, else None.
        """
        if ".wav" in style_emb:
            self.logger.info(f'Getting style emb from: {style_emb}')
            style_emb = models_manager.models("speaker_rep").compute_embedding(style_emb).squeeze()
        else:
            self.logger.info('Given style emb')
            style_emb = torch.tensor(style_emb).squeeze()

        try:
            content_emb = models_manager.models("speaker_rep").compute_embedding(audiopath).squeeze()
        except Exception as e:
            # Any embedding failure is reported as "TOO_SHORT", the case the
            # caller expects (presumably a clip too short to embed).
            self.logger.info(f'[run_speech_to_speech] content embedding failed: {e}')
            return "TOO_SHORT"

        style_emb = F.normalize(style_emb.unsqueeze(0), dim=1).unsqueeze(-1).to(self.models_manager.device)
        content_emb = F.normalize(content_emb.unsqueeze(0), dim=1).unsqueeze(-1).to(self.models_manager.device)
        content_emb = content_emb + (-(vc_strength - 1) * (style_emb - content_emb))

        y, y_lengths = self._load_spectrogram(audiopath)
        wav = self.model.voice_conversion(y=y, y_lengths=y_lengths, spk1_emb=content_emb, spk2_emb=style_emb)
        wav_int16 = self._to_int16_pcm(wav.squeeze().cpu().detach().numpy())
        self._save_output_audio(wav_int16, audio_out_path, useSR, useCleanup)
        return

    def infer_batch(self, plugin_manager, linesBatch, outputJSON, vocoder, speaker_i, old_sequence=None, useSR=False, useCleanup=False):
        """Synthesise a batch of lines: VC records one by one, TTS records in one model call.

        Each record is
        [sequence, pitch, duration, pace, tempFileLocation, outPath, outFolder,
         pitch_amp, base_lang, base_emb, vc_content, vc_style];
        a truthy ``vc_content`` marks a voice-conversion record. Audio is written
        to ``tempFileLocation``; with ``outputJSON`` the editor data for TTS
        records is written next to ``outPath`` as ``.json``.

        Returns:
            "" on success, or an error string from text encoding / the model.
        """
        self.logger.info(f'Inferring batch of {len(linesBatch)} lines')

        vc_input = [record for record in linesBatch if record[-2]]
        tts_input = [record for record in linesBatch if not record[-2]]

        # ======= Handle VC
        for record in vc_input:
            content_emb = self.models_manager.models("speaker_rep").compute_embedding(record[-2]).squeeze()
            style_emb = self.models_manager.models("speaker_rep").compute_embedding(record[-1]).squeeze()
            content_emb = content_emb.unsqueeze(0).unsqueeze(-1).to(self.models_manager.device)
            style_emb = style_emb.unsqueeze(0).unsqueeze(-1).to(self.models_manager.device)

            y, y_lengths = self._load_spectrogram(record[-2])
            self.model.logger = self.logger
            wav = self.model.voice_conversion(y=y, y_lengths=y_lengths, spk1_emb=content_emb, spk2_emb=style_emb)
            wav_int16 = self._to_int16_pcm(wav.squeeze().cpu().detach().numpy())
            # Write to this VC record's own output path (never a TTS record's).
            self._save_output_audio(wav_int16, record[4], useSR, useCleanup, remove_intermediates=True)

        # ======= Handle TTS
        if not tts_input:
            return ""

        text_sequences = []
        cleaned_text_sequences = []
        lang_embs = []
        speaker_embs = []
        for record in tts_input:
            text = record[0].replace("/lang", "\\lang")
            base_lang = record[-4]
            self.logger.info(f'[infer_batch] text: {text}')
            encoded = self._encode_text(text, base_lang)
            if isinstance(encoded, str):
                return encoded
            text_ids, cleaned_text, lang_ids = encoded
            cleaned_text_sequences.append(cleaned_text)
            text_sequences.append(text_ids)
            lang_embs.append(torch.tensor(lang_ids).to(self.models_manager.device))
            speaker_embs.append(torch.tensor(record[-3]).unsqueeze(-1))

        lang_embs = pad_sequence(lang_embs, batch_first=True).to(self.models_manager.device)
        text_sequences = pad_sequence(text_sequences, batch_first=True).to(self.models_manager.device)
        speaker_embs = pad_sequence(speaker_embs, batch_first=True).to(self.models_manager.device)
        pace = torch.tensor([record[3] for record in tts_input]).unsqueeze(1).to(self.device)
        pitch_amp = torch.tensor([record[7] for record in tts_input]).unsqueeze(1).to(self.device)

        out = self.model.infer_advanced(self.logger, plugin_manager, [cleaned_text_sequences], text_sequences, lang_embs=lang_embs, speaker_embs=speaker_embs, pace=pace, old_sequence=None, pitch_amp=pitch_amp)
        if isinstance(out, str):
            return out
        output_wav, dur_pred, pitch_pred, energy_pred, _, _, _, _ = out

        for record, wav in zip(tts_input, output_wav):
            wav_int16 = self._to_int16_pcm(wav.squeeze().cpu().detach().numpy())
            self._save_output_audio(wav_int16, record[4], useSR, useCleanup, remove_intermediates=True)

        if outputJSON:
            for ri, record in enumerate(tts_input):
                output_fname = record[5].replace(".wav", ".json")
                containing_folder = os.path.dirname(output_fname)
                if containing_folder:
                    os.makedirs(containing_folder, exist_ok=True)
                pitch = [float(val) for val in pitch_pred[ri][0].cpu().detach().numpy()]
                durs = [float(val) for val in dur_pred[ri][0].cpu().detach().numpy()]
                data = {}
                data["modelType"] = "xVAPitch"
                data["inputSequence"] = str(record[0])
                data["pacing"] = float(record[3])
                data["letters"] = [char.replace("{", "").replace("}", "") for char in cleaned_text_sequences[ri].split("|")]
                data["currentVoice"] = self.ckpt_path.split("/")[-1].replace(".pt", "")
                # Energy is not predicted by this model; emit a flat 1.0 per symbol.
                data["resetEnergy"] = [float(1)] * len(pitch)
                data["resetPitch"] = pitch
                data["resetDurs"] = durs
                data["ampFlatCounter"] = 0
                data["pitchNew"] = data["resetPitch"]
                data["energyNew"] = data["resetEnergy"]
                data["dursNew"] = data["resetDurs"]
                with open(output_fname, "w+") as f:
                    f.write(json.dumps(data, indent=4))
        return ""

    # ------------------------------------------------------- prompt processing

    def splitWords(self, sequence, addSpace=False):
        """Split words, breaking out ``\\lang[code][`` prefixes and edge brackets.

        Leading ``} ] [ {`` characters (checked in that order) and trailing ones
        become separate items; with ``addSpace`` a " " item follows each word.
        """
        words = []
        brackets = ["}", "]", "[", "{"]
        for word in sequence:
            if word.startswith("\\lang["):
                parts = word.split("][")
                words.append(parts[0] + "][")
                word = parts[1]
            for char in brackets:
                if word.startswith(char):
                    words.append(char)
                    word = word[1:]
            end_extras = []
            for char in brackets:
                if word.endswith(char):
                    end_extras.append(char)
                    word = word[:-1]
            words.append(word)
            words.extend(reversed(end_extras))
            if addSpace:
                words.append(" ")
        return words

    def preprocess_prompt_language(self, sequence, base_lang):
        """Split a prompt into ``[{lang_code: text}, ...]`` items.

        Text outside ``\\lang[code][...]`` blocks is attributed to ``base_lang``.
        Consecutive duplicate spaces are dropped, and ``{...}`` ARPAbet groups
        are collapsed into a single ``"{...}"`` item.
        """
        # Separate the ARPAbet brackets from punctuation (order matters and
        # matches the original replacement sequence).
        for punct in ".!?,\"'-)":
            sequence = sequence.replace("}" + punct, "} " + punct)
        for punct in ".!?,\"'-(":
            sequence = sequence.replace(punct + "{", punct + " {")

        # Prepare the input sequence for processing. Do a few times to catch edge cases
        sequence = self.splitWords(sequence.split(" "), True)
        sequence = self.splitWords(sequence)
        sequence = self.splitWords(sequence)
        sequence = self.splitWords(sequence)

        subSequences = []
        langs_stack = [base_lang]
        for word in sequence:
            skip_word = False
            if word.startswith("\\lang["):
                langs_stack.append(word.split("lang[")[1].split("]")[0])
                skip_word = True
            if word.endswith("]"):
                # Never pop the base language: a stray "]" (e.g. "[laugh]")
                # would otherwise empty the stack and crash on the next word.
                if len(langs_stack) > 1:
                    langs_stack.pop()
                skip_word = True
            # Add the word to the list if not skipping it, if it's not empty, or it's not a second space in a row
            if not skip_word and len(word) and (word != " " or len(subSequences) == 0 or next(iter(subSequences[-1].values())) != " "):
                subSequences.append({langs_stack[-1]: word})

        subSequences_collapsed = []
        current_open_arpabet = []
        last_lang = None
        is_in_arpabet = False
        # Collapse groups of inlined ARPABet symbols, to have them treated as such
        for subSequence in subSequences:
            ss_lang = next(iter(subSequence))
            ss_val = subSequence[ss_lang]
            if ss_lang != last_lang:
                if len(current_open_arpabet):
                    # NOTE(review): the flushed group is tagged with the new
                    # language, not the one it was opened in — confirm intent.
                    # The .replace(" "," ") below is a no-op as written
                    # (possibly meant "  " -> " "); kept unchanged.
                    subSequences_collapsed.append({ss_lang: "{" + " ".join(current_open_arpabet).replace(" ", " ") + "}"})
                    current_open_arpabet = []
                last_lang = ss_lang
            if ss_val.strip() == "{":
                is_in_arpabet = True
            elif ss_val.strip() == "}":
                subSequences_collapsed.append({ss_lang: "{" + " ".join(current_open_arpabet).replace(" ", " ") + "}"})
                current_open_arpabet = []
                is_in_arpabet = False
            elif is_in_arpabet:
                current_open_arpabet.append(ss_val)
            else:
                subSequences_collapsed.append({ss_lang: ss_val})
        return subSequences_collapsed

    def getG2P(self, text, base_lang):
        """Return the prompt rewritten as ``{ARPAbet}`` groups, keeping ``\\lang[..][..]`` blocks."""
        sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)
        self._ensure_text_processors(sequenceSplitByLanguage)

        returnString = "{"
        langs_stack = [base_lang]
        last_lang = base_lang
        for subSequence in sequenceSplitByLanguage:
            langCode = next(iter(subSequence))
            _, cleaned_text = self.lang_tp[langCode].text_to_sequence(subSequence[langCode])

            if langCode != last_lang:
                last_lang = langCode
                if len(langs_stack) > 1 and langs_stack[-2] == langCode:
                    # Returning to the enclosing language: close the \lang block.
                    langs_stack.pop()
                    if returnString[-1] == "}":
                        returnString = returnString[:-1]
                    returnString += "]}"
                else:
                    # Entering a new language: open a \lang block.
                    langs_stack.append(langCode)
                    if returnString[-1] == "{":
                        returnString = returnString[:-1]
                    returnString += f'\\lang[{langCode}][' + "{"
            returnString += " ".join([symb for symb in cleaned_text.split("|") if symb != "<PAD>"]).replace("_", "} {")

        if returnString[-1] == "{":
            returnString = returnString[:-1]
        else:
            returnString = returnString + "}"

        # Move punctuation outside the braces and tidy empty/duplicated braces.
        returnString = returnString.replace(".}", "}.")
        returnString = returnString.replace(",}", "},")
        returnString = returnString.replace("!}", "}!")
        returnString = returnString.replace("?}", "}?")
        returnString = returnString.replace("]}", "}]")
        returnString = returnString.replace("}]}", "}]")
        returnString = returnString.replace("{" + "}", "")
        returnString = returnString.replace("}" + "}", "}")
        returnString = returnString.replace("{" + "{", "{")
        return returnString

    def infer(self, plugin_manager, text, out_path, vocoder, speaker_i, pace=1.0, editor_data=None, old_sequence=None, globalAmplitudeModifier=None, base_lang="en", base_emb=None, useSR=False, useCleanup=False):
        """Synthesise a single line to ``out_path`` and return the editor data.

        Args:
            base_emb: comma-separated speaker embedding string; anything else
                (including None) falls back to the loaded voice's ``self.base_emb``.
            editor_data: passed to the model; its last element, when not None,
                maps style keys to ``{"embedding", "sliders"}`` used to
                interpolate the speaker embedding.

        Returns:
            A dict with per-symbol pitch/durations/energy/emotion values, the
            JSON-encoded editor styles and the cleaned ARPAbet text; or an error
            string ('ERR: ARPABET_NOT_IN_LIST: ...' / 'ERR:<model message>').
        """
        encoded = self._encode_text(text, base_lang)
        if isinstance(encoded, str):
            return encoded
        text_ids, all_cleaned_text, all_lang_ids = encoded
        text_batch = pad_sequence([text_ids], batch_first=True).to(self.models_manager.device)

        # Defined up-front: it is serialised into the result even when no
        # editor data was supplied.
        editorStyles = None

        with torch.no_grad():
            if old_sequence is not None:
                self._ensure_text_processor(base_lang)
                old_sequence = re.sub(r'[^a-zA-Z\s\(\)\[\]0-9\?\.\,\!\'\{\}\_\@]+', '', old_sequence)
                old_ids, _ = self.lang_tp[base_lang].text_to_sequence(old_sequence)
                old_sequence = pad_sequence([torch.LongTensor(old_ids)], batch_first=True).to(self.models_manager.device)

            lang_ids = torch.tensor(all_lang_ids).to(self.models_manager.device)
            num_embs = text_batch.shape[1]

            if isinstance(base_emb, str) and "," in base_emb:
                base_emb = [float(val) for val in base_emb.split(",")]
            else:
                base_emb = self.base_emb
            # (1, emb_dim, 1) repeated over every input symbol.
            speaker_embs = torch.tensor(base_emb).unsqueeze(-1).unsqueeze(0).to(self.models_manager.device)
            speaker_embs = speaker_embs.repeat(1, 1, num_embs)

            # Do interpolations of speaker style embeddings
            if editor_data is not None:
                editorStyles = editor_data[-1]
                if editorStyles is not None:
                    for style_key in list(editorStyles.keys()):
                        emb = editorStyles[style_key]["embedding"]
                        sliders_vals = torch.tensor(editorStyles[style_key]["sliders"]).to(self.models_manager.device)
                        style_embs = torch.tensor(emb).unsqueeze(-1).unsqueeze(0).to(self.models_manager.device)
                        style_embs = style_embs.repeat(1, 1, num_embs)
                        speaker_embs = speaker_embs * (1 - sliders_vals) + sliders_vals * style_embs
            speaker_embs = speaker_embs.float()

            lang_embs = lang_ids  # TODO, use pre-extracted trained language embeddings, for interpolation
            out = self.model.infer_advanced(self.logger, plugin_manager, [all_cleaned_text], text_batch, lang_embs=lang_embs, speaker_embs=speaker_embs, pace=pace, editor_data=editor_data, old_sequence=old_sequence)
            if isinstance(out, str):
                return f'ERR:{out}'

            output_wav, dur_pred, pitch_pred, energy_pred, em_pred, start_index, end_index, wav_mult = out
            em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred = em_pred

            wav_int16 = self._to_int16_pcm(output_wav.squeeze().cpu().detach().numpy(), wav_mult)
            self._save_output_audio(wav_int16, out_path, useSR, useCleanup)

            energy = [float(v) for v in energy_pred.cpu().detach().numpy()] if energy_pred is not None else []
            return {
                "pitch": self._squeezed_floats(pitch_pred),
                "durations": self._squeezed_floats(dur_pred),
                "energy": energy,
                "em_angry": self._squeezed_floats(em_angry_pred),
                "em_happy": self._squeezed_floats(em_happy_pred),
                "em_sad": self._squeezed_floats(em_sad_pred),
                "em_surprise": self._squeezed_floats(em_surprise_pred),
                "editorStyles": json.dumps(editorStyles),
                "arpabet": all_cleaned_text
            }