Naozumi0512 commited on
Commit
39c88a0
1 Parent(s): 1a79a73
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # flake8: noqa: E402
 
2
  import os
3
  import logging
4
  import re_matching
@@ -32,6 +33,7 @@ if device == "mps":
32
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
33
 
34
 
 
35
  def generate_audio(
36
  slices,
37
  sdp_ratio,
@@ -377,12 +379,10 @@ if __name__ == "__main__":
377
  hps = utils.get_hparams_from_file(config.webui_config.config_path)
378
  # 若config.json中未指定版本则默认为最新版本
379
  version = hps.version if hasattr(hps, "version") else latest_version
380
- net_g = get_net_g(
381
- model_path=config.webui_config.model, version=version, device=device, hps=hps
382
- )
383
  speaker_ids = hps.data.spk2id
384
  speakers = list(speaker_ids.keys())
385
- languages = ["ZH", "JP", "EN", "mix", "auto"]
386
  with gr.Blocks() as app:
387
  with gr.Row():
388
  with gr.Column():
 
1
  # flake8: noqa: E402
2
+ import spaces
3
  import os
4
  import logging
5
  import re_matching
 
33
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
34
 
35
 
36
+ @spaces.GPU
37
  def generate_audio(
38
  slices,
39
  sdp_ratio,
 
379
  hps = utils.get_hparams_from_file(config.webui_config.config_path)
380
  # 若config.json中未指定版本则默认为最新版本
381
  version = hps.version if hasattr(hps, "version") else latest_version
382
+ net_g = get_net_g(model_path=config.webui_config.model, device=device, hps=hps)
 
 
383
  speaker_ids = hps.data.spk2id
384
  speakers = list(speaker_ids.keys())
385
+ languages = ["HAKKA"]
386
  with gr.Blocks() as app:
387
  with gr.Row():
388
  with gr.Column():