import os
import re
import json
import argparse

# parse command-line arguments
parser = argparse.ArgumentParser(description="VITS model for Blue Archive characters")
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="inference device (default: cpu)")
parser.add_argument("--share", action="store_true", default=False, help="share the gradio app publicly")
parser.add_argument("--queue", type=int, default=1, help="max number of concurrent workers for the queue (default: 1, set to 0 to disable)")
parser.add_argument("--api", action="store_true", default=False, help="enable REST routes (allows skipping the queue)")
parser.add_argument("--limit", type=int, default=100, help="prompt character limit (default: 100, set to 0 to disable)")
args = parser.parse_args()
#limit_text = os.getenv("SYSTEM") == "spaces"  # limit text and audio length in huggingface spaces
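
# Example invocations (the filename app.py is illustrative, not taken from the
# repository):
#   python app.py                          # CPU inference, local UI only
#   python app.py --device cuda --share    # GPU inference with a public link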

import torch
import utils
import commons
from models import SynthesizerTrn
from text import text_to_sequence

import gradio as gr

metadata_path = "./models/metadata.json"
config_path = "./models/config.json"
default_sid = 10

# load each model's metadata (paths, display names, avatars, example texts)
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)
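
# The shape of metadata.json is inferred from how the fields are used below;
# each entry is expected to look roughly like this (the key and values are
# hypothetical, only the field names come from this file):
#   "aris": {
#       "model": "./models/aris/aris.pth",
#       "name_en": "Aris",
#       "avatar": "./models/aris/avatar.png",
#       "example_text": "..."
#   }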

##################### MODEL LOADING #####################
device = torch.device(args.device)

# loads one model and returns a container with all references to it
def load_checkpoint(model_path):
    hps = utils.get_hparams_from_file(config_path)
    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    net_g.eval().to(device)
    # load the trained weights; no optimizer is needed for inference
    utils.load_checkpoint(model_path, net_g, None)

    return {"hps": hps, "net_g": net_g}

# loads every model into memory and returns a dict keyed by model name
def load_models():
    models = {}
    for name, model_info in metadata.items():
        models[name] = load_checkpoint(model_info["model"])

    return models

models = load_models()

##################### INFERENCE #####################
# converts raw text into the integer symbol sequence the model expects
def get_text(text, hps):
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        # interleave blank tokens between symbols, matching the training setup
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
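
# For reference, intersperse (as implemented in the upstream VITS repo) places
# the blank token around and between every symbol, e.g.
#   commons.intersperse([5, 3, 7], 0) -> [0, 5, 0, 3, 0, 7, 0]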

# normalizes the prompt and enforces the length limit; returns None when the
# prompt exceeds the limit
def prepare_text(text):
    # strip newlines, carriage returns and spaces from the prompt
    prepared_text = text.replace('\n', '').replace('\r', '').replace(' ', '')
    if args.limit > 0:
        # language tags such as [JA] are excluded from the character count
        text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
        if text_len > args.limit:
            return None
    return prepared_text
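
# For example (the prompt is illustrative): a call like
#   prepare_text("[JA]こんにちは[JA]")
# returns the text unchanged and counts only the 5 characters between the tags
# towards --limit.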

# runs inference on the selected model and returns [status message, audio]
def speak(selected_index, text, noise_scale, noise_scale_w, length_scale, speaker_id):
    # get the selected model
    model = list(models.values())[selected_index]
    # clean and truncate text
    text = prepare_text(text)
    if text is None:
        return ["Error: text is too long", None]
    
    stn_tst = get_text(text, model["hps"])
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        # gr.Number delivers a float, so cast the speaker id explicitly
        sid = torch.LongTensor([int(speaker_id)]).to(device)
        audio = model["net_g"].infer(
            x_tst, x_tst_lengths, sid=sid,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()

    return ["Success", (22050, audio)]
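
# Programmatic usage (the prompt is illustrative; the numbers are the UI defaults):
#   status, result = speak(0, "こんにちは", 0.6, 0.668, 1.0, default_sid)
#   # on success, status == "Success" and result == (22050, waveform_ndarray)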

##################### GRADIO UI #####################
# javascript function for creating a download link; reads the generated audio
# from the DOM and saves it with the prompt as the filename
download_audio_js = """
() => {
    let root = document.querySelector("body > gradio-app");
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    let audio = root.querySelector("#output-audio").querySelector("audio");
    let text = root.querySelector("#input-text").querySelector("textarea");
    if (audio == null)
        return;
    text = text.value;
    if (!text)
        text = Math.floor(Math.random() * 100000000).toString();
    let oA = document.createElement("a");
    oA.download = text.substr(0, 20) + '.wav';
    oA.href = audio.src;
    document.body.appendChild(oA);
    oA.click();
    oA.remove();
}
"""
def make_gradio_ui():
    with gr.Blocks(css="#avatar {width:auto;height:200px;}") as demo:
        # make title
        gr.Markdown(
            """
            # <center> VITS for Blue Archive Characters
            ## <center> This is for educational purposes only; please do not generate harmful or inappropriate content\n
            [![HitCount](https://hits.dwyl.com/tovaru/vits-for-ba.svg?style=flat-square&show=unique)](http://hits.dwyl.com/tovaru/vits-for-ba)\n
            This is based on [zomehwh's implementation](https://huggingface.co/spaces/zomehwh/vits-models), please visit their space if you want to finetune your own models!
            """
        )

        first = list(metadata.items())[0][1]
        with gr.Row():
            # setup column
            with gr.Column():
                # set text input field
                text = gr.Textbox(
                    lines=5,
                    interactive=True,
                    label=f"Text ({args.limit} character limit)" if args.limit > 0 else "Text",
                    value=first["example_text"],
                    elem_id="input-text")

                # button for loading example texts
                example_btn = gr.Button("Example", size="sm")
                
                # character selector
                names = [info["name_en"] for info in metadata.values()]
                character = gr.Radio(label="Character", choices=names, value=names[0], type="index")

                # inference parameters
                with gr.Row():
                    ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
                    nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
                with gr.Row():
                    ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True)
                    sid = gr.Number(value=default_sid, label="Speaker ID (increase or decrease to change the intonation)")

                # generate button
                generate_btn = gr.Button(value="Generate", variant="primary")
                
            # results column
            with gr.Column():
                avatar = gr.Image(first["avatar"], type="pil", elem_id="avatar")
                output_message = gr.Textbox(label="Result Message")
                output_audio = gr.Audio(label="Output Audio", interactive=False, elem_id="output-audio")
                download_btn = gr.Button("Download Audio")

            # set character selection updates
            def update_selection(index):
                new_avatar = list(metadata.values())[index]["avatar"]
                return gr.Image.update(value=new_avatar)
            
            character.change(update_selection, character, avatar)

        # set generate button action
        generate_btn.click(fn=speak, inputs=[character, text, ns, nsw, ls, sid], outputs=[output_message, output_audio], api_name="generate")

        # set download button action
        download_btn.click(None, [], [], _js=download_audio_js)

        def load_example(index):
            return list(metadata.values())[index]["example_text"]
        example_btn.click(fn=load_example, inputs=[character], outputs=[text])
    
    return demo

if __name__ == "__main__":
    demo = make_gradio_ui()
    if args.queue > 0:
        demo.queue(concurrency_count=args.queue, api_open=args.api)
    demo.launch(share=args.share, show_api=args.api)