# Bert-VITS-YuukaBot / server.py
# XzJosh's picture
# Upload 153 files
# 6c9cbc5
from io import BytesIO
from typing import Dict, List

import gradio as gr
import numpy as np
import torch
from av import open as avopen
from flask import Flask, request, Response
from scipy.io import wavfile

import re_matching
import utils
from config import config
from infer import infer, get_net_g, latest_version
# Flask application object; the route below is registered on it.
app = Flask(__name__)
# Request that JSON responses keep non-ASCII characters unescaped.
# NOTE(review): newer Flask releases may ignore this key -- confirm against
# the Flask version this project pins.
app.config.update(JSON_AS_ASCII=False)
def replace_punctuation(text, i=2):
    """Return *text* with each handled punctuation mark repeated *i* times.

    The handled marks are the characters of ``",。?!"``; every other
    character is copied through unchanged.
    """
    # One translation pass: each mark maps to itself repeated ``i`` times.
    repeat_table = {ord(mark): mark * i for mark in ",。?!"}
    return text.translate(repeat_table)
def wav2(i, o, format):
    """Re-encode the first audio stream of *i* into *o* using PyAV.

    Args:
        i: Input handed to ``av.open`` in read mode.
        o: Output handed to ``av.open`` in write mode.
        format: Container format name for the output. It is also used as the
            encoder name, except for ``"ogg"``, which is encoded with
            ``"libvorbis"``.

    Both containers are closed even when opening the output, decoding,
    encoding or muxing raises, so a failed conversion no longer leaks them.
    """
    inp = avopen(i, "rb")
    try:
        out = avopen(o, "wb", format=format)
        try:
            # "ogg" names a container, not an encoder; pick the codec for it.
            codec_name = "libvorbis" if format == "ogg" else format
            ostream = out.add_stream(codec_name)

            for frame in inp.decode(audio=0):
                for p in ostream.encode(frame):
                    out.mux(p)

            # Passing None flushes whatever the encoder still has buffered.
            for p in ostream.encode(None):
                out.mux(p)
        finally:
            out.close()
    finally:
        inp.close()
# Per-model state; entries share the position of the model in ``models``.
net_g_List = []
hps_List = []

# Speaker dictionary per model.
# Usage: chr_name = chrsMap[model_id][chr_id]
chrsMap: List[Dict[int, str]] = []

# Load every model listed in the server configuration.
models = config.server_config.models
for model in models:
    model_hps = utils.get_hparams_from_file(model["config"])
    hps_List.append(model_hps)

    # Build this model's speaker dictionary: speaker id -> speaker name.
    id_to_name = {}
    for name, cid in model_hps.data.spk2id.items():
        id_to_name[cid] = name
    chrsMap.append(id_to_name)

    # Use ``latest_version`` when the hyper-parameters carry no ``version``.
    version = getattr(model_hps, "version", latest_version)

    net_g_List.append(
        get_net_g(
            model_path=model["model"],
            version=version,
            device=model["device"],
            hps=model_hps,
        )
    )
def generate_audio(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    language,
    model=0,
):
    """Synthesize every text slice with one of the loaded models.

    Args:
        slices: Iterable of text pieces; each one is passed to ``infer``
            separately.
        sdp_ratio: Forwarded unchanged to ``infer``.
        noise_scale: Forwarded unchanged to ``infer``.
        noise_scale_w: Forwarded unchanged to ``infer``.
        length_scale: Forwarded unchanged to ``infer``.
        speaker: Forwarded to ``infer`` as ``sid``.
        language: Forwarded to ``infer`` as ``language``.
        model: Position of the model to use in ``hps_List`` / ``net_g_List`` /
            ``models``. Defaults to the first loaded model.

    Returns:
        A flat list of arrays: for each slice, the 16-bit audio followed by
        ``sampling_rate // 2`` samples of int16 silence.

    Raises:
        IndexError: If *model* does not refer to a loaded model.
    """
    # The model's hyper-parameters, network and device are resolved here
    # because no module-level ``hps`` / ``net_g`` / ``device`` exists.
    hps = hps_List[model]
    net_g = net_g_List[model]
    # Same device the network was created on when the model was loaded.
    # NOTE(review): assumes ``models`` supports integer indexing -- confirm.
    device = models[model]["device"]

    silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
    audio_list = []
    with torch.no_grad():
        for piece in slices:
            audio = infer(
                piece,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
            )
            audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
            audio_list.append(audio16bit)
            audio_list.append(silence)  # separate consecutive slices with silence
    return audio_list


@app.route("/")
def main():
    """Handle ``GET /``: synthesize speech for the ``text`` query parameter.

    Query parameters:
        model: Integer index of the loaded model (required).
        text: Text to speak (required); ``|`` separates slices.
        speaker: Speaker name (default ``""``).
        speaker_id: Numeric speaker id; when it is all digits it overrides
            ``speaker`` through ``chrsMap``.
        sdp_ratio, noise, noisew, length: Floats, defaulting to 0.2, 0.5, 0.6
            and 1.2.
        language: One of ``JP``, ``ZH``, ``EN`` or ``mix``.
        format: ``wav`` (default), ``mp3`` or ``ogg``.

    Returns:
        A ``Response`` carrying the encoded audio, or a short plain-text
        message describing why the request was rejected.
    """
    args = request.args

    # Report absent required parameters explicitly instead of letting them
    # surface as conversion errors.
    raw_model = args.get("model")
    raw_text = args.get("text")
    if raw_model is None or raw_text is None:
        return "Missing Parameter"

    try:
        model = int(raw_model)
        sdp_ratio = float(args.get("sdp_ratio", 0.2))
        noise = float(args.get("noise", 0.5))
        noisew = float(args.get("noisew", 0.6))
        length = float(args.get("length", 1.2))
    except ValueError:
        return "Invalid Parameter"

    # Reject indexes outside the loaded models; a negative value would
    # otherwise silently wrap around to another model.
    if not 0 <= model < len(hps_List):
        return "Invalid Parameter"

    speaker = args.get("speaker", "")  # speaker given by name
    speaker_id = args.get("speaker_id", None)  # speaker given directly by id
    # NOTE(review): this strips the two characters "/n", not a newline --
    # confirm which one clients actually send.
    text = raw_text.replace("/n", "")
    language = args.get("language")
    fmt = args.get("format", "wav")

    if length >= 2:
        return "Too big length"
    if len(text) >= 250:
        return "Too long text"
    if fmt not in ("mp3", "wav", "ogg"):
        return "Invalid Format"
    if language not in ("JP", "ZH", "EN", "mix"):
        return "Invalid language"

    if speaker_id is not None and speaker_id.isdigit():
        try:
            speaker = chrsMap[model][int(speaker_id)]
        except (KeyError, ValueError):
            # Unknown id for this model, or digits that int() cannot parse.
            return "Invalid Parameter"

    audio_list = []
    if language == "mix":
        bool_valid, str_valid = re_matching.validate_text(text)
        if not bool_valid:
            # Return only the validation message; a (message, audio) tuple
            # is not a valid Flask response.
            return str(str_valid)
        for one in re_matching.text_matching(text):
            # The last element of each match is the speaker; the remaining
            # elements are (language, content) pairs.
            _speaker = one.pop()
            for lang, content in one:
                audio_list.extend(
                    generate_audio(
                        content.split("|"),
                        sdp_ratio,
                        noise,
                        noisew,
                        length,
                        _speaker,
                        lang,
                        model,
                    )
                )
    else:
        audio_list.extend(
            generate_audio(
                text.split("|"),
                sdp_ratio,
                noise,
                noisew,
                length,
                speaker,
                language,
                model,
            )
        )

    if not audio_list:
        # np.concatenate raises on an empty list; nothing was synthesized.
        return "Invalid Parameter"

    audio_concat = np.concatenate(audio_list)
    with BytesIO() as wav:
        wavfile.write(wav, hps_List[model].data.sampling_rate, audio_concat)
        torch.cuda.empty_cache()
        if fmt == "wav":
            return Response(wav.getvalue(), mimetype="audio/wav")
        # Rewind so the converter reads the WAV data from the start.
        wav.seek(0, 0)
        with BytesIO() as ofp:
            wav2(wav, ofp, fmt)
            return Response(
                ofp.getvalue(), mimetype="audio/mpeg" if fmt == "mp3" else "audio/ogg"
            )
if __name__ == "__main__":
    # Flask.run() takes the bind address as ``host``. ``server_name`` is not
    # one of its parameters, so it would be forwarded to the underlying
    # server as an unknown option.
    app.run(host="0.0.0.0", port=config.server_config.port)