ElesisSiegherts committed
Commit 245dd7d
1 Parent(s): ed6c2db

Upload 6 files

Files changed (6):
  1. requirements.txt +33 -0
  2. server_fastapi.py +642 -0
  3. spec_gen.py +87 -0
  4. transforms.py +209 -0
  5. update_status.py +89 -0
  6. utils.py +457 -0
requirements.txt ADDED
@@ -0,0 +1,33 @@
+ librosa==0.9.2
+ matplotlib
+ numpy
+ numba
+ phonemizer
+ scipy
+ tensorboard
+ Unidecode
+ amfm_decompy
+ jieba
+ transformers
+ pypinyin
+ cn2an
+ gradio==3.38.0
+ av
+ mecab-python3
+ loguru
+ unidic-lite
+ cmudict
+ fugashi
+ num2words
+ PyYAML
+ requests
+ pyopenjtalk; sys_platform == 'linux'
+ openjtalk; sys_platform != 'linux'
+ jaconv
+ psutil
+ GPUtil
+ vector_quantize_pytorch
+ g2p_en
+ sentencepiece
+ pykakasi
+ langid
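
The pyopenjtalk/openjtalk pair above uses PEP 508 environment markers, so pip resolves exactly one of the two per platform. A minimal sketch for checking which marker matches on the current machine, using the third-party packaging library (an assumption here; it is not part of this commit):

# Sketch: evaluate the platform markers used above (assumes `packaging` is installed).
from packaging.markers import Marker

for m in ("sys_platform == 'linux'", "sys_platform != 'linux'"):
    print(m, "->", Marker(m).evaluate())
# On Linux the first marker is True, selecting pyopenjtalk; elsewhere openjtalk is chosen.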
server_fastapi.py ADDED
@@ -0,0 +1,642 @@
+ """
+ API service: multi-version, multi-model FastAPI implementation
+ """
+ import logging
+ import gc
+ import random
+
+ import gradio
+ import numpy as np
+ import utils
+ from fastapi import FastAPI, Query, Request, File, UploadFile, Form
+ from fastapi.responses import Response, FileResponse
+ from fastapi.staticfiles import StaticFiles
+ from io import BytesIO
+ from scipy.io import wavfile
+ import uvicorn
+ import torch
+ import webbrowser
+ import psutil
+ import GPUtil
+ from typing import Any, Dict, Optional, List, Set, Union
+ import os
+ from tools.log import logger
+ from urllib.parse import unquote
+
+ from infer import infer, get_net_g, latest_version
+ import tools.translate as trans
+ from re_matching import cut_sent
+
+
+ from config import config
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ class Model:
+     """Wrapper around a single loaded model."""
+
+     def __init__(self, config_path: str, model_path: str, device: str, language: str):
+         self.config_path: str = os.path.normpath(config_path)
+         self.model_path: str = os.path.normpath(model_path)
+         self.device: str = device
+         self.language: str = language
+         self.hps = utils.get_hparams_from_file(config_path)
+         self.spk2id: Dict[str, int] = self.hps.data.spk2id  # speaker -> id mapping
+         self.id2spk: Dict[int, str] = dict()  # id -> speaker mapping
+         for speaker, speaker_id in self.hps.data.spk2id.items():
+             self.id2spk[speaker_id] = speaker
+         self.version: str = (
+             self.hps.version if hasattr(self.hps, "version") else latest_version
+         )
+         self.net_g = get_net_g(
+             model_path=model_path,
+             version=self.version,
+             device=device,
+             hps=self.hps,
+         )
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {
+             "config_path": self.config_path,
+             "model_path": self.model_path,
+             "device": self.device,
+             "language": self.language,
+             "spk2id": self.spk2id,
+             "id2spk": self.id2spk,
+             "version": self.version,
+         }
+
+
+ class Models:
+     def __init__(self):
+         self.models: Dict[int, Model] = dict()
+         self.num = 0
+         # spk_info[speaker name][model id] = speaker id
+         self.spk_info: Dict[str, Dict[int, int]] = dict()
+         self.path2ids: Dict[str, Set[int]] = dict()  # model ids backed by each path
+
+     def init_model(
+         self, config_path: str, model_path: str, device: str, language: str
+     ) -> int:
+         """
+         Initialize and register one model.
+
+         :param config_path: path to the model's config.json
+         :param model_path: path to the model file
+         :param device: device used for inference
+         :param language: default inference language
+         """
+         # Skip initialization if the files do not exist
+         if not os.path.isfile(model_path):
+             if model_path != "":
+                 logger.warning(f"Model file {model_path} does not exist; skipping initialization")
+             return self.num
+         if not os.path.isfile(config_path):
+             if config_path != "":
+                 logger.warning(f"Config file {config_path} does not exist; skipping initialization")
+             return self.num
+
+         # If a model at this path is already loaded, add a reference instead of loading it again.
+         model_path = os.path.realpath(model_path)
+         if model_path not in self.path2ids.keys():
+             self.path2ids[model_path] = {self.num}
+             self.models[self.num] = Model(
+                 config_path=config_path,
+                 model_path=model_path,
+                 device=device,
+                 language=language,
+             )
+             logger.success(f"Added model {model_path} with config file {os.path.realpath(config_path)}")
+         else:
+             # Reuse an existing id that points at this path
+             m_id = next(iter(self.path2ids[model_path]))
+             self.models[self.num] = self.models[m_id]
+             self.path2ids[model_path].add(self.num)
+             logger.success("Model already loaded; added a reference to it.")
+         # Register speaker info
+         for speaker, speaker_id in self.models[self.num].spk2id.items():
+             if speaker not in self.spk_info.keys():
+                 self.spk_info[speaker] = {self.num: speaker_id}
+             else:
+                 self.spk_info[speaker][self.num] = speaker_id
+         # Update the counter
+         self.num += 1
+         return self.num - 1
+
+     def del_model(self, index: int) -> Optional[int]:
+         """Delete the model with the given id; return None if it does not exist."""
+         if index not in self.models.keys():
+             return None
+         # Remove speaker info
+         for speaker, speaker_id in self.models[index].spk2id.items():
+             self.spk_info[speaker].pop(index)
+             if len(self.spk_info[speaker]) == 0:
+                 # If every model for this speaker has been deleted, drop the speaker entry
+                 self.spk_info.pop(speaker)
+         # Remove path info
+         model_path = os.path.realpath(self.models[index].model_path)
+         self.path2ids[model_path].remove(index)
+         if len(self.path2ids[model_path]) == 0:
+             self.path2ids.pop(model_path)
+             logger.success(f"Deleted model {model_path}, id = {index}")
+         else:
+             logger.success(f"Deleted model reference {model_path}, id = {index}")
+         # Delete the model
+         self.models.pop(index)
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return index
+
+     def get_models(self):
+         """Return all loaded models."""
+         return self.models
+
+
+ if __name__ == "__main__":
+     app = FastAPI()
+     app.logger = logger
+     # Mount static files
+     logger.info("Mounting web pages")
+     StaticDir: str = "./Web"
+     if not os.path.isdir(StaticDir):
+         logger.warning(
+             "Web resources are missing, so the web UI cannot be served. If needed, download them from https://github.com/jiangyuxiaoxiao/Bert-VITS2-UI or from the release page of the matching Bert-VITS2 version"
+         )
+     else:
+         dirs = [fir.name for fir in os.scandir(StaticDir) if fir.is_dir()]
+         files = [fir.name for fir in os.scandir(StaticDir) if fir.is_file()]
+         for dirName in dirs:
+             app.mount(
+                 f"/{dirName}",
+                 StaticFiles(directory=f"./{StaticDir}/{dirName}"),
+                 name=dirName,
+             )
+     loaded_models = Models()
+     # Load models
+     logger.info("Loading models")
+     models_info = config.server_config.models
+     for model_info in models_info:
+         loaded_models.init_model(
+             config_path=model_info["config"],
+             model_path=model_info["model"],
+             device=model_info["device"],
+             language=model_info["language"],
+         )
+
+     @app.get("/")
+     async def index():
+         return FileResponse("./Web/index.html")
+
+     async def _voice(
+         text: str,
+         model_id: int,
+         speaker_name: str,
+         speaker_id: int,
+         sdp_ratio: float,
+         noise: float,
+         noisew: float,
+         length: float,
+         language: str,
+         auto_translate: bool,
+         auto_split: bool,
+         emotion: Optional[int] = None,
+         reference_audio=None,
+     ) -> Union[Response, Dict[str, Any]]:
+         """Core TTS implementation."""
+         # Check that the model exists
+         if model_id not in loaded_models.models.keys():
+             return {"status": 10, "detail": f"Model model_id={model_id} is not loaded"}
+         # Check that a speaker was provided
+         if speaker_name is None and speaker_id is None:
+             return {"status": 11, "detail": "Please provide speaker_name or speaker_id"}
+         elif speaker_name is None:
+             # Check that speaker_id exists
+             if speaker_id not in loaded_models.models[model_id].id2spk.keys():
+                 return {"status": 12, "detail": f"Speaker speaker_id={speaker_id} does not exist"}
+             speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
+         # Check that speaker_name exists
+         if speaker_name not in loaded_models.models[model_id].spk2id.keys():
+             return {"status": 13, "detail": f"Speaker speaker_name={speaker_name} does not exist"}
+         if language is None:
+             language = loaded_models.models[model_id].language
+         if auto_translate:
+             text = trans.translate(Sentence=text, to_Language=language.lower())
+         if reference_audio is not None:
+             ref_audio = BytesIO(await reference_audio.read())
+         else:
+             ref_audio = reference_audio
+         if not auto_split:
+             with torch.no_grad():
+                 audio = infer(
+                     text=text,
+                     sdp_ratio=sdp_ratio,
+                     noise_scale=noise,
+                     noise_scale_w=noisew,
+                     length_scale=length,
+                     sid=speaker_name,
+                     language=language,
+                     hps=loaded_models.models[model_id].hps,
+                     net_g=loaded_models.models[model_id].net_g,
+                     device=loaded_models.models[model_id].device,
+                     emotion=emotion,
+                     reference_audio=ref_audio,
+                 )
+                 audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
+         else:
+             texts = cut_sent(text)
+             audios = []
+             with torch.no_grad():
+                 for t in texts:
+                     audios.append(
+                         infer(
+                             text=t,
+                             sdp_ratio=sdp_ratio,
+                             noise_scale=noise,
+                             noise_scale_w=noisew,
+                             length_scale=length,
+                             sid=speaker_name,
+                             language=language,
+                             hps=loaded_models.models[model_id].hps,
+                             net_g=loaded_models.models[model_id].net_g,
+                             device=loaded_models.models[model_id].device,
+                             emotion=emotion,
+                             reference_audio=ref_audio,
+                         )
+                     )
+                     audios.append(np.zeros(int(44100 * 0.2)))
+                 audio = np.concatenate(audios)
+                 audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
+         with BytesIO() as wavContent:
+             wavfile.write(
+                 wavContent, loaded_models.models[model_id].hps.data.sampling_rate, audio
+             )
+             response = Response(content=wavContent.getvalue(), media_type="audio/wav")
+             return response
+
+     @app.post("/voice")
+     async def voice(
+         request: Request,  # injected by FastAPI
+         text: str = Form(...),
+         model_id: int = Query(..., description="Model ID"),  # model index
+         speaker_name: str = Query(
+             None, description="Speaker name"
+         ),  # provide either speaker_name or speaker_id
+         speaker_id: int = Query(None, description="Speaker id; alternative to speaker_name"),
+         sdp_ratio: float = Query(0.2, description="SDP/DP mix ratio"),
+         noise: float = Query(0.2, description="Emotion (noise scale)"),
+         noisew: float = Query(0.9, description="Phoneme length (noise scale w)"),
+         length: float = Query(1, description="Speech speed (length scale)"),
+         language: str = Query(None, description="Language"),  # the model default is used if unspecified
+         auto_translate: bool = Query(False, description="Auto-translate"),
+         auto_split: bool = Query(False, description="Auto-split"),
+         emotion: Optional[int] = Query(None, description="emo"),
+         reference_audio: UploadFile = File(None),
+     ):
+         """Voice endpoint. Use a POST request when uploading reference audio."""
+         logger.info(
+             f"{request.client.host}:{request.client.port}/voice {unquote(str(request.query_params))} text={text}"
+         )
+         return await _voice(
+             text=text,
+             model_id=model_id,
+             speaker_name=speaker_name,
+             speaker_id=speaker_id,
+             sdp_ratio=sdp_ratio,
+             noise=noise,
+             noisew=noisew,
+             length=length,
+             language=language,
+             auto_translate=auto_translate,
+             auto_split=auto_split,
+             emotion=emotion,
+             reference_audio=reference_audio,
+         )
+
+     @app.get("/voice")
+     async def voice(
+         request: Request,  # injected by FastAPI
+         text: str = Query(..., description="Input text"),
+         model_id: int = Query(..., description="Model ID"),  # model index
+         speaker_name: str = Query(
+             None, description="Speaker name"
+         ),  # provide either speaker_name or speaker_id
+         speaker_id: int = Query(None, description="Speaker id; alternative to speaker_name"),
+         sdp_ratio: float = Query(0.2, description="SDP/DP mix ratio"),
+         noise: float = Query(0.2, description="Emotion (noise scale)"),
+         noisew: float = Query(0.9, description="Phoneme length (noise scale w)"),
+         length: float = Query(1, description="Speech speed (length scale)"),
+         language: str = Query(None, description="Language"),  # the model default is used if unspecified
+         auto_translate: bool = Query(False, description="Auto-translate"),
+         auto_split: bool = Query(False, description="Auto-split"),
+         emotion: Optional[int] = Query(None, description="emo"),
+     ):
+         """Voice endpoint."""
+         logger.info(
+             f"{request.client.host}:{request.client.port}/voice {unquote(str(request.query_params))}"
+         )
+         return await _voice(
+             text=text,
+             model_id=model_id,
+             speaker_name=speaker_name,
+             speaker_id=speaker_id,
+             sdp_ratio=sdp_ratio,
+             noise=noise,
+             noisew=noisew,
+             length=length,
+             language=language,
+             auto_translate=auto_translate,
+             auto_split=auto_split,
+             emotion=emotion,
+         )
+
+     @app.get("/models/info")
+     def get_loaded_models_info(request: Request):
+         """Get information about the loaded models."""
+
+         result: Dict[str, Dict] = dict()
+         for key, model in loaded_models.models.items():
+             result[str(key)] = model.to_dict()
+         return result
+
+     @app.get("/models/delete")
+     def delete_model(
+         request: Request, model_id: int = Query(..., description="Id of the model to delete")
+     ):
+         """Delete the specified model."""
+         logger.info(
+             f"{request.client.host}:{request.client.port}/models/delete {unquote(str(request.query_params))}"
+         )
+         result = loaded_models.del_model(model_id)
+         if result is None:
+             return {"status": 14, "detail": f"Model {model_id} does not exist; deletion failed"}
+         return {"status": 0, "detail": "Deleted successfully"}
+
+     @app.get("/models/add")
+     def add_model(
+         request: Request,
+         model_path: str = Query(..., description="Path of the model to add"),
+         config_path: str = Query(
+             None, description="Path of the model config file; if omitted, ./config.json or ../config.json next to the model is used"
+         ),
+         device: str = Query("cuda", description="Device used for inference"),
+         language: str = Query("ZH", description="Default model language"),
+     ):
+         """Add the specified model. The same path may be added repeatedly without using extra memory."""
+         logger.info(
+             f"{request.client.host}:{request.client.port}/models/add {unquote(str(request.query_params))}"
+         )
+         if config_path is None:
+             model_dir = os.path.dirname(model_path)
+             if os.path.isfile(os.path.join(model_dir, "config.json")):
+                 config_path = os.path.join(model_dir, "config.json")
+             elif os.path.isfile(os.path.join(model_dir, "../config.json")):
+                 config_path = os.path.join(model_dir, "../config.json")
+             else:
+                 return {
+                     "status": 15,
+                     "detail": "No config file path was given, and no config.json exists at the default locations ./ and ../",
+                 }
+         try:
+             model_id = loaded_models.init_model(
+                 config_path=config_path,
+                 model_path=model_path,
+                 device=device,
+                 language=language,
+             )
+         except Exception:
+             logging.exception("Failed to load the model")
+             return {
+                 "status": 16,
+                 "detail": "Failed to load the model; see the logs for details",
+             }
+         return {
+             "status": 0,
+             "detail": "Model added successfully",
+             "Data": {
+                 "model_id": model_id,
+                 "model_info": loaded_models.models[model_id].to_dict(),
+             },
+         }
+
+     def _get_all_models(root_dir: str = "Data", only_unloaded: bool = False):
+         """Search root_dir for all available models."""
+         result: Dict[str, List[str]] = dict()
+         files = os.listdir(root_dir) + ["."]
+         for file in files:
+             if os.path.isdir(os.path.join(root_dir, file)):
+                 sub_dir = os.path.join(root_dir, file)
+                 # Search both "sub_dir" and "sub_dir/models"
+                 result[file] = list()
+                 sub_files = os.listdir(sub_dir)
+                 model_files = []
+                 for sub_file in sub_files:
+                     relpath = os.path.realpath(os.path.join(sub_dir, sub_file))
+                     if only_unloaded and relpath in loaded_models.path2ids.keys():
+                         continue
+                     if sub_file.endswith(".pth") and sub_file.startswith("G_"):
+                         if os.path.isfile(relpath):
+                             model_files.append(sub_file)
+                 # Sort model files by step count
+                 model_files = sorted(
+                     model_files,
+                     key=lambda pth: int(pth.lstrip("G_").rstrip(".pth"))
+                     if pth.lstrip("G_").rstrip(".pth").isdigit()
+                     else 10**10,
+                 )
+                 result[file] = model_files
+                 models_dir = os.path.join(sub_dir, "models")
+                 model_files = []
+                 if os.path.isdir(models_dir):
+                     sub_files = os.listdir(models_dir)
+                     for sub_file in sub_files:
+                         relpath = os.path.realpath(os.path.join(models_dir, sub_file))
+                         if only_unloaded and relpath in loaded_models.path2ids.keys():
+                             continue
+                         if sub_file.endswith(".pth") and sub_file.startswith("G_"):
+                             if os.path.isfile(os.path.join(models_dir, sub_file)):
+                                 model_files.append(f"models/{sub_file}")
+                     # Sort model files by step count
+                     model_files = sorted(
+                         model_files,
+                         key=lambda pth: int(pth.lstrip("models/G_").rstrip(".pth"))
+                         if pth.lstrip("models/G_").rstrip(".pth").isdigit()
+                         else 10**10,
+                     )
+                     result[file] += model_files
+                 if len(result[file]) == 0:
+                     result.pop(file)
+
+         return result
+
+     @app.get("/models/get_unloaded")
+     def get_unloaded_models_info(
+         request: Request, root_dir: str = Query("Data", description="Search root directory")
+     ):
+         """Get the models that are not loaded."""
+         logger.info(
+             f"{request.client.host}:{request.client.port}/models/get_unloaded {unquote(str(request.query_params))}"
+         )
+         return _get_all_models(root_dir, only_unloaded=True)
+
+     @app.get("/models/get_local")
+     def get_local_models_info(
+         request: Request, root_dir: str = Query("Data", description="Search root directory")
+     ):
+         """Get all local models."""
+         logger.info(
+             f"{request.client.host}:{request.client.port}/models/get_local {unquote(str(request.query_params))}"
+         )
+         return _get_all_models(root_dir, only_unloaded=False)
+
+     @app.get("/status")
+     def get_status():
+         """Get the machine's runtime status."""
+         cpu_percent = psutil.cpu_percent(interval=1)
+         memory_info = psutil.virtual_memory()
+         memory_total = memory_info.total
+         memory_available = memory_info.available
+         memory_used = memory_info.used
+         memory_percent = memory_info.percent
+         gpuInfo = []
+         devices = ["cpu"]
+         for i in range(torch.cuda.device_count()):
+             devices.append(f"cuda:{i}")
+         gpus = GPUtil.getGPUs()
+         for gpu in gpus:
+             gpuInfo.append(
+                 {
+                     "gpu_id": gpu.id,
+                     "gpu_load": gpu.load,
+                     "gpu_memory": {
+                         "total": gpu.memoryTotal,
+                         "used": gpu.memoryUsed,
+                         "free": gpu.memoryFree,
+                     },
+                 }
+             )
+         return {
+             "devices": devices,
+             "cpu_percent": cpu_percent,
+             "memory_total": memory_total,
+             "memory_available": memory_available,
+             "memory_used": memory_used,
+             "memory_percent": memory_percent,
+             "gpu": gpuInfo,
+         }
+
+     @app.get("/tools/translate")
+     def translate(
+         request: Request,
+         texts: str = Query(..., description="Text to translate"),
+         to_language: str = Query(..., description="Target language"),
+     ):
+         """Translate text."""
+         logger.info(
+             f"{request.client.host}:{request.client.port}/tools/translate {unquote(str(request.query_params))}"
+         )
+         return {"texts": trans.translate(Sentence=texts, to_Language=to_language)}
+
+     all_examples: Dict[str, Dict[str, List]] = dict()  # example store
+
+     @app.get("/tools/random_example")
+     def random_example(
+         request: Request,
+         language: str = Query(None, description="Language filter; if unspecified, any language may be returned"),
+         root_dir: str = Query("Data", description="Search root directory"),
+     ):
+         """
+         Return one random audio/text pair for comparison; the audio is picked at random from local directories.
+         """
+         logger.info(
+             f"{request.client.host}:{request.client.port}/tools/random_example {unquote(str(request.query_params))}"
+         )
+         global all_examples
+         # Initialize the data
+         if root_dir not in all_examples.keys():
+             all_examples[root_dir] = {"ZH": [], "JP": [], "EN": []}
+
+             examples = all_examples[root_dir]
+
+             # Search the project Data directory for train.list/val.list
+             for root, directories, _files in os.walk(root_dir):
+                 for file in _files:
+                     if file in ["train.list", "val.list"]:
+                         with open(
+                             os.path.join(root, file), mode="r", encoding="utf-8"
+                         ) as f:
+                             lines = f.readlines()
+                         for line in lines:
+                             data = line.split("|")
+                             if len(data) != 7:
+                                 continue
+                             # The audio exists and the language is ZH/EN/JP
+                             if os.path.isfile(data[0]) and data[2] in [
+                                 "ZH",
+                                 "JP",
+                                 "EN",
+                             ]:
+                                 examples[data[2]].append(
+                                     {
+                                         "text": data[3],
+                                         "audio": data[0],
+                                         "speaker": data[1],
+                                     }
+                                 )
+
+         examples = all_examples[root_dir]
+         if language is None:
+             if len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) == 0:
+                 return {"status": 17, "detail": "No example data has been loaded"}
+             else:
+                 # Pick one at random
+                 rand_num = random.randint(
+                     0,
+                     len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) - 1,
+                 )
+                 # ZH
+                 if rand_num < len(examples["ZH"]):
+                     return {"status": 0, "Data": examples["ZH"][rand_num]}
+                 # JP
+                 if rand_num < len(examples["ZH"]) + len(examples["JP"]):
+                     return {
+                         "status": 0,
+                         "Data": examples["JP"][rand_num - len(examples["ZH"])],
+                     }
+                 # EN
+                 return {
+                     "status": 0,
+                     "Data": examples["EN"][
+                         rand_num - len(examples["ZH"]) - len(examples["JP"])
+                     ],
+                 }
+
+         else:
+             if len(examples[language]) == 0:
+                 return {"status": 17, "detail": f"No {language} data has been loaded"}
+             return {
+                 "status": 0,
+                 "Data": examples[language][
+                     random.randint(0, len(examples[language]) - 1)
+                 ],
+             }
+
+     @app.get("/tools/get_audio")
+     def get_audio(request: Request, path: str = Query(..., description="Local audio path")):
+         logger.info(
+             f"{request.client.host}:{request.client.port}/tools/get_audio {unquote(str(request.query_params))}"
+         )
+         if not os.path.isfile(path):
+             return {"status": 18, "detail": "The specified audio does not exist"}
+         if not path.endswith(".wav"):
+             return {"status": 19, "detail": "Not a wav file"}
+         return FileResponse(path=path)
+
+     logger.warning("This is a local service; do not expose the service port to the public internet")
+     logger.info(f"API docs at http://127.0.0.1:{config.server_config.port}/docs")
+     if os.path.isdir(StaticDir):
+         webbrowser.open(f"http://127.0.0.1:{config.server_config.port}")
+     uvicorn.run(
+         app, port=config.server_config.port, host="0.0.0.0", log_level="warning"
+     )
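
For reference, a minimal client for the GET /voice endpoint above might look like the following sketch. The port, model_id and speaker_name are placeholders that depend on the local server config; only the response handling follows the code above (a wav body on success, a JSON error dict otherwise).

# Sketch of a GET /voice client; host/port, model_id and speaker_name are placeholders.
import requests

resp = requests.get(
    "http://127.0.0.1:5000/voice",
    params={"text": "Hello", "model_id": 0, "speaker_name": "speaker0", "language": "EN"},
)
if resp.headers.get("content-type", "").startswith("audio/wav"):
    with open("out.wav", "wb") as f:
        f.write(resp.content)  # the endpoint returns a complete wav file
else:
    print(resp.json())  # otherwise an error dict with "status" and "detail"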
spec_gen.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ from tqdm import tqdm
+ from multiprocessing import Pool
+ from mel_processing import spectrogram_torch, mel_spectrogram_torch
+ from utils import load_wav_to_torch
+
+
+ class AudioProcessor:
+     def __init__(
+         self,
+         max_wav_value,
+         use_mel_spec_posterior,
+         filter_length,
+         n_mel_channels,
+         sampling_rate,
+         hop_length,
+         win_length,
+         mel_fmin,
+         mel_fmax,
+     ):
+         self.max_wav_value = max_wav_value
+         self.use_mel_spec_posterior = use_mel_spec_posterior
+         self.filter_length = filter_length
+         self.n_mel_channels = n_mel_channels
+         self.sampling_rate = sampling_rate
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.mel_fmin = mel_fmin
+         self.mel_fmax = mel_fmax
+
+     def process_audio(self, filename):
+         audio, sampling_rate = load_wav_to_torch(filename)
+         audio_norm = audio / self.max_wav_value
+         audio_norm = audio_norm.unsqueeze(0)
+         spec_filename = filename.replace(".wav", ".spec.pt")
+         if self.use_mel_spec_posterior:
+             spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
+         try:
+             spec = torch.load(spec_filename)
+         except Exception:
+             if self.use_mel_spec_posterior:
+                 spec = mel_spectrogram_torch(
+                     audio_norm,
+                     self.filter_length,
+                     self.n_mel_channels,
+                     self.sampling_rate,
+                     self.hop_length,
+                     self.win_length,
+                     self.mel_fmin,
+                     self.mel_fmax,
+                     center=False,
+                 )
+             else:
+                 spec = spectrogram_torch(
+                     audio_norm,
+                     self.filter_length,
+                     self.sampling_rate,
+                     self.hop_length,
+                     self.win_length,
+                     center=False,
+                 )
+             spec = torch.squeeze(spec, 0)
+             torch.save(spec, spec_filename)
+         return spec, audio_norm
+
+
+ # Usage example
+ processor = AudioProcessor(
+     max_wav_value=32768.0,
+     use_mel_spec_posterior=False,
+     filter_length=2048,
+     n_mel_channels=128,
+     sampling_rate=44100,
+     hop_length=512,
+     win_length=2048,
+     mel_fmin=0.0,
+     mel_fmax=None,  # unused when use_mel_spec_posterior=False
+ )
+
+ with open("filelists/train.list", "r", encoding="utf-8") as f:
+     filepaths = [line.split("|")[0] for line in f]  # the first field of each line is the audio path
+
+ # Process with a pool of worker processes
+ with Pool(processes=32) as pool:  # 32 worker processes
+     with tqdm(total=len(filepaths)) as pbar:
+         for i, _ in enumerate(pool.imap_unordered(processor.process_audio, filepaths)):
+             pbar.update()
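
The script caches each spectrogram next to its wav as `<name>.spec.pt` (or `.mel.pt` for mel posteriors). A sketch of consuming one cached file; the path is a placeholder:

# Sketch: load one cached spectrogram written by spec_gen.py ("example.spec.pt" is a placeholder).
import torch

spec = torch.load("example.spec.pt")
# spectrogram_torch yields (n_fft // 2 + 1, n_frames); with filter_length=2048 that is (1025, n_frames).
print(spec.shape)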
transforms.py ADDED
@@ -0,0 +1,209 @@
+ import torch
+ from torch.nn import functional as F
+
+ import numpy as np
+
+
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
+ DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+ def piecewise_rational_quadratic_transform(
+     inputs,
+     unnormalized_widths,
+     unnormalized_heights,
+     unnormalized_derivatives,
+     inverse=False,
+     tails=None,
+     tail_bound=1.0,
+     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+     min_derivative=DEFAULT_MIN_DERIVATIVE,
+ ):
+     if tails is None:
+         spline_fn = rational_quadratic_spline
+         spline_kwargs = {}
+     else:
+         spline_fn = unconstrained_rational_quadratic_spline
+         spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+     outputs, logabsdet = spline_fn(
+         inputs=inputs,
+         unnormalized_widths=unnormalized_widths,
+         unnormalized_heights=unnormalized_heights,
+         unnormalized_derivatives=unnormalized_derivatives,
+         inverse=inverse,
+         min_bin_width=min_bin_width,
+         min_bin_height=min_bin_height,
+         min_derivative=min_derivative,
+         **spline_kwargs
+     )
+     return outputs, logabsdet
+
+
+ def searchsorted(bin_locations, inputs, eps=1e-6):
+     bin_locations[..., -1] += eps
+     return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
+
+
+ def unconstrained_rational_quadratic_spline(
+     inputs,
+     unnormalized_widths,
+     unnormalized_heights,
+     unnormalized_derivatives,
+     inverse=False,
+     tails="linear",
+     tail_bound=1.0,
+     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+     min_derivative=DEFAULT_MIN_DERIVATIVE,
+ ):
+     inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+     outside_interval_mask = ~inside_interval_mask
+
+     outputs = torch.zeros_like(inputs)
+     logabsdet = torch.zeros_like(inputs)
+
+     if tails == "linear":
+         unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+         constant = np.log(np.exp(1 - min_derivative) - 1)
+         unnormalized_derivatives[..., 0] = constant
+         unnormalized_derivatives[..., -1] = constant
+
+         outputs[outside_interval_mask] = inputs[outside_interval_mask]
+         logabsdet[outside_interval_mask] = 0
+     else:
+         raise RuntimeError("{} tails are not implemented.".format(tails))
+
+     (
+         outputs[inside_interval_mask],
+         logabsdet[inside_interval_mask],
+     ) = rational_quadratic_spline(
+         inputs=inputs[inside_interval_mask],
+         unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+         unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+         unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+         inverse=inverse,
+         left=-tail_bound,
+         right=tail_bound,
+         bottom=-tail_bound,
+         top=tail_bound,
+         min_bin_width=min_bin_width,
+         min_bin_height=min_bin_height,
+         min_derivative=min_derivative,
+     )
+
+     return outputs, logabsdet
+
+
+ def rational_quadratic_spline(
+     inputs,
+     unnormalized_widths,
+     unnormalized_heights,
+     unnormalized_derivatives,
+     inverse=False,
+     left=0.0,
+     right=1.0,
+     bottom=0.0,
+     top=1.0,
+     min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+     min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+     min_derivative=DEFAULT_MIN_DERIVATIVE,
+ ):
+     if torch.min(inputs) < left or torch.max(inputs) > right:
+         raise ValueError("Input to a transform is not within its domain")
+
+     num_bins = unnormalized_widths.shape[-1]
+
+     if min_bin_width * num_bins > 1.0:
+         raise ValueError("Minimal bin width too large for the number of bins")
+     if min_bin_height * num_bins > 1.0:
+         raise ValueError("Minimal bin height too large for the number of bins")
+
+     widths = F.softmax(unnormalized_widths, dim=-1)
+     widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+     cumwidths = torch.cumsum(widths, dim=-1)
+     cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+     cumwidths = (right - left) * cumwidths + left
+     cumwidths[..., 0] = left
+     cumwidths[..., -1] = right
+     widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+     derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+     heights = F.softmax(unnormalized_heights, dim=-1)
+     heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+     cumheights = torch.cumsum(heights, dim=-1)
+     cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+     cumheights = (top - bottom) * cumheights + bottom
+     cumheights[..., 0] = bottom
+     cumheights[..., -1] = top
+     heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+     if inverse:
+         bin_idx = searchsorted(cumheights, inputs)[..., None]
+     else:
+         bin_idx = searchsorted(cumwidths, inputs)[..., None]
+
+     input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+     input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+     input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+     delta = heights / widths
+     input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+     input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+     input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+     input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+     if inverse:
+         a = (inputs - input_cumheights) * (
+             input_derivatives + input_derivatives_plus_one - 2 * input_delta
+         ) + input_heights * (input_delta - input_derivatives)
+         b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+             input_derivatives + input_derivatives_plus_one - 2 * input_delta
+         )
+         c = -input_delta * (inputs - input_cumheights)
+
+         discriminant = b.pow(2) - 4 * a * c
+         assert (discriminant >= 0).all()
+
+         root = (2 * c) / (-b - torch.sqrt(discriminant))
+         outputs = root * input_bin_widths + input_cumwidths
+
+         theta_one_minus_theta = root * (1 - root)
+         denominator = input_delta + (
+             (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+             * theta_one_minus_theta
+         )
+         derivative_numerator = input_delta.pow(2) * (
+             input_derivatives_plus_one * root.pow(2)
+             + 2 * input_delta * theta_one_minus_theta
+             + input_derivatives * (1 - root).pow(2)
+         )
+         logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+         return outputs, -logabsdet
+     else:
+         theta = (inputs - input_cumwidths) / input_bin_widths
+         theta_one_minus_theta = theta * (1 - theta)
+
+         numerator = input_heights * (
+             input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
+         )
+         denominator = input_delta + (
+             (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+             * theta_one_minus_theta
+         )
+         outputs = input_cumheights + numerator / denominator
+
+         derivative_numerator = input_delta.pow(2) * (
+             input_derivatives_plus_one * theta.pow(2)
+             + 2 * input_delta * theta_one_minus_theta
+             + input_derivatives * (1 - theta).pow(2)
+         )
+         logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+         return outputs, logabsdet
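
This is the standard rational-quadratic spline flow, so the forward and inverse passes should round-trip and their log-determinants should cancel. A self-contained sketch of that check, run from the repo root so `transforms` is importable; shapes follow rational_quadratic_spline: K bins take K widths/heights and K - 1 interior derivatives when tails="linear" (the linear-tail path pads the two boundary ones).

# Sketch: round-trip check of the spline with linear tails.
import torch
from transforms import piecewise_rational_quadratic_transform

torch.manual_seed(0)
n, num_bins = 16, 10
x = torch.randn(n) * 2  # some points deliberately fall outside the tail bound
w = torch.randn(n, num_bins)
h = torch.randn(n, num_bins)
d = torch.randn(n, num_bins - 1)

y, logdet = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails="linear", tail_bound=1.0
)
x_rec, inv_logdet = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=1.0
)
assert torch.allclose(x, x_rec, atol=1e-4)          # inverse recovers the input
assert torch.allclose(logdet, -inv_logdet, atol=1e-4)  # log-determinants cancel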
update_status.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ import gradio as gr
+
+ lang_dict = {"EN(英文)": "_en", "ZH(中文)": "_zh", "JP(日语)": "_jp"}
+
+
+ def raw_dir_convert_to_path(target_dir: str, lang):
+     res = target_dir.rstrip("/").rstrip("\\")
+     if (not target_dir.startswith("raw")) and (not target_dir.startswith("./raw")):
+         res = os.path.join("./raw", res)
+     if (
+         (not res.endswith("_zh"))
+         and (not res.endswith("_jp"))
+         and (not res.endswith("_en"))
+     ):
+         res += lang_dict[lang]
+     return res
+
+
+ def update_g_files():
+     g_files = []
+     cnt = 0
+     for root, dirs, files in os.walk(os.path.abspath("./logs")):
+         for file in files:
+             if file.startswith("G_") and file.endswith(".pth"):
+                 g_files.append(os.path.join(root, file))
+                 cnt += 1
+     print(g_files)
+     return f"Model list updated; found {cnt} models", gr.Dropdown.update(choices=g_files)
+
+
+ def update_c_files():
+     c_files = []
+     cnt = 0
+     for root, dirs, files in os.walk(os.path.abspath("./logs")):
+         for file in files:
+             if file.startswith("config.json"):
+                 c_files.append(os.path.join(root, file))
+                 cnt += 1
+     print(c_files)
+     return f"Config list updated; found {cnt} config files", gr.Dropdown.update(choices=c_files)
+
+
+ def update_model_folders():
+     subdirs = []
+     cnt = 0
+     for root, dirs, files in os.walk(os.path.abspath("./logs")):
+         for dir_name in dirs:
+             if os.path.basename(dir_name) != "eval":
+                 subdirs.append(os.path.join(root, dir_name))
+                 cnt += 1
+     print(subdirs)
+     return f"Model folder list updated; found {cnt} folders", gr.Dropdown.update(choices=subdirs)
+
+
+ def update_wav_lab_pairs():
+     wav_count = tot_count = 0
+     for root, _, files in os.walk("./raw"):
+         for file in files:
+             # print(file)
+             file_path = os.path.join(root, file)
+             if file.lower().endswith(".wav"):
+                 lab_file = os.path.splitext(file_path)[0] + ".lab"
+                 if os.path.exists(lab_file):
+                     wav_count += 1
+                 tot_count += 1
+     return f"{wav_count} / {tot_count}"
+
+
+ def update_raw_folders():
+     subdirs = []
+     cnt = 0
+     script_path = os.path.dirname(os.path.abspath(__file__))  # absolute path of this script
+     raw_path = os.path.join(script_path, "raw")
+     print(raw_path)
+     os.makedirs(raw_path, exist_ok=True)
+     for root, dirs, files in os.walk(raw_path):
+         for dir_name in dirs:
+             relative_path = os.path.relpath(
+                 os.path.join(root, dir_name), script_path
+             )  # path relative to the script
+             subdirs.append(relative_path)
+             cnt += 1
+     print(subdirs)
+     return (
+         f"Raw audio folder list updated; found {cnt} folders",
+         gr.Dropdown.update(choices=subdirs),
+         gr.Textbox.update(value=update_wav_lab_pairs()),
+     )
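
A small usage sketch for raw_dir_convert_to_path (POSIX paths shown; importing update_status pulls in gradio, since it is imported at module level). The directory name "myvoice" is a placeholder:

# Sketch: raw_dir_convert_to_path normalizes dataset directory names.
from update_status import raw_dir_convert_to_path

print(raw_dir_convert_to_path("myvoice", "ZH(中文)"))          # ./raw/myvoice_zh
print(raw_dir_convert_to_path("./raw/myvoice_zh", "ZH(中文)"))  # unchanged: ./raw/myvoice_zh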
utils.py ADDED
@@ -0,0 +1,457 @@
+ import os
+ import glob
+ import argparse
+ import logging
+ import json
+ import shutil
+ import subprocess
+ import numpy as np
+ from huggingface_hub import hf_hub_download
+ from scipy.io.wavfile import read
+ import torch
+ import re
+
+ MATPLOTLIB_FLAG = False
+
+ logger = logging.getLogger(__name__)
+
+
+ def download_emo_models(mirror, repo_id, model_name):
+     if mirror == "openi":
+         import openi
+
+         openi.model.download_model(
+             "Stardust_minus/Bert-VITS2",
+             repo_id.split("/")[-1],
+             "./emotional",
+         )
+     else:
+         hf_hub_download(
+             repo_id,
+             "pytorch_model.bin",
+             local_dir=model_name,
+             local_dir_use_symlinks=False,
+         )
+
+
+ def download_checkpoint(
+     dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
+ ):
+     repo_id = repo_config["repo_id"]
+     f_list = glob.glob(os.path.join(dir_path, regex))
+     if f_list:
+         print("Using the existing model; skipping download.")
+         return
+     if mirror.lower() == "openi":
+         import openi
+
+         kwargs = {"token": token} if token else {}
+         openi.login(**kwargs)
+
+         model_image = repo_config["model_image"]
+         openi.model.download_model(repo_id, model_image, dir_path)
+
+         fs = glob.glob(os.path.join(dir_path, model_image, "*.pth"))
+         for file in fs:
+             shutil.move(file, dir_path)
+         shutil.rmtree(os.path.join(dir_path, model_image))
+     else:
+         for file in ["DUR_0.pth", "D_0.pth", "G_0.pth"]:
+             hf_hub_download(
+                 repo_id, file, local_dir=dir_path, local_dir_use_symlinks=False
+             )
+
+
+ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
+     assert os.path.isfile(checkpoint_path)
+     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+     iteration = checkpoint_dict["iteration"]
+     learning_rate = checkpoint_dict["learning_rate"]
+     if (
+         optimizer is not None
+         and not skip_optimizer
+         and checkpoint_dict["optimizer"] is not None
+     ):
+         optimizer.load_state_dict(checkpoint_dict["optimizer"])
+     elif optimizer is None and not skip_optimizer:
+         # else: disable this branch for inference; when resuming a checkpoint, enable the branch above instead
+         new_opt_dict = optimizer.state_dict()
+         new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
+         new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
+         new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
+         optimizer.load_state_dict(new_opt_dict)
+
+     saved_state_dict = checkpoint_dict["model"]
+     if hasattr(model, "module"):
+         state_dict = model.module.state_dict()
+     else:
+         state_dict = model.state_dict()
+
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         try:
+             # assert "emb_g" not in k
+             new_state_dict[k] = saved_state_dict[k]
+             assert saved_state_dict[k].shape == v.shape, (
+                 saved_state_dict[k].shape,
+                 v.shape,
+             )
+         except Exception:
+             # For upgrading from the old version
+             if "ja_bert_proj" in k:
+                 v = torch.zeros_like(v)
+                 logger.warning(
+                     f"Seems you are using an old version of the model; {k} is automatically set to zero for backward compatibility"
+                 )
+             else:
+                 logger.error(f"{k} is not in the checkpoint")
+
+             new_state_dict[k] = v
+
+     if hasattr(model, "module"):
+         model.module.load_state_dict(new_state_dict, strict=False)
+     else:
+         model.load_state_dict(new_state_dict, strict=False)
+
+     logger.info(
+         "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
+     )
+
+     return model, optimizer, learning_rate, iteration
+
+
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+     logger.info(
+         "Saving model and optimizer state at iteration {} to {}".format(
+             iteration, checkpoint_path
+         )
+     )
+     if hasattr(model, "module"):
+         state_dict = model.module.state_dict()
+     else:
+         state_dict = model.state_dict()
+     torch.save(
+         {
+             "model": state_dict,
+             "iteration": iteration,
+             "optimizer": optimizer.state_dict(),
+             "learning_rate": learning_rate,
+         },
+         checkpoint_path,
+     )
+
+
+ def summarize(
+     writer,
+     global_step,
+     scalars={},
+     histograms={},
+     images={},
+     audios={},
+     audio_sampling_rate=22050,
+ ):
+     for k, v in scalars.items():
+         writer.add_scalar(k, v, global_step)
+     for k, v in histograms.items():
+         writer.add_histogram(k, v, global_step)
+     for k, v in images.items():
+         writer.add_image(k, v, global_step, dataformats="HWC")
+     for k, v in audios.items():
+         writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+     f_list = glob.glob(os.path.join(dir_path, regex))
+     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+     x = f_list[-1]
+     return x
+
+
+ def plot_spectrogram_to_numpy(spectrogram):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger("matplotlib")
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+     import numpy as np
+
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+     plt.xlabel("Frames")
+     plt.ylabel("Channels")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close()
+     return data
+
+
+ def plot_alignment_to_numpy(alignment, info=None):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger("matplotlib")
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+     import numpy as np
+
+     fig, ax = plt.subplots(figsize=(6, 4))
+     im = ax.imshow(
+         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+     )
+     fig.colorbar(im, ax=ax)
+     xlabel = "Decoder timestep"
+     if info is not None:
+         xlabel += "\n\n" + info
+     plt.xlabel(xlabel)
+     plt.ylabel("Encoder timestep")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close()
+     return data
+
+
+ def load_wav_to_torch(full_path):
+     sampling_rate, data = read(full_path)
+     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+ def load_filepaths_and_text(filename, split="|"):
+     with open(filename, encoding="utf-8") as f:
+         filepaths_and_text = [line.strip().split(split) for line in f]
+     return filepaths_and_text
+
+
+ def get_hparams(init=True):
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-c",
+         "--config",
+         type=str,
+         default="./configs/base.json",
+         help="JSON file for configuration",
+     )
+     parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
+
+     args = parser.parse_args()
+     model_dir = os.path.join("./logs", args.model)
+
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+
+     config_path = args.config
+     config_save_path = os.path.join(model_dir, "config.json")
+     if init:
+         with open(config_path, "r", encoding="utf-8") as f:
+             data = f.read()
+         with open(config_save_path, "w", encoding="utf-8") as f:
+             f.write(data)
+     else:
+         with open(config_save_path, "r", encoding="utf-8") as f:
+             data = f.read()
+     config = json.loads(data)
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
+     """Free up space by deleting saved checkpoints.
+
+     Arguments:
+     path_to_models -- path to the model directory
+     n_ckpts_to_keep -- number of checkpoints to keep, excluding G_0.pth and D_0.pth
+     sort_by_time -- True -> delete checkpoints chronologically
+                     False -> delete checkpoints lexicographically
+     """
+     ckpts_files = [
+         f
+         for f in os.listdir(path_to_models)
+         if os.path.isfile(os.path.join(path_to_models, f))
+     ]
+
+     def name_key(_f):
+         return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
+
+     def time_key(_f):
+         return os.path.getmtime(os.path.join(path_to_models, _f))
+
+     sort_key = time_key if sort_by_time else name_key
+
+     def x_sorted(_x):
+         return sorted(
+             [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
+             key=sort_key,
+         )
+
+     to_del = [
+         os.path.join(path_to_models, fn)
+         for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
+     ]
+
+     def del_info(fn):
+         return logger.info(f".. Free up space by deleting ckpt {fn}")
+
+     def del_routine(x):
+         return [os.remove(x), del_info(x)]
+
+     [del_routine(fn) for fn in to_del]
+
+
+ def get_hparams_from_dir(model_dir):
+     config_save_path = os.path.join(model_dir, "config.json")
+     with open(config_save_path, "r", encoding="utf-8") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def get_hparams_from_file(config_path):
+     # print("config_path: ", config_path)
+     with open(config_path, "r", encoding="utf-8") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     return hparams
+
+
+ def check_git_hash(model_dir):
+     source_dir = os.path.dirname(os.path.realpath(__file__))
+     if not os.path.exists(os.path.join(source_dir, ".git")):
+         logger.warning(
+             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
+                 source_dir
+             )
+         )
+         return
+
+     cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+     path = os.path.join(model_dir, "githash")
+     if os.path.exists(path):
+         saved_hash = open(path).read()
+         if saved_hash != cur_hash:
+             logger.warning(
+                 "git hash values are different. {}(saved) != {}(current)".format(
+                     saved_hash[:8], cur_hash[:8]
+                 )
+             )
+     else:
+         open(path, "w").write(cur_hash)
+
+
+ def get_logger(model_dir, filename="train.log"):
+     global logger
+     logger = logging.getLogger(os.path.basename(model_dir))
+     logger.setLevel(logging.DEBUG)
+
+     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+     h = logging.FileHandler(os.path.join(model_dir, filename))
+     h.setLevel(logging.DEBUG)
+     h.setFormatter(formatter)
+     logger.addHandler(h)
+     return logger
+
+
+ class HParams:
+     def __init__(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, dict):
+                 v = HParams(**v)
+             self[k] = v
+
+     def keys(self):
+         return self.__dict__.keys()
+
+     def items(self):
+         return self.__dict__.items()
+
+     def values(self):
+         return self.__dict__.values()
+
+     def __len__(self):
+         return len(self.__dict__)
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+     def __setitem__(self, key, value):
+         return setattr(self, key, value)
+
+     def __contains__(self, key):
+         return key in self.__dict__
+
+     def __repr__(self):
+         return self.__dict__.__repr__()
+
+
+ def load_model(model_path, config_path):
+     hps = get_hparams_from_file(config_path)
+     from models import SynthesizerTrn  # local import; assumes the repo's models.py defines SynthesizerTrn
+
+     net = SynthesizerTrn(
+         # len(symbols),
+         108,
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         **hps.model,
+     ).to("cpu")
+     _ = net.eval()
+     _ = load_checkpoint(model_path, net, None, skip_optimizer=True)
+     return net
+
+
+ def mix_model(
+     network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5)
+ ):
+     if hasattr(network1, "module"):
+         state_dict1 = network1.module.state_dict()
+         state_dict2 = network2.module.state_dict()
+     else:
+         state_dict1 = network1.state_dict()
+         state_dict2 = network2.state_dict()
+     for k in state_dict1.keys():
+         if k not in state_dict2.keys():
+             continue
+         if "enc_p" in k:
+             state_dict1[k] = (
+                 state_dict1[k].clone() * tone_ratio[0]
+                 + state_dict2[k].clone() * tone_ratio[1]
+             )
+         else:
+             state_dict1[k] = (
+                 state_dict1[k].clone() * voice_ratio[0]
+                 + state_dict2[k].clone() * voice_ratio[1]
+             )
+     for k in state_dict2.keys():
+         if k not in state_dict1.keys():
+             state_dict1[k] = state_dict2[k].clone()
+     torch.save(
+         {"model": state_dict1, "iteration": 0, "optimizer": None, "learning_rate": 0},
+         output_path,
+     )
+
+
+ def get_steps(model_path):
+     matches = re.findall(r"\d+", model_path)
+     return matches[-1] if matches else None
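
HParams is the small attribute-access wrapper that get_hparams_from_file builds from config.json and that the rest of the code relies on. A short sketch of its behavior together with get_steps (the path is a placeholder; importing utils assumes the repo's dependencies are installed):

# Sketch: nested dicts become nested HParams, mirroring get_hparams_from_file.
from utils import HParams, get_steps

hps = HParams(**{"data": {"sampling_rate": 44100}, "train": {"batch_size": 8}})
print(hps.data.sampling_rate)                 # 44100
print("train" in hps)                         # True (__contains__ checks top-level keys)
print(get_steps("logs/mymodel/G_8000.pth"))   # "8000" (last run of digits in the path)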