zhzluke96
commited on
Commit
•
8c22399
1
Parent(s):
4554b6b
update
Browse files- launch.py +150 -58
- modules/api/Api.py +6 -20
- modules/api/impl/refiner_api.py +6 -1
- modules/api/impl/speaker_api.py +13 -13
- modules/api/impl/tts_api.py +0 -2
- modules/gradio_dcls_fix.py +6 -0
- modules/webui/app.py +12 -4
- modules/webui/js/localization.js +22 -3
- modules/webui/tts_tab.py +1 -1
- webui.py +79 -63
launch.py
CHANGED
@@ -1,109 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from modules import config
|
|
|
3 |
from modules import generate_audio as generate
|
4 |
-
|
5 |
-
from functools import lru_cache
|
6 |
-
from typing import Callable
|
7 |
-
|
8 |
from modules.api.Api import APIManager
|
9 |
|
10 |
from modules.api.impl import (
|
11 |
-
|
12 |
tts_api,
|
13 |
ssml_api,
|
14 |
google_api,
|
15 |
openai_api,
|
16 |
refiner_api,
|
|
|
|
|
|
|
17 |
)
|
18 |
|
|
|
|
|
19 |
torch._dynamo.config.cache_size_limit = 64
|
20 |
torch._dynamo.config.suppress_errors = True
|
21 |
torch.set_float32_matmul_precision("high")
|
22 |
|
23 |
|
24 |
-
def create_api():
|
25 |
-
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
33 |
|
34 |
-
return
|
35 |
|
36 |
|
37 |
-
def
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
|
43 |
-
def wrapper(*args, **kwargs):
|
44 |
-
if condition(*args, **kwargs):
|
45 |
-
return cached_func(*args, **kwargs)
|
46 |
-
else:
|
47 |
-
return func(*args, **kwargs)
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
import argparse
|
56 |
-
import uvicorn
|
57 |
-
|
58 |
-
parser = argparse.ArgumentParser(
|
59 |
-
description="Start the FastAPI server with command line arguments"
|
60 |
)
|
61 |
parser.add_argument(
|
62 |
-
"--
|
|
|
|
|
63 |
)
|
64 |
parser.add_argument(
|
65 |
-
"--
|
|
|
|
|
|
|
66 |
)
|
67 |
parser.add_argument(
|
68 |
-
"--
|
|
|
|
|
|
|
|
|
69 |
)
|
70 |
-
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
71 |
parser.add_argument(
|
72 |
"--lru_size",
|
73 |
type=int,
|
74 |
default=64,
|
75 |
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
76 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
parser.add_argument(
|
78 |
"--cors_origin",
|
79 |
type=str,
|
80 |
-
default="*",
|
81 |
help="Allowed CORS origins. Use '*' to allow all origins.",
|
82 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
args = parser.parse_args()
|
85 |
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
|
92 |
-
def should_cache(*args, **kwargs):
|
93 |
-
spk_seed = kwargs.get("spk_seed", -1)
|
94 |
-
infer_seed = kwargs.get("infer_seed", -1)
|
95 |
-
return spk_seed != -1 and infer_seed != -1
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
|
103 |
-
api = create_api()
|
104 |
config.api = api
|
105 |
|
106 |
-
if
|
107 |
-
api.set_cors(allow_origins=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
-
uvicorn.run(
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
|
4 |
+
logging.basicConfig(
|
5 |
+
level=os.getenv("LOG_LEVEL", "INFO"),
|
6 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
7 |
+
)
|
8 |
+
|
9 |
+
from modules.devices import devices
|
10 |
+
import argparse
|
11 |
+
import uvicorn
|
12 |
+
|
13 |
import torch
|
14 |
from modules import config
|
15 |
+
from modules.utils import env
|
16 |
from modules import generate_audio as generate
|
|
|
|
|
|
|
|
|
17 |
from modules.api.Api import APIManager
|
18 |
|
19 |
from modules.api.impl import (
|
20 |
+
style_api,
|
21 |
tts_api,
|
22 |
ssml_api,
|
23 |
google_api,
|
24 |
openai_api,
|
25 |
refiner_api,
|
26 |
+
speaker_api,
|
27 |
+
ping_api,
|
28 |
+
models_api,
|
29 |
)
|
30 |
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
torch._dynamo.config.cache_size_limit = 64
|
34 |
torch._dynamo.config.suppress_errors = True
|
35 |
torch.set_float32_matmul_precision("high")
|
36 |
|
37 |
|
38 |
+
def create_api(app, no_docs=False, exclude=[]):
|
39 |
+
app_mgr = APIManager(app=app, no_docs=no_docs, exclude_patterns=exclude)
|
40 |
|
41 |
+
ping_api.setup(app_mgr)
|
42 |
+
models_api.setup(app_mgr)
|
43 |
+
style_api.setup(app_mgr)
|
44 |
+
speaker_api.setup(app_mgr)
|
45 |
+
tts_api.setup(app_mgr)
|
46 |
+
ssml_api.setup(app_mgr)
|
47 |
+
google_api.setup(app_mgr)
|
48 |
+
openai_api.setup(app_mgr)
|
49 |
+
refiner_api.setup(app_mgr)
|
50 |
|
51 |
+
return app_mgr
|
52 |
|
53 |
|
54 |
+
def get_and_update_env(*args):
|
55 |
+
val = env.get_env_or_arg(*args)
|
56 |
+
key = args[1]
|
57 |
+
config.runtime_env_vars[key] = val
|
58 |
+
return val
|
59 |
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
+
def setup_model_args(parser: argparse.ArgumentParser):
|
62 |
+
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
63 |
+
parser.add_argument(
|
64 |
+
"--half",
|
65 |
+
action="store_true",
|
66 |
+
help="Enable half precision for model inference",
|
|
|
|
|
|
|
|
|
|
|
67 |
)
|
68 |
parser.add_argument(
|
69 |
+
"--off_tqdm",
|
70 |
+
action="store_true",
|
71 |
+
help="Disable tqdm progress bar",
|
72 |
)
|
73 |
parser.add_argument(
|
74 |
+
"--device_id",
|
75 |
+
type=str,
|
76 |
+
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
|
77 |
+
default=None,
|
78 |
)
|
79 |
parser.add_argument(
|
80 |
+
"--use_cpu",
|
81 |
+
nargs="+",
|
82 |
+
help="use CPU as torch device for specified modules",
|
83 |
+
default=[],
|
84 |
+
type=str.lower,
|
85 |
)
|
|
|
86 |
parser.add_argument(
|
87 |
"--lru_size",
|
88 |
type=int,
|
89 |
default=64,
|
90 |
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
91 |
)
|
92 |
+
|
93 |
+
|
94 |
+
def setup_api_args(parser: argparse.ArgumentParser):
|
95 |
+
parser.add_argument("--api_host", type=str, help="Host to run the server on")
|
96 |
+
parser.add_argument("--api_port", type=int, help="Port to run the server on")
|
97 |
+
parser.add_argument(
|
98 |
+
"--reload", action="store_true", help="Enable auto-reload for development"
|
99 |
+
)
|
100 |
parser.add_argument(
|
101 |
"--cors_origin",
|
102 |
type=str,
|
|
|
103 |
help="Allowed CORS origins. Use '*' to allow all origins.",
|
104 |
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--no_playground",
|
107 |
+
action="store_true",
|
108 |
+
help="Disable the playground entry",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--no_docs",
|
112 |
+
action="store_true",
|
113 |
+
help="Disable the documentation entry",
|
114 |
+
)
|
115 |
+
# 配置哪些api要跳过 比如 exclude="/v1/speakers/*,/v1/tts/*"
|
116 |
+
parser.add_argument(
|
117 |
+
"--exclude",
|
118 |
+
type=str,
|
119 |
+
help="Exclude the specified API from the server",
|
120 |
+
)
|
121 |
|
|
|
122 |
|
123 |
+
def process_model_args(args):
|
124 |
+
lru_size = get_and_update_env(args, "lru_size", 64, int)
|
125 |
+
compile = get_and_update_env(args, "compile", False, bool)
|
126 |
+
device_id = get_and_update_env(args, "device_id", None, str)
|
127 |
+
use_cpu = get_and_update_env(args, "use_cpu", [], list)
|
128 |
+
half = get_and_update_env(args, "half", False, bool)
|
129 |
+
off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
|
130 |
|
131 |
+
generate.setup_lru_cache()
|
132 |
+
devices.reset_device()
|
133 |
+
devices.first_time_calculation()
|
134 |
|
|
|
|
|
|
|
|
|
135 |
|
136 |
+
def process_api_args(args, app):
|
137 |
+
cors_origin = get_and_update_env(args, "cors_origin", "*", str)
|
138 |
+
no_playground = get_and_update_env(args, "no_playground", False, bool)
|
139 |
+
no_docs = get_and_update_env(args, "no_docs", False, bool)
|
140 |
+
exclude = get_and_update_env(args, "exclude", "", str)
|
141 |
|
142 |
+
api = create_api(app=app, no_docs=no_docs, exclude=exclude.split(","))
|
143 |
config.api = api
|
144 |
|
145 |
+
if cors_origin:
|
146 |
+
api.set_cors(allow_origins=[cors_origin])
|
147 |
+
|
148 |
+
if not no_playground:
|
149 |
+
api.setup_playground()
|
150 |
+
|
151 |
+
if compile:
|
152 |
+
logger.info("Model compile is enabled")
|
153 |
+
|
154 |
+
|
155 |
+
app_description = """
|
156 |
+
ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
|
157 |
+
ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
|
158 |
+
|
159 |
+
项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
160 |
+
|
161 |
+
> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
|
162 |
+
> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
|
163 |
+
|
164 |
+
> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
|
165 |
+
> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
|
166 |
+
"""
|
167 |
+
app_title = "ChatTTS Forge API"
|
168 |
+
app_version = "0.1.0"
|
169 |
+
|
170 |
+
if __name__ == "__main__":
|
171 |
+
import dotenv
|
172 |
+
from fastapi import FastAPI
|
173 |
+
|
174 |
+
dotenv.load_dotenv(
|
175 |
+
dotenv_path=os.getenv("ENV_FILE", ".env.api"),
|
176 |
+
)
|
177 |
+
|
178 |
+
parser = argparse.ArgumentParser(
|
179 |
+
description="Start the FastAPI server with command line arguments"
|
180 |
+
)
|
181 |
+
setup_api_args(parser)
|
182 |
+
setup_model_args(parser)
|
183 |
+
|
184 |
+
args = parser.parse_args()
|
185 |
+
|
186 |
+
app = FastAPI(
|
187 |
+
title=app_title,
|
188 |
+
description=app_description,
|
189 |
+
version=app_version,
|
190 |
+
redoc_url=None if config.runtime_env_vars.no_docs else "/redoc",
|
191 |
+
docs_url=None if config.runtime_env_vars.no_docs else "/docs",
|
192 |
+
)
|
193 |
+
|
194 |
+
process_model_args(args)
|
195 |
+
process_api_args(args, app)
|
196 |
+
|
197 |
+
host = get_and_update_env(args, "api_host", "0.0.0.0", str)
|
198 |
+
port = get_and_update_env(args, "api_port", 7870, int)
|
199 |
+
reload = get_and_update_env(args, "reload", False, bool)
|
200 |
|
201 |
+
uvicorn.run(app, host=host, port=port, reload=reload)
|
modules/api/Api.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
|
4 |
import logging
|
@@ -24,25 +24,8 @@ def is_excluded(path, exclude_patterns):
|
|
24 |
|
25 |
|
26 |
class APIManager:
|
27 |
-
def __init__(self, no_docs=False, exclude_patterns=[]):
|
28 |
-
self.app =
|
29 |
-
title="ChatTTS Forge API",
|
30 |
-
description="""
|
31 |
-
ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
|
32 |
-
ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
|
33 |
-
|
34 |
-
项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
35 |
-
|
36 |
-
> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
|
37 |
-
> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
|
38 |
-
|
39 |
-
> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
|
40 |
-
> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
|
41 |
-
""",
|
42 |
-
version="0.1.0",
|
43 |
-
redoc_url=None if no_docs else "/redoc",
|
44 |
-
docs_url=None if no_docs else "/docs",
|
45 |
-
)
|
46 |
self.registered_apis = {}
|
47 |
self.logger = logging.getLogger(__name__)
|
48 |
self.exclude = exclude_patterns
|
@@ -57,6 +40,8 @@ ChatTTS-Forge is a powerful text-to-speech generation tool that supports generat
|
|
57 |
allow_methods: list = ["*"],
|
58 |
allow_headers: list = ["*"],
|
59 |
):
|
|
|
|
|
60 |
self.app.add_middleware(
|
61 |
CORSMiddleware,
|
62 |
allow_origins=allow_origins,
|
@@ -64,6 +49,7 @@ ChatTTS-Forge is a powerful text-to-speech generation tool that supports generat
|
|
64 |
allow_methods=allow_methods,
|
65 |
allow_headers=allow_headers,
|
66 |
)
|
|
|
67 |
|
68 |
def setup_playground(self):
|
69 |
app = self.app
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
|
4 |
import logging
|
|
|
24 |
|
25 |
|
26 |
class APIManager:
|
27 |
+
def __init__(self, app: FastAPI, no_docs=False, exclude_patterns=[]):
|
28 |
+
self.app = app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
self.registered_apis = {}
|
30 |
self.logger = logging.getLogger(__name__)
|
31 |
self.exclude = exclude_patterns
|
|
|
40 |
allow_methods: list = ["*"],
|
41 |
allow_headers: list = ["*"],
|
42 |
):
|
43 |
+
# reset middleware stack
|
44 |
+
self.app.middleware_stack = None
|
45 |
self.app.add_middleware(
|
46 |
CORSMiddleware,
|
47 |
allow_origins=allow_origins,
|
|
|
49 |
allow_methods=allow_methods,
|
50 |
allow_headers=allow_headers,
|
51 |
)
|
52 |
+
self.app.build_middleware_stack()
|
53 |
|
54 |
def setup_playground(self):
|
55 |
app = self.app
|
modules/api/impl/refiner_api.py
CHANGED
@@ -7,6 +7,7 @@ from modules import refiner
|
|
7 |
|
8 |
from modules.api import utils as api_utils
|
9 |
from modules.api.Api import APIManager
|
|
|
10 |
|
11 |
|
12 |
class RefineTextRequest(BaseModel):
|
@@ -18,6 +19,7 @@ class RefineTextRequest(BaseModel):
|
|
18 |
temperature: float = 0.7
|
19 |
repetition_penalty: float = 1.0
|
20 |
max_new_token: int = 384
|
|
|
21 |
|
22 |
|
23 |
async def refiner_prompt_post(request: RefineTextRequest):
|
@@ -26,8 +28,11 @@ async def refiner_prompt_post(request: RefineTextRequest):
|
|
26 |
"""
|
27 |
|
28 |
try:
|
|
|
|
|
|
|
29 |
refined_text = refiner.refine_text(
|
30 |
-
text=
|
31 |
prompt=request.prompt,
|
32 |
seed=request.seed,
|
33 |
top_P=request.top_P,
|
|
|
7 |
|
8 |
from modules.api import utils as api_utils
|
9 |
from modules.api.Api import APIManager
|
10 |
+
from modules.normalization import text_normalize
|
11 |
|
12 |
|
13 |
class RefineTextRequest(BaseModel):
|
|
|
19 |
temperature: float = 0.7
|
20 |
repetition_penalty: float = 1.0
|
21 |
max_new_token: int = 384
|
22 |
+
normalize: bool = True
|
23 |
|
24 |
|
25 |
async def refiner_prompt_post(request: RefineTextRequest):
|
|
|
28 |
"""
|
29 |
|
30 |
try:
|
31 |
+
text = request.text
|
32 |
+
if request.normalize:
|
33 |
+
text = text_normalize(request.text)
|
34 |
refined_text = refiner.refine_text(
|
35 |
+
text=text,
|
36 |
prompt=request.prompt,
|
37 |
seed=request.seed,
|
38 |
top_P=request.top_P,
|
modules/api/impl/speaker_api.py
CHANGED
@@ -35,10 +35,14 @@ def setup(app: APIManager):
|
|
35 |
|
36 |
@app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
|
37 |
async def list_speakers():
|
38 |
-
return
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
|
43 |
@app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
|
44 |
async def update_speakers(request: SpeakersUpdate):
|
@@ -59,7 +63,8 @@ def setup(app: APIManager):
|
|
59 |
# number array => Tensor
|
60 |
speaker.emb = torch.tensor(spk["tensor"])
|
61 |
speaker_mgr.save_all()
|
62 |
-
|
|
|
63 |
|
64 |
@app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
|
65 |
async def create_speaker(request: CreateSpeaker):
|
@@ -88,12 +93,7 @@ def setup(app: APIManager):
|
|
88 |
raise HTTPException(
|
89 |
status_code=400, detail="Missing tensor or seed in request"
|
90 |
)
|
91 |
-
return
|
92 |
-
|
93 |
-
@app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
|
94 |
-
async def refresh_speakers():
|
95 |
-
speaker_mgr.refresh_speakers()
|
96 |
-
return {"message": "ok"}
|
97 |
|
98 |
@app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
|
99 |
async def update_speaker(request: UpdateSpeaker):
|
@@ -113,11 +113,11 @@ def setup(app: APIManager):
|
|
113 |
# number array => Tensor
|
114 |
speaker.emb = torch.tensor(request.tensor)
|
115 |
speaker_mgr.update_speaker(speaker)
|
116 |
-
return
|
117 |
|
118 |
@app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
|
119 |
async def speaker_detail(request: SpeakerDetail):
|
120 |
speaker = speaker_mgr.get_speaker_by_id(request.id)
|
121 |
if speaker is None:
|
122 |
raise HTTPException(status_code=404, detail="Speaker not found")
|
123 |
-
return
|
|
|
35 |
|
36 |
@app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
|
37 |
async def list_speakers():
|
38 |
+
return api_utils.success_response(
|
39 |
+
[spk.to_json() for spk in speaker_mgr.list_speakers()]
|
40 |
+
)
|
41 |
+
|
42 |
+
@app.post("/v1/speakers/refresh", response_model=api_utils.BaseResponse)
|
43 |
+
async def refresh_speakers():
|
44 |
+
speaker_mgr.refresh_speakers()
|
45 |
+
return api_utils.success_response(None)
|
46 |
|
47 |
@app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
|
48 |
async def update_speakers(request: SpeakersUpdate):
|
|
|
63 |
# number array => Tensor
|
64 |
speaker.emb = torch.tensor(spk["tensor"])
|
65 |
speaker_mgr.save_all()
|
66 |
+
|
67 |
+
return api_utils.success_response(None)
|
68 |
|
69 |
@app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
|
70 |
async def create_speaker(request: CreateSpeaker):
|
|
|
93 |
raise HTTPException(
|
94 |
status_code=400, detail="Missing tensor or seed in request"
|
95 |
)
|
96 |
+
return api_utils.success_response(speaker.to_json())
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
@app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
|
99 |
async def update_speaker(request: UpdateSpeaker):
|
|
|
113 |
# number array => Tensor
|
114 |
speaker.emb = torch.tensor(request.tensor)
|
115 |
speaker_mgr.update_speaker(speaker)
|
116 |
+
return api_utils.success_response(None)
|
117 |
|
118 |
@app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
|
119 |
async def speaker_detail(request: SpeakerDetail):
|
120 |
speaker = speaker_mgr.get_speaker_by_id(request.id)
|
121 |
if speaker is None:
|
122 |
raise HTTPException(status_code=404, detail="Speaker not found")
|
123 |
+
return api_utils.success_response(speaker.to_json(with_emb=request.with_emb))
|
modules/api/impl/tts_api.py
CHANGED
@@ -9,8 +9,6 @@ from fastapi.responses import FileResponse
|
|
9 |
|
10 |
from modules.normalization import text_normalize
|
11 |
|
12 |
-
from modules import generate_audio as generate
|
13 |
-
|
14 |
from modules.api import utils as api_utils
|
15 |
from modules.api.Api import APIManager
|
16 |
from modules.synthesize_audio import synthesize_audio
|
|
|
9 |
|
10 |
from modules.normalization import text_normalize
|
11 |
|
|
|
|
|
12 |
from modules.api import utils as api_utils
|
13 |
from modules.api.Api import APIManager
|
14 |
from modules.synthesize_audio import synthesize_audio
|
modules/gradio_dcls_fix.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def dcls_patch():
|
2 |
+
from gradio import data_classes
|
3 |
+
|
4 |
+
data_classes.PredictBody.__get_pydantic_json_schema__ = lambda x, y: {
|
5 |
+
"type": "object",
|
6 |
+
}
|
modules/webui/app.py
CHANGED
@@ -46,11 +46,19 @@ def create_app_footer():
|
|
46 |
|
47 |
config.versions.gradio_version = gradio_version
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
gr.Markdown(
|
50 |
-
|
51 |
-
🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
52 |
-
version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit}) | branch: `{git_branch}` | python: `{python_version}` | torch: `{torch_version}`
|
53 |
-
""",
|
54 |
elem_classes=["no-translate"],
|
55 |
)
|
56 |
|
|
|
46 |
|
47 |
config.versions.gradio_version = gradio_version
|
48 |
|
49 |
+
footer_items = ["🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)"]
|
50 |
+
footer_items.append(
|
51 |
+
f"version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit})"
|
52 |
+
)
|
53 |
+
footer_items.append(f"branch: `{git_branch}`")
|
54 |
+
footer_items.append(f"python: `{python_version}`")
|
55 |
+
footer_items.append(f"torch: `{torch_version}`")
|
56 |
+
|
57 |
+
if config.runtime_env_vars.api and not config.runtime_env_vars.no_docs:
|
58 |
+
footer_items.append(f"[API](/docs)")
|
59 |
+
|
60 |
gr.Markdown(
|
61 |
+
" | ".join(footer_items),
|
|
|
|
|
|
|
62 |
elem_classes=["no-translate"],
|
63 |
)
|
64 |
|
modules/webui/js/localization.js
CHANGED
@@ -163,6 +163,23 @@ function localizeWholePage() {
|
|
163 |
}
|
164 |
}
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
document.addEventListener("DOMContentLoaded", function () {
|
167 |
if (!hasLocalization()) {
|
168 |
return;
|
@@ -170,9 +187,11 @@ document.addEventListener("DOMContentLoaded", function () {
|
|
170 |
|
171 |
onUiUpdate(function (m) {
|
172 |
m.forEach(function (mutation) {
|
173 |
-
mutation.addedNodes
|
174 |
-
|
175 |
-
|
|
|
|
|
176 |
});
|
177 |
});
|
178 |
|
|
|
163 |
}
|
164 |
}
|
165 |
|
166 |
+
/**
|
167 |
+
*
|
168 |
+
* @param {HTMLElement} node
|
169 |
+
*/
|
170 |
+
function isNeedTranslate(node) {
|
171 |
+
if (!node) return false;
|
172 |
+
if (!(node instanceof HTMLElement)) return true;
|
173 |
+
while (node.parentElement !== document.body) {
|
174 |
+
if (node.classList.contains("no-translate")) {
|
175 |
+
return false;
|
176 |
+
}
|
177 |
+
node = node.parentElement;
|
178 |
+
if (!node) break;
|
179 |
+
}
|
180 |
+
return true;
|
181 |
+
}
|
182 |
+
|
183 |
document.addEventListener("DOMContentLoaded", function () {
|
184 |
if (!hasLocalization()) {
|
185 |
return;
|
|
|
187 |
|
188 |
onUiUpdate(function (m) {
|
189 |
m.forEach(function (mutation) {
|
190 |
+
Array.from(mutation.addedNodes)
|
191 |
+
.filter(isNeedTranslate)
|
192 |
+
.forEach(function (node) {
|
193 |
+
processNode(node);
|
194 |
+
});
|
195 |
});
|
196 |
});
|
197 |
|
modules/webui/tts_tab.py
CHANGED
@@ -96,7 +96,7 @@ def create_tts_interface():
|
|
96 |
)
|
97 |
|
98 |
gr.Markdown("📝Speaker info")
|
99 |
-
infos = gr.Markdown("empty")
|
100 |
|
101 |
spk_file_upload.change(
|
102 |
fn=load_spk_info,
|
|
|
96 |
)
|
97 |
|
98 |
gr.Markdown("📝Speaker info")
|
99 |
+
infos = gr.Markdown("empty", elem_classes=["no-translate"])
|
100 |
|
101 |
spk_file_upload.change(
|
102 |
fn=load_spk_info,
|
webui.py
CHANGED
@@ -1,27 +1,30 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
|
9 |
-
from
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
from modules.webui import webui_config
|
12 |
from modules.webui.app import webui_init, create_interface
|
13 |
-
|
14 |
-
from modules import
|
15 |
|
16 |
-
|
17 |
-
import argparse
|
18 |
-
import dotenv
|
19 |
|
20 |
-
dotenv.load_dotenv(
|
21 |
-
dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
|
22 |
-
)
|
23 |
|
24 |
-
|
25 |
parser.add_argument("--server_name", type=str, help="server name")
|
26 |
parser.add_argument("--server_port", type=int, help="server port")
|
27 |
parser.add_argument(
|
@@ -29,16 +32,6 @@ if __name__ == "__main__":
|
|
29 |
)
|
30 |
parser.add_argument("--debug", action="store_true", help="enable debug mode")
|
31 |
parser.add_argument("--auth", type=str, help="username:password for authentication")
|
32 |
-
parser.add_argument(
|
33 |
-
"--half",
|
34 |
-
action="store_true",
|
35 |
-
help="Enable half precision for model inference",
|
36 |
-
)
|
37 |
-
parser.add_argument(
|
38 |
-
"--off_tqdm",
|
39 |
-
action="store_true",
|
40 |
-
help="Disable tqdm progress bar",
|
41 |
-
)
|
42 |
parser.add_argument(
|
43 |
"--tts_max_len",
|
44 |
type=int,
|
@@ -54,58 +47,39 @@ if __name__ == "__main__":
|
|
54 |
type=int,
|
55 |
help="Max batch size for TTS",
|
56 |
)
|
57 |
-
parser.add_argument(
|
58 |
-
"--lru_size",
|
59 |
-
type=int,
|
60 |
-
default=64,
|
61 |
-
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
62 |
-
)
|
63 |
-
parser.add_argument(
|
64 |
-
"--device_id",
|
65 |
-
type=str,
|
66 |
-
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
|
67 |
-
default=None,
|
68 |
-
)
|
69 |
-
parser.add_argument(
|
70 |
-
"--use_cpu",
|
71 |
-
nargs="+",
|
72 |
-
help="use CPU as torch device for specified modules",
|
73 |
-
default=[],
|
74 |
-
type=str.lower,
|
75 |
-
)
|
76 |
-
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
77 |
# webui_Experimental
|
78 |
parser.add_argument(
|
79 |
"--webui_experimental",
|
80 |
action="store_true",
|
81 |
help="Enable webui_experimental features",
|
82 |
)
|
83 |
-
|
84 |
parser.add_argument(
|
85 |
"--language",
|
86 |
type=str,
|
87 |
help="Set the default language for the webui",
|
88 |
)
|
89 |
-
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
def get_and_update_env(*args):
|
92 |
-
val = env.get_env_or_arg(*args)
|
93 |
-
key = args[1]
|
94 |
-
config.runtime_env_vars[key] = val
|
95 |
-
return val
|
96 |
|
|
|
97 |
server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
|
98 |
server_port = get_and_update_env(args, "server_port", 7860, int)
|
99 |
share = get_and_update_env(args, "share", False, bool)
|
100 |
debug = get_and_update_env(args, "debug", False, bool)
|
101 |
auth = get_and_update_env(args, "auth", None, str)
|
102 |
-
half = get_and_update_env(args, "half", False, bool)
|
103 |
-
off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
|
104 |
-
lru_size = get_and_update_env(args, "lru_size", 64, int)
|
105 |
-
device_id = get_and_update_env(args, "device_id", None, str)
|
106 |
-
use_cpu = get_and_update_env(args, "use_cpu", [], list)
|
107 |
-
compile = get_and_update_env(args, "compile", False, bool)
|
108 |
language = get_and_update_env(args, "language", "zh-CN", str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
webui_config.experimental = get_and_update_env(
|
111 |
args, "webui_experimental", False, bool
|
@@ -120,15 +94,57 @@ if __name__ == "__main__":
|
|
120 |
if auth:
|
121 |
auth = tuple(auth.split(":"))
|
122 |
|
123 |
-
|
124 |
-
devices.reset_device()
|
125 |
-
devices.first_time_calculation()
|
126 |
-
|
127 |
-
demo.queue().launch(
|
128 |
server_name=server_name,
|
129 |
server_port=server_port,
|
130 |
share=share,
|
131 |
debug=debug,
|
132 |
auth=auth,
|
133 |
show_api=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
+
logging.basicConfig(
|
5 |
+
level=os.getenv("LOG_LEVEL", "INFO"),
|
6 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
7 |
+
)
|
8 |
|
9 |
+
from launch import (
|
10 |
+
get_and_update_env,
|
11 |
+
setup_api_args,
|
12 |
+
setup_model_args,
|
13 |
+
process_api_args,
|
14 |
+
process_model_args,
|
15 |
+
app_description,
|
16 |
+
app_title,
|
17 |
+
app_version,
|
18 |
+
)
|
19 |
from modules.webui import webui_config
|
20 |
from modules.webui.app import webui_init, create_interface
|
21 |
+
import argparse
|
22 |
+
from modules.gradio_dcls_fix import dcls_patch
|
23 |
|
24 |
+
dcls_patch()
|
|
|
|
|
25 |
|
|
|
|
|
|
|
26 |
|
27 |
+
def setup_webui_args(parser: argparse.ArgumentParser):
|
28 |
parser.add_argument("--server_name", type=str, help="server name")
|
29 |
parser.add_argument("--server_port", type=int, help="server port")
|
30 |
parser.add_argument(
|
|
|
32 |
)
|
33 |
parser.add_argument("--debug", action="store_true", help="enable debug mode")
|
34 |
parser.add_argument("--auth", type=str, help="username:password for authentication")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
parser.add_argument(
|
36 |
"--tts_max_len",
|
37 |
type=int,
|
|
|
47 |
type=int,
|
48 |
help="Max batch size for TTS",
|
49 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# webui_Experimental
|
51 |
parser.add_argument(
|
52 |
"--webui_experimental",
|
53 |
action="store_true",
|
54 |
help="Enable webui_experimental features",
|
55 |
)
|
|
|
56 |
parser.add_argument(
|
57 |
"--language",
|
58 |
type=str,
|
59 |
help="Set the default language for the webui",
|
60 |
)
|
61 |
+
parser.add_argument(
|
62 |
+
"--api",
|
63 |
+
action="store_true",
|
64 |
+
help="use api=True to launch the API together with the webui (run launch.py for only API server)",
|
65 |
+
)
|
66 |
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
def process_webui_args(args):
|
69 |
server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
|
70 |
server_port = get_and_update_env(args, "server_port", 7860, int)
|
71 |
share = get_and_update_env(args, "share", False, bool)
|
72 |
debug = get_and_update_env(args, "debug", False, bool)
|
73 |
auth = get_and_update_env(args, "auth", None, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
language = get_and_update_env(args, "language", "zh-CN", str)
|
75 |
+
api = get_and_update_env(args, "api", "zh-CN", str)
|
76 |
+
|
77 |
+
webui_config.experimental = get_and_update_env(
|
78 |
+
args, "webui_experimental", False, bool
|
79 |
+
)
|
80 |
+
webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
|
81 |
+
webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
|
82 |
+
webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
|
83 |
|
84 |
webui_config.experimental = get_and_update_env(
|
85 |
args, "webui_experimental", False, bool
|
|
|
94 |
if auth:
|
95 |
auth = tuple(auth.split(":"))
|
96 |
|
97 |
+
app, local_url, share_url = demo.queue().launch(
|
|
|
|
|
|
|
|
|
98 |
server_name=server_name,
|
99 |
server_port=server_port,
|
100 |
share=share,
|
101 |
debug=debug,
|
102 |
auth=auth,
|
103 |
show_api=False,
|
104 |
+
prevent_thread_lock=True,
|
105 |
+
app_kwargs={
|
106 |
+
"title": app_title,
|
107 |
+
"description": app_description,
|
108 |
+
"version": app_version,
|
109 |
+
# "redoc_url": (
|
110 |
+
# None
|
111 |
+
# if api is False
|
112 |
+
# else None if config.runtime_env_vars.no_docs else "/redoc"
|
113 |
+
# ),
|
114 |
+
# "docs_url": (
|
115 |
+
# None
|
116 |
+
# if api is False
|
117 |
+
# else None if config.runtime_env_vars.no_docs else "/docs"
|
118 |
+
# ),
|
119 |
+
"docs_url": "/docs",
|
120 |
+
},
|
121 |
+
)
|
122 |
+
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
123 |
+
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
124 |
+
# running web ui and do whatever the attacker wants, including installing an extension and
|
125 |
+
# running its code. We disable this here. Suggested by RyotaK.
|
126 |
+
app.user_middleware = [
|
127 |
+
x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
|
128 |
+
]
|
129 |
+
|
130 |
+
if api:
|
131 |
+
process_api_args(args, app)
|
132 |
+
|
133 |
+
|
134 |
+
if __name__ == "__main__":
|
135 |
+
import dotenv
|
136 |
+
|
137 |
+
dotenv.load_dotenv(
|
138 |
+
dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
|
139 |
)
|
140 |
+
|
141 |
+
parser = argparse.ArgumentParser(description="Gradio App")
|
142 |
+
|
143 |
+
setup_webui_args(parser)
|
144 |
+
setup_model_args(parser)
|
145 |
+
setup_api_args(parser)
|
146 |
+
|
147 |
+
args = parser.parse_args()
|
148 |
+
|
149 |
+
process_model_args(args)
|
150 |
+
process_webui_args(args)
|