File size: 2,996 Bytes
804432e
 
 
a91158d
 
 
 
 
804432e
 
a91158d
 
 
 
 
 
 
804432e
a91158d
 
 
 
804432e
 
a91158d
804432e
a91158d
 
 
 
804432e
a91158d
804432e
a91158d
 
804432e
 
a91158d
804432e
a91158d
804432e
 
a91158d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804432e
 
a91158d
 
804432e
a91158d
 
 
804432e
 
a91158d
 
 
 
 
 
 
 
804432e
 
 
 
 
a91158d
804432e
 
 
 
 
a91158d
 
 
 
 
778d33f
804432e
778d33f
804432e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os

# Run `python setup.py build_ext --inplace` inside the monotonic_align/
# directory every time this module is loaded. The exit status returned by
# os.system is discarded, so a failed build is not detected at this point.
# NOTE(review): presumably this has to happen before the project imports
# further down (which may load the built extension) — confirm before moving it.
os.system('cd monotonic_align && python setup.py build_ext --inplace && cd ..')

import logging

# Raise the threshold of the 'numba' logger to WARNING so that its
# lower-severity records are dropped.
# NOTE(review): presumably numba is pulled in indirectly by the librosa import
# below — verify; nothing in this file imports numba directly.
numba_logger = logging.getLogger('numba')
numba_logger.setLevel(logging.WARNING)

import librosa

import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text.cleaners import japanese_phrase_cleaners
from text import cleaned_text_to_sequence
from pypinyin import lazy_pinyin, Style

from scipy.io.wavfile import write

def get_text(text, hps):
    """Turn ``text`` into a ``torch.LongTensor`` of symbol ids.

    Args:
        text: String handed unchanged to ``cleaned_text_to_sequence``.
            NOTE(review): the helper's name suggests the text must already
            be cleaned/phonemized — confirm against the ``text`` package.
        hps: Hyper-parameter object; only ``hps.data.add_blank`` is read.

    Returns:
        A 1-D ``torch.LongTensor`` built from the id sequence. When
        ``hps.data.add_blank`` is truthy the sequence is first passed through
        ``commons.intersperse(..., 0)``.
    """
    symbol_ids = cleaned_text_to_sequence(text)
    # Blank handling is optional and driven purely by the config flag.
    padded_ids = commons.intersperse(symbol_ids, 0) if hps.data.add_blank else symbol_ids
    return torch.LongTensor(padded_ids)
# Multi-speaker variant, kept commented out for reference only:
# hps_ms = utils.get_hparams_from_file("./configs/vctk_base.json")
# net_g_ms = SynthesizerTrn(
#     len(symbols),
#     hps_ms.data.filter_length // 2 + 1,
#     hps_ms.train.segment_size // hps.data.hop_length,
#     n_speakers=hps_ms.data.n_speakers,
#     **hps_ms.model)


# Hyper-parameters for the model actually used by this script.
hps = utils.get_hparams_from_file("./configs/tokaiteio.json")

# Positional constructor arguments, all derived from the symbol table and the
# loaded hyper-parameters.
# NOTE(review): presumably (vocabulary size, spectrogram channels, segment
# length in frames) — confirm against models.SynthesizerTrn.
_net_g_args = (
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
)

# Every entry of hps.model is forwarded as a keyword argument. Weights are not
# loaded here; the checkpoint is read later, in the script entry point.
net_g = SynthesizerTrn(*_net_g_args, **hps.model)
_ = net_g.eval()


def tts(text):
    """Synthesize ``text`` with ``net_g`` and show it as an inline audio player.

    Intended for notebook use: the result is rendered through
    ``IPython.display`` rather than returned.

    Args:
        text: Input string; passed to ``get_text`` together with the
            module-level ``hps``.

    Returns:
        ``("Error: Text is too long", None)`` when ``text`` has more than
        150 characters (nothing is displayed in that case); otherwise
        ``None`` after the audio widget has been displayed.
    """
    if len(text) > 150:
        return "Error: Text is too long", None
    stn_tst = get_text(text, hps)
    with torch.no_grad():
        # Add a leading batch dimension of size 1 and pass the sequence length.
        x_tst = stn_tst.unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
        output = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)
        # .cpu() before .numpy(), matching tts_fn: Tensor.numpy() only works
        # on CPU tensors, and .cpu() is a no-op when the tensor is already there.
        audio = output[0][0, 0].data.cpu().float().numpy()
    ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate))


def tts_fn(text, speaker_id=None):
    """Synthesize ``text`` with ``net_g`` and return it for a UI callback.

    Args:
        text: Input string; passed to ``get_text`` together with the
            module-level ``hps``.
        speaker_id: Accepted so existing two-argument callers keep working;
            the body never reads it.

    Returns:
        A ``(message, audio)`` pair. ``("Error: Text is too long", None)``
        when ``text`` has more than 150 characters, otherwise
        ``("Success", (hps.data.sampling_rate, samples))`` where ``samples``
        is the NumPy array produced from the model output.
    """
    if len(text) > 150:
        return "Error: Text is too long", None
    stn_tst = get_text(text, hps)
    with torch.no_grad():
        # Add a leading batch dimension of size 1 and pass the sequence length.
        x_tst = stn_tst.unsqueeze(0)
        # Was a bare `LongTensor`, which is not defined anywhere in this
        # module and raised NameError on every non-error call.
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
        output = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)
        audio = output[0][0, 0].data.cpu().float().numpy()
    return "Success", (hps.data.sampling_rate, audio)


if __name__ == '__main__':
    # `gr` is used below but no import in this module provides it, so the
    # script failed with NameError. Imported here because gradio is only
    # needed when the file is run as a script.
    import gradio as gr

    # Load the generator weights into net_g (no optimizer is passed).
    _ = utils.load_checkpoint("G_50000.pth", net_g, None)

    app = gr.Blocks()

    with app:
        with gr.Tabs():
            with gr.Column():
                # NOTE(review): the label says "words" but tts_fn enforces the
                # limit with len(text), i.e. 150 characters.
                tts_input1 = gr.TextArea(label="Text (150 words limitation)", value="こんにちは。")
                tts_submit = gr.Button("Generate", variant="primary")
                tts_output1 = gr.Textbox(label="Output Message")
                tts_output2 = gr.Audio(label="Output Audio")

        # The original wiring listed an undefined `tts_input2` as a second
        # input (NameError). The UI has no speaker selector and tts_fn ignores
        # its speaker argument, so only the text box is wired and a
        # placeholder speaker id is supplied.
        tts_submit.click(lambda text: tts_fn(text, 0), [tts_input1], [tts_output1, tts_output2])

    app.launch()