mhrahmani commited on
Commit
d60ac5e
1 Parent(s): d6e1594

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -20
app.py CHANGED
@@ -6,30 +6,26 @@ from huggingface_hub import hf_hub_download
6
 
7
  # Define constants
8
  MODEL_INFO = [
9
- # ["Model Name", "Model File", "Config File", "Hub URL"]
10
  ["vits-espeak-57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
11
- # Add other models similarly...
12
  ]
13
 
14
  # Extract model names from MODEL_INFO
15
  MODEL_NAMES = [info[0] for info in MODEL_INFO]
16
 
17
  MAX_TXT_LEN = 400
18
- TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN') # Replace with the environment variable containing your token, if different
19
 
 
 
20
 
21
- # Verify if the files are downloaded correctly
22
  for model_name, model_file, config_file, repo_name in MODEL_INFO:
23
- # os.makedirs(model_name, exist_ok=True)
24
  print(f"|> Downloading: {model_name}")
25
-
26
  model_file_path = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
27
  config_file_path = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
28
-
29
- # # Check if the files exist after download
30
- # if not os.path.exists(model_file_path) or not os.path.exists(config_file_path):
31
- # raise FileNotFoundError(f"Failed to download files for {model_name}. Please check the repository and file names.")
32
 
 
33
 
34
  def synthesize(text: str, model_name: str) -> str:
35
  """Synthesize speech using the selected model."""
@@ -37,20 +33,17 @@ def synthesize(text: str, model_name: str) -> str:
37
  text = text[:MAX_TXT_LEN]
38
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
39
 
40
- # Extract model_file and config_file based on the model_name
41
- model_file, config_file = next((model_file, config_file) for name, model_file, config_file, _ in MODEL_INFO if name == model_name)
42
-
43
- synthesizer = Synthesizer(model_file_path, config_file_path)
44
  if synthesizer is None:
45
  raise NameError("Model not found")
46
-
47
  wavs = synthesizer.tts(text)
48
 
49
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
50
  synthesizer.save_wav(wavs, fp)
51
  return fp.name
52
 
53
- # Define Gradio interface
54
  iface = gr.Interface(
55
  fn=synthesize,
56
  inputs=[
@@ -59,12 +52,10 @@ iface = gr.Interface(
59
  ],
60
  outputs=gr.Audio(label="Output", type='filepath'),
61
  examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
62
- title='persian tts playground',
63
- description="Persian text to speech model demo", # Add the required description here.
64
  article="",
65
  live=False
66
  )
67
 
68
- # Launch the interface
69
  iface.launch(share=False)
70
-
 
6
 
7
  # Define constants
8
  MODEL_INFO = [
 
9
  ["vits-espeak-57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
 
10
  ]
11
 
12
  # Extract model names from MODEL_INFO
13
  MODEL_NAMES = [info[0] for info in MODEL_INFO]
14
 
15
  MAX_TXT_LEN = 400
16
+ TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
17
 
18
+ # Dictionary to keep synthesizers
19
+ synthesizers = {}
20
 
21
+ # Download files and create synthesizers
22
  for model_name, model_file, config_file, repo_name in MODEL_INFO:
 
23
  print(f"|> Downloading: {model_name}")
24
+
25
  model_file_path = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
26
  config_file_path = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
 
 
 
 
27
 
28
+ synthesizers[model_name] = Synthesizer(model_file_path, config_file_path)
29
 
30
  def synthesize(text: str, model_name: str) -> str:
31
  """Synthesize speech using the selected model."""
 
33
  text = text[:MAX_TXT_LEN]
34
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
35
 
36
+ synthesizer = synthesizers[model_name]
 
 
 
37
  if synthesizer is None:
38
  raise NameError("Model not found")
39
+
40
  wavs = synthesizer.tts(text)
41
 
42
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
43
  synthesizer.save_wav(wavs, fp)
44
  return fp.name
45
 
46
+
47
  iface = gr.Interface(
48
  fn=synthesize,
49
  inputs=[
 
52
  ],
53
  outputs=gr.Audio(label="Output", type='filepath'),
54
  examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
55
+ title='Persian TTS Playground',
56
+ description="Persian text to speech model demo",
57
  article="",
58
  live=False
59
  )
60
 
 
61
  iface.launch(share=False)