Pipe1213 committed on
Commit
132e1e7
1 Parent(s): 1ccde86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -42
app.py CHANGED
@@ -13,38 +13,43 @@ import commons
13
  import utils
14
  from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
15
  from models import SynthesizerTrn
16
- from text.symbols import symbols_gra # import symbols graphemes
17
  from text.symbols_ft import symbols_ft # import symbols finetuned model
18
  from scipy.io.wavfile import write
19
  from text import cleaners
20
 
21
- symbols = symbols_gra
22
- _symbol_to_id = {s: i for i, s in enumerate(symbols)}
23
- _id_to_symbol = {i: s for i, s in enumerate(symbols)}
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def text_to_sequence(text, cleaner_names):
26
- '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
27
- Args:
28
- text: string to convert to a sequence
29
- cleaner_names: names of the cleaner functions to run the text through
30
- Returns:
31
- List of integers corresponding to the symbols in the text
32
- '''
33
- sequence = []
34
-
35
- clean_text = _clean_text(text, cleaner_names)
36
- for symbol in clean_text:
37
- symbol_id = _symbol_to_id[symbol]
38
- sequence += [symbol_id]
39
- return sequence
40
 
41
  def _clean_text(text, cleaner_names):
42
- for name in cleaner_names:
43
- cleaner = getattr(cleaners, name)
44
- if not cleaner:
45
- raise Exception('Unknown cleaner: %s' % name)
46
- text = cleaner(text)
47
- return text
48
 
49
  def get_text(text, hps):
50
  text_norm = text_to_sequence(text, hps.data.text_cleaners)
@@ -53,31 +58,25 @@ def get_text(text, hps):
53
  text_norm = torch.LongTensor(text_norm)
54
  return text_norm
55
 
56
- def load_model(model_path, hps):
 
 
 
 
 
 
57
  net_g = SynthesizerTrn(
58
- len(symbols), # change here
59
  hps.data.filter_length // 2 + 1,
60
  hps.train.segment_size // hps.data.hop_length,
61
  n_speakers=hps.data.n_speakers,
62
  **hps.model)
63
  _ = net_g.eval()
64
- _ = utils.load_checkpoint(model_path, net_g, None)
65
- return net_g
66
-
67
- hps = utils.get_hparams_from_file("wa_graphemes/config.json")
68
- #hps = utils.get_hparams_from_file("wa_graphemes/config.json")
69
-
70
- model_paths = {
71
- "Graphemes": "wa_graphemes/G_258000.pth"
72
- }
73
- #"Graphemes_ft": "fr_wa_finetune/G_198000.pth"
74
- # Load the model
75
- net_g = load_model(model_paths["Graphemes"], hps)
76
 
77
  def tts(text, speaker_id, tab_name):
78
- global net_g
79
- net_g = load_model(model_paths[tab_name], hps)
80
- sid = torch.LongTensor([speaker_id]) # speaker ID
81
  stn_tst = get_text(text, hps)
82
 
83
  with torch.no_grad():
@@ -90,13 +89,15 @@ def tts(text, speaker_id, tab_name):
90
  def create_tab(tab_name):
91
  with gr.TabItem(tab_name):
92
  gr.Markdown(f"### {tab_name} TTS Model")
93
- tts_input1 = gr.TextArea(label="Text in Walloon on graphemes", value="")
94
  tts_input2 = gr.Dropdown(label="Speaker", choices=["Male", "Female"], type="index", value="Male")
95
  tts_submit = gr.Button("Generate", variant="primary")
96
  tts_output1 = gr.Textbox(label="Message")
97
  tts_output2 = gr.Audio(label="Output")
98
  tts_submit.click(lambda text, speaker_id: tts(text, speaker_id, tab_name), [tts_input1, tts_input2], [tts_output1, tts_output2])
99
 
 
 
100
  app = gr.Blocks()
101
  with app:
102
  gr.Markdown(
 
13
  import utils
14
  from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
15
  from models import SynthesizerTrn
16
+ from text.symbols import symbols as symbols_default # import symbols graphemes
17
  from text.symbols_ft import symbols_ft # import symbols finetuned model
18
  from scipy.io.wavfile import write
19
  from text import cleaners
20
 
21
+ model_configs = {
22
+ "Graphemes_finetuned": {
23
+ "path": "fr_wa_graphemes/G_30000.pth",
24
+ "symbols": symbols_ft
25
+ },
26
+ "Graphemes": {
27
+ "path": "wa_graphemes/G_258000.pth",
28
+ "symbols": symbols_default
29
+ }
30
+ }
31
+
32
+ # Global variables
33
+ net_g = None
34
+ symbols = []
35
+ _symbol_to_id = {}
36
+ _id_to_symbol = {}
37
 
38
  def text_to_sequence(text, cleaner_names):
39
+ sequence = []
40
+ clean_text = _clean_text(text, cleaner_names)
41
+ for symbol in clean_text:
42
+ symbol_id = _symbol_to_id[symbol]
43
+ sequence += [symbol_id]
44
+ return sequence
 
 
 
 
 
 
 
 
45
 
46
  def _clean_text(text, cleaner_names):
47
+ for name in cleaner_names:
48
+ cleaner = getattr(cleaners, name)
49
+ if not cleaner:
50
+ raise Exception('Unknown cleaner: %s' % name)
51
+ text = cleaner(text)
52
+ return text
53
 
54
  def get_text(text, hps):
55
  text_norm = text_to_sequence(text, hps.data.text_cleaners)
 
58
  text_norm = torch.LongTensor(text_norm)
59
  return text_norm
60
 
61
+ def load_model_and_symbols(tab_name):
62
+ global net_g, symbols, _symbol_to_id, _id_to_symbol
63
+ model_config = model_configs[tab_name]
64
+ symbols = model_config["symbols"]
65
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
66
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
67
+
68
  net_g = SynthesizerTrn(
69
+ len(symbols),
70
  hps.data.filter_length // 2 + 1,
71
  hps.train.segment_size // hps.data.hop_length,
72
  n_speakers=hps.data.n_speakers,
73
  **hps.model)
74
  _ = net_g.eval()
75
+ _ = utils.load_checkpoint(model_config["path"], net_g, None)
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def tts(text, speaker_id, tab_name):
78
+ load_model_and_symbols(tab_name)
79
+ sid = torch.LongTensor([speaker_id]) # speaker identity
 
80
  stn_tst = get_text(text, hps)
81
 
82
  with torch.no_grad():
 
89
  def create_tab(tab_name):
90
  with gr.TabItem(tab_name):
91
  gr.Markdown(f"### {tab_name} TTS Model")
92
+ tts_input1 = gr.TextArea(label="Text in Walloon on IPA phonemes", value="")
93
  tts_input2 = gr.Dropdown(label="Speaker", choices=["Male", "Female"], type="index", value="Male")
94
  tts_submit = gr.Button("Generate", variant="primary")
95
  tts_output1 = gr.Textbox(label="Message")
96
  tts_output2 = gr.Audio(label="Output")
97
  tts_submit.click(lambda text, speaker_id: tts(text, speaker_id, tab_name), [tts_input1, tts_input2], [tts_output1, tts_output2])
98
 
99
+ hps = utils.get_hparams_from_file("configs/vctk_base.json")
100
+
101
  app = gr.Blocks()
102
  with app:
103
  gr.Markdown(