zhzluke96 committed
Commit 627d3d7
1 Parent(s): 72bb5b8
language/zh-CN.json CHANGED
@@ -31,8 +31,8 @@
   "🔊Generate": "🔊生成",
   "Disable Normalize": "禁用文本预处理",
   "💪🏼Enhance": "💪🏼增强",
-  "Enable Enhance": "启用增强",
-  "Enable De-noise": "启用降噪",
+  "Enable Enhance": "启用人声增强",
+  "Enable De-noise": "启用背景降噪",
   "🔊Generate Audio": "🔊生成音频",
   "SSML": "SSML",
   "Editor": "编辑器",
launch.py CHANGED
@@ -1,201 +1,69 @@
 import os
 import logging
 
+from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
+from modules.ffmpeg_env import setup_ffmpeg_path
+
+setup_ffmpeg_path()
 logging.basicConfig(
     level=os.getenv("LOG_LEVEL", "INFO"),
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
 
-from modules.devices import devices
 import argparse
 import uvicorn
 
-import torch
 from modules import config
 from modules.utils import env
-from modules import generate_audio as generate
-from modules.api.Api import APIManager
 
-from modules.api.impl import (
-    style_api,
-    tts_api,
-    ssml_api,
-    google_api,
-    openai_api,
-    refiner_api,
-    speaker_api,
-    ping_api,
-    models_api,
-)
+from fastapi import FastAPI
 
 logger = logging.getLogger(__name__)
 
-torch._dynamo.config.cache_size_limit = 64
-torch._dynamo.config.suppress_errors = True
-torch.set_float32_matmul_precision("high")
-
-
-def create_api(app, no_docs=False, exclude=[]):
-    app_mgr = APIManager(app=app, no_docs=no_docs, exclude_patterns=exclude)
-
-    ping_api.setup(app_mgr)
-    models_api.setup(app_mgr)
-    style_api.setup(app_mgr)
-    speaker_api.setup(app_mgr)
-    tts_api.setup(app_mgr)
-    ssml_api.setup(app_mgr)
-    google_api.setup(app_mgr)
-    openai_api.setup(app_mgr)
-    refiner_api.setup(app_mgr)
-
-    return app_mgr
-
-
-def get_and_update_env(*args):
-    val = env.get_env_or_arg(*args)
-    key = args[1]
-    config.runtime_env_vars[key] = val
-    return val
-
-
-def setup_model_args(parser: argparse.ArgumentParser):
-    parser.add_argument("--compile", action="store_true", help="Enable model compile")
-    parser.add_argument(
-        "--half",
-        action="store_true",
-        help="Enable half precision for model inference",
-    )
-    parser.add_argument(
-        "--off_tqdm",
-        action="store_true",
-        help="Disable tqdm progress bar",
-    )
-    parser.add_argument(
-        "--device_id",
-        type=str,
-        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
-        default=None,
-    )
-    parser.add_argument(
-        "--use_cpu",
-        nargs="+",
-        help="use CPU as torch device for specified modules",
-        default=[],
-        type=str.lower,
-    )
-    parser.add_argument(
-        "--lru_size",
-        type=int,
-        default=64,
-        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
-    )
-
-
-def setup_api_args(parser: argparse.ArgumentParser):
-    parser.add_argument("--api_host", type=str, help="Host to run the server on")
-    parser.add_argument("--api_port", type=int, help="Port to run the server on")
-    parser.add_argument(
-        "--reload", action="store_true", help="Enable auto-reload for development"
-    )
-    parser.add_argument(
-        "--cors_origin",
-        type=str,
-        help="Allowed CORS origins. Use '*' to allow all origins.",
-    )
-    parser.add_argument(
-        "--no_playground",
-        action="store_true",
-        help="Disable the playground entry",
-    )
-    parser.add_argument(
-        "--no_docs",
-        action="store_true",
-        help="Disable the documentation entry",
-    )
-    # Configure which APIs to skip, e.g. exclude="/v1/speakers/*,/v1/tts/*"
-    parser.add_argument(
-        "--exclude",
-        type=str,
-        help="Exclude the specified API from the server",
-    )
-
-
-def process_model_args(args):
-    lru_size = get_and_update_env(args, "lru_size", 64, int)
-    compile = get_and_update_env(args, "compile", False, bool)
-    device_id = get_and_update_env(args, "device_id", None, str)
-    use_cpu = get_and_update_env(args, "use_cpu", [], list)
-    half = get_and_update_env(args, "half", False, bool)
-    off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
-
-    generate.setup_lru_cache()
-    devices.reset_device()
-    devices.first_time_calculation()
-
-
-def process_api_args(args, app):
-    cors_origin = get_and_update_env(args, "cors_origin", "*", str)
-    no_playground = get_and_update_env(args, "no_playground", False, bool)
-    no_docs = get_and_update_env(args, "no_docs", False, bool)
-    exclude = get_and_update_env(args, "exclude", "", str)
-
-    api = create_api(app=app, no_docs=no_docs, exclude=exclude.split(","))
-    config.api = api
-
-    if cors_origin:
-        api.set_cors(allow_origins=[cors_origin])
-
-    if not no_playground:
-        api.setup_playground()
-
-    if compile:
-        logger.info("Model compile is enabled")
-
-
-app_description = """
-ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
-ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
-
-项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
-
-> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
-> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
-
-> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
-> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
-"""
-app_title = "ChatTTS Forge API"
-app_version = "0.1.0"
-
 if __name__ == "__main__":
     import dotenv
-    from fastapi import FastAPI
 
     dotenv.load_dotenv(
         dotenv_path=os.getenv("ENV_FILE", ".env.api"),
     )
-
     parser = argparse.ArgumentParser(
         description="Start the FastAPI server with command line arguments"
     )
     setup_api_args(parser)
     setup_model_args(parser)
+    setup_uvicon_args(parser=parser)
 
     args = parser.parse_args()
 
-    app = FastAPI(
-        title=app_title,
-        description=app_description,
-        version=app_version,
-        redoc_url=None if config.runtime_env_vars.no_docs else "/redoc",
-        docs_url=None if config.runtime_env_vars.no_docs else "/docs",
+    host = env.get_and_update_env(args, "host", "0.0.0.0", str)
+    port = env.get_and_update_env(args, "port", 7870, int)
+    reload = env.get_and_update_env(args, "reload", False, bool)
+    workers = env.get_and_update_env(args, "workers", 1, int)
+    log_level = env.get_and_update_env(args, "log_level", "info", str)
+    access_log = env.get_and_update_env(args, "access_log", True, bool)
+    proxy_headers = env.get_and_update_env(args, "proxy_headers", True, bool)
+    timeout_keep_alive = env.get_and_update_env(args, "timeout_keep_alive", 5, int)
+    timeout_graceful_shutdown = env.get_and_update_env(
+        args, "timeout_graceful_shutdown", 0, int
+    )
+    ssl_keyfile = env.get_and_update_env(args, "ssl_keyfile", None, str)
+    ssl_certfile = env.get_and_update_env(args, "ssl_certfile", None, str)
+    ssl_keyfile_password = env.get_and_update_env(
+        args, "ssl_keyfile_password", None, str
     )
 
-    process_model_args(args)
-    process_api_args(args, app)
-
-    host = get_and_update_env(args, "api_host", "0.0.0.0", str)
-    port = get_and_update_env(args, "api_port", 7870, int)
-    reload = get_and_update_env(args, "reload", False, bool)
-
-    uvicorn.run(app, host=host, port=port, reload=reload)
+    uvicorn.run(
+        "modules.api.worker:app",
+        host=host,
+        port=port,
+        reload=reload,
+        workers=workers,
+        log_level=log_level,
+        access_log=access_log,
+        proxy_headers=proxy_headers,
+        timeout_keep_alive=timeout_keep_alive,
+        timeout_graceful_shutdown=timeout_graceful_shutdown,
+        ssl_keyfile=ssl_keyfile,
+        ssl_certfile=ssl_certfile,
+        ssl_keyfile_password=ssl_keyfile_password,
+    )
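
Note: uvicorn is now handed the app as the import string "modules.api.worker:app" rather than an object; uvicorn requires an import string for --reload and multi-worker modes, since each worker process re-imports the module. A minimal sketch of the same pattern (host/port/worker values are illustrative):

import uvicorn

if __name__ == "__main__":
    # "modules.api.worker:app" means: import modules.api.worker and use its `app` attribute.
    uvicorn.run("modules.api.worker:app", host="0.0.0.0", port=7870, workers=2)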
modules/Enhancer/ResembleEnhance.py CHANGED
@@ -1,5 +1,8 @@
+import gc
 import os
 from typing import List, Literal
+
+import numpy as np
 from modules.devices import devices
 from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
 from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
@@ -11,53 +14,54 @@ from modules.utils.constants import MODELS_DIR
 from pathlib import Path
 
 from threading import Lock
+from modules import config
 
-resemble_enhance = None
-lock = Lock()
+import logging
 
+logger = logging.getLogger(__name__)
 
-def load_enhancer(device: torch.device):
-    global resemble_enhance
-    with lock:
-        if resemble_enhance is None:
-            resemble_enhance = ResembleEnhance(device)
-            resemble_enhance.load_model()
-        return resemble_enhance
+resemble_enhance = None
+lock = Lock()
 
 
 class ResembleEnhance:
-    def __init__(self, device: torch.device):
+    def __init__(self, device: torch.device, dtype=torch.float32):
        self.device = device
+        self.dtype = dtype
 
        self.enhancer: Enhancer = None
        self.hparams: HParams = None
 
    def load_model(self):
        hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
-        enhancer = Enhancer(hparams)
+        enhancer = Enhancer(hparams).to(device=self.device, dtype=self.dtype).eval()
        state_dict = torch.load(
            Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
            map_location=self.device,
        )["module"]
        enhancer.load_state_dict(state_dict)
-        enhancer.to(self.device).eval()
 
        self.hparams = hparams
        self.enhancer = enhancer
 
    @torch.inference_mode()
-    def denoise(self, dwav, sr, device) -> tuple[torch.Tensor, int]:
+    def denoise(self, dwav, sr) -> tuple[torch.Tensor, int]:
        assert self.enhancer is not None, "Model not loaded"
        assert self.enhancer.denoiser is not None, "Denoiser not loaded"
        enhancer = self.enhancer
-        return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
+        return inference(
+            model=enhancer.denoiser,
+            dwav=dwav,
+            sr=sr,
+            device=self.device,
+            dtype=self.dtype,
+        )
 
    @torch.inference_mode()
    def enhance(
        self,
        dwav,
        sr,
-        device,
        nfe=32,
        solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
        lambd=0.5,
@@ -74,7 +78,81 @@ class ResembleEnhance:
        assert self.enhancer is not None, "Model not loaded"
        enhancer = self.enhancer
        enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
-        return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
+        return inference(
+            model=enhancer, dwav=dwav, sr=sr, device=self.device, dtype=self.dtype
+        )
+
+
+def load_enhancer() -> ResembleEnhance:
+    global resemble_enhance
+    with lock:
+        if resemble_enhance is None:
+            logger.info("Loading ResembleEnhance model")
+            resemble_enhance = ResembleEnhance(
+                device=devices.device, dtype=devices.dtype
+            )
+            resemble_enhance.load_model()
+            logger.info("ResembleEnhance model loaded")
+        return resemble_enhance
+
+
+def unload_enhancer():
+    global resemble_enhance
+    with lock:
+        if resemble_enhance is not None:
+            logger.info("Unloading ResembleEnhance model")
+            del resemble_enhance
+            resemble_enhance = None
+            devices.torch_gc()
+            gc.collect()
+            logger.info("ResembleEnhance model unloaded")
+
+
+def reload_enhancer():
+    logger.info("Reloading ResembleEnhance model")
+    unload_enhancer()
+    load_enhancer()
+    logger.info("ResembleEnhance model reloaded")
+
+
+def apply_audio_enhance_full(
+    audio_data: np.ndarray,
+    sr: int,
+    nfe=32,
+    solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
+    lambd=0.5,
+    tau=0.5,
+):
+    # FIXME: moving this to(device) might be a slight optimization?
+    tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
+    enhancer = load_enhancer()
+
+    tensor, sr = enhancer.enhance(
+        tensor, sr, tau=tau, nfe=nfe, solver=solver, lambd=lambd
+    )
+
+    audio_data = tensor.cpu().numpy()
+    return audio_data, int(sr)
+
+
+def apply_audio_enhance(
+    audio_data: np.ndarray, sr: int, enable_denoise: bool, enable_enhance: bool
+):
+    if not enable_denoise and not enable_enhance:
+        return audio_data, sr
+
+    # FIXME: moving this to(device) might be a slight optimization?
+    tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
+    enhancer = load_enhancer()
+
+    if enable_enhance or enable_denoise:
+        lambd = 0.9 if enable_denoise else 0.1
+        tensor, sr = enhancer.enhance(
+            tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd
+        )
+
+    audio_data = tensor.cpu().numpy()
+    return audio_data, int(sr)
 
 
 if __name__ == "__main__":
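
Note: a minimal usage sketch of the new module-level helpers (assumes a 1-D float32 numpy waveform at the 44.1 kHz rate the resemble-enhance checkpoint expects; the input array here is purely illustrative):

import numpy as np
from modules.Enhancer.ResembleEnhance import apply_audio_enhance

audio = np.zeros(44100, dtype=np.float32)  # one second of silence
# lambd is picked internally: 0.9 when denoising is requested, 0.1 otherwise
audio, sr = apply_audio_enhance(audio, 44100, enable_denoise=True, enable_enhance=True)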
modules/api/Api.py CHANGED
@@ -24,7 +24,7 @@ def is_excluded(path, exclude_patterns):
 
 
 class APIManager:
-    def __init__(self, app: FastAPI, no_docs=False, exclude_patterns=[]):
+    def __init__(self, app: FastAPI, exclude_patterns=[]):
        self.app = app
        self.registered_apis = {}
        self.logger = logging.getLogger(__name__)
modules/api/api_setup.py ADDED
@@ -0,0 +1,164 @@
+import logging
+from modules.devices import devices
+import argparse
+
+import torch
+from modules import config
+from modules.utils import env
+from modules import generate_audio
+from modules.api.Api import APIManager
+
+from modules.api.impl import (
+    style_api,
+    tts_api,
+    ssml_api,
+    google_api,
+    openai_api,
+    refiner_api,
+    speaker_api,
+    ping_api,
+    models_api,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def create_api(app, exclude=[]):
+    app_mgr = APIManager(app=app, exclude_patterns=exclude)
+
+    ping_api.setup(app_mgr)
+    models_api.setup(app_mgr)
+    style_api.setup(app_mgr)
+    speaker_api.setup(app_mgr)
+    tts_api.setup(app_mgr)
+    ssml_api.setup(app_mgr)
+    google_api.setup(app_mgr)
+    openai_api.setup(app_mgr)
+    refiner_api.setup(app_mgr)
+
+    return app_mgr
+
+
+def setup_model_args(parser: argparse.ArgumentParser):
+    parser.add_argument("--compile", action="store_true", help="Enable model compile")
+    parser.add_argument(
+        "--half",
+        action="store_true",
+        help="Enable half precision for model inference",
+    )
+    parser.add_argument(
+        "--off_tqdm",
+        action="store_true",
+        help="Disable tqdm progress bar",
+    )
+    parser.add_argument(
+        "--device_id",
+        type=str,
+        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
+        default=None,
+    )
+    parser.add_argument(
+        "--use_cpu",
+        nargs="+",
+        help="use CPU as torch device for specified modules",
+        default=[],
+        type=str.lower,
+    )
+    parser.add_argument(
+        "--lru_size",
+        type=int,
+        default=64,
+        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
+    )
+    parser.add_argument(
+        "--debug_generate",
+        action="store_true",
+        help="Enable debug mode for audio generation",
+    )
+
+
+def process_model_args(args):
+    lru_size = env.get_and_update_env(args, "lru_size", 64, int)
+    compile = env.get_and_update_env(args, "compile", False, bool)
+    device_id = env.get_and_update_env(args, "device_id", None, str)
+    use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
+    half = env.get_and_update_env(args, "half", False, bool)
+    off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
+    debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
+
+    generate_audio.setup_lru_cache()
+    devices.reset_device()
+    devices.first_time_calculation()
+
+    if debug_generate:
+        generate_audio.logger.setLevel(logging.DEBUG)
+
+
+def setup_uvicon_args(parser: argparse.ArgumentParser):
+    parser.add_argument("--host", type=str, help="Host to run the server on")
+    parser.add_argument("--port", type=int, help="Port to run the server on")
+    parser.add_argument(
+        "--reload", action="store_true", help="Enable auto-reload for development"
+    )
+    parser.add_argument("--workers", type=int, help="Number of worker processes")
+    parser.add_argument("--log_level", type=str, help="Log level")
+    parser.add_argument("--access_log", action="store_true", help="Enable access log")
+    parser.add_argument(
+        "--proxy_headers", action="store_true", help="Enable proxy headers"
+    )
+    parser.add_argument(
+        "--timeout_keep_alive", type=int, help="Keep-alive timeout duration"
+    )
+    parser.add_argument(
+        "--timeout_graceful_shutdown",
+        type=int,
+        help="Graceful shutdown timeout duration",
+    )
+    parser.add_argument("--ssl_keyfile", type=str, help="SSL key file path")
+    parser.add_argument("--ssl_certfile", type=str, help="SSL certificate file path")
+    parser.add_argument(
+        "--ssl_keyfile_password", type=str, help="SSL key file password"
+    )
+
+
+def setup_api_args(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--cors_origin",
+        type=str,
+        help="Allowed CORS origins. Use '*' to allow all origins.",
+    )
+    parser.add_argument(
+        "--no_playground",
+        action="store_true",
+        help="Disable the playground entry",
+    )
+    parser.add_argument(
+        "--no_docs",
+        action="store_true",
+        help="Disable the documentation entry",
+    )
+    # Configure which APIs to skip, e.g. exclude="/v1/speakers/*,/v1/tts/*"
+    parser.add_argument(
+        "--exclude",
+        type=str,
+        help="Exclude the specified API from the server",
+    )
+
+
+def process_api_args(args, app):
+    cors_origin = env.get_and_update_env(args, "cors_origin", "*", str)
+    no_playground = env.get_and_update_env(args, "no_playground", False, bool)
+    no_docs = env.get_and_update_env(args, "no_docs", False, bool)
+    exclude = env.get_and_update_env(args, "exclude", "", str)
+
+    api = create_api(app=app, exclude=exclude.split(","))
+    config.api = api
+
+    if cors_origin:
+        api.set_cors(allow_origins=[cors_origin])
+
+    if not no_playground:
+        api.setup_playground()
+
+    if config.runtime_env_vars.compile:
+        logger.info("Model compile is enabled")
modules/api/app_config.py ADDED
@@ -0,0 +1,14 @@
+app_description = """
+ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
+ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
+
+项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
+
+> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
+> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
+
+> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
+> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
+"""
+app_title = "ChatTTS Forge API"
+app_version = "0.1.0"
modules/api/impl/google_api.py CHANGED
@@ -1,4 +1,5 @@
 import base64
+from typing import Literal
 from fastapi import HTTPException
 
 import io
@@ -6,7 +7,12 @@ import soundfile as sf
 from pydantic import BaseModel
 
 
+from modules.Enhancer.ResembleEnhance import (
+    apply_audio_enhance,
+    apply_audio_enhance_full,
+)
 from modules.api.Api import APIManager
+from modules.synthesize_audio import synthesize_audio
 from modules.utils.audio import apply_prosody_to_audio_data
 from modules.normalization import text_normalize
 
@@ -44,15 +50,25 @@ class AudioConfig(BaseModel):
    speakingRate: float = 1
    pitch: float = 0
    volumeGainDb: float = 0
-    sampleRateHertz: int
+    sampleRateHertz: int = 24000
    batchSize: int = 1
    spliterThreshold: int = 100
 
 
+class EnhancerConfig(BaseModel):
+    enabled: bool = False
+    model: str = "resemble-enhance"
+    nfe: int = 32
+    solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
+    lambd: float = 0.5
+    tau: float = 0.5
+
+
 class GoogleTextSynthesizeRequest(BaseModel):
    input: SynthesisInput
    voice: VoiceSelectionParams
-    audioConfig: dict
+    audioConfig: AudioConfig
+    enhancerConfig: EnhancerConfig = None
 
 
 class GoogleTextSynthesizeResponse(BaseModel):
@@ -63,6 +79,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
    input = request.input
    voice = request.voice
    audioConfig = request.audioConfig
+    enhancerConfig = request.enhancerConfig
 
    # Extract parameters
 
@@ -70,40 +87,41 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
    language_code = voice.languageCode
    voice_name = voice.name
    infer_seed = voice.seed or 42
-    audio_format = audioConfig.get("audioEncoding", "mp3")
-    speaking_rate = audioConfig.get("speakingRate", 1)
-    pitch = audioConfig.get("pitch", 0)
-    volume_gain_db = audioConfig.get("volumeGainDb", 0)
+    audio_format = audioConfig.audioEncoding or "mp3"
+    speaking_rate = audioConfig.speakingRate or 1
+    pitch = audioConfig.pitch or 0
+    volume_gain_db = audioConfig.volumeGainDb or 0
 
-    batch_size = audioConfig.get("batchSize", 1)
+    batch_size = audioConfig.batchSize or 1
 
    # TODO spliter_threshold
-    spliter_threshold = audioConfig.get("spliterThreshold", 100)
+    spliter_threshold = audioConfig.spliterThreshold or 100
 
    # TODO sample_rate
-    sample_rate_hertz = audioConfig.get("sampleRateHertz", 24000)
+    sample_rate_hertz = audioConfig.sampleRateHertz or 24000
 
    params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
 
-    # TODO maybe need to change the sample rate
-    sample_rate = 24000
-
    # Although calc_spk_style can parse the seed form, this endpoint only supports speakers present in the speakers list
    if speaker_mgr.get_speaker(voice_name) is None:
        raise HTTPException(
-            status_code=400, detail="The specified voice name is not supported."
+            status_code=422, detail="The specified voice name is not supported."
        )
 
    if audio_format != "mp3" and audio_format != "wav":
        raise HTTPException(
-            status_code=400, detail="Invalid audio encoding format specified."
+            status_code=422, detail="Invalid audio encoding format specified."
        )
 
+    if enhancerConfig and enhancerConfig.enabled:
+        # TODO enhancer params checker
+        pass
+
    try:
        if input.text:
            # Handle plain-text synthesis
            text = text_normalize(input.text, is_end=True)
-            sample_rate, audio_data = generate.generate_audio(
+            sample_rate, audio_data = synthesize_audio(
                text,
                temperature=(
                    voice.temperature
@@ -117,6 +135,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
                prompt1=params.get("prompt1", ""),
                prompt2=params.get("prompt2", ""),
                prefix=params.get("prefix", ""),
+                batch_size=batch_size,
+                spliter_threshold=spliter_threshold,
            )
 
        elif input.ssml:
@@ -128,7 +148,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
 
            if len(segments) == 0:
                raise HTTPException(
-                    status_code=400, detail="The SSML text is empty or parsing failed."
+                    status_code=422, detail="The SSML text is empty or parsing failed."
                )
 
            synthesize = SynthesizeSegments(batch_size=batch_size)
@@ -144,7 +164,17 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
 
        else:
            raise HTTPException(
-                status_code=400, detail="Either text or SSML input must be provided."
+                status_code=422, detail="Either text or SSML input must be provided."
+            )
+
+        if enhancerConfig and enhancerConfig.enabled:
+            audio_data, sample_rate = apply_audio_enhance_full(
+                audio_data=audio_data,
+                sr=sample_rate,
+                nfe=enhancerConfig.nfe,
+                solver=enhancerConfig.solver,
+                lambd=enhancerConfig.lambd,
+                tau=enhancerConfig.tau,
            )
 
        audio_data = apply_prosody_to_audio_data(
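
Note: a hypothetical request body exercising the new enhancerConfig field (the endpoint path is not shown in this hunk; field names follow the pydantic models above):

payload = {
    "input": {"text": "你好"},
    "voice": {"languageCode": "zh-CN", "name": "female2", "seed": 42},
    "audioConfig": {"audioEncoding": "mp3", "sampleRateHertz": 24000},
    "enhancerConfig": {"enabled": True, "nfe": 32, "solver": "midpoint", "lambd": 0.5, "tau": 0.5},
}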
modules/api/impl/models_api.py CHANGED
@@ -1,11 +1,18 @@
+from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
 from modules.api import utils as api_utils
 from modules.api.Api import APIManager
-from modules.models import reload_chat_tts
+from modules.models import reload_chat_tts, unload_chat_tts
 
 
 def setup(app: APIManager):
    @app.get("/v1/models/reload", response_model=api_utils.BaseResponse)
    async def reload_models():
-        # Reload models
        reload_chat_tts()
+        reload_enhancer()
        return api_utils.success_response("Models reloaded")
+
+    @app.get("/v1/models/unload", response_model=api_utils.BaseResponse)
+    async def unload_models():
+        unload_chat_tts()
+        unload_enhancer()
+        return api_utils.success_response("Models unloaded")
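
Note: a hypothetical client call for the new unload route (host/port assume the defaults from launch.py; the response shape follows api_utils.BaseResponse):

import requests

resp = requests.get("http://localhost:7870/v1/models/unload")
print(resp.json())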
modules/api/worker.py ADDED
@@ -0,0 +1,49 @@
+import argparse
+import logging
+import os
+import dotenv
+from fastapi import FastAPI
+
+from modules.ffmpeg_env import setup_ffmpeg_path
+
+setup_ffmpeg_path()
+logging.basicConfig(
+    level=os.getenv("LOG_LEVEL", "INFO"),
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+
+from modules.api.api_setup import (
+    process_api_args,
+    process_model_args,
+    setup_api_args,
+    setup_model_args,
+    setup_uvicon_args,
+)
+from modules.api.app_config import app_description, app_title, app_version
+from modules import config
+from modules.utils.torch_opt import configure_torch_optimizations
+
+dotenv.load_dotenv(
+    dotenv_path=os.getenv("ENV_FILE", ".env.api"),
+)
+parser = argparse.ArgumentParser(
+    description="Start the FastAPI server with command line arguments"
+)
+setup_api_args(parser)
+setup_model_args(parser)
+setup_uvicon_args(parser)
+
+args = parser.parse_args()
+
+app = FastAPI(
+    title=app_title,
+    description=app_description,
+    version=app_version,
+    redoc_url=None if config.runtime_env_vars.no_docs else "/redoc",
+    docs_url=None if config.runtime_env_vars.no_docs else "/docs",
+)
+
+process_model_args(args)
+process_api_args(args, app)
+
+configure_torch_optimizations()
modules/config.py CHANGED
@@ -3,7 +3,7 @@ import sys
 import torch
 from modules.utils.JsonObject import JsonObject
 
-from modules.utils import git
+from modules.utils import git, ffmpeg
 
 # TODO impl RuntimeEnvVars() class
 runtime_env_vars = JsonObject({})
@@ -20,5 +20,6 @@ versions = JsonObject(
        "git_tag": git.git_tag(),
        "git_branch": git.branch_name(),
        "git_commit": git.commit_hash(),
+        "ffmpeg_version": ffmpeg.ffmpeg_version(),
    }
 )
modules/ffmpeg_env.py ADDED
@@ -0,0 +1,16 @@
+import os
+from modules.utils.constants import ROOT_DIR
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def setup_ffmpeg_path():
+    ffmpeg_path = os.path.join(ROOT_DIR, "ffmpeg")
+    os.environ["PATH"] = ffmpeg_path + os.pathsep + os.environ["PATH"]
+
+    import pydub.utils
+
+    if pydub.utils.which("ffmpeg") is None:
+        logger.error("ffmpeg not found in PATH")
+        raise Exception("ffmpeg not found in PATH")
modules/generate_audio.py CHANGED
@@ -74,10 +74,10 @@ def generate_audio_batch(
    if isinstance(spk, int):
        with SeedContext(spk, True):
            params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
-        logger.info(("spk", spk))
+        logger.debug(("spk", spk))
    elif isinstance(spk, Speaker):
        params_infer_code["spk_emb"] = spk.emb
-        logger.info(("spk", spk.name))
+        logger.debug(("spk", spk.name))
    else:
        logger.warn(
            f"spk must be int or Speaker, but: <{type(spk)}> {spk}, will set to default voice"
@@ -85,7 +85,7 @@ def generate_audio_batch(
        with SeedContext(2, True):
            params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
 
-    logger.info(
+    logger.debug(
        {
            "text": texts,
            "infer_seed": infer_seed,
modules/gradio_dcls_fix.py CHANGED
@@ -1,6 +1,7 @@
 def dcls_patch():
    from gradio import data_classes
 
+    # https://github.com/gradio-app/gradio/pull/8530
    data_classes.PredictBody.__get_pydantic_json_schema__ = lambda x, y: {
        "type": "object",
    }
modules/models.py CHANGED
@@ -55,10 +55,9 @@ def unload_chat_tts():
        if isinstance(model, torch.nn.Module):
            model.cpu()
            del model
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    gc.collect()
    chat_tts = None
+    devices.torch_gc()
+    gc.collect()
    logger.info("ChatTTS models unloaded")
 
 
modules/repos_static/resemble_enhance/denoiser/denoiser.py CHANGED
@@ -65,7 +65,9 @@ class Denoiser(nn.Module):
            x = x.cpu()
 
        window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
-        s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True)  # (b f t+1)
+        s = torch.stft(
+            x.float(), **self.stft_cfg, window=window, return_complex=True
+        )  # (b f t+1)
 
        s = s[..., :-1]  # (b f t)
 
@@ -106,6 +108,7 @@ class Denoiser(nn.Module):
        if s.isnan().any():
            logger.warning("NaN detected in ISTFT input.")
 
+        s = s.to(torch.complex64)
        s = F.pad(s, (0, 1), "replicate")  # (b f t+1)
 
        window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
@@ -168,7 +171,9 @@ class Denoiser(nn.Module):
 
        mag, cos, sin = self._stft(x)  # (b 2f t)
        mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
-        sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
+        sep_mag, sep_cos, sep_sin = self._separate(
+            mag, cos, sin, mag_mask, cos_res, sin_res
+        )
 
        o = self._istft(sep_mag, sep_cos, sep_sin)
 
modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py CHANGED
@@ -64,7 +64,12 @@ class IRMAE(nn.Module):
            nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
-            *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
+            *[
+                nn.Conv1d(
+                    hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False
+                )
+                for i in range(num_irms)
+            ],
            nn.Tanh(),
        )
 
@@ -92,9 +97,10 @@ class IRMAE(nn.Module):
            self.stats = {}
            self.stats["z_mean"] = z.mean().item()
            self.stats["z_std"] = z.std().item()
-            self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
-            self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
-            self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
+            z_float = z.float()
+            self.stats["z_abs_68"] = z_float.abs().quantile(0.6827).item()
+            self.stats["z_abs_95"] = z_float.abs().quantile(0.9545).item()
+            self.stats["z_abs_99"] = z_float.abs().quantile(0.9973).item()
        return z
 
    def decode(self, z):
modules/repos_static/resemble_enhance/inference.py CHANGED
@@ -8,6 +8,8 @@ from torchaudio.functional import resample
 from torchaudio.transforms import MelSpectrogram
 from tqdm import trange
 
+from modules.devices import devices
+
 from .hparams import HParams
 
 from modules import config
@@ -16,7 +18,14 @@ logger = logging.getLogger(__name__)
 
 
 @torch.inference_mode()
-def inference_chunk(model, dwav, sr, device, npad=441):
+def inference_chunk(
+    model,
+    dwav: torch.Tensor,
+    sr: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    npad=441,
+) -> torch.Tensor:
    assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
    del sr
 
@@ -24,10 +33,10 @@ def inference_chunk(model, dwav, sr, device, npad=441):
    abs_max = dwav.abs().max().clamp(min=1e-7)
 
    assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
-    dwav = dwav.to(device)
+    dwav = dwav.to(device=device, dtype=dtype)
    dwav = dwav / abs_max  # Normalize
    dwav = F.pad(dwav, (0, npad))
-    hwav = model(dwav[None])[0].cpu()  # (T,)
+    hwav: torch.Tensor = model(dwav[None])[0].cpu()  # (T,)
    hwav = hwav[:length]  # Trim padding
    hwav = hwav * abs_max  # Unnormalize
 
@@ -60,6 +69,9 @@ def compute_offset(chunk1, chunk2, sr=44100):
        f_max=sr // 2,
    )
 
+    chunk1 = chunk1.float()
+    chunk2 = chunk2.float()
+
    spec1 = mel_fn(chunk1).log1p()
    spec2 = mel_fn(chunk2).log1p()
 
@@ -123,7 +135,13 @@ def remove_weight_norm_recursively(module):
 
 
 def inference(
-    model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0
+    model,
+    dwav,
+    sr,
+    device,
+    dtype,
+    chunk_seconds: float = 30.0,
+    overlap_seconds: float = 1.0,
 ):
    if config.runtime_env_vars.off_tqdm:
        trange = range
@@ -159,9 +177,11 @@ def inference(
 
    chunks = []
    for start in trange(0, dwav.shape[-1], hop_length):
-        chunks.append(
-            inference_chunk(model, dwav[start : start + chunk_length], sr, device)
+        chunk_dwav = inference_chunk(
+            model, dwav[start : start + chunk_length], sr, device, dtype
        )
+        chunks.append(chunk_dwav.cpu())
+        devices.torch_gc()
 
    hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
 
@@ -172,5 +192,6 @@ def inference(
    logger.info(
        f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz"
    )
+    devices.torch_gc()
 
    return hwav, sr
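
Note: a worked example of the chunking arithmetic behind the trange loop (variable derivations are illustrative; the function computes them from chunk_seconds/overlap_seconds before the loop shown above):

sr = 44100
chunk_length = int(sr * 30.0)               # 1,323,000 samples per chunk
overlap_length = int(sr * 1.0)              # 44,100 samples shared with the next chunk
hop_length = chunk_length - overlap_length  # 1,278,900 samples between chunk starts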
modules/speaker.py CHANGED
@@ -104,7 +104,7 @@ class SpeakerManager:
        if not os.path.exists(self.speaker_dir + fname):
            del self.speakers[fname]
 
-    def list_speakers(self):
+    def list_speakers(self) -> list[Speaker]:
        return list(self.speakers.values())
 
    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
modules/utils/env.py CHANGED
@@ -1,5 +1,7 @@
 import os
 
+from modules import config
+
 
 def get_env_val(key, val_type):
    env_val = os.getenv(key.upper())
@@ -27,3 +29,10 @@ def get_env_or_arg(args, arg_name, default, arg_type):
        return env_val
 
    return default
+
+
+def get_and_update_env(*args):
+    val = get_env_or_arg(*args)
+    key = args[1]
+    config.runtime_env_vars[key] = val
+    return val
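
Note: a short sketch of how get_and_update_env resolves a value (assuming, per get_env_val above, that the environment key is the upper-cased argument name and that an explicit CLI argument takes precedence over the environment):

import argparse
import os

from modules.api.api_setup import setup_model_args
from modules.utils import env

parser = argparse.ArgumentParser()
setup_model_args(parser)

os.environ["LRU_SIZE"] = "128"
args = parser.parse_args([])  # no --lru_size on the command line
val = env.get_and_update_env(args, "lru_size", 64, int)
# val == 128; the resolved value is also cached in config.runtime_env_vars["lru_size"]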
modules/utils/ffmpeg.py ADDED
@@ -0,0 +1,20 @@
+import subprocess
+from functools import lru_cache
+
+
+@lru_cache()
+def ffmpeg_version():
+    try:
+        result = subprocess.check_output(
+            ["ffmpeg", "-version"], shell=False, encoding="utf8"
+        )
+        version_info = result.split("\n")[0]
+        version_info = version_info.split("ffmpeg version")[1].strip()
+        version_info = version_info.split("Copyright")[0].strip()
+        return version_info
+    except Exception:
+        return "<none>"
+
+
+if __name__ == "__main__":
+    print(ffmpeg_version())
modules/utils/git.py CHANGED
@@ -3,23 +3,15 @@ import os
 import subprocess
 
 
-from modules.utils import constants
-
-# Used to check whether we are running in HF Spaces
-try:
-    import spaces
-except:
-    spaces = None
+from modules.utils import constants, hf
 
 git = os.environ.get("GIT", "git")
 
-in_hf_spaces = spaces is not None
-
 
 @lru_cache()
 def commit_hash():
    try:
-        if in_hf_spaces:
+        if hf.is_spaces_env:
            return "<hf>"
        return subprocess.check_output(
            [git, "-C", constants.ROOT_DIR, "rev-parse", "HEAD"],
@@ -33,7 +25,7 @@ def commit_hash():
 @lru_cache()
 def git_tag():
    try:
-        if in_hf_spaces:
+        if hf.is_spaces_env:
            return "<hf>"
        return subprocess.check_output(
            [git, "-C", constants.ROOT_DIR, "describe", "--tags"],
@@ -57,7 +49,7 @@ def git_tag():
 @lru_cache()
 def branch_name():
    try:
-        if in_hf_spaces:
+        if hf.is_spaces_env:
            return "<hf>"
        return subprocess.check_output(
            [git, "-C", constants.ROOT_DIR, "rev-parse", "--abbrev-ref", "HEAD"],
modules/utils/hf.py ADDED
@@ -0,0 +1,17 @@
+# Compatibility shim for Hugging Face Spaces
+
+try:
+    import spaces
+
+    is_spaces_env = True
+except:
+
+    class NoneSpaces:
+        def __init__(self):
+            pass
+
+        def GPU(self, fn):
+            return fn
+
+    spaces = NoneSpaces()
+    is_spaces_env = False
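
Note: the stub keeps the @spaces.GPU decorator usable outside Hugging Face Spaces; a minimal sketch:

from modules.utils.hf import spaces

@spaces.GPU  # on Spaces this requests GPU time; elsewhere the stub returns fn unchanged
def synthesize(text: str):
    ...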
modules/utils/torch_opt.py ADDED
@@ -0,0 +1,7 @@
+import torch
+
+
+def configure_torch_optimizations():
+    torch._dynamo.config.cache_size_limit = 64
+    torch._dynamo.config.suppress_errors = True
+    torch.set_float32_matmul_precision("high")
modules/webui/app.py CHANGED
@@ -1,11 +1,10 @@
 import logging
 import os
 
-import torch
 import gradio as gr
 
 from modules import config
-from modules.webui import gradio_extensions, localization, webui_config, gradio_hijack
+from modules.webui import gradio_extensions, webui_config
 
 from modules.webui.changelog_tab import create_changelog_tab
 from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
@@ -24,10 +23,6 @@ def webui_init():
    # fix: If the system proxy is enabled in the Windows system, you need to skip these
    os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
 
-    torch._dynamo.config.cache_size_limit = 64
-    torch._dynamo.config.suppress_errors = True
-    torch.set_float32_matmul_precision("high")
-
    if config.runtime_env_vars.language == "en":
        webui_config.localization = ENLocalizationVars()
    else:
@@ -43,6 +38,7 @@ def create_app_footer():
    git_branch = os.environ.get("V_GIT_BRANCH") or config.versions.git_branch
    python_version = config.versions.python_version
    torch_version = config.versions.torch_version
+    ffmpeg_version = config.versions.ffmpeg_version
 
    config.versions.gradio_version = gradio_version
 
@@ -53,9 +49,10 @@ def create_app_footer():
    footer_items.append(f"branch: `{git_branch}`")
    footer_items.append(f"python: `{python_version}`")
    footer_items.append(f"torch: `{torch_version}`")
+    footer_items.append(f"ffmpeg: `{ffmpeg_version}`")
 
    if config.runtime_env_vars.api and not config.runtime_env_vars.no_docs:
-        footer_items.append(f"[API](/docs)")
+        footer_items.append(f"[api](/docs)")
 
    gr.Markdown(
        " | ".join(footer_items),
modules/webui/gradio_extensions.py CHANGED
@@ -14,7 +14,7 @@ WEBUI_DIR_PATH = Path(os.path.dirname(os.path.realpath(__file__)))
 
 
 def read_file(fp):
-    with open(WEBUI_DIR_PATH / fp, "r") as f:
+    with open(WEBUI_DIR_PATH / fp, "r", encoding="utf-8") as f:
        return f.read()
 
 
modules/webui/js/index.js CHANGED
@@ -154,6 +154,7 @@ addObserverIfDesiredNodeAvailable(".toast-wrap", function (added) {
  added.forEach(function (element) {
    if (element.innerText.includes("Connection errored out.")) {
      window.setTimeout(function () {
+        // FIXME: these buttons don't seem to exist... the gradio version in a1111 differs from ours
        document.getElementById("reset_button")?.classList.remove("hidden");
        document.getElementById("generate_button")?.classList.add("hidden");
        document.getElementById("skip_button")?.classList.add("hidden");
modules/webui/speaker/speaker_creator.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import torch
 from modules.speaker import Speaker
 from modules.utils.SeedContext import SeedContext
-from modules.hf import spaces
+from modules.utils.hf import spaces
 from modules.models import load_chat_tts
 from modules.utils.rng import np_rng
 from modules.webui import webui_config
modules/webui/speaker/speaker_editor.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import torch
 from modules.speaker import Speaker
-from modules.hf import spaces
+from modules.utils.hf import spaces
 from modules.webui import webui_config
 from modules.webui.webui_utils import tts_generate
 
modules/webui/speaker/speaker_merger.py CHANGED
@@ -2,7 +2,7 @@ import io
 import gradio as gr
 import torch
 
-from modules.hf import spaces
+from modules.utils.hf import spaces
 from modules.webui import webui_config, webui_utils
 from modules.webui.webui_utils import get_speakers, tts_generate
 from modules.speaker import speaker_mgr, Speaker
modules/webui/ssml/podcast_tab.py CHANGED
@@ -4,68 +4,68 @@ import torch
 
 from modules.normalization import text_normalize
 from modules.webui import webui_utils
-from modules.hf import spaces
+from modules.utils.hf import spaces
 
 podcast_default_case = [
    [
        1,
        "female2",
        "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        2,
        "Alice",
        "嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        3,
        "Bob",
        "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        4,
        "female2",
        "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        5,
        "Alice",
        "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        6,
        "Bob",
        "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        7,
        "female2",
        "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        8,
        "Alice",
        "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        9,
        "Bob",
        "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
    [
        10,
        "female2",
        "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]",
-        "podcast_p",
+        "podcast",
    ],
 ]
 
modules/webui/ssml/spliter_tab.py CHANGED
@@ -7,7 +7,7 @@ from modules.webui.webui_utils import (
    get_styles,
    split_long_text,
 )
-from modules.hf import spaces
+from modules.utils.hf import spaces
 
 
 # NOTE: because text_normalize needs the tokenizer
modules/webui/webui_utils.py CHANGED
@@ -2,10 +2,10 @@ import io
 from typing import Union
 import numpy as np
 
-from modules.Enhancer.ResembleEnhance import load_enhancer
+from modules.Enhancer.ResembleEnhance import apply_audio_enhance as _apply_audio_enhance
 from modules.devices import devices
 from modules.synthesize_audio import synthesize_audio
-from modules.hf import spaces
+from modules.utils.hf import spaces
 from modules.webui import webui_config
 
 import torch
@@ -85,22 +85,7 @@ def segments_length_limit(
 @torch.inference_mode()
 @spaces.GPU
 def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
-    if not enable_denoise and not enable_enhance:
-        return audio_data, sr
-
-    device = devices.device
-    # NOTE: oddly this ought to be moved onto the device, but the enhancer errors out when chunking... so it has to stay on cpu()
-    tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
-    enhancer = load_enhancer(device)
-
-    if enable_enhance or enable_denoise:
-        lambd = 0.9 if enable_denoise else 0.1
-        tensor, sr = enhancer.enhance(
-            tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd, device=device
-        )
-
-    audio_data = tensor.cpu().numpy()
-    return audio_data, int(sr)
+    return _apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance)
 
 
 @torch.inference_mode()
webui.py CHANGED
@@ -1,21 +1,23 @@
 import os
 import logging
 
+from modules.api.api_setup import (
+    process_api_args,
+    process_model_args,
+    setup_api_args,
+    setup_model_args,
+)
+from modules.ffmpeg_env import setup_ffmpeg_path
+from modules.utils.env import get_and_update_env
+from modules.api.app_config import app_description, app_title, app_version
+from modules.utils.torch_opt import configure_torch_optimizations
+
+setup_ffmpeg_path()
 logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
 
-from launch import (
-    get_and_update_env,
-    setup_api_args,
-    setup_model_args,
-    process_api_args,
-    process_model_args,
-    app_description,
-    app_title,
-    app_version,
-)
 from modules.webui import webui_config
 from modules import config
 from modules.webui.app import webui_init, create_interface
@@ -89,6 +91,7 @@ def process_webui_args(args):
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
 
+    configure_torch_optimizations()
    webui_init()
    demo = create_interface()
 
@@ -102,7 +105,7 @@ def process_webui_args(args):
        debug=debug,
        auth=auth,
        show_api=False,
-        prevent_thread_lock=False,
+        prevent_thread_lock=True,
        app_kwargs={
            "title": app_title,
            "description": app_description,
@@ -119,6 +122,18 @@ def process_webui_args(args):
            ),
        },
    )
+    # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
+    # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
+    # running web ui and do whatever the attacker wants, including installing an extension and
+    # running its code. We disable this here. Suggested by RyotaK.
+    app.user_middleware = [
+        x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
+    ]
+
+    if api:
+        process_api_args(args, app)
+
+    demo.block_thread()
 
 
 if __name__ == "__main__":
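
Note: a minimal sketch of the launch pattern used above (assumes gradio's Blocks API). With prevent_thread_lock=True, demo.launch() returns immediately, so the underlying FastAPI app can still be adjusted (middleware stripped, API routes mounted) before demo.block_thread() parks the main thread:

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("hello")

app, local_url, share_url = demo.launch(prevent_thread_lock=True)
# mutate `app` here, e.g. remove middleware or call process_api_args(args, app)
demo.block_thread()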