File size: 8,073 Bytes
2d8ad0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c96e2d
2d8ad0f
 
 
5eb461c
2d8ad0f
 
 
 
 
 
 
 
 
 
0c151d7
 
2d8ad0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2eaf5c
3607437
2d8ad0f
 
3607437
2d8ad0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adaad9d
2d8ad0f
 
 
 
982af36
2d8ad0f
 
 
 
 
 
fd8a022
 
2d8ad0f
fd8a022
2d8ad0f
 
3607437
2d8ad0f
 
 
 
 
 
 
 
 
 
 
5eb461c
2d8ad0f
 
 
 
 
 
 
 
 
81369fa
2d8ad0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c96e2d
2d8ad0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adaad9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# import gradio as gr

# gr.Interface.load("models/ulysses115/pmvoice").launch()

import argparse
import json
import os
import re
import tempfile

import librosa
import numpy as np
import torch
from torch import no_grad, LongTensor
import commons
import utils
import gradio as gr
import gradio.utils as gr_utils
import gradio.processing_utils as gr_processing_utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence, _clean_text
from mel_processing import spectrogram_torch

# When True, tts_fn rejects over-long input text (only the text length is
# checked in this file). It is hard-coded off; the trailing comment preserves
# the original Hugging Face Spaces detection expression for reference.
limitation = False#os.getenv("SYSTEM") == "spaces"  # limit text and audio length in huggingface spaces


def audio_postprocess(self, y):
    """Replacement for ``gr.Audio.postprocess`` that keeps temp files in the CWD.

    Turns the value an event handler returns for an Audio output into the
    base64 payload sent to the browser.

    Args:
        self: the ``gr.Audio`` component this is patched onto. Its
            ``temp_dir`` is overwritten with ``"./"`` on every call.
        y: ``None``, a URL, a ``(sample_rate, data)`` tuple, or otherwise
            something ``create_tmp_copy_of_file`` accepts (presumably a local
            file path -- confirm against the Gradio version in use).

    Returns:
        The result of ``encode_url_or_file_to_base64`` for the file written,
        or ``None`` when ``y`` is ``None``.
    """
    if y is None:
        return None

    # Write temporary audio files into the working directory.
    self.temp_dir = "./"

    if gr_utils.validate_url(y):
        file = gr_processing_utils.download_to_file(y, dir=self.temp_dir)
    elif isinstance(y, tuple):
        sample_rate, data = y
        # Only reserve a unique filename here (delete=False keeps it on disk).
        # Close the handle right away: audio_to_file reopens the path by name,
        # so keeping it open would leak the descriptor and can make the write
        # fail on platforms that lock open files.
        file = tempfile.NamedTemporaryFile(
            suffix=".wav", dir=self.temp_dir, delete=False
        )
        file.close()
        gr_processing_utils.audio_to_file(sample_rate, data, file.name)
    else:
        file = gr_processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)

    return gr_processing_utils.encode_url_or_file_to_base64(file.name)


# Monkey-patch the class attribute so every gr.Audio component in this app
# uses audio_postprocess (temp files written to the working directory).
gr.Audio.postprocess = audio_postprocess

def get_text(text, hps, is_symbol):
    """Convert user text into a LongTensor of symbol ids for the synthesizer.

    Args:
        text: text from the input box.
        hps: hyper-parameters loaded from the config; reads
            ``hps.data.text_cleaners`` and ``hps.data.add_blank``.
        is_symbol: True when ``text`` is already in cleaned symbol form (the
            "Symbol input" toggle fills the box with ``_clean_text`` output).
            In that case the cleaners are not applied a second time.

    Returns:
        ``torch.LongTensor`` of the id sequence, with a 0 id interspersed
        between entries when ``hps.data.add_blank`` is set.
    """
    # Symbol-mode text has already been through the cleaners; running them
    # again would re-transform the symbols instead of passing them through.
    cleaners = [] if is_symbol else hps.data.text_cleaners
    text_norm = text_to_sequence(text, cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

def create_tts_fn(model, hps, speaker_ids):
    """Build the click handler that synthesizes speech with ``model``.

    Args:
        model: the loaded synthesizer; its ``infer`` method is called and its
            parameters determine the device inputs are moved to.
        hps: hyper-parameters from the config (text cleaners, sampling rate).
        speaker_ids: sequence indexed by the speaker dropdown's selected index,
            giving the model's speaker id.

    Returns:
        ``tts_fn(text, speaker, speed, is_symbol)`` which returns a
        ``(message, audio)`` pair: ``("Success", (sampling_rate, waveform))``
        on success, or ``("Error: Text is too long", None)`` when the optional
        length limit is exceeded.
    """
    def tts_fn(text, speaker, speed, is_symbol):
        if limitation:
            # Two-letter language tags such as "[ZH]" don't count towards the
            # length limit. Raw string: "\[" is an invalid escape otherwise.
            text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
            max_len = 150
            if is_symbol:
                # Symbol input gets a 3x larger character budget.
                max_len *= 3
            if text_len > max_len:
                return "Error: Text is too long", None

        speaker_id = speaker_ids[speaker]
        stn_tst = get_text(text, hps, is_symbol)
        # Use the device the model's weights live on rather than a global
        # that only exists when this file is run as a script.
        device = next(model.parameters()).device
        with no_grad():
            x_tst = stn_tst.unsqueeze(0).to(device)
            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
            sid = LongTensor([speaker_id]).to(device)
            # length_scale is the inverse of the requested speed.
            audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
                                length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
        del stn_tst, x_tst, x_tst_lengths, sid
        return "Success", (hps.data.sampling_rate, audio)

    return tts_fn


def create_to_symbol_fn(hps):
    """Return the change handler for the "Symbol input" checkbox.

    The handler takes ``(is_symbol_input, input_text, temp_text)`` and returns
    ``(new_textbox_value, new_temp_text)``. Switching symbol mode on shows the
    cleaned symbol form of the current text and stashes the plain text.
    Switching it off restores the stashed text.
    """
    def to_symbol_fn(is_symbol_input, input_text, temp_text):
        if not is_symbol_input:
            # Leaving symbol mode: put the stashed plain text back.
            return temp_text, temp_text
        # Entering symbol mode: display cleaned symbols, remember the original.
        symbol_text = _clean_text(input_text, hps.data.text_cleaners)
        return symbol_text, input_text

    return to_symbol_fn


# Client-side JavaScript for the "Download Audio" button's ``_js`` callback.
# This is a ``str.format`` template: ``{audio_id}`` is replaced with the
# element id of the Audio output, and the doubled braces ``{{``/``}}`` become
# literal JS braces. The script locates the <audio> element inside the Gradio
# app (entering its shadow root when present), returns early if there is none,
# and otherwise downloads its ``src`` under a random numeric ``.wav`` filename.
download_audio_js = """
() =>{{
    let root = document.querySelector("body > gradio-app");
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    let audio = root.querySelector("#{audio_id}").querySelector("audio");
    if (audio == undefined)
        return;
    audio = audio.src;
    let oA = document.createElement("a");
    oA.download = Math.floor(Math.random()*100000000)+'.wav';
    oA.href = audio;
    document.body.appendChild(oA);
    oA.click();
    oA.remove();
}}
"""

if __name__ == '__main__':
    # Script entry point: load the model(s) listed in save_model/info.json and
    # serve one Gradio tab per "vits" entry.
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    args = parser.parse_args()

    # Kept as a module-level global for any code that reads ``device``.
    device = torch.device(args.device)

    # Every info.json entry is served from this same config and checkpoint, so
    # they are loaded lazily, at most once, and shared by all tabs instead of
    # being re-read into a fresh model for every entry.
    config_path = "config.json"
    model_path = "G_1434000.pth"
    hps = None
    model = None

    models_tts = []
    with open("save_model/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)
    for i, info in models_info.items():
        # Only "vits" entries get a tab; skip others before the costly load.
        if info["type"] != "vits":
            continue
        name = info["title"]
        author = info["author"]
        lang = info["lang"]
        example = info["example"]
        cover_path = info["cover"]
        if model is None:
            hps = utils.get_hparams_from_file(config_path)
            model = SynthesizerTrn(
                len(symbols),
                hps.data.filter_length // 2 + 1,
                hps.train.segment_size // hps.data.hop_length,
                **hps.model)
            utils.load_checkpoint(model_path, model, None)
            model.eval().to(device)
        # Speakers literally named "None" are hidden from the dropdown; the
        # two lists stay aligned so a dropdown index maps to a model id.
        speaker_ids = [sid for sid, spk in enumerate(hps.speakers) if spk != "None"]
        speakers = [spk for spk in hps.speakers if spk != "None"]

        models_tts.append((name, author, cover_path, speakers, lang, example,
                           symbols, create_tts_fn(model, hps, speaker_ids),
                           create_to_symbol_fn(hps)))

    app = gr.Blocks()

    with app:
        # ``tab_symbols`` avoids rebinding the imported ``symbols`` name.
        for i, (name, author, cover_path, speakers, lang, example, tab_symbols, tts_fn,
                to_symbol_fn) in enumerate(models_tts):
            with gr.TabItem(f"model{i}"):
                with gr.Column():
                    tts_input1 = gr.TextArea(label="Text", value="你好,旅行者!我是派蒙~有什么可以帮助你的吗?",
                                             elem_id=f"tts-input{i}")
                    # type="index": tts_fn receives the selected position,
                    # which indexes speaker_ids.
                    tts_input2 = gr.Dropdown(label="Speaker", choices=speakers,
                                             type="index", value=speakers[0] if speakers else None)
                    tts_input3 = gr.Slider(label="Speed", value=1, minimum=0.5, maximum=2, step=0.1)
                    with gr.Accordion(label="Advanced Options", open=False):
                        # Holds the plain text while the box shows symbols.
                        temp_text_var = gr.Variable()
                        symbol_input = gr.Checkbox(value=False, label="Symbol input")
                        symbol_list = gr.Dataset(label="Symbol list", components=[tts_input1],
                                                 samples=[[x] for x in tab_symbols],
                                                 elem_id=f"symbol-list{i}")
                        symbol_list_json = gr.Json(value=tab_symbols, visible=False)
                    tts_submit = gr.Button("Generate", variant="primary")
                    tts_output1 = gr.Textbox(label="Output Message")
                    tts_output2 = gr.Audio(label="Output Audio", elem_id=f"tts-audio{i}")
                    download = gr.Button("Download Audio")
                    download.click(None, [], [], _js=download_audio_js.format(audio_id=f"tts-audio{i}"))

                    tts_submit.click(tts_fn, [tts_input1, tts_input2, tts_input3, symbol_input],
                                     [tts_output1, tts_output2])
                    symbol_input.change(to_symbol_fn,
                                        [symbol_input, tts_input1, temp_text_var],
                                        [tts_input1, temp_text_var])
                    # Clicking a symbol inserts it at the textarea cursor. In
                    # the f-string, ``{i}`` is the Python tab index; the JS
                    # ``i`` parameter is the clicked sample index.
                    symbol_list.click(None, [symbol_list, symbol_list_json], [],
                                      _js=f"""
                    (i,symbols) => {{
                        let root = document.querySelector("body > gradio-app");
                        if (root.shadowRoot != null)
                            root = root.shadowRoot;
                        let text_input = root.querySelector("#tts-input{i}").querySelector("textarea");
                        let startPos = text_input.selectionStart;
                        let endPos = text_input.selectionEnd;
                        let oldTxt = text_input.value;
                        let result = oldTxt.substring(0, startPos) + symbols[i] + oldTxt.substring(endPos);
                        text_input.value = result;
                        let x = window.scrollX, y = window.scrollY;
                        text_input.focus();
                        text_input.selectionStart = startPos + symbols[i].length;
                        text_input.selectionEnd = startPos + symbols[i].length;
                        text_input.blur();
                        window.scrollTo(x, y);
                        return [];
                    }}""")
    app.queue(concurrency_count=3).launch(show_api=True, share=args.share)