barghavani committed on
Commit
a244f0f
1 Parent(s): b39f00f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -12
app.py CHANGED
@@ -8,14 +8,23 @@ import json
8
 
9
  # Define constants
10
  MODEL_INFO = [
11
- ["VITS Grapheme Multispeaker CV15(90K)", "best_model_56960.pth", "config.json", "saillab/multi_speaker"],
12
-
 
 
 
 
 
 
 
13
  ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],
14
 
15
  ["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
16
  "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
17
  ["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
18
  "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
 
 
19
  ]
20
 
21
  # Extract model names from MODEL_INFO
@@ -26,14 +35,37 @@ TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
26
 
27
  model_files = {}
28
  config_files = {}
 
29
 
30
  # Create a dictionary to store synthesizer objects for each model
31
  synthesizers = {}
32
 
33
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Download models and initialize synthesizers
35
  for info in MODEL_INFO:
36
  model_name, model_file, config_file, repo_name = info[:4]
 
37
 
38
  print(f"|> Downloading: {model_name}")
39
 
@@ -41,42 +73,90 @@ for info in MODEL_INFO:
41
  model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
42
  config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
43
 
44
- # Initialize synthesizer for the model
45
- synthesizer = Synthesizer(tts_checkpoint=model_files[model_name],
46
- tts_config_path=config_files[model_name],
47
- use_cuda=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  synthesizers[model_name] = synthesizer
49
 
50
- def synthesize(text: str, model_name: str) -> str:
51
-
 
 
 
52
  if len(text) > MAX_TXT_LEN:
53
  text = text[:MAX_TXT_LEN]
54
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
55
 
 
56
  synthesizer = synthesizers[model_name]
 
 
57
  if synthesizer is None:
58
  raise NameError("Model not found")
59
 
60
- wavs = synthesizer.tts(text)
 
 
 
 
 
 
 
61
 
62
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
63
  synthesizer.save_wav(wavs, fp)
64
  return fp.name
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  iface = gr.Interface(
67
  fn=synthesize,
68
  inputs=[
69
  gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
70
  gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
 
71
  ],
72
  outputs=gr.Audio(label="Output", type='filepath'),
73
- examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
74
  title='Persian TTS Playground',
75
  description="""
76
  ### Persian text to speech model demo.
 
 
77
  """,
78
  article="",
79
  live=False
80
  )
81
 
82
- iface.launch()
 
8
 
9
  # Define constants
10
  MODEL_INFO = [
11
+ #["vits checkpoint 57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
12
+ # ["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json",
13
+ # "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"],
14
+ ["VITS Grapheme Multispeaker CV15(reduct)(22000)", "checkpoint_22000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
15
+ ["VITS Grapheme Multispeaker CV15(reduct)(26000)", "checkpoint_25000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
16
+ ["VITS Grapheme Multispeaker CV15(90K)", "best_model_56960.pth", "config.json", "saillab/multi_speaker", "speakers.pth"],
17
+
18
+ # ["VITS Grapheme Azure (best at 15934)", "best_model_15934.pth", "config.json",
19
+ # "saillab/persian-tts-azure-grapheme-60K"],
20
  ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],
21
 
22
  ["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
23
  "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
24
  ["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
25
  "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
26
+
27
+ # ... Add other models similarly
28
  ]
29
 
30
  # Extract model names from MODEL_INFO
 
35
 
36
  model_files = {}
37
  config_files = {}
38
+ speaker_files = {}
39
 
40
  # Create a dictionary to store synthesizer objects for each model
41
  synthesizers = {}
42
 
43
def update_config_speakers_file_recursive(config_dict, speakers_path):
    """Recursively point every existing ``speakers_file`` entry at *speakers_path*.

    The structure is modified in place and nothing is returned. Only
    ``speakers_file`` keys that are already present are overwritten; the key
    is never added to a mapping that lacks it.

    Args:
        config_dict: Parsed configuration; nested dicts, which may also hold
            lists of dicts.
        speakers_path: Value to store under each ``speakers_file`` key.
    """
    if isinstance(config_dict, dict):
        if "speakers_file" in config_dict:
            config_dict["speakers_file"] = speakers_path
        children = config_dict.values()
    elif isinstance(config_dict, list):
        # Lists are walked too, so entries nested inside arrays of sections
        # are not silently skipped.
        children = config_dict
    else:
        return

    for child in children:
        if isinstance(child, (dict, list)):
            update_config_speakers_file_recursive(child, speakers_path)
50
+
51
def update_config_speakers_file(config_path, speakers_path):
    """Rewrite the JSON config at *config_path* so it references *speakers_path*.

    The file is loaded, every existing ``speakers_file`` entry is replaced via
    ``update_config_speakers_file_recursive``, and the result is written back
    to the same path with 4-space indentation.

    Args:
        config_path: Path of the JSON config file to modify in place.
        speakers_path: Path to store in each ``speakers_file`` entry.
    """
    # Read and write as UTF-8 explicitly: relying on the platform's locale
    # encoding can fail on configs that contain non-ASCII characters.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    update_config_speakers_file_recursive(config, speakers_path)

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
64
+
65
  # Download models and initialize synthesizers
66
  for info in MODEL_INFO:
67
  model_name, model_file, config_file, repo_name = info[:4]
68
+ speaker_file = info[4] if len(info) == 5 else None # Check if speakers.pth is defined for the model
69
 
70
  print(f"|> Downloading: {model_name}")
71
 
 
73
  model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
74
  config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
75
 
76
+ # Download speakers.pth if it exists
77
+ if speaker_file:
78
+ speaker_files[model_name] = hf_hub_download(repo_id=repo_name, filename=speaker_file, use_auth_token=TOKEN)
79
+ update_config_speakers_file(config_files[model_name], speaker_files[model_name]) # Update the config file
80
+ print(speaker_files[model_name])
81
+ # Initialize synthesizer for the model
82
+ synthesizer = Synthesizer(
83
+ tts_checkpoint=model_files[model_name],
84
+ tts_config_path=config_files[model_name],
85
+ tts_speakers_file=speaker_files[model_name], # Pass the speakers.pth file if it exists
86
+ use_cuda=False # Assuming you don't want to use GPU, adjust if needed
87
+ )
88
+
89
+ elif speaker_file is None:
90
+
91
+ # Initialize synthesizer for the model
92
+ synthesizer = Synthesizer(
93
+ tts_checkpoint=model_files[model_name],
94
+ tts_config_path=config_files[model_name],
95
+ # tts_speakers_file=speaker_files.get(model_name, None), # Pass the speakers.pth file if it exists
96
+ use_cuda=False # Assuming you don't want to use GPU, adjust if needed
97
+ )
98
+
99
  synthesizers[model_name] = synthesizer
100
 
101
+
102
+
103
+
104
def synthesize(text: str, model_name: str, speaker_name="speaker-0") -> str:
    """Synthesize *text* with the selected model and return the WAV file path.

    Text longer than ``MAX_TXT_LEN`` characters is truncated to that length
    and a notice is printed. The audio is written to a temporary ``.wav``
    file that is not deleted automatically.

    Args:
        text: Text to synthesize.
        model_name: Key of the synthesizer in ``synthesizers``.
        speaker_name: Speaker to use when the synthesizer has a speakers
            file. An empty or missing value falls back to ``"speaker-0"``.
            Ignored when the synthesizer has no speakers file.

    Returns:
        Path of the generated ``.wav`` file.

    Raises:
        NameError: If no synthesizer is registered under *model_name*.
    """
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")

    # .get() so an unknown model reaches the explicit error below instead of
    # failing earlier with a KeyError.
    synthesizer = synthesizers.get(model_name)
    if synthesizer is None:
        raise NameError("Model not found")

    # Truthiness check instead of `is ""`: identity comparison against a
    # string literal is not a reliable equality test.
    if not synthesizer.tts_speakers_file:
        wavs = synthesizer.tts(text)
    else:
        # The UI may pass an empty or missing speaker; fall back to the
        # default speaker in that case.
        wavs = synthesizer.tts(text, speaker_name=speaker_name or "speaker-0")

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
        return fp.name
129
 
130
+ # Callback function to update UI based on the selected model
131
def update_options(model_name):
    """Return the speaker choices to offer for *model_name*.

    Args:
        model_name: Key of the synthesizer in ``synthesizers``.

    Returns:
        The speaker names reported by the model's speaker manager when
        *model_name* equals ``MODEL_NAMES[1]``; otherwise an empty list.

    Raises:
        KeyError: If *model_name* is not a key of ``synthesizers``.
    """
    synthesizer = synthesizers[model_name]
    # Compare by value: `is` tests object identity and fails for an equal
    # string that happens to be a different object.
    # NOTE(review): the multi-speaker check is still tied to MODEL_NAMES[1];
    # other entries that ship a speakers file get no options — confirm intent.
    if model_name == MODEL_NAMES[1]:
        return synthesizer.tts_model.speaker_manager.speaker_names
    return []
141
+
142
+ # Create Gradio interface
143
  iface = gr.Interface(
144
  fn=synthesize,
145
  inputs=[
146
  gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
147
  gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
148
+ gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", default="speaker-0")
149
  ],
150
  outputs=gr.Audio(label="Output", type='filepath'),
151
+ examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], ""]], # Example should include a speaker name for multispeaker models
152
  title='Persian TTS Playground',
153
  description="""
154
  ### Persian text to speech model demo.
155
+
156
+ #### Pick a speaker for MultiSpeaker models. (It won't affect the single speaker models)
157
  """,
158
  article="",
159
  live=False
160
  )
161
 
162
+ iface.launch()