import os
import re
import json
import argparse

# parse command-line arguments
parser = argparse.ArgumentParser(description="VITS model for Blue Archive characters")
parser.add_argument('--device', type=str, default='cpu', choices=["cpu", "cuda"], help="inference device (default: cpu)")
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
parser.add_argument("--queue", type=int, default=1, help="max number of concurrent workers for the queue (default: 1, set to 0 to disable)")
parser.add_argument('--api', action="store_true", default=False, help="enable REST routes (allows skipping the queue)")
parser.add_argument("--limit", type=int, default=100, help="prompt character limit (default: 100, set to 0 to disable)")
args = parser.parse_args()

#limit_text = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces

import torch
import utils
import commons
from models import SynthesizerTrn
from text import text_to_sequence
import gradio as gr

metadata_path = "./models/metadata.json"
config_path = "./models/config.json"
default_sid = 10

# load each model's definition; metadata.json maps a display name to an
# entry that must at least contain a "model" checkpoint path
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

##################### MODEL LOADING #####################

device = torch.device(args.device)

# loads one checkpoint and returns a container with all references to the model
def load_checkpoint(model_path):
    hps = utils.get_hparams_from_file(config_path)
    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    net_g.eval().to(device)
    utils.load_checkpoint(model_path, net_g, None)
    return {"hps": hps, "net_g": net_g}

# loads all models into system memory and returns a dict with all references
def load_models():
    models = {}
    for name, model_info in metadata.items():
        models[name] = load_checkpoint(model_info["model"])
    return models

models = load_models()

##################### INFERENCE #####################

# converts raw text to a tensor of symbol ids using the model's text cleaners
def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

# strips whitespace and rejects prompts above the length limit; language
# tags such as [JA] are excluded from the character count
def prepare_text(text):
    prepared_text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
    if args.limit > 0:
        text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
        if text_len > args.limit:
            return None
    return prepared_text

# runs inference on the selected model and returns a status message plus
# the synthesized audio as (sample_rate, waveform)
def speak(selected_index, text, noise_scale, noise_scale_w, length_scale, speaker_id):
    # get selected model
    model = list(models.items())[selected_index][1]
    # clean the text and enforce the length limit
    text = prepare_text(text)
    if text is None:
        return ["Error: text is too long", None]
    stn_tst = get_text(text, model["hps"])
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        sid = torch.LongTensor([speaker_id]).to(device)
        audio = model["net_g"].infer(x_tst, x_tst_lengths, sid=sid,
                                     noise_scale=noise_scale,
                                     noise_scale_w=noise_scale_w,
                                     length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
    return ["Success", (22050, audio)]
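# A minimal sanity check of the inference path, kept as a comment so startup
# behaviour is unchanged. The model index, prompt, and scale values below are
# illustrative assumptions, not values shipped with this repo:
#
#   status, result = speak(0, "こんにちは", 0.667, 0.8, 1.0, default_sid)
#   if result is not None:
#       sample_rate, waveform = result  # 22050 Hz mono float32 numpy array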
##################### GRADIO UI #####################

# javascript function for creating a download link; reads the audio element
# from the DOM and saves it with the (truncated) prompt as the filename
download_audio_js = """
() => {{
    let root = document.querySelector("body > gradio-app");
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    let audio = root.querySelector("#output-audio").querySelector("audio");
    let text = root.querySelector("#input-text").querySelector("textarea");
    if (audio == undefined)
        return;
    text = text.value;
    if (text == undefined)
        text = Math.floor(Math.random() * 100000000);
    audio = audio.src;
    let oA = document.createElement("a");
    oA.download = text.substr(0, 20) + '.wav';
    oA.href = audio;
    document.body.appendChild(oA);
    oA.click();
    oA.remove();
}}
"""

def make_gradio_ui():
    with gr.Blocks(css="#avatar {width:auto;height:200px;}") as demo:
        # make title
        gr.Markdown(
            """
            #