File size: 3,185 Bytes
72aa6e6
 
 
5a39f1e
72aa6e6
 
 
2e910d9
 
 
 
 
 
5a39f1e
 
72aa6e6
5a39f1e
 
 
 
 
2e910d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72aa6e6
2e910d9
 
 
 
 
 
 
5a39f1e
2e910d9
72aa6e6
5a39f1e
2e910d9
 
 
 
 
72aa6e6
2e910d9
3980d4c
a15140f
 
 
5a39f1e
 
 
 
2e910d9
 
 
f6fdc84
 
 
 
 
 
5a39f1e
 
 
 
f6fdc84
2e910d9
 
72aa6e6
5a39f1e
2e910d9
 
72aa6e6
5a39f1e
 
2e910d9
 
 
 
 
 
 
 
 
 
 
5a39f1e
2e910d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from scipy.io.wavfile import write
import numpy as np

import gradio as gr
import IPython.display as ipd

import commons
import utils
from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence


def get_text(text, hps):
    """Turn raw text into a ``torch.LongTensor`` of symbol ids.

    The text is converted by ``text_to_sequence`` using the cleaners listed
    in ``hps.data.text_cleaners``. When ``hps.data.add_blank`` is truthy the
    id sequence is additionally passed through ``commons.intersperse`` with
    ``0`` as the filler value.

    Args:
        text: Input text to synthesize.
        hps: Hyper-parameter object exposing ``data.text_cleaners`` and
            ``data.add_blank``.

    Returns:
        A 1-D ``torch.LongTensor`` holding the resulting id sequence.
    """
    token_ids = text_to_sequence(text, hps.data.text_cleaners)
    if not hps.data.add_blank:
        return torch.LongTensor(token_ids)
    return torch.LongTensor(commons.intersperse(token_ids, 0))


def get_text_byroma(text, hps):
    """Map already-romanized input directly onto symbol ids.

    Each element of ``text`` is looked up with ``symbols.index`` (no text
    cleaning is applied), so an element that is not present in ``symbols``
    makes the lookup raise. When ``hps.data.add_blank`` is truthy the ids
    are passed through ``commons.intersperse`` with ``0`` as the filler.

    Args:
        text: Iterable of symbols (e.g. a string iterated per character).
        hps: Hyper-parameter object exposing ``data.add_blank``.

    Returns:
        A 1-D ``torch.LongTensor`` holding the resulting id sequence.
    """
    token_ids = [symbols.index(ch) for ch in text]
    if hps.data.add_blank:
        token_ids = commons.intersperse(token_ids, 0)
    return torch.LongTensor(token_ids)


# Load the hyper-parameters that describe the data pipeline and the model.
hps = utils.get_hparams_from_file("./configs/leo.json")

# Build the synthesizer from the config: the symbol-table size, a value
# derived from the STFT filter length, a segment length expressed in hops,
# the speaker count, and every entry of ``hps.model`` as keyword arguments.
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model
)
# Switch to evaluation mode; the return value is discarded on purpose.
_ = net_g.eval()

# Load checkpoint weights into ``net_g``; ``None`` is passed where the helper
# takes its third argument (presumably an optimizer -- confirm in utils).
_ = utils.load_checkpoint("logs/leo/G_4000.pth", net_g, None)

# Root directory from which reference audio is drawn when a random emotion
# sample is requested.
random_emotion_root = "wavs"

# Read the JSON config through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE(review): this reads the same file as the model hyper-parameters and the
# result is not used elsewhere in the visible code -- confirm it is needed.
with open("configs/leo.json", "r", encoding="utf-8") as _emotion_file:
    emotion_dict = json.load(_emotion_file)


def _load_emotion_embedding(emotion):
    """Resolve the ``emotion`` argument of ``tts`` to an embedding tensor.

    Resolution order:

    1. ``f"{emotion}.emo.npy"`` exists: load it and add a leading dimension.
    2. ``emotion == "random_sample"``: pick a random ``*wav`` entry in
       ``random_emotion_root`` that has a ``.emo.npy`` companion and load
       that companion (the chosen path is printed).
    3. ``emotion`` ends with ``"wav"``: extract the embedding from the audio
       file with ``emotion_extract.extract_wav``.

    Raises:
        FileNotFoundError: ``random_sample`` was requested but no wav file
            with a ``.emo.npy`` companion exists in ``random_emotion_root``.
        ValueError: ``emotion`` matches none of the supported forms.
    """
    if os.path.exists(f"{emotion}.emo.npy"):
        return torch.FloatTensor(np.load(f"{emotion}.emo.npy")).unsqueeze(0)

    if emotion == "random_sample":
        # ``random`` is not imported at file level, so import it here to keep
        # this block self-contained.
        import random

        # Collect the valid candidates once instead of rejection-sampling in
        # an unbounded loop, which never terminates when none qualifies.
        candidates = [
            name
            for name in os.listdir(random_emotion_root)
            if name.endswith("wav")
            and os.path.exists(f"{random_emotion_root}/{name}.emo.npy")
        ]
        if not candidates:
            raise FileNotFoundError(
                f"no wav file with a .emo.npy companion found in "
                f"{random_emotion_root!r}"
            )
        rand_wav = random.choice(candidates)
        emo = torch.FloatTensor(
            np.load(f"{random_emotion_root}/{rand_wav}.emo.npy")
        ).unsqueeze(0)
        print(f"{random_emotion_root}/{rand_wav}")
        return emo

    if emotion.endswith("wav"):
        # Imported lazily so the extractor is only loaded when it is needed.
        import emotion_extract

        return torch.FloatTensor(emotion_extract.extract_wav(emotion))

    # Fail with a clear error instead of printing and then crashing on an
    # unbound ``emo`` variable.
    raise ValueError(f"emotion参数不正确: {emotion!r}")


def tts(txt, emotion, roma=False, length_scale=1):
    """Synthesize speech for ``txt`` with the given emotion reference.

    Args:
        txt: Text to synthesize. When ``roma`` is true it is treated as a
            sequence of symbols and mapped with ``get_text_byroma``;
            otherwise it goes through ``get_text``.
        emotion: Path of the emotion reference audio (or of a file that has a
            ``.emo.npy`` companion), or ``"random_sample"`` to draw a random
            reference from ``random_emotion_root``.
        roma: Whether ``txt`` is already romanized symbols.
        length_scale: Value forwarded to ``net_g.infer`` as ``length_scale``.

    Returns:
        The synthesized waveform as a NumPy float array. The audio is also
        shown through ``IPython.display`` for notebook use.

    Raises:
        FileNotFoundError: No usable random emotion sample is available.
        ValueError: ``emotion`` is not a supported value.
    """
    stn_tst = get_text_byroma(txt, hps) if roma else get_text(txt, hps)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
        # The speaker id is fixed to 0.
        sid = torch.LongTensor([0])
        # Resolved inside no_grad so any model-based extraction is also
        # performed without gradient tracking, as before.
        emo = _load_emotion_embedding(emotion)
        audio = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=length_scale,
            emo=emo,
        )[0][0, 0].data.float().numpy()
    ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))
    return audio


def run_tts(text, emotion, roma):
    """Gradio callback: synthesize ``text`` and return it as audio output.

    Returns:
        A ``(sampling_rate, waveform)`` tuple for the interface's audio
        output component.
    """
    # 1.2 is the length scale this demo has always synthesized with; pass it
    # explicitly now that ``tts`` honours its ``length_scale`` argument.
    audio = tts(text, emotion, roma, length_scale=1.2)
    return hps.data.sampling_rate, audio


# Web UI definition: two text inputs and one checkbox are passed, in order,
# to ``run_tts`` as (text, emotion, roma); the result is shown as audio.
# NOTE(review): ``layout``, ``theme`` and a boolean ``allow_flagging`` look
# like options of an older Gradio API -- confirm against the pinned version.
iface = gr.Interface(
    fn=run_tts,
    inputs=["text", "text", "checkbox"],
    outputs="audio",
    layout="vertical",
    title="TTS Demo",
    description="Generative TTS Demo with Emotional Control",
    allow_flagging=False,
    theme="huggingface",
    flagging_dir="flagged",
)

# Start the interface; this runs as a module-level side effect on import.
iface.launch(inline=True)