import os
import re
import json
import argparse
# get arguments
parser = argparse.ArgumentParser(description="VITS model for Blue Archive characters")
parser.add_argument('--device', type=str, default='cpu', choices=["cpu", "cuda"], help="inference device (default: cpu)")
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
parser.add_argument("--queue", type=int, default=1, help="max number of concurrent workers for the queue (default: 1, set to 0 to disable)")
parser.add_argument('--api', action="store_true", default=False, help="enable the REST API routes (allows skipping the queue)")
parser.add_argument("--limit", type=int, default=100, help="prompt words limit (default: 100, set to 0 to disable)")
args = parser.parse_args()
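# Example invocations (script name and values are illustrative only):
#   python app.py                           # CPU inference, local only
#   python app.py --device cuda --share     # GPU inference with a public Gradio link
#   python app.py --limit 0 --queue 4       # no prompt length limit, up to 4 queue workers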
#limit_text = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces
import torch
import utils
import commons
from models import SynthesizerTrn
from text import text_to_sequence
import gradio as gr
metadata_path = "./models/metadata.json"
config_path = "./models/config.json"
default_sid = 10
# load each model's definition
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)
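# metadata.json is expected to map one key per character to an entry roughly like the
# sketch below (structure inferred from the fields accessed later in this file; the
# exact schema and values are assumptions):
#   {"aris": {"name_en": "Aris", "model": "./models/aris.pth",
#             "avatar": "./models/aris.png", "example_text": "..."}}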
##################### MODEL LOADING #####################
# create the torch device once, shared by model loading and inference
device = torch.device(args.device)

# loads one model and returns a dict with references to its hyperparameters and network
def load_checkpoint(model_path):
    hps = utils.get_hparams_from_file(config_path)
    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    net_g = net_g.eval().to(device)
    utils.load_checkpoint(model_path, net_g, None)
    return {"hps": hps, "net_g": net_g}
# loads all models into memory and returns a dict mapping each name to its container
def load_models():
    models = {}
    for name, model_info in metadata.items():
        model = load_checkpoint(model_info["model"])
        models[name] = model
    return models
models = load_models()
##################### INFERENCE #####################
def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
def prepare_text(text):
    # remove newlines and spaces from the prompt
    prepared_text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
    if args.limit > 0:
        # count the length without language tags such as [JA]
        text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
        max_len = args.limit
        if text_len > max_len:
            return None
    return prepared_text
# runs inference on the selected model
def speak(selected_index, text, noise_scale, noise_scale_w, length_scale, speaker_id):
    # get selected model
    model = list(models.items())[selected_index][1]
    # clean the text and enforce the length limit
    text = prepare_text(text)
    if text is None:
        return ["Error: text is too long", None]
    stn_tst = get_text(text, model["hps"])
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        sid = torch.LongTensor([speaker_id]).to(device)
        audio = model["net_g"].infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
    return ["Success", (22050, audio)]
##################### GRADIO UI #####################
# javascript function that creates a download link: it reads the audio element from the
# DOM and downloads it, using the prompt text as the filename
download_audio_js = """
() => {
    let root = document.querySelector("body > gradio-app");
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    let audio = root.querySelector("#output-audio").querySelector("audio");
    let text = root.querySelector("#input-text").querySelector("textarea");
    if (audio == undefined)
        return;
    text = text.value;
    if (text == undefined)
        text = Math.floor(Math.random() * 100000000);
    audio = audio.src;
    let oA = document.createElement("a");
    oA.download = text.substr(0, 20) + '.wav';
    oA.href = audio;
    document.body.appendChild(oA);
    oA.click();
    oA.remove();
}
"""
def make_gradio_ui():
    with gr.Blocks(css="#avatar {width:auto;height:200px;}") as demo:
        # make title
        gr.Markdown(
            """
            # <center> VITS for Blue Archive Characters
            ## <center> This is for educational purposes, please do not generate harmful or inappropriate content\n
            [![HitCount](https://hits.dwyl.com/tovaru/vits-for-ba.svg?style=flat-square&show=unique)](http://hits.dwyl.com/tovaru/vits-for-ba)\n
            This is based on [zomehwh's implementation](https://huggingface.co/spaces/zomehwh/vits-models); please visit their space if you want to finetune your own models!
            """
        )
        first = list(metadata.items())[0][1]
        with gr.Row():
            # setup column
            with gr.Column():
                # text input field
                text = gr.Textbox(
                    lines=5,
                    interactive=True,
                    label=f"Text ({args.limit} words limit)" if args.limit > 0 else "Text",
                    value=first["example_text"],
                    elem_id="input-text")
                # button for loading example texts
                example_btn = gr.Button("Example", size="sm")
                # character selector
                names = []
                for _, model in metadata.items():
                    names.append(model["name_en"])
                character = gr.Radio(label="Character", choices=names, value=names[0], type="index")
                # inference parameters
                with gr.Row():
                    ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
                    nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
                with gr.Row():
                    ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True)
                    sid = gr.Number(value=default_sid, label="Speaker ID (increase or decrease to change the intonation)")
                # generate button
                generate_btn = gr.Button(value="Generate", variant="primary")
            # results column
            with gr.Column():
                avatar = gr.Image(first["avatar"], type="pil", elem_id="avatar")
                output_message = gr.Textbox(label="Result Message")
                output_audio = gr.Audio(label="Output Audio", interactive=False, elem_id="output-audio")
                download_btn = gr.Button("Download Audio")
        # update the avatar when a different character is selected
        def update_selection(index):
            avatars = list(metadata.items())
            new_avatar = avatars[index][1]["avatar"]
            return gr.Image.update(value=new_avatar)
        character.change(update_selection, character, avatar)
        # set generate button action
        generate_btn.click(fn=speak, inputs=[character, text, ns, nsw, ls, sid], outputs=[output_message, output_audio], api_name="generate")
        # set download button action
        download_btn.click(None, [], [], _js=download_audio_js)
        # load the selected character's example text into the prompt box
        def load_example(index):
            example = list(metadata.items())[index][1]["example_text"]
            return example
        example_btn.click(fn=load_example, inputs=[character], outputs=[text])
    return demo
if __name__ == "__main__":
    demo = make_gradio_ui()
    if args.queue > 0:
        demo.queue(concurrency_count=args.queue, api_open=args.api)
    demo.launch(share=args.share, show_api=args.api)