Spaces:
Running
Running
Upload server_fastapi.py
Browse files- server_fastapi.py +263 -0
server_fastapi.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
API server for TTS
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
from io import BytesIO
|
8 |
+
from typing import Dict, Optional, Union
|
9 |
+
from urllib.parse import unquote
|
10 |
+
|
11 |
+
import GPUtil
|
12 |
+
import psutil
|
13 |
+
import torch
|
14 |
+
import uvicorn
|
15 |
+
from fastapi import FastAPI, HTTPException, Query, Request, status
|
16 |
+
from fastapi.middleware.cors import CORSMiddleware
|
17 |
+
from fastapi.responses import FileResponse, Response
|
18 |
+
from scipy.io import wavfile
|
19 |
+
|
20 |
+
from common.constants import (
|
21 |
+
DEFAULT_ASSIST_TEXT_WEIGHT,
|
22 |
+
DEFAULT_LENGTH,
|
23 |
+
DEFAULT_LINE_SPLIT,
|
24 |
+
DEFAULT_NOISE,
|
25 |
+
DEFAULT_NOISEW,
|
26 |
+
DEFAULT_SDP_RATIO,
|
27 |
+
DEFAULT_SPLIT_INTERVAL,
|
28 |
+
DEFAULT_STYLE,
|
29 |
+
DEFAULT_STYLE_WEIGHT,
|
30 |
+
Languages,
|
31 |
+
)
|
32 |
+
from common.log import logger
|
33 |
+
from common.tts_model import Model, ModelHolder
|
34 |
+
from config import config
|
35 |
+
|
36 |
+
ln = config.server_config.language
|
37 |
+
|
38 |
+
|
39 |
+
def raise_validation_error(msg: str, param: str):
    """Log a validation failure and abort the request with HTTP 422.

    Args:
        msg: Human-readable reason; it is logged and echoed to the client.
        param: Name of the offending query parameter, reported under ``loc``.

    Raises:
        HTTPException: Always, with status 422 and a single-entry ``detail``
            list of the form ``{"type", "msg", "loc": ["query", param]}``.
    """
    logger.warning(f"Validation error: {msg}")
    error_entry = {"type": "invalid_params", "msg": msg, "loc": ["query", param]}
    raise HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail=[error_entry],
    )
|
45 |
+
|
46 |
+
|
47 |
+
class AudioResponse(Response):
    # Response subclass whose default content type is WAV audio; it is used
    # as ``response_class`` on the endpoints that return audio.
    media_type = "audio/wav"
|
49 |
+
|
50 |
+
|
51 |
+
def load_models(model_holder: ModelHolder):
    """Instantiate and load every model known to ``model_holder``.

    ``model_holder.models`` is reset to an empty list and then filled, in
    the iteration order of ``model_holder.model_files_dict``, with one loaded
    ``Model`` per entry.

    Args:
        model_holder: Holder providing ``model_files_dict`` (model name ->
            sequence of weight paths), ``root_dir`` and ``device``.
    """
    # The list is cleared *before* loading, so the previously loaded models
    # are no longer referenced from the holder while the new ones are built.
    model_holder.models = []
    for name, weight_paths in model_holder.model_files_dict.items():
        model_dir = os.path.join(model_holder.root_dir, name)
        loaded = Model(
            # Only the first weight file listed for a model is used.
            model_path=weight_paths[0],
            config_path=os.path.join(model_dir, "config.json"),
            style_vec_path=os.path.join(model_dir, "style_vectors.npy"),
            device=model_holder.device,
        )
        loaded.load_net_g()
        model_holder.models.append(loaded)
|
64 |
+
|
65 |
+
|
66 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI options, load every model, build the
    # FastAPI application with its routes, and serve it with uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    parser.add_argument(
        "--dir", "-d", type=str, help="Model directory", default=config.assets_root
    )
    args = parser.parse_args()

    # Use CUDA only when it was not disabled with --cpu and is available.
    device = "cuda" if (not args.cpu and torch.cuda.is_available()) else "cpu"

    model_dir = args.dir
    model_holder = ModelHolder(model_dir, device)
    if len(model_holder.model_names) == 0:
        logger.error(f"Models not found in {model_dir}.")
        sys.exit(1)

    logger.info("Loading models...")
    load_models(model_holder)
    limit = config.server_config.limit
    app = FastAPI()
    allow_origins = config.server_config.origins
    if allow_origins:
        logger.warning(
            f"CORS allow_origins={allow_origins}. If you don't want, modify config.yml"
        )
        app.add_middleware(
            CORSMiddleware,
            allow_origins=allow_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    app.logger = logger

    def _client_addr(request: Request) -> str:
        """Return the caller as ``host:port``, or ``unknown`` if unavailable."""
        client = request.client
        if client is None:  # Starlette documents that ``client`` may be None.
            return "unknown"
        return f"{client.host}:{client.port}"

    # NOTE: the handler docstrings below are kept verbatim because FastAPI
    # publishes them as the operation description in the generated API docs.

    @app.get("/voice", response_class=AudioResponse)
    async def voice(
        request: Request,
        text: str = Query(..., min_length=1, max_length=limit, description="セリフ"),
        encoding: str = Query(None, description="textをURLデコードする(ex, `utf-8`)"),
        model_id: int = Query(0, description="モデルID。`GET /models/info`のkeyの値を指定ください"),
        speaker_name: str = Query(
            None, description="話者名(speaker_idより優先)。esd.listの2列目の文字列を指定"
        ),
        speaker_id: int = Query(
            0, description="話者ID。model_assets>[model]>config.json内のspk2idを確認"
        ),
        sdp_ratio: float = Query(
            DEFAULT_SDP_RATIO,
            description="SDP(Stochastic Duration Predictor)/DP混合比。比率が高くなるほどトーンのばらつきが大きくなる",
        ),
        noise: float = Query(DEFAULT_NOISE, description="サンプルノイズの割合。大きくするほどランダム性が高まる"),
        noisew: float = Query(
            DEFAULT_NOISEW, description="SDPノイズ。大きくするほど発音の間隔にばらつきが出やすくなる"
        ),
        length: float = Query(
            DEFAULT_LENGTH, description="話速。基準は1で大きくするほど音声は長くなり読み上げが遅まる"
        ),
        language: Languages = Query(ln, description="textの言語"),
        auto_split: bool = Query(DEFAULT_LINE_SPLIT, description="改行で分けて生成"),
        split_interval: float = Query(
            DEFAULT_SPLIT_INTERVAL, description="分けた場合に挟む無音の長さ(秒)"
        ),
        assist_text: Optional[str] = Query(
            None, description="このテキストの読み上げと似た声音・感情になりやすくなる。ただし抑揚やテンポ等が犠牲になる傾向がある"
        ),
        assist_text_weight: float = Query(
            DEFAULT_ASSIST_TEXT_WEIGHT, description="assist_textの強さ"
        ),
        style: Optional[Union[int, str]] = Query(DEFAULT_STYLE, description="スタイル"),
        style_weight: float = Query(DEFAULT_STYLE_WEIGHT, description="スタイルの強さ"),
        reference_audio_path: Optional[str] = Query(None, description="スタイルを音声ファイルで行う"),
    ):
        """Infer text to speech(テキストから感情付き音声を生成する)"""
        # Synthesise ``text`` with the selected model/speaker/style and
        # return the result as an in-memory WAV file. Invalid selections are
        # reported as HTTP 422 through raise_validation_error().
        logger.info(
            f"{_client_addr(request)}/voice {unquote(str(request.query_params))}"
        )
        # The upper bound cannot be expressed with Query(le=...) because
        # /models/refresh can change the number of loaded models at runtime.
        # Negative ids are rejected too; otherwise Python's negative indexing
        # would silently pick a model counted from the end of the list.
        if not 0 <= model_id < len(model_holder.models):
            raise_validation_error(f"model_id={model_id} not found", "model_id")

        model = model_holder.models[model_id]
        if speaker_name is None:
            if speaker_id not in model.id2spk:
                raise_validation_error(
                    f"speaker_id={speaker_id} not found", "speaker_id"
                )
        else:
            # A speaker name, when given, takes precedence over speaker_id.
            if speaker_name not in model.spk2id:
                raise_validation_error(
                    f"speaker_name={speaker_name} not found", "speaker_name"
                )
            speaker_id = model.spk2id[speaker_name]
        if style not in model.style2id:
            raise_validation_error(f"style={style} not found", "style")
        if encoding is not None:
            # Extra percent-decoding pass requested by the caller. An unknown
            # codec name is a client error, not a server failure.
            try:
                text = unquote(text, encoding=encoding)
            except LookupError:
                raise_validation_error(
                    f"encoding={encoding} is not supported", "encoding"
                )
        sr, audio = model.infer(
            text=text,
            language=language,
            sid=speaker_id,
            reference_audio_path=reference_audio_path,
            sdp_ratio=sdp_ratio,
            noise=noise,
            noisew=noisew,
            length=length,
            line_split=auto_split,
            split_interval=split_interval,
            assist_text=assist_text,
            assist_text_weight=assist_text_weight,
            use_assist_text=bool(assist_text),
            style=style,
            style_weight=style_weight,
        )
        logger.success("Audio data generated and sent successfully")
        with BytesIO() as wav_buffer:
            wavfile.write(wav_buffer, sr, audio)
            return Response(content=wav_buffer.getvalue(), media_type="audio/wav")

    @app.get("/models/info")
    def get_loaded_models_info():
        """ロードされたモデル情報の取得"""
        # Describe every loaded model, keyed by its index (as a string) in
        # model_holder.models; that index is the ``model_id`` of /voice.
        result: Dict[str, Dict] = {
            str(model_id): {
                "config_path": model.config_path,
                "model_path": model.model_path,
                "device": model.device,
                "spk2id": model.spk2id,
                "id2spk": model.id2spk,
                "style2id": model.style2id,
            }
            for model_id, model in enumerate(model_holder.models)
        }
        return result

    @app.post("/models/refresh")
    def refresh():
        """モデルをパスに追加/削除した際などに読み込ませる"""
        # Refresh the holder, reload all models, and report the new set in
        # the same shape as GET /models/info.
        model_holder.refresh()
        load_models(model_holder)
        return get_loaded_models_info()

    @app.get("/status")
    def get_status():
        """実行環境のステータスを取得"""
        # Report CPU, memory and GPU usage of the host.
        # cpu_percent(interval=1) blocks for one second while sampling.
        cpu_percent = psutil.cpu_percent(interval=1)
        memory_info = psutil.virtual_memory()
        devices = ["cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        gpu_info = [
            {
                "gpu_id": gpu.id,
                "gpu_load": gpu.load,
                "gpu_memory": {
                    "total": gpu.memoryTotal,
                    "used": gpu.memoryUsed,
                    "free": gpu.memoryFree,
                },
            }
            for gpu in GPUtil.getGPUs()
        ]
        return {
            "devices": devices,
            "cpu_percent": cpu_percent,
            "memory_total": memory_info.total,
            "memory_available": memory_info.available,
            "memory_used": memory_info.used,
            "memory_percent": memory_info.percent,
            "gpu": gpu_info,
        }

    @app.get("/tools/get_audio", response_class=AudioResponse)
    def get_audio(
        request: Request, path: str = Query(..., description="local wav path")
    ):
        """wavデータを取得する"""
        # Return the WAV file stored at ``path`` on the server's filesystem.
        # NOTE(review): ``path`` is caller-controlled and is not confined to
        # any directory, so every readable ``*.wav`` file on the host can be
        # fetched (the server binds to 0.0.0.0). Consider restricting it to a
        # configured root if the server is reachable by untrusted clients.
        logger.info(
            f"{_client_addr(request)}/tools/get_audio {unquote(str(request.query_params))}"
        )
        if not os.path.isfile(path):
            raise_validation_error(f"path={path} not found", "path")
        if not path.lower().endswith(".wav"):
            raise_validation_error(f"wav file not found in {path}", "path")
        return FileResponse(path=path, media_type="audio/wav")

    logger.info(f"server listen: http://127.0.0.1:{config.server_config.port}")
    logger.info(f"API docs: http://127.0.0.1:{config.server_config.port}/docs")
    uvicorn.run(
        app, port=config.server_config.port, host="0.0.0.0", log_level="warning"
    )
|