import os import re import json import argparse # get arguments parser = argparse.ArgumentParser(description="VITS model for Blue Archive characters") parser.add_argument('--device', type=str, default='cpu', choices=["cpu","cuda"], help="inference method (default: cpu)") parser.add_argument("--share", action="store_true", default=False, help="share gradio app") parser.add_argument("--queue", type=int, default=1, help="max number of concurrent workers for the queue (default: 1, set to 0 to disable)") parser.add_argument('--api', action="store_true", default=False, help="enable REST routes (allows to skip queue)") parser.add_argument("--limit", type=int, default=100, help="prompt words limit (default: 100, set to 0 to disable)") args = parser.parse_args() #limit_text = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces import torch import utils import commons from models import SynthesizerTrn from text import text_to_sequence import gradio as gr metadata_path = "./models/metadata.json" config_path = "./models/config.json" default_sid = 10 # load each model's definition with open(metadata_path, "r", encoding="utf-8") as f: metadata = json.load(f) ##################### MODEL LOADING ##################### # loads one model and returns a container with all references to the model device = torch.device(args.device) def load_checkpoint(model_path): hps = utils.get_hparams_from_file(config_path) net_g = SynthesizerTrn( len(hps.symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model) model = net_g.eval().to(device) pythomodel = utils.load_checkpoint(model_path, net_g, None) return {"hps": hps, "net_g": net_g} # loads all models on system memory and returns a dict with all references def load_models(): models = {} for name, model_info in metadata.items(): model = load_checkpoint(model_info["model"]) models[name] = model return models models = load_models() ##################### INFERENCE ##################### def get_text(text, hps): text_norm = text_to_sequence(text, hps.data.text_cleaners) if hps.data.add_blank: text_norm = commons.intersperse(text_norm, 0) text_norm = torch.LongTensor(text_norm) return text_norm def prepare_text(text): prepared_text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") if args.limit > 0: text_len = len(re.sub("\[([A-Z]{2})\]", "", text)) max_len = args.limit if text_len > max_len: return None return prepared_text # inferes a model def speak(selected_index, text, noise_scale, noise_scale_w, length_scale, speaker_id): # get selected model model = list(models.items())[selected_index][1] # clean and truncate text text = prepare_text(text) if text is None: return ["Error: text is too long", None] stn_tst = get_text(text, model["hps"]) with torch.no_grad(): x_tst = stn_tst.unsqueeze(0).to(device) x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device) sid = torch.LongTensor([speaker_id]).to(device) audio = model["net_g"].infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.cpu().float().numpy() return ["Success", (22050, audio)] ##################### GRADIO UI ##################### # javascript function for creating a download link, reads the audio file from the DOM and returns it with the prompt as the filename download_audio_js = """ () =>{{ let root = document.querySelector("body > gradio-app"); if (root.shadowRoot != null) root = root.shadowRoot; let audio = root.querySelector("#output-audio").querySelector("audio"); let text = root.querySelector("#input-text").querySelector("textarea"); if (audio == undefined) return; text = text.value; if (text == undefined) text = Math.floor(Math.random()*100000000); audio = audio.src; let oA = document.createElement("a"); oA.download = text.substr(0, 20)+'.wav'; oA.href = audio; document.body.appendChild(oA); oA.click(); oA.remove(); }} """ def make_gradio_ui(): with gr.Blocks(css="#avatar {width:auto;height:200px;}") as demo: # make title gr.Markdown( """ #
VITS for Blue Archive Characters ##
This is for educational purposes, please do not generate harmful or inappropriate content\n [![HitCount](https://hits.dwyl.com/tovaru/vits-for-ba.svg?style=flat-square&show=unique)](http://hits.dwyl.com/tovaru/vits-for-ba)\n This is based on [zomehwh's implementation](https://huggingface.co/spaces/zomehwh/vits-models), please visit their space if you want to finetune your own models! """ ) first = list(metadata.items())[0][1] with gr.Row(): # setup column with gr.Column(): # set text input field text = gr.Textbox( lines=5, interactive=True, label=f"Text ({args.limit} words limit)" if args.limit > 0 else "Text", value=first["example_text"], elem_id="input-text") # button for loading example texts example_btn = gr.Button("Example", size="sm") # character selector names = [] for _, model in metadata.items(): names.append(model["name_en"]) character = gr.Radio(label="Character", choices=names, value=names[0], type="index") # inference parameters with gr.Row(): ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True) nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True) with gr.Row(): ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True) sid = gr.Number(value=default_sid, label="Speaker ID (increase or decrease to change the intonation)") # generate button generate_btn = gr.Button(value="Generate", variant="primary") # results column with gr.Column(): avatar = gr.Image(first["avatar"], type="pil", elem_id="avatar") output_message = gr.Textbox(label="Result Message") output_audio = gr.Audio(label="Output Audio", interactive=False, elem_id="output-audio") download_btn = gr.Button("Download Audio") # set character selection updates def update_selection(index): avatars = list(metadata.items()) new_avatar = avatars[index][1]["avatar"] return gr.Image.update(value=new_avatar) character.change(update_selection, character, avatar) # set generate button action generate_btn.click(fn=speak, inputs=[character, text, ns, nsw, ls, sid], outputs=[output_message, output_audio], api_name="generate") # set download butto naction download_btn.click(None, [], [], _js=download_audio_js) def load_example(index): example = list(metadata.items())[index][1]["example_text"] return example example_btn.click(fn=load_example, inputs=[character], outputs=[text]) return demo if __name__ == "__main__": demo = make_gradio_ui() if args.queue > 0: demo.queue(concurrency_count=args.queue, api_open=args.api) demo.launch(share=args.share, show_api=args.api)