|
|
|
|
|
|
|
|
|
|
|
import os |
|
import re |
|
import tempfile |
|
import torch |
|
import sys |
|
import gradio as gr |
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
if "vits" not in sys.path: |
|
sys.path.append("vits") |
|
|
|
from vits import commons, utils |
|
from vits.models import SynthesizerTrn |
|
|
|
|
|
class TextMapper(object): |
|
def __init__(self, vocab_file): |
|
self.symbols = [ |
|
x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines() |
|
] |
|
self.SPACE_ID = self.symbols.index(" ") |
|
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)} |
|
self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)} |
|
|
|
def text_to_sequence(self, text, cleaner_names): |
|
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text. |
|
Args: |
|
text: string to convert to a sequence |
|
cleaner_names: names of the cleaner functions to run the text through |
|
Returns: |
|
List of integers corresponding to the symbols in the text |
|
""" |
|
sequence = [] |
|
clean_text = text.strip() |
|
for symbol in clean_text: |
|
symbol_id = self._symbol_to_id[symbol] |
|
sequence += [symbol_id] |
|
return sequence |
|
|
|
def uromanize(self, text, uroman_pl): |
|
iso = "xxx" |
|
with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2: |
|
with open(tf.name, "w") as f: |
|
f.write("\n".join([text])) |
|
cmd = f"perl " + uroman_pl |
|
cmd += f" -l {iso} " |
|
cmd += f" < {tf.name} > {tf2.name}" |
|
os.system(cmd) |
|
outtexts = [] |
|
with open(tf2.name) as f: |
|
for line in f: |
|
line = re.sub(r"\s+", " ", line).strip() |
|
outtexts.append(line) |
|
outtext = outtexts[0] |
|
return outtext |
|
|
|
def get_text(self, text, hps): |
|
text_norm = self.text_to_sequence(text, hps.data.text_cleaners) |
|
if hps.data.add_blank: |
|
text_norm = commons.intersperse(text_norm, 0) |
|
text_norm = torch.LongTensor(text_norm) |
|
return text_norm |
|
|
|
def filter_oov(self, text, lang=None): |
|
text = self.preprocess_char(text, lang=lang) |
|
val_chars = self._symbol_to_id |
|
txt_filt = "".join(list(filter(lambda x: x in val_chars, text))) |
|
return txt_filt |
|
|
|
def preprocess_char(self, text, lang=None): |
|
""" |
|
Special treatement of characters in certain languages |
|
""" |
|
if lang == "ron": |
|
text = text.replace("ț", "ţ") |
|
print(f"{lang} (ț -> ţ): {text}") |
|
return text |
|
|
|
|
|
def synthesize(text, lang, speed): |
|
|
|
if speed is None: |
|
speed = 1.0 |
|
|
|
lang_code = lang.split(":")[0].strip() |
|
|
|
vocab_file = hf_hub_download( |
|
repo_id="facebook/mms-tts", |
|
filename="vocab.txt", |
|
subfolder=f"models/{lang_code}", |
|
) |
|
config_file = hf_hub_download( |
|
repo_id="facebook/mms-tts", |
|
filename="config.json", |
|
subfolder=f"models/{lang_code}", |
|
) |
|
g_pth = hf_hub_download( |
|
repo_id="facebook/mms-tts", |
|
filename="G_100000.pth", |
|
subfolder=f"models/{lang_code}", |
|
) |
|
|
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda") |
|
elif ( |
|
hasattr(torch.backends, "mps") |
|
and torch.backends.mps.is_available() |
|
and torch.backends.mps.is_built() |
|
): |
|
device = torch.device("mps") |
|
else: |
|
device = torch.device("cpu") |
|
|
|
print(f"Run inference with {device}") |
|
|
|
assert os.path.isfile(config_file), f"{config_file} doesn't exist" |
|
hps = utils.get_hparams_from_file(config_file) |
|
text_mapper = TextMapper(vocab_file) |
|
net_g = SynthesizerTrn( |
|
len(text_mapper.symbols), |
|
hps.data.filter_length // 2 + 1, |
|
hps.train.segment_size // hps.data.hop_length, |
|
**hps.model, |
|
) |
|
net_g.to(device) |
|
_ = net_g.eval() |
|
|
|
_ = utils.load_checkpoint(g_pth, net_g, None) |
|
|
|
is_uroman = hps.data.training_files.split(".")[-1] == "uroman" |
|
|
|
if is_uroman: |
|
uroman_dir = "uroman" |
|
assert os.path.exists(uroman_dir) |
|
uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl") |
|
text = text_mapper.uromanize(text, uroman_pl) |
|
|
|
text = text.lower() |
|
text = text_mapper.filter_oov(text, lang=lang) |
|
stn_tst = text_mapper.get_text(text, hps) |
|
with torch.no_grad(): |
|
x_tst = stn_tst.unsqueeze(0).to(device) |
|
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device) |
|
hyp = ( |
|
net_g.infer( |
|
x_tst, |
|
x_tst_lengths, |
|
noise_scale=0.667, |
|
noise_scale_w=0.8, |
|
length_scale=1.0 / speed, |
|
)[0][0, 0] |
|
.cpu() |
|
.float() |
|
.numpy() |
|
) |
|
|
|
return gr.Audio.update(value=(hps.data.sampling_rate, hyp)), text |
|
|
|
|
|
TTS_EXAMPLES = [ |
|
["I am going to the store.", "eng: English"], |
|
["안녕하세요", "kor: Korean"], |
|
["क्या मुझे पीने का पानी मिल सकता है?", "hin: Hindi"], |
|
["Tanış olmağıma çox şadam", "azj-script_latin: Azerbaijani, North"], |
|
["Mu zo murna a cikin ƙasar.", "hau: Hausa"], |
|
] |
|
|