Pipe1213 commited on
Commit
06435a9
1 Parent(s): c7920d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -20
app.py CHANGED
@@ -13,18 +13,20 @@ import commons
13
  import utils
14
  from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
15
  from models import SynthesizerTrn
16
- from text import cleaners
 
17
  from scipy.io.wavfile import write
 
18
 
19
- # Define a dictionary to store the model paths and symbols for each tab
20
  model_configs = {
21
  "Phonemes_finetuned": {
22
  "path": "fr_wa_finetuned_pho/G_125000.pth",
23
- "symbols_module": "text.symbols"
24
  },
25
  "Phonemes": {
26
  "path": "wallon_pho/G_277000.pth",
27
- "symbols_module": "text.symbols_pho"
28
  }
29
  }
30
 
@@ -34,12 +36,6 @@ symbols = []
34
  _symbol_to_id = {}
35
  _id_to_symbol = {}
36
 
37
- def load_symbols(module_name):
38
- global symbols, _symbol_to_id, _id_to_symbol
39
- symbols = __import__(module_name, fromlist=['symbols']).symbols
40
- _symbol_to_id = {s: i for i, s in enumerate(symbols)}
41
- _id_to_symbol = {i: s for i, s in enumerate(symbols)}
42
-
43
  def text_to_sequence(text, cleaner_names):
44
  sequence = []
45
  clean_text = _clean_text(text, cleaner_names)
@@ -63,7 +59,13 @@ def get_text(text, hps):
63
  text_norm = torch.LongTensor(text_norm)
64
  return text_norm
65
 
66
- def load_model(model_path, hps):
 
 
 
 
 
 
67
  net_g = SynthesizerTrn(
68
  len(symbols),
69
  hps.data.filter_length // 2 + 1,
@@ -71,17 +73,10 @@ def load_model(model_path, hps):
71
  n_speakers=hps.data.n_speakers,
72
  **hps.model)
73
  _ = net_g.eval()
74
- _ = utils.load_checkpoint(model_path, net_g, None)
75
- return net_g
76
-
77
- def update_model_and_symbols(tab_name):
78
- global net_g, hps
79
- model_config = model_configs[tab_name]
80
- load_symbols(model_config["symbols_module"])
81
- net_g = load_model(model_config["path"], hps)
82
 
83
  def tts(text, speaker_id, tab_name):
84
- update_model_and_symbols(tab_name)
85
  sid = torch.LongTensor([speaker_id]) # speaker identity
86
  stn_tst = get_text(text, hps)
87
 
@@ -141,3 +136,4 @@ with app:
141
 
142
  app.launch()
143
 
 
 
13
  import utils
14
  from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
15
  from models import SynthesizerTrn
16
+ from text.symbols import symbols as symbols_default
17
+ from text.symbols_pho import symbols_pho
18
  from scipy.io.wavfile import write
19
+ from text import cleaners
20
 
21
+ # Define a dictionary to store the model paths and corresponding symbols
22
  model_configs = {
23
  "Phonemes_finetuned": {
24
  "path": "fr_wa_finetuned_pho/G_125000.pth",
25
+ "symbols": symbols_default
26
  },
27
  "Phonemes": {
28
  "path": "wallon_pho/G_277000.pth",
29
+ "symbols": symbols_pho
30
  }
31
  }
32
 
 
36
  _symbol_to_id = {}
37
  _id_to_symbol = {}
38
 
 
 
 
 
 
 
39
  def text_to_sequence(text, cleaner_names):
40
  sequence = []
41
  clean_text = _clean_text(text, cleaner_names)
 
59
  text_norm = torch.LongTensor(text_norm)
60
  return text_norm
61
 
62
+ def load_model_and_symbols(tab_name):
63
+ global net_g, symbols, _symbol_to_id, _id_to_symbol
64
+ model_config = model_configs[tab_name]
65
+ symbols = model_config["symbols"]
66
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
67
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
68
+
69
  net_g = SynthesizerTrn(
70
  len(symbols),
71
  hps.data.filter_length // 2 + 1,
 
73
  n_speakers=hps.data.n_speakers,
74
  **hps.model)
75
  _ = net_g.eval()
76
+ _ = utils.load_checkpoint(model_config["path"], net_g, None)
 
 
 
 
 
 
 
77
 
78
  def tts(text, speaker_id, tab_name):
79
+ load_model_and_symbols(tab_name)
80
  sid = torch.LongTensor([speaker_id]) # speaker identity
81
  stn_tst = get_text(text, hps)
82
 
 
136
 
137
  app.launch()
138
 
139
+