barghavani commited on
Commit
c687704
1 Parent(s): eb9806c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -115
app.py CHANGED
@@ -8,32 +8,7 @@ import json
8
 
9
  # Define constants
10
  MODEL_INFO = [
11
- #["vits checkpoint 57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
12
- # ["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json",
13
- # "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"],
14
- #["Single speaker (best)VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],
15
-
16
- #["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json","saillab/persian-tts-grapheme-arm24-finetuned-on1"],
17
- #["Single speaker female best VITS Grapheme CV-Azure_male-Azure_female","best_model_15397.pth","config.json","saillab/female_cv_azure_male_azure_female","speakers1.pth"],
18
-
19
- ["Xtts Persian","checkpoint_361721.pth","config.json","saillab/xtts_v2_fa"],
20
- #["Multi Speaker Vits Grapheme CV+Azure in one set ","best_model_358320.pth","config.json","saillab/Multi_Speaker_Cv_plus_Azure_female_in_one_set","speakers.pth"],
21
- #["Multispeaker VITS Grapheme CV15(reduct)(22000)", "checkpoint_22000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
22
- #["Multispeaker VITS Grapheme CV15(reduct)(26000)", "checkpoint_25000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
23
- #["Multispeaker VITS Grapheme CV15(90K)", "best_model_56960.pth", "config.json", "saillab/multi_speaker", "speakers.pth"],
24
- #["Single speaker female best VITS Grapheme CV-Azure_male-Azure_female","best_model_15397.pth","config.json","saillab/female_cv_azure_male_azure_female","speakers.pth"],
25
-
26
-
27
- # ["VITS Grapheme Azure (best at 15934)", "best_model_15934.pth", "config.json",
28
- # "saillab/persian-tts-azure-grapheme-60K"],
29
-
30
-
31
- #["Single speaker VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json","saillab/persian-tts-grapheme-arm24-finetuned-on1"],
32
- #["Single speaker VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json","saillab/persian-tts-grapheme-arm24-finetuned-on1"],
33
-
34
-
35
-
36
- # ... Add other models similarly
37
  ]
38
 
39
  # Extract model names from MODEL_INFO
@@ -44,37 +19,13 @@ TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
44
 
45
  model_files = {}
46
  config_files = {}
47
- speaker_files = {}
48
 
49
  # Create a dictionary to store synthesizer objects for each model
50
  synthesizers = {}
51
 
52
- def update_config_speakers_file_recursive(config_dict, speakers_path):
53
- """Recursively update speakers_file keys in a dictionary."""
54
- if "speakers_file" in config_dict:
55
- config_dict["speakers_file"] = speakers_path
56
- for key, value in config_dict.items():
57
- if isinstance(value, dict):
58
- update_config_speakers_file_recursive(value, speakers_path)
59
-
60
- def update_config_speakers_file(config_path, speakers_path):
61
- """Update the config.json file to point to the correct speakers.pth file."""
62
-
63
- # Load the existing config
64
- with open(config_path, 'r') as f:
65
- config = json.load(f)
66
-
67
- # Modify the speakers_file entry
68
- update_config_speakers_file_recursive(config, speakers_path)
69
-
70
- # Save the modified config
71
- with open(config_path, 'w') as f:
72
- json.dump(config, f, indent=4)
73
-
74
  # Download models and initialize synthesizers
75
  for info in MODEL_INFO:
76
  model_name, model_file, config_file, repo_name = info[:4]
77
- speaker_file = info[4] if len(info) == 5 else None # Check if speakers.pth is defined for the model
78
 
79
  print(f"|> Downloading: {model_name}")
80
 
@@ -82,94 +33,41 @@ for info in MODEL_INFO:
82
  model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
83
  config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
84
 
85
- # Download speakers.pth if it exists
86
- if speaker_file:
87
- speaker_files[model_name] = hf_hub_download(repo_id=repo_name, filename=speaker_file, use_auth_token=TOKEN)
88
- update_config_speakers_file(config_files[model_name], speaker_files[model_name]) # Update the config file
89
- print(speaker_files[model_name])
90
- # Initialize synthesizer for the model
91
- synthesizer = Synthesizer(
92
- tts_checkpoint=model_files[model_name],
93
- tts_config_path=config_files[model_name],
94
- tts_speakers_file=speaker_files[model_name], # Pass the speakers.pth file if it exists
95
- use_cuda=False # Assuming you don't want to use GPU, adjust if needed
96
- )
97
-
98
- elif speaker_file is None:
99
-
100
- # Initialize synthesizer for the model
101
- synthesizer = Synthesizer(
102
- tts_checkpoint=model_files[model_name],
103
- tts_config_path=config_files[model_name],
104
- # tts_speakers_file=speaker_files.get(model_name, None), # Pass the speakers.pth file if it exists
105
- use_cuda=False # Assuming you don't want to use GPU, adjust if needed
106
- )
107
 
108
  synthesizers[model_name] = synthesizer
109
 
110
 
111
-
112
-
113
- #def synthesize(text: str, model_name: str, speaker_name="speaker-0") -> str:
114
- def synthesize(text: str, model_name: str, speaker_name=None) -> str:
115
- """Synthesize speech using the selected model."""
116
  if len(text) > MAX_TXT_LEN:
117
  text = text[:MAX_TXT_LEN]
118
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
119
-
120
- # Use the synthesizer object for the selected model
121
  synthesizer = synthesizers[model_name]
122
-
123
-
124
  if synthesizer is None:
125
  raise NameError("Model not found")
 
126
 
127
- if synthesizer.tts_speakers_file is "":
128
- wavs = synthesizer.tts(text)
129
-
130
- elif synthesizer.tts_speakers_file is not "":
131
- if speaker_name == "":
132
- #wavs = synthesizer.tts(text, speaker_name="speaker-0") ## should change, better if gradio conditions are figure out.
133
- wavs = synthesizer.tts(text, speaker_name=None)
134
- else:
135
- wavs = synthesizer.tts(text, speaker_name=speaker_name)
136
-
137
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
138
  synthesizer.save_wav(wavs, fp)
139
  return fp.name
140
-
141
- # Callback function to update UI based on the selected model
142
- def update_options(model_name):
143
- synthesizer = synthesizers[model_name]
144
- # if synthesizer.tts.is_multi_speaker:
145
- if model_name is MODEL_NAMES[1]:
146
- speakers = synthesizer.tts_model.speaker_manager.speaker_names
147
- # return options for the dropdown
148
- return speakers
149
- else:
150
- # return empty options if not multi-speaker
151
- return []
152
-
153
- # Create Gradio interface
154
  iface = gr.Interface(
155
  fn=synthesize,
156
  inputs=[
157
  gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
158
  gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
159
- #gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", default="speaker-0")
160
- gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", default=None)
161
  ],
162
  outputs=gr.Audio(label="Output", type='filepath'),
163
- examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], ""]], # Example should include a speaker name for multispeaker models
164
  title='Persian TTS Playground',
165
- description="""
166
- ### Persian text to speech model demo.
167
-
168
-
169
- #### Pick a speaker for MultiSpeaker models. (for single speaker go for speaker-0)
170
- """,
171
  article="",
172
  live=False
173
  )
174
 
175
- iface.launch()
 
8
 
9
  # Define constants
10
  MODEL_INFO = [
11
+ ["Xtts Persian","best_model_110880.pth","config.json","saillab/xtts_v2_fa"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ]
13
 
14
  # Extract model names from MODEL_INFO
 
19
 
20
  model_files = {}
21
  config_files = {}
 
22
 
23
  # Create a dictionary to store synthesizer objects for each model
24
  synthesizers = {}
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Download models and initialize synthesizers
27
  for info in MODEL_INFO:
28
  model_name, model_file, config_file, repo_name = info[:4]
 
29
 
30
  print(f"|> Downloading: {model_name}")
31
 
 
33
  model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
34
  config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
35
 
36
+ # Initialize synthesizer for the model
37
+ synthesizer = Synthesizer(
38
+ tts_checkpoint=model_files[model_name],
39
+ tts_config_path=config_files[model_name],
40
+ use_cuda=False
41
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  synthesizers[model_name] = synthesizer
44
 
45
 
46
+ def synthesize(text: str, model_name: str) -> str:
 
 
 
 
47
  if len(text) > MAX_TXT_LEN:
48
  text = text[:MAX_TXT_LEN]
49
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
 
 
50
  synthesizer = synthesizers[model_name]
 
 
51
  if synthesizer is None:
52
  raise NameError("Model not found")
53
+ wavs = synthesizer.tts(text)
54
 
 
 
 
 
 
 
 
 
 
 
55
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
56
  synthesizer.save_wav(wavs, fp)
57
  return fp.name
58
+
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  iface = gr.Interface(
60
  fn=synthesize,
61
  inputs=[
62
  gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
63
  gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
 
 
64
  ],
65
  outputs=gr.Audio(label="Output", type='filepath'),
66
+ examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]], # Example should include a speaker name for multispeaker models
67
  title='Persian TTS Playground',
68
+ description="",
 
 
 
 
 
69
  article="",
70
  live=False
71
  )
72
 
73
+ iface.launch()