chilge committed
Commit 2e910d9
Parent: 03a7259

Update app.py

Files changed (1):
  app.py (+57 -29)
app.py CHANGED
@@ -1,6 +1,3 @@
- import gradio as gr
- import matplotlib.pyplot as plt
- import IPython.display as ipd
  import os
  import json
  import math
@@ -8,39 +5,68 @@ import torch
  from torch import nn
  from torch.nn import functional as F
  from torch.utils.data import DataLoader
+ from scipy.io.wavfile import write
+ import numpy as np
+
+ import gradio as gr
+ import IPython.display as ipd
+
  import commons
  import utils
  from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate
  from models import SynthesizerTrn
  from text.symbols import symbols
  from text import text_to_sequence
- from scipy.io.wavfile import write
- import numpy as np

- # Load the emotion dictionary
- emotion_dict = json.load(open("configs/leo.json", "r"))

- # Load the pretrained model
+ def get_text(text, hps):
+     text_norm = text_to_sequence(text, hps.data.text_cleaners)
+     if hps.data.add_blank:
+         text_norm = commons.intersperse(text_norm, 0)
+     text_norm = torch.LongTensor(text_norm)
+     return text_norm
+
+
+ def get_text_byroma(text, hps):
+     text_norm = []
+     for i in text:
+         text_norm.append(symbols.index(i))
+     if hps.data.add_blank:
+         text_norm = commons.intersperse(text_norm, 0)
+     text_norm = torch.LongTensor(text_norm)
+     return text_norm
+
+
  hps = utils.get_hparams_from_file("./configs/leo.json")
- net_g = SynthesizerTrn(len(symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **hps.model)
+ net_g = SynthesizerTrn(
+     len(symbols),
+     hps.data.filter_length // 2 + 1,
+     hps.train.segment_size // hps.data.hop_length,
+     n_speakers=hps.data.n_speakers,
+     **hps.model
+ )
  _ = net_g.eval()
+
  _ = utils.load_checkpoint("logs/leo/G_4000.pth", net_g, None)

- # Define the text-to-speech function
+ # Root directory for randomly sampled emotion reference audio
+ random_emotion_root = "wavs"
+ emotion_dict = json.load(open("configs/leo.json", "r"))
+
+
  def tts(txt, emotion, roma=False, length_scale=1):
+     """emotion is the path to a reference emotion audio clip, or 'random_sample' to draw one at random"""
      if roma:
          stn_tst = get_text_byroma(txt, hps)
      else:
          stn_tst = get_text(txt, hps)
-
      with torch.no_grad():
          x_tst = stn_tst.unsqueeze(0)
          x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
          sid = torch.LongTensor([0])
-
-         if emotion == "random_sample":
-             # Randomly pick an emotion reference audio clip
-             random_emotion_root = "wavs"
+         if os.path.exists(f"{emotion}.emo.npy"):
+             emo = torch.FloatTensor(np.load(f"{emotion}.emo.npy")).unsqueeze(0)
+         elif emotion == "random_sample":
              while True:
                  rand_wav = random.sample(os.listdir(random_emotion_root), 1)[0]
                  if rand_wav.endswith('wav') and os.path.exists(f"{random_emotion_root}/{rand_wav}.emo.npy"):
@@ -48,27 +74,29 @@ def tts(txt, emotion, roma=False, length_scale=1):
              emo = torch.FloatTensor(np.load(f"{random_emotion_root}/{rand_wav}.emo.npy")).unsqueeze(0)
              print(f"{random_emotion_root}/{rand_wav}")
          elif emotion.endswith("wav"):
-             # Extract emotion features from the provided audio
              import emotion_extract
              emo = torch.FloatTensor(emotion_extract.extract_wav(emotion))
          else:
              print("invalid emotion argument")
-
-         audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.2, emo=emo)[0][0, 0].data.float().numpy()
-
+
+         audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.2, emo=emo)[0][0, 0].data.float().numpy()
      ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))

- # Build the GUI
- def run_tts(text, emotion, roma=False):
+
+ def run_tts(text, emotion, roma):
      tts(text, emotion, roma)

- inputs = [
-     gr.inputs.Textbox(label="Enter text"),
-     gr.inputs.Textbox(label="Enter a reference audio path, or 'random_sample' to pick one at random"),
-     gr.inputs.Checkbox(label="Use phoneme-based synthesis")
- ]

- outputs = gr.outputs.Audio(type="numpy", label="Synthesized audio")
+ iface = gr.Interface(
+     fn=run_tts,
+     inputs=["text", "text", "checkbox"],
+     outputs="audio",
+     layout="vertical",
+     title="TTS Demo",
+     description="Generative TTS Demo with Emotional Control",
+     allow_flagging=False,
+     theme="huggingface",
+     flagging_dir="flagged",
+ )

- interface = gr.Interface(fn=run_tts, inputs=inputs, outputs=outputs, title="Chinese Text-to-Speech")
- interface.launch()
+ iface.launch(inline=True)
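
Two caveats for anyone running this revision. First, tts() calls random.sample but the file never imports random, so the "random_sample" branch raises a NameError until an "import random" is added at the top. Second, Gradio's "audio" output expects the wrapped function to return a (sample_rate, samples) tuple, but run_tts returns nothing: tts only renders the clip through IPython.display, which shows up in a notebook rather than in the Gradio app. A minimal sketch of a return-based wrapper, assuming tts is modified to end with "return audio" instead of only calling ipd.display:

import random  # needed by tts() for the "random_sample" branch

def run_tts(text, emotion, roma):
    # Assumes tts() has been changed to return the float32 NumPy waveform.
    audio = tts(text, emotion, roma)
    # Gradio's "audio" output consumes a (sample_rate, samples) tuple.
    return hps.data.sampling_rate, audio

iface = gr.Interface(fn=run_tts, inputs=["text", "text", "checkbox"], outputs="audio")
iface.launch()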