# Origin: Hugging Face Space "vits-for-ba" by tovaru (initial commit, 7.99 kB).
import os
import re
import json
import argparse
# get arguments
def _build_arg_parser():
    """Return the command-line parser for this app (nothing is parsed here)."""
    cli = argparse.ArgumentParser(description="VITS model for Blue Archive characters")
    # (flag, keyword arguments for add_argument), registered in this order
    option_specs = (
        ("--device", dict(type=str, default="cpu", choices=["cpu", "cuda"],
                          help="inference method (default: cpu)")),
        ("--share", dict(action="store_true", default=False,
                         help="share gradio app")),
        ("--queue", dict(type=int, default=1,
                         help="max number of concurrent workers for the queue (default: 1, set to 0 to disable)")),
        ("--api", dict(action="store_true", default=False,
                       help="enable REST routes (allows to skip queue)")),
        ("--limit", dict(type=int, default=100,
                         help="prompt words limit (default: 100, set to 0 to disable)")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


parser = _build_arg_parser()
# Parsed once at import time; the rest of the module reads these options.
args = parser.parse_args()
#limit_text = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces
import torch
import utils
import commons
from models import SynthesizerTrn
from text import text_to_sequence
import gradio as gr
# Location of the model registry and of the hyper-parameter file shared by all models.
metadata_path = "./models/metadata.json"
config_path = "./models/config.json"
# Speaker ID the UI starts with.
default_sid = 10

# Load each model's definition: a JSON object keyed by model name whose entries
# are later read for their "model", "avatar" and "example_text" fields.
with open(metadata_path, encoding="utf-8") as f:
    metadata = json.loads(f.read())
##################### MODEL LOADING #####################
# Torch device (from --device) that every model and input tensor is moved to.
device = torch.device(args.device)


def load_checkpoint(model_path):
    """Build one VITS generator, load its weights and return it with its config.

    Args:
        model_path: checkpoint path handed to ``utils.load_checkpoint``.

    Returns:
        dict with ``"hps"`` (hyper-parameters read from ``config_path``) and
        ``"net_g"`` (the ``SynthesizerTrn`` in eval mode on ``device``).
    """
    hps = utils.get_hparams_from_file(config_path)
    # NOTE(review): constructor arguments follow the upstream VITS inference
    # recipe (vocab size, spectrogram channels, segment frames, speaker count,
    # remaining model kwargs); confirm against models.SynthesizerTrn.
    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    # Inference mode; nn.Module.to() moves the parameters in place.
    net_g.eval().to(device)
    # Third argument presumably is the optimizer; None since only weights are needed.
    utils.load_checkpoint(model_path, net_g, None)
    return {"hps": hps, "net_g": net_g}
def load_models():
    """Load every model listed in ``metadata`` into memory.

    Returns:
        dict mapping each metadata key to the container produced by
        ``load_checkpoint`` for that entry's ``"model"`` path, in metadata order.
    """
    return {
        name: load_checkpoint(entry["model"])
        for name, entry in metadata.items()
    }


# All models are loaded eagerly at import time.
models = load_models()
##################### INFERENCE #####################
def get_text(text, hps):
    """Convert prepared text into the tensor of symbol IDs the model consumes.

    Args:
        text: prompt already normalised by ``prepare_text``.
        hps: hyper-parameters of the selected model.

    Returns:
        1-D ``torch.LongTensor`` of symbol IDs with 0 inserted between them.
    """
    # NOTE(review): uses the config's symbol table (the same one load_checkpoint
    # sizes the model with) and its cleaner list, following the
    # text_to_sequence(text, symbols, cleaner_names) API; confirm against text/.
    text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
    # Insert ID 0 between consecutive symbols (the blank token in upstream VITS).
    # NOTE(review): upstream only does this when hps.data.add_blank is set.
    text_norm = commons.intersperse(text_norm, 0)
    return torch.LongTensor(text_norm)
# Two-letter uppercase bracket tags such as "[JA]" (presumably language markers);
# they are excluded from the length count. Raw string: "\[" is not a valid
# escape in a normal string literal.
_BRACKET_TAG_RE = re.compile(r"\[([A-Z]{2})\]")


def prepare_text(text, limit=None):
    """Normalise whitespace in *text* and enforce the prompt length limit.

    Newlines become spaces, carriage returns are dropped, then every space
    is removed.

    Args:
        text: raw prompt from the UI.
        limit: maximum allowed length; ``None`` (default) uses ``args.limit``.
            A value of 0 or less disables the check.

    Returns:
        The normalised text, or ``None`` when the prompt is too long. The
        length is counted in characters of the raw *text* after removing
        bracket tags like "[JA]".
    """
    if limit is None:
        limit = args.limit
    prepared_text = text.replace("\n", " ").replace("\r", "").replace(" ", "")
    if limit > 0:
        text_len = len(_BRACKET_TAG_RE.sub("", text))
        if text_len > limit:
            return None
    return prepared_text
# runs inference with one model
def speak(selected_index, text, noise_scale, noise_scale_w, length_scale, speaker_id):
    """Synthesise *text* with the model at position *selected_index*.

    Args:
        selected_index: position of the model in ``models`` (the Character
            radio is created with ``type="index"``).
        text: raw prompt; normalised and length-checked by ``prepare_text``.
        noise_scale, noise_scale_w, length_scale: forwarded to ``net_g.infer``.
        speaker_id: speaker index; Gradio's Number widget may deliver a float,
            so it is converted with ``int()``.

    Returns:
        ``[message, audio]``: audio is ``(sample_rate, numpy waveform)`` on
        success and ``None`` when the text is rejected.
    """
    model = list(models.values())[selected_index]
    hps = model["hps"]

    # Normalise the text; long prompts are rejected, not truncated.
    text = prepare_text(text)
    if text is None:
        return ["Error: text is too long", None]

    text_ids = get_text(text, hps)
    with torch.no_grad():
        x = text_ids.unsqueeze(0).to(device)  # add a batch dimension of 1
        x_lengths = torch.LongTensor([text_ids.size(0)]).to(device)
        sid = torch.LongTensor([int(speaker_id)]).to(device)
        output = model["net_g"].infer(
            x,
            x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        audio = output[0][0, 0].data.cpu().float().numpy()
    # Use the model config's sample rate when it declares one; 22050 Hz otherwise.
    sampling_rate = getattr(getattr(hps, "data", None), "sampling_rate", 22050)
    return ["Success", (sampling_rate, audio)]
##################### GRADIO UI #####################
# JavaScript run in the browser by the "Download Audio" button. It reads the
# generated audio's URL from the #output-audio element and downloads it as a
# .wav named after the first 20 characters of the prompt in #input-text
# (a random number is used when the prompt is empty).
download_audio_js = """
() => {
    let root = document.querySelector("body > gradio-app");
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    const audio = root.querySelector("#output-audio audio");
    if (audio == null || !audio.src)
        return;
    const textarea = root.querySelector("#input-text textarea");
    let name = textarea == null ? "" : textarea.value;
    if (!name)
        name = String(Math.floor(Math.random() * 100000000));
    const link = document.createElement("a");
    link.download = name.substr(0, 20) + ".wav";
    link.href = audio.src;
    document.body.appendChild(link);
    link.click();
    link.remove();
}
"""
def make_gradio_ui():
    """Build and return the Gradio Blocks app (it is not launched here).

    Left column: prompt, example button, character picker, inference
    parameters and the Generate button. Right column: character avatar,
    result message, output audio and a download button.
    """
    # Metadata entries in file order; the radio's indices refer to this order,
    # which is also the order of ``models`` built from the same dict.
    entries = list(metadata.values())
    first = entries[0]
    # TODO(review): if metadata entries carry a display-name field, use it here.
    names = list(metadata.keys())

    with gr.Blocks(css="#avatar {width:auto;height:200px;}") as demo:
        # make title
        # TODO(review): restore the link to zomehwh's implementation.
        gr.Markdown(
            "# <center> VITS for Blue Archive Characters\n"
            "## <center> This is for educational purposes, please do not generate harmful or inappropriate content\n\n"
            "This is based on zomehwh's implementation, please visit their space if you want to finetune your own models!"
        )
        with gr.Row():
            # setup column
            with gr.Column():
                # prompt input; elem_id is looked up by download_audio_js
                text = gr.Textbox(
                    label=f"Text ({args.limit} words limit)" if args.limit > 0 else "Text",
                    elem_id="input-text",
                )
                # button for loading example texts
                example_btn = gr.Button("Example", size="sm")
                # character selector; type="index" passes the position to callbacks
                character = gr.Radio(label="Character", choices=names, value=names[0], type="index")
                # inference parameters
                with gr.Row():
                    ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
                    nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
                with gr.Row():
                    ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True)
                    sid = gr.Number(value=default_sid, label="Speaker ID (increase or decrease to change the intonation)")
                # generate button
                generate_btn = gr.Button(value="Generate", variant="primary")
            # results column
            with gr.Column():
                avatar = gr.Image(first["avatar"], type="pil", elem_id="avatar")
                output_message = gr.Textbox(label="Result Message")
                # elem_id is looked up by download_audio_js
                output_audio = gr.Audio(label="Output Audio", interactive=False, elem_id="output-audio")
                download_btn = gr.Button("Download Audio")

        def update_selection(index):
            """Show the avatar of the character at radio position *index*."""
            return gr.Image.update(value=entries[index]["avatar"])

        def load_example(index):
            """Return the example prompt of the character at radio position *index*."""
            return entries[index]["example_text"]

        # set character selection updates
        character.change(update_selection, character, avatar)
        # set generate button action
        generate_btn.click(
            speak,
            inputs=[character, text, ns, nsw, ls, sid],
            outputs=[output_message, output_audio],
            api_name="generate",
        )
        # set download button action (client-side JavaScript only, no Python callback)
        download_btn.click(None, [], [], _js=download_audio_js)
        # set example button action
        example_btn.click(load_example, inputs=[character], outputs=[text])
    return demo
def _main():
    """Build the UI, optionally enable the request queue, and launch the server."""
    app = make_gradio_ui()
    # --queue 0 disables queuing; any positive value caps concurrent workers.
    queue_workers = args.queue
    if queue_workers > 0:
        app.queue(concurrency_count=queue_workers, api_open=args.api)
    app.launch(share=args.share, show_api=args.api)


if __name__ == "__main__":
    _main()