import os
import re
import json
import argparse

# parse command-line arguments
parser = argparse.ArgumentParser(description="VITS model for Blue Archive characters")
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="inference device (default: cpu)")
parser.add_argument("--share", action="store_true", default=False, help="share the Gradio app publicly")
parser.add_argument("--queue", type=int, default=1, help="max number of concurrent workers for the queue (default: 1, set to 0 to disable)")
parser.add_argument("--api", action="store_true", default=False, help="enable REST routes (allows skipping the queue)")
parser.add_argument("--limit", type=int, default=100, help="prompt word limit (default: 100, set to 0 to disable)")
args = parser.parse_args()
# limit_text = os.getenv("SYSTEM") == "spaces"  # limit text and audio length in Hugging Face Spaces
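# heavy imports (torch, model code) sit below the argument parsing so --help responds without loading them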
import torch

import utils
import commons
from models import SynthesizerTrn
from text import text_to_sequence
import gradio as gr
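# note: this script targets the Gradio 3.x API (gr.Image.update, the _js click argument
# and queue(concurrency_count=...) were changed or removed in Gradio 4)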
metadata_path = "./models/metadata.json"
config_path = "./models/config.json"
default_sid = 10

# load each model's definition
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)
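# each metadata entry maps a character key to its definition; the fields used below are
# "model" (checkpoint path), "name_en", "avatar" and "example_text"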
##################### MODEL LOADING #####################
device = torch.device(args.device)

# loads one checkpoint and returns a container with all references to the model
def load_checkpoint(model_path):
    hps = utils.get_hparams_from_file(config_path)
    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    net_g.eval().to(device)
    utils.load_checkpoint(model_path, net_g, None)
    return {"hps": hps, "net_g": net_g}
# loads all models into system memory and returns a dict with all references
def load_models():
    models = {}
    for name, model_info in metadata.items():
        models[name] = load_checkpoint(model_info["model"])
    return models

models = load_models()
##################### INFERENCE #####################
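# converts raw text into the LongTensor of symbol ids expected by the synthesizer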
def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
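# normalizes whitespace and enforces the --limit cap; returns None when the prompt is too long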
def prepare_text(text):
    # collapse line breaks and remove spaces
    prepared_text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
    if args.limit > 0:
        # two-letter language tags such as [JA] do not count against the limit
        text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
        if text_len > args.limit:
            return None
    return prepared_text
# runs inference with the selected model and returns [status message, audio]
def speak(selected_index, text, noise_scale, noise_scale_w, length_scale, speaker_id):
    # get the selected model
    model = list(models.items())[selected_index][1]
    # clean the text and enforce the length limit
    text = prepare_text(text)
    if text is None:
        return ["Error: text is too long", None]
    stn_tst = get_text(text, model["hps"])
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        # gr.Number delivers a float, so cast the speaker id to int
        sid = torch.LongTensor([int(speaker_id)]).to(device)
        audio = model["net_g"].infer(
            x_tst, x_tst_lengths, sid=sid,
            noise_scale=noise_scale, noise_scale_w=noise_scale_w,
            length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
    return ["Success", (22050, audio)]
##################### GRADIO UI #####################
# JavaScript that triggers a download: it reads the generated audio element from the
# DOM and saves it using the first 20 characters of the prompt as the filename
download_audio_js = """
() => {
    let root = document.querySelector("body > gradio-app");
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    let audio = root.querySelector("#output-audio").querySelector("audio");
    let text = root.querySelector("#input-text").querySelector("textarea");
    if (audio == undefined)
        return;
    text = text.value;
    if (text == undefined)
        text = Math.floor(Math.random() * 100000000);
    audio = audio.src;
    let oA = document.createElement("a");
    oA.download = text.substr(0, 20) + '.wav';
    oA.href = audio;
    document.body.appendChild(oA);
    oA.click();
    oA.remove();
}
"""
def make_gradio_ui():
    with gr.Blocks(css="#avatar {width:auto;height:200px;}") as demo:
        # title and credits
        gr.Markdown(
            """
            # <center> VITS for Blue Archive Characters
            ## <center> This is for educational purposes; please do not generate harmful or inappropriate content.\n
            [![HitCount](https://hits.dwyl.com/tovaru/vits-for-ba.svg?style=flat-square&show=unique)](http://hits.dwyl.com/tovaru/vits-for-ba)\n
            This is based on [zomehwh's implementation](https://huggingface.co/spaces/zomehwh/vits-models); please visit their space if you want to finetune your own models!
            """
        )
        first = list(metadata.items())[0][1]
        with gr.Row():
            # input column
            with gr.Column():
                # text input field
                text = gr.Textbox(
                    lines=5,
                    interactive=True,
                    label=f"Text ({args.limit} words limit)" if args.limit > 0 else "Text",
                    value=first["example_text"],
                    elem_id="input-text")
                # button for loading example texts
                example_btn = gr.Button("Example", size="sm")
                # character selector
                names = [model["name_en"] for model in metadata.values()]
                character = gr.Radio(label="Character", choices=names, value=names[0], type="index")
                # inference parameters
                with gr.Row():
                    ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
                    nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
                with gr.Row():
                    ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True)
                    sid = gr.Number(value=default_sid, label="Speaker ID (increase or decrease to change the intonation)")
                # generate button
                generate_btn = gr.Button(value="Generate", variant="primary")
            # results column
            with gr.Column():
                avatar = gr.Image(first["avatar"], type="pil", elem_id="avatar")
                output_message = gr.Textbox(label="Result Message")
                output_audio = gr.Audio(label="Output Audio", interactive=False, elem_id="output-audio")
                download_btn = gr.Button("Download Audio")
        # update the avatar when the character selection changes
        def update_selection(index):
            new_avatar = list(metadata.items())[index][1]["avatar"]
            return gr.Image.update(value=new_avatar)
        character.change(update_selection, character, avatar)

        # set generate button action
        generate_btn.click(fn=speak, inputs=[character, text, ns, nsw, ls, sid], outputs=[output_message, output_audio], api_name="generate")

        # set download button action
        download_btn.click(None, [], [], _js=download_audio_js)

        # set example button action: load the selected character's example text
        def load_example(index):
            return list(metadata.items())[index][1]["example_text"]
        example_btn.click(fn=load_example, inputs=[character], outputs=[text])
    return demo
if __name__ == "__main__":
    demo = make_gradio_ui()
    if args.queue > 0:
        demo.queue(concurrency_count=args.queue, api_open=args.api)
    demo.launch(share=args.share, show_api=args.api)
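# example invocations (assuming this file is saved as app.py, the usual Spaces entry point):
#   python app.py                                    # CPU inference at http://localhost:7860
#   python app.py --device cuda --share --limit 200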