barghavani committed
Commit: a244f0f
Parent(s): b39f00f
Update app.py
app.py
CHANGED
@@ -8,14 +8,23 @@ import json
 
 # Define constants
 MODEL_INFO = [
-    ["
-
+    #["vits checkpoint 57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
+    # ["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json",
+    #  "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"],
+    ["VITS Grapheme Multispeaker CV15(reduct)(22000)", "checkpoint_22000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
+    ["VITS Grapheme Multispeaker CV15(reduct)(26000)", "checkpoint_25000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
+    ["VITS Grapheme Multispeaker CV15(90K)", "best_model_56960.pth", "config.json", "saillab/multi_speaker", "speakers.pth"],
+
+    # ["VITS Grapheme Azure (best at 15934)", "best_model_15934.pth", "config.json",
+    #  "saillab/persian-tts-azure-grapheme-60K"],
     ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],
 
     ["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
      "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
     ["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
      "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
+
+    # ... Add other models similarly
 ]
 
 # Extract model names from MODEL_INFO
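The new MODEL_INFO rows optionally carry a fifth element naming a speakers file, and the download loop later in this commit unpacks rows exactly that way. A minimal sketch of the row schema (the row below is hypothetical, not one of the entries above):

# Hypothetical row: the optional fifth element points at a speakers.pth file.
row = ["Some multispeaker model", "checkpoint.pth", "config.json", "user/repo", "speakers.pth"]
model_name, model_file, config_file, repo_name = row[:4]
speaker_file = row[4] if len(row) == 5 else None  # None for single-speaker rows
print(model_name, repo_name, speaker_file)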
@@ -26,14 +35,37 @@ TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
 
 model_files = {}
 config_files = {}
+speaker_files = {}
 
 # Create a dictionary to store synthesizer objects for each model
 synthesizers = {}
 
-
+def update_config_speakers_file_recursive(config_dict, speakers_path):
+    """Recursively update speakers_file keys in a dictionary."""
+    if "speakers_file" in config_dict:
+        config_dict["speakers_file"] = speakers_path
+    for key, value in config_dict.items():
+        if isinstance(value, dict):
+            update_config_speakers_file_recursive(value, speakers_path)
+
+def update_config_speakers_file(config_path, speakers_path):
+    """Update the config.json file to point to the correct speakers.pth file."""
+
+    # Load the existing config
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+
+    # Modify the speakers_file entry
+    update_config_speakers_file_recursive(config, speakers_path)
+
+    # Save the modified config
+    with open(config_path, 'w') as f:
+        json.dump(config, f, indent=4)
+
 # Download models and initialize synthesizers
 for info in MODEL_INFO:
     model_name, model_file, config_file, repo_name = info[:4]
+    speaker_file = info[4] if len(info) == 5 else None  # Check if speakers.pth is defined for the model
 
     print(f"|> Downloading: {model_name}")
 
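To illustrate the recursive helper added in the hunk above, here is a small self-contained sketch of the same technique applied to a toy config dictionary (the dictionary contents and the path are made up for the example):

# Toy illustration of the recursive speakers_file patch; keys and path are invented.
def patch_speakers_file(config_dict, speakers_path):
    if "speakers_file" in config_dict:
        config_dict["speakers_file"] = speakers_path
    for value in config_dict.values():
        if isinstance(value, dict):
            patch_speakers_file(value, speakers_path)

cfg = {"speakers_file": None, "model_args": {"speakers_file": None}}
patch_speakers_file(cfg, "/tmp/speakers.pth")
print(cfg)  # both the top-level and the nested speakers_file now point to /tmp/speakers.pth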
@@ -41,42 +73,90 @@ for info in MODEL_INFO:
     model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
     config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
 
-    #
-
-
-
+    # Download speakers.pth if it exists
+    if speaker_file:
+        speaker_files[model_name] = hf_hub_download(repo_id=repo_name, filename=speaker_file, use_auth_token=TOKEN)
+        update_config_speakers_file(config_files[model_name], speaker_files[model_name])  # Update the config file
+        print(speaker_files[model_name])
+        # Initialize synthesizer for the model
+        synthesizer = Synthesizer(
+            tts_checkpoint=model_files[model_name],
+            tts_config_path=config_files[model_name],
+            tts_speakers_file=speaker_files[model_name],  # Pass the speakers.pth file if it exists
+            use_cuda=False  # Assuming you don't want to use GPU, adjust if needed
+        )
+
+    elif speaker_file is None:
+
+        # Initialize synthesizer for the model
+        synthesizer = Synthesizer(
+            tts_checkpoint=model_files[model_name],
+            tts_config_path=config_files[model_name],
+            # tts_speakers_file=speaker_files.get(model_name, None),  # Pass the speakers.pth file if it exists
+            use_cuda=False  # Assuming you don't want to use GPU, adjust if needed
+        )
+
     synthesizers[model_name] = synthesizer
 
-
-
+
+
+
+def synthesize(text: str, model_name: str, speaker_name="speaker-0") -> str:
+    """Synthesize speech using the selected model."""
     if len(text) > MAX_TXT_LEN:
         text = text[:MAX_TXT_LEN]
         print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
 
+    # Use the synthesizer object for the selected model
     synthesizer = synthesizers[model_name]
+
+
     if synthesizer is None:
         raise NameError("Model not found")
 
-
+    if synthesizer.tts_speakers_file == "":
+        wavs = synthesizer.tts(text)
+
+    elif synthesizer.tts_speakers_file != "":
+        if speaker_name == "":
+            wavs = synthesizer.tts(text, speaker_name="speaker-0")  ## should change; better once the Gradio conditions are figured out
+        else:
+            wavs = synthesizer.tts(text, speaker_name=speaker_name)
 
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
         synthesizer.save_wav(wavs, fp)
         return fp.name
 
+# Callback function to update UI based on the selected model
+def update_options(model_name):
+    synthesizer = synthesizers[model_name]
+    # if synthesizer.tts.is_multi_speaker:
+    if model_name == MODEL_NAMES[1]:
+        speakers = synthesizer.tts_model.speaker_manager.speaker_names
+        # return options for the dropdown
+        return speakers
+    else:
+        # return empty options if not multi-speaker
+        return []
+
+# Create Gradio interface
 iface = gr.Interface(
     fn=synthesize,
     inputs=[
         gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
         gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
+        gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", default="speaker-0")
     ],
     outputs=gr.Audio(label="Output", type='filepath'),
-    examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
+    examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], ""]],  # Example should include a speaker name for multispeaker models
     title='Persian TTS Playground',
     description="""
         ### Persian text to speech model demo.
+
+        #### Pick a speaker for multi-speaker models. (It won't affect single-speaker models.)
     """,
     article="",
     live=False
 )
 
-iface.launch()
+iface.launch()
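One possible follow-up to the hard-coded MODEL_NAMES[1] check in update_options: the speaker list could be derived from whichever synthesizer is actually loaded. The sketch below is an illustration under that assumption, reusing only attributes the commit itself touches (synthesizers and tts_model.speaker_manager.speaker_names) and guarding them with getattr, since single-speaker models may not carry a speaker manager:

# Sketch only: list speakers for any loaded model, falling back to [] for single-speaker ones.
def list_speakers(model_name):
    synthesizer = synthesizers[model_name]
    manager = getattr(synthesizer.tts_model, "speaker_manager", None)
    if manager is not None and getattr(manager, "speaker_names", None):
        return list(manager.speaker_names)
    return []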