import os import re import json import codecs import ffmpeg import argparse import torch import torch.nn as nn from python.fastpitch1_1 import models from scipy.io.wavfile import write from torch.nn.utils.rnn import pad_sequence from python.common.text import text_to_sequence, sequence_to_text ARPAbet_replacements_dict = { "YO": "IY0 UW0", "UH": "UH0", "AR": "R", "EY": "EY0", "A": "AA0", "AW": "AW0", "X": "K S", "CX": "K HH", "AO": "AO0", "PF": "P F", "AY": "AY0", "OE": "OW0 IY0", "IY": "IY0", "EH": "EH0", "OY": "OY0", "IH": "IH0", "H": "HH" } class FastPitch1_1(object): def __init__(self, logger, PROD, device, models_manager): super(FastPitch1_1, self).__init__() self.logger = logger self.PROD = PROD self.models_manager = models_manager self.device = device self.ckpt_path = None self.arpabet_dict = {} torch.backends.cudnn.benchmark = True self.init_model("english_basic") self.isReady = True def init_model (self, symbols_alphabet): parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference', allow_abbrev=False) self.symbols_alphabet = symbols_alphabet model_parser = models.parse_model_args("FastPitch", symbols_alphabet, parser, add_help=False) model_args, model_unk_args = model_parser.parse_known_args() model_config = models.get_model_config("FastPitch", model_args) self.model = models.get_model("FastPitch", model_config, self.device, self.logger, forward_is_infer=True, jitable=False) self.model.eval() self.model.device = self.device def load_state_dict (self, ckpt_path, ckpt, n_speakers=1, base_lang=None): self.ckpt_path = ckpt_path with open(ckpt_path.replace(".pt", ".json"), "r") as f: data = json.load(f) if "symbols_alphabet" in data.keys() and data["symbols_alphabet"]!=self.symbols_alphabet: self.logger.info(f'Changing symbols_alphabet from {self.symbols_alphabet} to {data["symbols_alphabet"]}') self.init_model(data["symbols_alphabet"]) if 'state_dict' in ckpt: ckpt = ckpt['state_dict'] self.model.load_state_dict(ckpt, strict=False) self.model = self.model.float() self.model.eval() def init_arpabet_dicts (self): if len(list(self.arpabet_dict.keys()))==0: self.refresh_arpabet_dicts() def refresh_arpabet_dicts (self): self.arpabet_dict = {} json_files = sorted(os.listdir(f'{"./resources/app" if self.PROD else "."}/arpabet')) json_files = [fname for fname in json_files if fname.endswith(".json")] for fname in json_files: with codecs.open(f'{"./resources/app" if self.PROD else "."}/arpabet/{fname}', encoding="utf-8") as f: json_data = json.load(f) for word in list(json_data["data"].keys()): if json_data["data"][word]["enabled"]==True: self.arpabet_dict[word] = json_data["data"][word]["arpabet"] def infer_arpabet_dict (self, sentence, plugin_manager=None): dict_words = list(self.arpabet_dict.keys()) data_context = {} data_context["sentence"] = sentence data_context["dict_words"] = dict_words data_context["language"] = "en" plugin_manager.run_plugins(plist=plugin_manager.plugins["arpabet-replace"]["pre"], event="pre arpabet-replace", data=data_context) sentence = data_context["sentence"] dict_words = data_context["dict_words"] # Don't run the ARPAbet replacement for every single word, as it would be too slow. Instead, do it only for words that are actually present in the prompt words_in_prompt = (sentence+" ").replace("}","").replace("{","").replace(",","").replace("?","").replace("!","").replace("...",".").replace(". "," ").lower().split(" ") words_in_prompt = [word.strip() for word in words_in_prompt if len(word.strip()) and word in dict_words] if len(words_in_prompt): # Pad out punctuation, to make sure they don't get used in the word look-ups sentence = " "+sentence.replace(",", " ,").replace(".", " .").replace("!", " !").replace("?", " ?")+" " for dict_word in words_in_prompt: arpabet_string = " "+self.arpabet_dict[dict_word]+" " if "CX" in arpabet_string: # German: # hhhhh sound After "a", "o", "u" and "au" # The usual K HH otherwise # Need to account for multiple ch per word pass for key in ARPAbet_replacements_dict.keys(): arpabet_string = arpabet_string.replace(f' {key} ', f' {ARPAbet_replacements_dict[key]} ') arpabet_string = arpabet_string.strip() sentence = re.sub("(?