ElesisSiegherts committed
Commit 245dd7d
Parent(s): ed6c2db
Upload 6 files

- requirements.txt +33 -0
- server_fastapi.py +642 -0
- spec_gen.py +87 -0
- transforms.py +209 -0
- update_status.py +89 -0
- utils.py +457 -0
requirements.txt
ADDED
@@ -0,0 +1,33 @@
+librosa==0.9.2
+matplotlib
+numpy
+numba
+phonemizer
+scipy
+tensorboard
+Unidecode
+amfm_decompy
+jieba
+transformers
+pypinyin
+cn2an
+gradio==3.38.0
+av
+mecab-python3
+loguru
+unidic-lite
+cmudict
+fugashi
+num2words
+PyYAML
+requests
+pyopenjtalk; sys_platform == 'linux'
+openjtalk; sys_platform != 'linux'
+jaconv
+psutil
+GPUtil
+vector_quantize_pytorch
+g2p_en
+sentencepiece
+pykakasi
+langid
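These requirements install in one step with pip install -r requirements.txt; the platform markers on the pyopenjtalk/openjtalk lines select the right backend for Linux versus other systems automatically.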
server_fastapi.py
ADDED
@@ -0,0 +1,642 @@
+"""
+API service: multi-version, multi-model, implemented with FastAPI.
+"""
+import logging
+import gc
+import random
+
+import gradio
+import numpy as np
+import utils
+from fastapi import FastAPI, Query, Request, File, UploadFile, Form
+from fastapi.responses import Response, FileResponse
+from fastapi.staticfiles import StaticFiles
+from io import BytesIO
+from scipy.io import wavfile
+import uvicorn
+import torch
+import webbrowser
+import psutil
+import GPUtil
+from typing import Any, Dict, Optional, List, Set, Union
+import os
+from tools.log import logger
+from urllib.parse import unquote
+
+from infer import infer, get_net_g, latest_version
+import tools.translate as trans
+from re_matching import cut_sent
+
+
+from config import config
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class Model:
+    """Wrapper around one loaded model."""
+
+    def __init__(self, config_path: str, model_path: str, device: str, language: str):
+        self.config_path: str = os.path.normpath(config_path)
+        self.model_path: str = os.path.normpath(model_path)
+        self.device: str = device
+        self.language: str = language
+        self.hps = utils.get_hparams_from_file(config_path)
+        self.spk2id: Dict[str, int] = self.hps.data.spk2id  # speaker name -> speaker id
+        self.id2spk: Dict[int, str] = dict()  # speaker id -> speaker name
+        for speaker, speaker_id in self.hps.data.spk2id.items():
+            self.id2spk[speaker_id] = speaker
+        self.version: str = (
+            self.hps.version if hasattr(self.hps, "version") else latest_version
+        )
+        self.net_g = get_net_g(
+            model_path=model_path,
+            version=self.version,
+            device=device,
+            hps=self.hps,
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "config_path": self.config_path,
+            "model_path": self.model_path,
+            "device": self.device,
+            "language": self.language,
+            "spk2id": self.spk2id,
+            "id2spk": self.id2spk,
+            "version": self.version,
+        }
+
+
+class Models:
+    def __init__(self):
+        self.models: Dict[int, Model] = dict()
+        self.num = 0
+        # spk_info[speaker name][model id] = speaker id
+        self.spk_info: Dict[str, Dict[int, int]] = dict()
+        self.path2ids: Dict[str, Set[int]] = dict()  # model ids sharing each path
+
+    def init_model(
+        self, config_path: str, model_path: str, device: str, language: str
+    ) -> int:
+        """
+        Initialize and register a model.
+
+        :param config_path: path to the model's config.json
+        :param model_path: path to the model checkpoint
+        :param device: inference device
+        :param language: default inference language
+        """
+        # Skip loading if the files do not exist.
+        if not os.path.isfile(model_path):
+            if model_path != "":
+                logger.warning(f"Model file {model_path} does not exist; skipping initialization")
+            return self.num
+        if not os.path.isfile(config_path):
+            if config_path != "":
+                logger.warning(f"Config file {config_path} does not exist; skipping initialization")
+            return self.num
+
+        # If a model at this path is already loaded, register a reference
+        # instead of loading it again.
+        model_path = os.path.realpath(model_path)
+        if model_path not in self.path2ids.keys():
+            self.path2ids[model_path] = {self.num}
+            self.models[self.num] = Model(
+                config_path=config_path,
+                model_path=model_path,
+                device=device,
+                language=language,
+            )
+            logger.success(f"Added model {model_path} with config {os.path.realpath(config_path)}")
+        else:
+            # Reuse an id that already points at this path.
+            m_id = next(iter(self.path2ids[model_path]))
+            self.models[self.num] = self.models[m_id]
+            self.path2ids[model_path].add(self.num)
+            logger.success("Model already loaded; added a reference.")
+        # Register speaker info.
+        for speaker, speaker_id in self.models[self.num].spk2id.items():
+            if speaker not in self.spk_info.keys():
+                self.spk_info[speaker] = {self.num: speaker_id}
+            else:
+                self.spk_info[speaker][self.num] = speaker_id
+        # Bump the counter.
+        self.num += 1
+        return self.num - 1
+
+    def del_model(self, index: int) -> Optional[int]:
+        """Delete the model with the given id; return None if it does not exist."""
+        if index not in self.models.keys():
+            return None
+        # Remove speaker info.
+        for speaker, speaker_id in self.models[index].spk2id.items():
+            self.spk_info[speaker].pop(index)
+            if len(self.spk_info[speaker]) == 0:
+                # Drop the speaker entirely once all of its models are gone.
+                self.spk_info.pop(speaker)
+        # Remove path info.
+        model_path = os.path.realpath(self.models[index].model_path)
+        self.path2ids[model_path].remove(index)
+        if len(self.path2ids[model_path]) == 0:
+            self.path2ids.pop(model_path)
+            logger.success(f"Deleted model {model_path}, id = {index}")
+        else:
+            logger.success(f"Deleted model reference {model_path}, id = {index}")
+        # Remove the model itself.
+        self.models.pop(index)
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        return index
+
+    def get_models(self):
+        """Return all loaded models."""
+        return self.models
+
+
+if __name__ == "__main__":
+    app = FastAPI()
+    app.logger = logger
+    # Mount static files for the web UI.
+    logger.info("Mounting web pages")
+    StaticDir: str = "./Web"
+    if not os.path.isdir(StaticDir):
+        logger.warning(
+            "Web assets are missing, so the web UI is disabled. If you need it, download it from https://github.com/jiangyuxiaoxiao/Bert-VITS2-UI or from the release page of the matching Bert-VITS version."
+        )
+    else:
+        dirs = [fir.name for fir in os.scandir(StaticDir) if fir.is_dir()]
+        files = [fir.name for fir in os.scandir(StaticDir) if fir.is_file()]
+        for dirName in dirs:
+            app.mount(
+                f"/{dirName}",
+                StaticFiles(directory=f"./{StaticDir}/{dirName}"),
+                name=dirName,
+            )
+    loaded_models = Models()
+    # Load models.
+    logger.info("Loading models")
+    models_info = config.server_config.models
+    for model_info in models_info:
+        loaded_models.init_model(
+            config_path=model_info["config"],
+            model_path=model_info["model"],
+            device=model_info["device"],
+            language=model_info["language"],
+        )
+
+    @app.get("/")
+    async def index():
+        return FileResponse("./Web/index.html")
+
+    async def _voice(
+        text: str,
+        model_id: int,
+        speaker_name: str,
+        speaker_id: int,
+        sdp_ratio: float,
+        noise: float,
+        noisew: float,
+        length: float,
+        language: str,
+        auto_translate: bool,
+        auto_split: bool,
+        emotion: Optional[int] = None,
+        reference_audio=None,
+    ) -> Union[Response, Dict[str, Any]]:
+        """TTS implementation shared by the GET and POST endpoints."""
+        # Check that the model is loaded.
+        if model_id not in loaded_models.models.keys():
+            return {"status": 10, "detail": f"Model model_id={model_id} is not loaded"}
+        # Check that a speaker was provided.
+        if speaker_name is None and speaker_id is None:
+            return {"status": 11, "detail": "Please provide speaker_name or speaker_id"}
+        elif speaker_name is None:
+            # Check that speaker_id exists.
+            if speaker_id not in loaded_models.models[model_id].id2spk.keys():
+                return {"status": 12, "detail": f"Speaker speaker_id={speaker_id} does not exist"}
+            speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
+        # Check that speaker_name exists.
+        if speaker_name not in loaded_models.models[model_id].spk2id.keys():
+            return {"status": 13, "detail": f"Speaker speaker_name={speaker_name} does not exist"}
+        if language is None:
+            language = loaded_models.models[model_id].language
+        if auto_translate:
+            text = trans.translate(Sentence=text, to_Language=language.lower())
+        if reference_audio is not None:
+            ref_audio = BytesIO(await reference_audio.read())
+        else:
+            ref_audio = reference_audio
+        if not auto_split:
+            with torch.no_grad():
+                audio = infer(
+                    text=text,
+                    sdp_ratio=sdp_ratio,
+                    noise_scale=noise,
+                    noise_scale_w=noisew,
+                    length_scale=length,
+                    sid=speaker_name,
+                    language=language,
+                    hps=loaded_models.models[model_id].hps,
+                    net_g=loaded_models.models[model_id].net_g,
+                    device=loaded_models.models[model_id].device,
+                    emotion=emotion,
+                    reference_audio=ref_audio,
+                )
+                audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
+        else:
+            texts = cut_sent(text)
+            audios = []
+            with torch.no_grad():
+                for t in texts:
+                    audios.append(
+                        infer(
+                            text=t,
+                            sdp_ratio=sdp_ratio,
+                            noise_scale=noise,
+                            noise_scale_w=noisew,
+                            length_scale=length,
+                            sid=speaker_name,
+                            language=language,
+                            hps=loaded_models.models[model_id].hps,
+                            net_g=loaded_models.models[model_id].net_g,
+                            device=loaded_models.models[model_id].device,
+                            emotion=emotion,
+                            reference_audio=ref_audio,
+                        )
+                    )
+                    audios.append(np.zeros(int(44100 * 0.2)))  # 0.2 s gap between sentences
+                audio = np.concatenate(audios)
+                audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
+        with BytesIO() as wavContent:
+            wavfile.write(
+                wavContent, loaded_models.models[model_id].hps.data.sampling_rate, audio
+            )
+            response = Response(content=wavContent.getvalue(), media_type="audio/wav")
+            return response
+
+    @app.post("/voice")
+    async def voice(
+        request: Request,  # injected by FastAPI
+        text: str = Form(...),
+        model_id: int = Query(..., description="model id"),  # model index
+        speaker_name: str = Query(
+            None, description="speaker name"
+        ),  # provide either speaker_name or speaker_id
+        speaker_id: int = Query(None, description="speaker id; alternative to speaker_name"),
+        sdp_ratio: float = Query(0.2, description="SDP/DP mix ratio"),
+        noise: float = Query(0.2, description="emotion (noise scale)"),
+        noisew: float = Query(0.9, description="phoneme length (noise scale w)"),
+        length: float = Query(1, description="speed (length scale)"),
+        language: str = Query(None, description="language"),  # falls back to the model default if unset
+        auto_translate: bool = Query(False, description="translate automatically"),
+        auto_split: bool = Query(False, description="split sentences automatically"),
+        emotion: Optional[int] = Query(None, description="emo"),
+        reference_audio: UploadFile = File(None),
+    ):
+        """Voice endpoint; use this POST variant when you need to upload reference audio."""
+        logger.info(
+            f"{request.client.host}:{request.client.port}/voice {unquote(str(request.query_params))} text={text}"
+        )
+        return await _voice(
+            text=text,
+            model_id=model_id,
+            speaker_name=speaker_name,
+            speaker_id=speaker_id,
+            sdp_ratio=sdp_ratio,
+            noise=noise,
+            noisew=noisew,
+            length=length,
+            language=language,
+            auto_translate=auto_translate,
+            auto_split=auto_split,
+            emotion=emotion,
+            reference_audio=reference_audio,
+        )
+
+    @app.get("/voice")
+    async def voice(
+        request: Request,  # injected by FastAPI
+        text: str = Query(..., description="input text"),
+        model_id: int = Query(..., description="model id"),  # model index
+        speaker_name: str = Query(
+            None, description="speaker name"
+        ),  # provide either speaker_name or speaker_id
+        speaker_id: int = Query(None, description="speaker id; alternative to speaker_name"),
+        sdp_ratio: float = Query(0.2, description="SDP/DP mix ratio"),
+        noise: float = Query(0.2, description="emotion (noise scale)"),
+        noisew: float = Query(0.9, description="phoneme length (noise scale w)"),
+        length: float = Query(1, description="speed (length scale)"),
+        language: str = Query(None, description="language"),  # falls back to the model default if unset
+        auto_translate: bool = Query(False, description="translate automatically"),
+        auto_split: bool = Query(False, description="split sentences automatically"),
+        emotion: Optional[int] = Query(None, description="emo"),
+    ):
+        """Voice endpoint."""
+        logger.info(
+            f"{request.client.host}:{request.client.port}/voice {unquote(str(request.query_params))}"
+        )
+        return await _voice(
+            text=text,
+            model_id=model_id,
+            speaker_name=speaker_name,
+            speaker_id=speaker_id,
+            sdp_ratio=sdp_ratio,
+            noise=noise,
+            noisew=noisew,
+            length=length,
+            language=language,
+            auto_translate=auto_translate,
+            auto_split=auto_split,
+            emotion=emotion,
+        )
+
+    @app.get("/models/info")
+    def get_loaded_models_info(request: Request):
+        """Return information about the loaded models."""
+
+        result: Dict[str, Dict] = dict()
+        for key, model in loaded_models.models.items():
+            result[str(key)] = model.to_dict()
+        return result
+
+    @app.get("/models/delete")
+    def delete_model(
+        request: Request, model_id: int = Query(..., description="id of the model to delete")
+    ):
+        """Delete the specified model."""
+        logger.info(
+            f"{request.client.host}:{request.client.port}/models/delete {unquote(str(request.query_params))}"
+        )
+        result = loaded_models.del_model(model_id)
+        if result is None:
+            return {"status": 14, "detail": f"Model {model_id} does not exist; deletion failed"}
+        return {"status": 0, "detail": "Deleted successfully"}
+
+    @app.get("/models/add")
+    def add_model(
+        request: Request,
+        model_path: str = Query(..., description="path of the model to add"),
+        config_path: str = Query(
+            None, description="path of the model config file; defaults to ./config.json or ../config.json"
+        ),
+        device: str = Query("cuda", description="inference device"),
+        language: str = Query("ZH", description="default model language"),
+    ):
+        """Add the specified model. Adding the same path twice is allowed and does not use extra memory."""
+        logger.info(
+            f"{request.client.host}:{request.client.port}/models/add {unquote(str(request.query_params))}"
+        )
+        if config_path is None:
+            model_dir = os.path.dirname(model_path)
+            if os.path.isfile(os.path.join(model_dir, "config.json")):
+                config_path = os.path.join(model_dir, "config.json")
+            elif os.path.isfile(os.path.join(model_dir, "../config.json")):
+                config_path = os.path.join(model_dir, "../config.json")
+            else:
+                return {
+                    "status": 15,
+                    "detail": "No config path was given and no config.json exists at the default locations ./ and ../.",
+                }
+        try:
+            model_id = loaded_models.init_model(
+                config_path=config_path,
+                model_path=model_path,
+                device=device,
+                language=language,
+            )
+        except Exception:
+            logging.exception("Failed to load the model")
+            return {
+                "status": 16,
+                "detail": "Failed to load the model; see the logs for details",
+            }
+        return {
+            "status": 0,
+            "detail": "Model added successfully",
+            "Data": {
+                "model_id": model_id,
+                "model_info": loaded_models.models[model_id].to_dict(),
+            },
+        }
+
+    def _get_all_models(root_dir: str = "Data", only_unloaded: bool = False):
+        """Search root_dir for all available models."""
+        result: Dict[str, List[str]] = dict()
+        files = os.listdir(root_dir) + ["."]
+        for file in files:
+            if os.path.isdir(os.path.join(root_dir, file)):
+                sub_dir = os.path.join(root_dir, file)
+                # Search both "sub_dir" and "sub_dir/models".
+                result[file] = list()
+                sub_files = os.listdir(sub_dir)
+                model_files = []
+                for sub_file in sub_files:
+                    relpath = os.path.realpath(os.path.join(sub_dir, sub_file))
+                    if only_unloaded and relpath in loaded_models.path2ids.keys():
+                        continue
+                    if sub_file.endswith(".pth") and sub_file.startswith("G_"):
+                        if os.path.isfile(relpath):
+                            model_files.append(sub_file)
+                # Sort model files by step count; removeprefix/removesuffix avoid
+                # the character-set pitfall of lstrip/rstrip.
+                model_files = sorted(
+                    model_files,
+                    key=lambda pth: int(pth.removeprefix("G_").removesuffix(".pth"))
+                    if pth.removeprefix("G_").removesuffix(".pth").isdigit()
+                    else 10**10,
+                )
+                result[file] = model_files
+                models_dir = os.path.join(sub_dir, "models")
+                model_files = []
+                if os.path.isdir(models_dir):
+                    sub_files = os.listdir(models_dir)
+                    for sub_file in sub_files:
+                        relpath = os.path.realpath(os.path.join(models_dir, sub_file))
+                        if only_unloaded and relpath in loaded_models.path2ids.keys():
+                            continue
+                        if sub_file.endswith(".pth") and sub_file.startswith("G_"):
+                            if os.path.isfile(os.path.join(models_dir, sub_file)):
+                                model_files.append(f"models/{sub_file}")
+                    # Sort model files by step count.
+                    model_files = sorted(
+                        model_files,
+                        key=lambda pth: int(pth.removeprefix("models/G_").removesuffix(".pth"))
+                        if pth.removeprefix("models/G_").removesuffix(".pth").isdigit()
+                        else 10**10,
+                    )
+                    result[file] += model_files
+                if len(result[file]) == 0:
+                    result.pop(file)
+
+        return result
+
+    @app.get("/models/get_unloaded")
+    def get_unloaded_models_info(
+        request: Request, root_dir: str = Query("Data", description="search root directory")
+    ):
+        """List models that are not loaded."""
+        logger.info(
+            f"{request.client.host}:{request.client.port}/models/get_unloaded {unquote(str(request.query_params))}"
+        )
+        return _get_all_models(root_dir, only_unloaded=True)
+
+    @app.get("/models/get_local")
+    def get_local_models_info(
+        request: Request, root_dir: str = Query("Data", description="search root directory")
+    ):
+        """List all local models."""
+        logger.info(
+            f"{request.client.host}:{request.client.port}/models/get_local {unquote(str(request.query_params))}"
+        )
+        return _get_all_models(root_dir, only_unloaded=False)
+
+    @app.get("/status")
+    def get_status():
+        """Return host machine status."""
+        cpu_percent = psutil.cpu_percent(interval=1)
+        memory_info = psutil.virtual_memory()
+        memory_total = memory_info.total
+        memory_available = memory_info.available
+        memory_used = memory_info.used
+        memory_percent = memory_info.percent
+        gpuInfo = []
+        devices = ["cpu"]
+        for i in range(torch.cuda.device_count()):
+            devices.append(f"cuda:{i}")
+        gpus = GPUtil.getGPUs()
+        for gpu in gpus:
+            gpuInfo.append(
+                {
+                    "gpu_id": gpu.id,
+                    "gpu_load": gpu.load,
+                    "gpu_memory": {
+                        "total": gpu.memoryTotal,
+                        "used": gpu.memoryUsed,
+                        "free": gpu.memoryFree,
+                    },
+                }
+            )
+        return {
+            "devices": devices,
+            "cpu_percent": cpu_percent,
+            "memory_total": memory_total,
+            "memory_available": memory_available,
+            "memory_used": memory_used,
+            "memory_percent": memory_percent,
+            "gpu": gpuInfo,
+        }
+
+    @app.get("/tools/translate")
+    def translate(
+        request: Request,
+        texts: str = Query(..., description="text to translate"),
+        to_language: str = Query(..., description="target language"),
+    ):
+        """Translate text."""
+        logger.info(
+            f"{request.client.host}:{request.client.port}/tools/translate {unquote(str(request.query_params))}"
+        )
+        return {"texts": trans.translate(Sentence=texts, to_Language=to_language)}
+
+    all_examples: Dict[str, Dict[str, List]] = dict()  # cached examples
+
+    @app.get("/tools/random_example")
+    def random_example(
+        request: Request,
+        language: str = Query(None, description="language; picked at random if unset"),
+        root_dir: str = Query("Data", description="search root directory"),
+    ):
+        """
+        Return a random audio/text pair for comparison; the audio is chosen at random from the local directory.
+        """
+        logger.info(
+            f"{request.client.host}:{request.client.port}/tools/random_example {unquote(str(request.query_params))}"
+        )
+        global all_examples
+        # Initialize the cache for this root directory.
+        if root_dir not in all_examples.keys():
+            all_examples[root_dir] = {"ZH": [], "JP": [], "EN": []}
+
+        examples = all_examples[root_dir]
+
+        # Search the project's Data directory for train.list / val.list.
+        for root, directories, _files in os.walk(root_dir):
+            for file in _files:
+                if file in ["train.list", "val.list"]:
+                    with open(
+                        os.path.join(root, file), mode="r", encoding="utf-8"
+                    ) as f:
+                        lines = f.readlines()
+                    for line in lines:
+                        data = line.split("|")
+                        if len(data) != 7:
+                            continue
+                        # Keep entries whose audio exists and whose language is ZH/EN/JP.
+                        if os.path.isfile(data[0]) and data[2] in [
+                            "ZH",
+                            "JP",
+                            "EN",
+                        ]:
+                            examples[data[2]].append(
+                                {
+                                    "text": data[3],
+                                    "audio": data[0],
+                                    "speaker": data[1],
+                                }
+                            )
+
+        examples = all_examples[root_dir]
+        if language is None:
+            if len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) == 0:
+                return {"status": 17, "detail": "No example data is loaded"}
+            else:
+                # Pick one at random across languages.
+                rand_num = random.randint(
+                    0,
+                    len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) - 1,
+                )
+                # ZH
+                if rand_num < len(examples["ZH"]):
+                    return {"status": 0, "Data": examples["ZH"][rand_num]}
+                # JP
+                if rand_num < len(examples["ZH"]) + len(examples["JP"]):
+                    return {
+                        "status": 0,
+                        "Data": examples["JP"][rand_num - len(examples["ZH"])],
+                    }
+                # EN
+                return {
+                    "status": 0,
+                    "Data": examples["EN"][
+                        rand_num - len(examples["ZH"]) - len(examples["JP"])
+                    ],
+                }
+
+        else:
+            if len(examples[language]) == 0:
+                return {"status": 17, "detail": f"No {language} data is loaded"}
+            return {
+                "status": 0,
+                "Data": examples[language][
+                    random.randint(0, len(examples[language]) - 1)
+                ],
+            }
+
+    @app.get("/tools/get_audio")
+    def get_audio(request: Request, path: str = Query(..., description="local audio path")):
+        logger.info(
+            f"{request.client.host}:{request.client.port}/tools/get_audio {unquote(str(request.query_params))}"
+        )
+        if not os.path.isfile(path):
+            return {"status": 18, "detail": "The requested audio does not exist"}
+        if not path.endswith(".wav"):
+            return {"status": 19, "detail": "Not a wav file"}
+        return FileResponse(path=path)
+
+    logger.warning("This service is meant to run locally; do not expose its port to the public internet")
+    logger.info(f"API docs at http://127.0.0.1:{config.server_config.port}/docs")
+    if os.path.isdir(StaticDir):
+        webbrowser.open(f"http://127.0.0.1:{config.server_config.port}")
+    uvicorn.run(
+        app, port=config.server_config.port, host="0.0.0.0", log_level="warning"
+    )
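A minimal sketch of calling the GET /voice endpoint above from Python, using the requests package already listed in requirements.txt. The port (5000) and the speaker name are placeholders; substitute the port from your own server config and a speaker of a loaded model:

import requests

# Hypothetical call to GET /voice; port and speaker_name are placeholders.
resp = requests.get(
    "http://127.0.0.1:5000/voice",
    params={"text": "Hello", "model_id": 0, "speaker_name": "speaker0", "language": "EN"},
)
if resp.headers.get("content-type") == "audio/wav":
    with open("out.wav", "wb") as f:
        f.write(resp.content)  # raw 16-bit WAV bytes from the server
else:
    print(resp.json())  # on failure: a dict with a non-zero "status" and a "detail" message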
spec_gen.py
ADDED
@@ -0,0 +1,87 @@
+import torch
+from tqdm import tqdm
+from multiprocessing import Pool
+from mel_processing import spectrogram_torch, mel_spectrogram_torch
+from utils import load_wav_to_torch
+
+
+class AudioProcessor:
+    def __init__(
+        self,
+        max_wav_value,
+        use_mel_spec_posterior,
+        filter_length,
+        n_mel_channels,
+        sampling_rate,
+        hop_length,
+        win_length,
+        mel_fmin,
+        mel_fmax,
+    ):
+        self.max_wav_value = max_wav_value
+        self.use_mel_spec_posterior = use_mel_spec_posterior
+        self.filter_length = filter_length
+        self.n_mel_channels = n_mel_channels
+        self.sampling_rate = sampling_rate
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+
+    def process_audio(self, filename):
+        audio, sampling_rate = load_wav_to_torch(filename)
+        audio_norm = audio / self.max_wav_value
+        audio_norm = audio_norm.unsqueeze(0)
+        spec_filename = filename.replace(".wav", ".spec.pt")
+        if self.use_mel_spec_posterior:
+            spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
+        try:
+            spec = torch.load(spec_filename)
+        except Exception:  # no cached spectrogram yet: compute and save it
+            if self.use_mel_spec_posterior:
+                spec = mel_spectrogram_torch(
+                    audio_norm,
+                    self.filter_length,
+                    self.n_mel_channels,
+                    self.sampling_rate,
+                    self.hop_length,
+                    self.win_length,
+                    self.mel_fmin,
+                    self.mel_fmax,
+                    center=False,
+                )
+            else:
+                spec = spectrogram_torch(
+                    audio_norm,
+                    self.filter_length,
+                    self.sampling_rate,
+                    self.hop_length,
+                    self.win_length,
+                    center=False,
+                )
+            spec = torch.squeeze(spec, 0)
+            torch.save(spec, spec_filename)
+        return spec, audio_norm
+
+
+# Usage example
+processor = AudioProcessor(
+    max_wav_value=32768.0,
+    use_mel_spec_posterior=False,
+    filter_length=2048,
+    n_mel_channels=128,
+    sampling_rate=44100,
+    hop_length=512,
+    win_length=2048,
+    mel_fmin=0.0,
+    mel_fmax=None,  # "null" in the JSON config corresponds to None here
+)
+
+with open("filelists/train.list", "r") as f:
+    filepaths = [line.split("|")[0] for line in f]  # first field of each line is the audio path
+
+# Process the files with a pool of worker processes.
+with Pool(processes=32) as pool:  # 32 worker processes
+    with tqdm(total=len(filepaths)) as pbar:
+        for i, _ in enumerate(pool.imap_unordered(processor.process_audio, filepaths)):
+            pbar.update()
transforms.py
ADDED
@@ -0,0 +1,209 @@
+import torch
+from torch.nn import functional as F
+
+import numpy as np
+
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+def piecewise_rational_quadratic_transform(
+    inputs,
+    unnormalized_widths,
+    unnormalized_heights,
+    unnormalized_derivatives,
+    inverse=False,
+    tails=None,
+    tail_bound=1.0,
+    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+    min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+    if tails is None:
+        spline_fn = rational_quadratic_spline
+        spline_kwargs = {}
+    else:
+        spline_fn = unconstrained_rational_quadratic_spline
+        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+    outputs, logabsdet = spline_fn(
+        inputs=inputs,
+        unnormalized_widths=unnormalized_widths,
+        unnormalized_heights=unnormalized_heights,
+        unnormalized_derivatives=unnormalized_derivatives,
+        inverse=inverse,
+        min_bin_width=min_bin_width,
+        min_bin_height=min_bin_height,
+        min_derivative=min_derivative,
+        **spline_kwargs
+    )
+    return outputs, logabsdet
+
+
+def searchsorted(bin_locations, inputs, eps=1e-6):
+    bin_locations[..., -1] += eps
+    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
+
+
+def unconstrained_rational_quadratic_spline(
+    inputs,
+    unnormalized_widths,
+    unnormalized_heights,
+    unnormalized_derivatives,
+    inverse=False,
+    tails="linear",
+    tail_bound=1.0,
+    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+    min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+    outside_interval_mask = ~inside_interval_mask
+
+    outputs = torch.zeros_like(inputs)
+    logabsdet = torch.zeros_like(inputs)
+
+    if tails == "linear":
+        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+        constant = np.log(np.exp(1 - min_derivative) - 1)
+        unnormalized_derivatives[..., 0] = constant
+        unnormalized_derivatives[..., -1] = constant
+
+        outputs[outside_interval_mask] = inputs[outside_interval_mask]
+        logabsdet[outside_interval_mask] = 0
+    else:
+        raise RuntimeError("{} tails are not implemented.".format(tails))
+
+    (
+        outputs[inside_interval_mask],
+        logabsdet[inside_interval_mask],
+    ) = rational_quadratic_spline(
+        inputs=inputs[inside_interval_mask],
+        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+        inverse=inverse,
+        left=-tail_bound,
+        right=tail_bound,
+        bottom=-tail_bound,
+        top=tail_bound,
+        min_bin_width=min_bin_width,
+        min_bin_height=min_bin_height,
+        min_derivative=min_derivative,
+    )
+
+    return outputs, logabsdet
+
+
+def rational_quadratic_spline(
+    inputs,
+    unnormalized_widths,
+    unnormalized_heights,
+    unnormalized_derivatives,
+    inverse=False,
+    left=0.0,
+    right=1.0,
+    bottom=0.0,
+    top=1.0,
+    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+    min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+    if torch.min(inputs) < left or torch.max(inputs) > right:
+        raise ValueError("Input to a transform is not within its domain")
+
+    num_bins = unnormalized_widths.shape[-1]
+
+    if min_bin_width * num_bins > 1.0:
+        raise ValueError("Minimal bin width too large for the number of bins")
+    if min_bin_height * num_bins > 1.0:
+        raise ValueError("Minimal bin height too large for the number of bins")
+
+    widths = F.softmax(unnormalized_widths, dim=-1)
+    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+    cumwidths = torch.cumsum(widths, dim=-1)
+    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+    cumwidths = (right - left) * cumwidths + left
+    cumwidths[..., 0] = left
+    cumwidths[..., -1] = right
+    widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+    heights = F.softmax(unnormalized_heights, dim=-1)
+    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+    cumheights = torch.cumsum(heights, dim=-1)
+    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+    cumheights = (top - bottom) * cumheights + bottom
+    cumheights[..., 0] = bottom
+    cumheights[..., -1] = top
+    heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+    if inverse:
+        bin_idx = searchsorted(cumheights, inputs)[..., None]
+    else:
+        bin_idx = searchsorted(cumwidths, inputs)[..., None]
+
+    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+    delta = heights / widths
+    input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+    input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+    if inverse:
+        a = (inputs - input_cumheights) * (
+            input_derivatives + input_derivatives_plus_one - 2 * input_delta
+        ) + input_heights * (input_delta - input_derivatives)
+        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+            input_derivatives + input_derivatives_plus_one - 2 * input_delta
+        )
+        c = -input_delta * (inputs - input_cumheights)
+
+        discriminant = b.pow(2) - 4 * a * c
+        assert (discriminant >= 0).all()
+
+        root = (2 * c) / (-b - torch.sqrt(discriminant))
+        outputs = root * input_bin_widths + input_cumwidths
+
+        theta_one_minus_theta = root * (1 - root)
+        denominator = input_delta + (
+            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+            * theta_one_minus_theta
+        )
+        derivative_numerator = input_delta.pow(2) * (
+            input_derivatives_plus_one * root.pow(2)
+            + 2 * input_delta * theta_one_minus_theta
+            + input_derivatives * (1 - root).pow(2)
+        )
+        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+        return outputs, -logabsdet
+    else:
+        theta = (inputs - input_cumwidths) / input_bin_widths
+        theta_one_minus_theta = theta * (1 - theta)
+
+        numerator = input_heights * (
+            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
+        )
+        denominator = input_delta + (
+            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+            * theta_one_minus_theta
+        )
+        outputs = input_cumheights + numerator / denominator
+
+        derivative_numerator = input_delta.pow(2) * (
+            input_derivatives_plus_one * theta.pow(2)
+            + 2 * input_delta * theta_one_minus_theta
+            + input_derivatives * (1 - theta).pow(2)
+        )
+        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+        return outputs, logabsdet
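A minimal smoke test of piecewise_rational_quadratic_transform above, with tensor shapes chosen purely for illustration. With tails="linear" the derivative tensor carries num_bins - 1 values per input element, since the function pads one boundary value on each side:

import torch
from transforms import piecewise_rational_quadratic_transform

batch, dim, num_bins = 4, 10, 8
x = torch.rand(batch, dim) * 2 - 1         # inputs inside the (-1, 1) tail bound
w = torch.randn(batch, dim, num_bins)      # unnormalized bin widths
h = torch.randn(batch, dim, num_bins)      # unnormalized bin heights
d = torch.randn(batch, dim, num_bins - 1)  # unnormalized derivatives (padded internally)
y, logabsdet = piecewise_rational_quadratic_transform(
    x, w, h, d, tails="linear", tail_bound=1.0
)
print(y.shape, logabsdet.shape)            # both torch.Size([4, 10])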
update_status.py
ADDED
@@ -0,0 +1,89 @@
+import os
+import gradio as gr
+
+# UI labels (matched verbatim by dropdowns elsewhere) mapped to directory suffixes.
+lang_dict = {"EN(英文)": "_en", "ZH(中文)": "_zh", "JP(日语)": "_jp"}
+
+
+def raw_dir_convert_to_path(target_dir: str, lang):
+    res = target_dir.rstrip("/").rstrip("\\")
+    if (not target_dir.startswith("raw")) and (not target_dir.startswith("./raw")):
+        res = os.path.join("./raw", res)
+    if (
+        (not res.endswith("_zh"))
+        and (not res.endswith("_jp"))
+        and (not res.endswith("_en"))
+    ):
+        res += lang_dict[lang]
+    return res
+
+
+def update_g_files():
+    g_files = []
+    cnt = 0
+    for root, dirs, files in os.walk(os.path.abspath("./logs")):
+        for file in files:
+            if file.startswith("G_") and file.endswith(".pth"):
+                g_files.append(os.path.join(root, file))
+                cnt += 1
+    print(g_files)
+    return f"Model list updated; found {cnt} models", gr.Dropdown.update(choices=g_files)
+
+
+def update_c_files():
+    c_files = []
+    cnt = 0
+    for root, dirs, files in os.walk(os.path.abspath("./logs")):
+        for file in files:
+            if file.startswith("config.json"):
+                c_files.append(os.path.join(root, file))
+                cnt += 1
+    print(c_files)
+    return f"Config list updated; found {cnt} config files", gr.Dropdown.update(choices=c_files)
+
+
+def update_model_folders():
+    subdirs = []
+    cnt = 0
+    for root, dirs, files in os.walk(os.path.abspath("./logs")):
+        for dir_name in dirs:
+            if os.path.basename(dir_name) != "eval":
+                subdirs.append(os.path.join(root, dir_name))
+                cnt += 1
+    print(subdirs)
+    return f"Model folder list updated; found {cnt} folders", gr.Dropdown.update(choices=subdirs)
+
+
+def update_wav_lab_pairs():
+    wav_count = tot_count = 0
+    for root, _, files in os.walk("./raw"):
+        for file in files:
+            # print(file)
+            file_path = os.path.join(root, file)
+            if file.lower().endswith(".wav"):
+                lab_file = os.path.splitext(file_path)[0] + ".lab"
+                if os.path.exists(lab_file):
+                    wav_count += 1
+                tot_count += 1
+    return f"{wav_count} / {tot_count}"
+
+
+def update_raw_folders():
+    subdirs = []
+    cnt = 0
+    script_path = os.path.dirname(os.path.abspath(__file__))  # absolute path of this script
+    raw_path = os.path.join(script_path, "raw")
+    print(raw_path)
+    os.makedirs(raw_path, exist_ok=True)
+    for root, dirs, files in os.walk(raw_path):
+        for dir_name in dirs:
+            relative_path = os.path.relpath(
+                os.path.join(root, dir_name), script_path
+            )  # path relative to the script directory
+            subdirs.append(relative_path)
+            cnt += 1
+    print(subdirs)
+    return (
+        f"Raw audio folder list updated; found {cnt} folders",
+        gr.Dropdown.update(choices=subdirs),
+        gr.Textbox.update(value=update_wav_lab_pairs()),
+    )
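A sketch of how these helpers might be wired into a Gradio 3.x Blocks UI; the component names and layout here are hypothetical, since the real UI lives elsewhere in the project:

import gradio as gr
from update_status import update_g_files

with gr.Blocks() as demo:
    status = gr.Textbox(label="status")
    model_dropdown = gr.Dropdown(label="models", choices=[])
    refresh = gr.Button("Refresh model list")
    # update_g_files returns (message, Dropdown.update), matching the two outputs.
    refresh.click(update_g_files, outputs=[status, model_dropdown])

demo.launch()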
utils.py
ADDED
@@ -0,0 +1,457 @@
+import os
+import glob
+import argparse
+import logging
+import json
+import shutil
+import subprocess
+import numpy as np
+from huggingface_hub import hf_hub_download
+from scipy.io.wavfile import read
+import torch
+import re
+
+MATPLOTLIB_FLAG = False
+
+logger = logging.getLogger(__name__)
+
+
+def download_emo_models(mirror, repo_id, model_name):
+    if mirror == "openi":
+        import openi
+
+        openi.model.download_model(
+            "Stardust_minus/Bert-VITS2",
+            repo_id.split("/")[-1],
+            "./emotional",
+        )
+    else:
+        hf_hub_download(
+            repo_id,
+            "pytorch_model.bin",
+            local_dir=model_name,
+            local_dir_use_symlinks=False,
+        )
+
+
+def download_checkpoint(
+    dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
+):
+    repo_id = repo_config["repo_id"]
+    f_list = glob.glob(os.path.join(dir_path, regex))
+    if f_list:
+        print("Use existing model, skip downloading.")
+        return
+    if mirror.lower() == "openi":
+        import openi
+
+        kwargs = {"token": token} if token else {}
+        openi.login(**kwargs)
+
+        model_image = repo_config["model_image"]
+        openi.model.download_model(repo_id, model_image, dir_path)
+
+        fs = glob.glob(os.path.join(dir_path, model_image, "*.pth"))
+        for file in fs:
+            shutil.move(file, dir_path)
+        shutil.rmtree(os.path.join(dir_path, model_image))
+    else:
+        for file in ["DUR_0.pth", "D_0.pth", "G_0.pth"]:
+            hf_hub_download(
+                repo_id, file, local_dir=dir_path, local_dir_use_symlinks=False
+            )
+
+
+def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
+    assert os.path.isfile(checkpoint_path)
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+    iteration = checkpoint_dict["iteration"]
+    learning_rate = checkpoint_dict["learning_rate"]
+    if (
+        optimizer is not None
+        and not skip_optimizer
+        and checkpoint_dict["optimizer"] is not None
+    ):
+        optimizer.load_state_dict(checkpoint_dict["optimizer"])
+    elif optimizer is None and not skip_optimizer:
+        # else: Disable this line if Infer and resume checkpoint,then enable the line upper
+        # NOTE: this branch would dereference a None optimizer if it were taken.
+        new_opt_dict = optimizer.state_dict()
+        new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
+        new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
+        new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
+        optimizer.load_state_dict(new_opt_dict)
+
+    saved_state_dict = checkpoint_dict["model"]
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        try:
+            # assert "emb_g" not in k
+            new_state_dict[k] = saved_state_dict[k]
+            assert saved_state_dict[k].shape == v.shape, (
+                saved_state_dict[k].shape,
+                v.shape,
+            )
+        except Exception:  # missing key or shape mismatch
+            # For upgrading from the old version
+            if "ja_bert_proj" in k:
+                v = torch.zeros_like(v)
+                logger.warning(
+                    f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
+                )
+            else:
+                logger.error(f"{k} is not in the checkpoint")
+
+            new_state_dict[k] = v
+
+    if hasattr(model, "module"):
+        model.module.load_state_dict(new_state_dict, strict=False)
+    else:
+        model.load_state_dict(new_state_dict, strict=False)
+
+    logger.info(
+        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
+    )
+
+    return model, optimizer, learning_rate, iteration
+
+
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+    logger.info(
+        "Saving model and optimizer state at iteration {} to {}".format(
+            iteration, checkpoint_path
+        )
+    )
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    torch.save(
+        {
+            "model": state_dict,
+            "iteration": iteration,
+            "optimizer": optimizer.state_dict(),
+            "learning_rate": learning_rate,
+        },
+        checkpoint_path,
+    )
+
+
+def summarize(
+    writer,
+    global_step,
+    scalars={},
+    histograms={},
+    images={},
+    audios={},
+    audio_sampling_rate=22050,
+):
+    for k, v in scalars.items():
+        writer.add_scalar(k, v, global_step)
+    for k, v in histograms.items():
+        writer.add_histogram(k, v, global_step)
+    for k, v in images.items():
+        writer.add_image(k, v, global_step, dataformats="HWC")
+    for k, v in audios.items():
+        writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+    f_list = glob.glob(os.path.join(dir_path, regex))
+    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+    x = f_list[-1]
+    return x
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    # np.frombuffer replaces the removed np.fromstring API.
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def plot_alignment_to_numpy(alignment, info=None):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    im = ax.imshow(
+        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+    )
+    fig.colorbar(im, ax=ax)
+    xlabel = "Decoder timestep"
+    if info is not None:
+        xlabel += "\n\n" + info
+    plt.xlabel(xlabel)
+    plt.ylabel("Encoder timestep")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def load_wav_to_torch(full_path):
+    sampling_rate, data = read(full_path)
+    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+def load_filepaths_and_text(filename, split="|"):
+    with open(filename, encoding="utf-8") as f:
+        filepaths_and_text = [line.strip().split(split) for line in f]
+    return filepaths_and_text
+
+
+def get_hparams(init=True):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--config",
+        type=str,
+        default="./configs/base.json",
+        help="JSON file for configuration",
+    )
+    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
+
+    args = parser.parse_args()
+    model_dir = os.path.join("./logs", args.model)
+
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+
+    config_path = args.config
+    config_save_path = os.path.join(model_dir, "config.json")
+    if init:
+        with open(config_path, "r", encoding="utf-8") as f:
+            data = f.read()
+        with open(config_save_path, "w", encoding="utf-8") as f:
+            f.write(data)
+    else:
+        with open(config_save_path, "r", encoding="utf-8") as f:
+            data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    hparams.model_dir = model_dir
+    return hparams
+
+
+def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
+    """Free up space by deleting saved checkpoints.
+
+    Arguments:
+    path_to_models  -- path to the model directory
+    n_ckpts_to_keep -- number of checkpoints to keep, excluding G_0.pth and D_0.pth
+    sort_by_time    -- True -> delete checkpoints chronologically
+                       False -> delete checkpoints lexicographically
+    """
+    ckpts_files = [
+        f
+        for f in os.listdir(path_to_models)
+        if os.path.isfile(os.path.join(path_to_models, f))
+    ]
+
+    def name_key(_f):
+        return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
+
+    def time_key(_f):
+        return os.path.getmtime(os.path.join(path_to_models, _f))
+
+    sort_key = time_key if sort_by_time else name_key
+
+    def x_sorted(_x):
+        return sorted(
+            [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
+            key=sort_key,
+        )
+
+    to_del = [
+        os.path.join(path_to_models, fn)
+        for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
+    ]
+
+    def del_info(fn):
+        return logger.info(f".. Free up space by deleting ckpt {fn}")
+
+    def del_routine(x):
+        return [os.remove(x), del_info(x)]
+
+    [del_routine(fn) for fn in to_del]
+
+
+def get_hparams_from_dir(model_dir):
+    config_save_path = os.path.join(model_dir, "config.json")
+    with open(config_save_path, "r", encoding="utf-8") as f:
+        data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    hparams.model_dir = model_dir
+    return hparams
+
+
+def get_hparams_from_file(config_path):
+    # print("config_path: ", config_path)
+    with open(config_path, "r", encoding="utf-8") as f:
+        data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    return hparams
+
+
+def check_git_hash(model_dir):
+    source_dir = os.path.dirname(os.path.realpath(__file__))
+    if not os.path.exists(os.path.join(source_dir, ".git")):
+        logger.warning(
+            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
+                source_dir
+            )
+        )
+        return
+
+    cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+    path = os.path.join(model_dir, "githash")
+    if os.path.exists(path):
+        saved_hash = open(path).read()
+        if saved_hash != cur_hash:
+            logger.warning(
+                "git hash values are different. {}(saved) != {}(current)".format(
+                    saved_hash[:8], cur_hash[:8]
+                )
+            )
+    else:
+        open(path, "w").write(cur_hash)
+
+
+def get_logger(model_dir, filename="train.log"):
+    global logger
+    logger = logging.getLogger(os.path.basename(model_dir))
+    logger.setLevel(logging.DEBUG)
+
+    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    h = logging.FileHandler(os.path.join(model_dir, filename))
+    h.setLevel(logging.DEBUG)
+    h.setFormatter(formatter)
+    logger.addHandler(h)
+    return logger
+
+
+class HParams:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, dict):
+                v = HParams(**v)
+            self[k] = v
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def items(self):
+        return self.__dict__.items()
+
+    def values(self):
+        return self.__dict__.values()
+
+    def __len__(self):
+        return len(self.__dict__)
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __repr__(self):
+        return self.__dict__.__repr__()
+
+
+def load_model(model_path, config_path):
+    # Imported here to avoid a circular import; SynthesizerTrn is defined in
+    # this project's models.py.
+    from models import SynthesizerTrn
+
+    hps = get_hparams_from_file(config_path)
+    net = SynthesizerTrn(
+        # len(symbols),
+        108,
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to("cpu")
+    _ = net.eval()
+    _ = load_checkpoint(model_path, net, None, skip_optimizer=True)
+    return net
+
+
+def mix_model(
+    network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5)
+):
+    if hasattr(network1, "module"):
+        state_dict1 = network1.module.state_dict()
+        state_dict2 = network2.module.state_dict()
+    else:
+        state_dict1 = network1.state_dict()
+        state_dict2 = network2.state_dict()
+    for k in state_dict1.keys():
+        if k not in state_dict2.keys():
+            continue
+        if "enc_p" in k:
+            state_dict1[k] = (
+                state_dict1[k].clone() * tone_ratio[0]
+                + state_dict2[k].clone() * tone_ratio[1]
+            )
+        else:
+            state_dict1[k] = (
+                state_dict1[k].clone() * voice_ratio[0]
+                + state_dict2[k].clone() * voice_ratio[1]
+            )
+    for k in state_dict2.keys():
+        if k not in state_dict1.keys():
+            state_dict1[k] = state_dict2[k].clone()
+    torch.save(
+        {"model": state_dict1, "iteration": 0, "optimizer": None, "learning_rate": 0},
+        output_path,
+    )
+
+
+def get_steps(model_path):
+    matches = re.findall(r"\d+", model_path)
+    return matches[-1] if matches else None
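A brief example of the HParams wrapper above; the config path is a placeholder for any config.json in this project's format:

from utils import get_hparams_from_file

hps = get_hparams_from_file("configs/config.json")  # hypothetical path
print(hps.data.sampling_rate)  # nested dicts become attribute access
print("train" in hps)          # dict-style membership tests also work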