# tts / app.py (Hugging Face Space, author: indiejoseph)
# Last commit: 76282ae "Update app.py" (verified), 5.1 kB
from infer import OnnxInferenceSession
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
import numpy as np
from huggingface_hub import hf_hub_download
import asyncio
from pathlib import Path
# Process-wide ONNX inference session. It starts unset and is created lazily
# (and then reused) by get_onnx_session().
OnnxSession = None

# Hugging Face Hub artifacts required at runtime. Each entry gives the source
# repo, the local directory to place the files in, and the repo-relative file
# paths to fetch; download_model_files() skips files that already exist.
models = [
    dict(
        local_path="./bert/bert-large-cantonese",
        repo_id="hon9kon9ize/bert-large-cantonese",
        files=["pytorch_model.bin"],
    ),
    dict(
        local_path="./bert/deberta-v3-large",
        repo_id="microsoft/deberta-v3-large",
        files=["spm.model", "pytorch_model.bin"],
    ),
    dict(
        local_path="./onnx",
        repo_id="hon9kon9ize/bert-vits-zoengjyutgaai-onnx",
        files=[
            "BertVits2.2PT.json",
            "BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
            "BertVits2.2PT/BertVits2.2PT_emb.onnx",
            "BertVits2.2PT/BertVits2.2PT_dp.onnx",
            "BertVits2.2PT/BertVits2.2PT_sdp.onnx",
            "BertVits2.2PT/BertVits2.2PT_flow.onnx",
            "BertVits2.2PT/BertVits2.2PT_dec.onnx",
        ],
    ),
]
def get_onnx_session():
    """Return the shared ``OnnxInferenceSession``, constructing it on first call.

    The instance is cached in the module-level ``OnnxSession`` global and
    reused on every later call. It is built from the six BertVits2.2PT
    sub-model files (paths relative to the working directory, matching the
    ``./onnx`` download target) on the CPU execution provider.

    NOTE(review): there is no lock here, so two concurrent first calls could
    each construct a session; only the last assignment is kept.
    """
    global OnnxSession
    if OnnxSession is None:
        part_paths = {
            "enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
            "emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
            "dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
            "sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
            "flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
            "dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx",
        }
        OnnxSession = OnnxInferenceSession(
            part_paths,
            Providers=["CPUExecutionProvider"],
        )
    return OnnxSession
def download_model_files(repo_id, files, local_path):
    """Download each of ``files`` from Hub repo ``repo_id`` unless already present.

    Existence is checked as ``Path(local_path) / name`` just before each
    download. Missing files are fetched with ``hf_hub_download`` into
    ``local_dir=local_path`` with symlinks disabled.
    """
    target_dir = Path(local_path)
    for name in files:
        if (target_dir / name).exists():
            continue  # already fetched on a previous run
        hf_hub_download(
            repo_id, name, local_dir=local_path, local_dir_use_symlinks=False
        )
def download_models():
    """Make sure every artifact listed in the module-level ``models`` is on disk."""
    for entry in models:
        repo_id = entry["repo_id"]
        wanted = entry["files"]
        destination = entry["local_path"]
        download_model_files(repo_id, wanted, destination)
def intersperse(lst, item):
    """Return a new list with ``item`` placed before, between and after each element of ``lst``.

    The result has length ``2 * len(lst) + 1``.

    >>> intersperse([1, 2], 0)
    [0, 1, 0, 2, 0]
    """
    out = [item]
    for element in lst:
        out += (element, item)
    return out
def get_text(text, language_str, style_text=None, style_weight=0.7):
    """Turn ``text`` into the phone/tone/language/BERT arrays the model consumes.

    Args:
        text: Raw input text.
        language_str: ``"EN"`` or ``"YUE"``. Selects the cleaner and decides
            which of the two BERT feature slots receives real features.
        style_text: Optional style-reference text forwarded to ``get_bert``;
            an empty string is treated as ``None``.
        style_weight: Style mixing weight forwarded to ``get_bert``.

    Returns:
        ``(en_bert, yue_bert, phone, tone, language)`` as NumPy arrays. The two
        BERT arrays are transposed relative to ``get_bert``'s output, so the
        phone axis comes first. The slot for the language not being spoken is
        filled with ``np.random.randn(1024, n_phones)`` noise.

    Raises:
        ValueError: if ``language_str`` is not ``"EN"``/``"YUE"``, if the
            cleaner yields no words, or if the BERT sequence length does not
            match the number of phones.
    """
    # Fail fast, before running the (comparatively expensive) cleaner and BERT.
    if language_str not in ("EN", "YUE"):
        raise ValueError("language_str should be EN or YUE")
    style_text = None if style_text == "" else style_text

    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    # Add a blank token (0) before, between and after every symbol.
    phone = intersperse(phone, 0)
    tone = intersperse(tone, 0)
    language = intersperse(language, 0)

    if not word2ph:
        raise ValueError(f"text produced no words after cleaning: {text!r}")
    # After interspersing, each word covers twice as many positions (symbol +
    # following blank), and the first word also absorbs the leading blank.
    # Build a new list rather than mutating the cleaner's output in place.
    word2ph = [count * 2 for count in word2ph]
    word2ph[0] += 1

    bert_ori = get_bert(
        norm_text, word2ph, language_str, "cpu", style_text, style_weight
    )
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if bert_ori.shape[-1] != len(phone):
        raise ValueError(f"Bert seq len {bert_ori.shape[-1]} != {len(phone)}")

    # The model takes BERT features for both languages; the unused slot gets
    # random noise of shape (1024, n_phones).
    noise = np.random.randn(1024, len(phone))
    if language_str == "EN":
        en_bert, yue_bert = bert_ori, noise
    else:
        en_bert, yue_bert = noise, bert_ori

    return (
        np.asarray(en_bert.T),
        np.asarray(yue_bert.T),
        np.asarray(phone),
        np.asarray(tone),
        np.asarray(language),
    )
async def text_to_speech(text, sid=0, language="YUE"):
    """Synthesize ``text`` with the shared ONNX session.

    Args:
        text: Text to speak.
        sid: Speaker id; wrapped as ``np.array([sid])`` for the session.
        language: ``"YUE"`` (default) or ``"EN"``; forwarded to ``get_text``.

    Returns:
        ``audio[0][0]`` from the session output (presumably the mono
        waveform; TODO confirm shape against OnnxInferenceSession). Returns
        ``None`` when ``text`` is empty or whitespace-only; in that case a
        Gradio warning is shown instead of synthesizing.
    """
    # Validate before touching the model so blank input never triggers loading.
    # Previously this returned the tuple (None, gr.Warning(...)), which callers
    # then passed to gr.Audio as if it were a waveform.
    if not text or not text.strip():
        gr.Warning("Please enter text to convert.")
        return None
    session = get_onnx_session()
    # `lang_ids` avoids shadowing the `language` parameter with the id array.
    en_bert, yue_bert, x, tone, lang_ids = get_text(text, language)
    audio = session(
        x, tone, lang_ids, en_bert, yue_bert, np.array([sid]), sdp_ratio=0.4
    )
    return audio[0][0]
# Create Gradio application
import gradio as gr
def tts_interface(text):
    """Gradio callback: synthesize Cantonese (``"YUE"``, speaker 0) speech for ``text``.

    Returns:
        ``(44100, waveform)`` for ``gr.Audio``, or ``None`` (clearing the audio
        output) when the text is blank or no audio was produced. Previously
        ``(44100, <missing audio>)`` was returned, which ``gr.Audio`` cannot
        render.
    """
    if not text or not text.strip():
        gr.Warning("Please enter text to convert.")
        return None
    # text_to_speech is a coroutine function; run it to completion here.
    audio = asyncio.run(text_to_speech(text, 0, "YUE"))
    if audio is None:
        return None
    # Hard-coded output sample rate; presumably must match the model config's
    # sampling rate (BertVits2.2PT.json) — TODO confirm.
    return 44100, audio
async def create_demo():
    """Build (but do not launch) the Gradio ``Interface`` for the TTS demo.

    Declared ``async`` and run via ``asyncio.run`` by the entry point, although
    nothing inside awaits. The interface maps a 5-line textbox to an audio
    output via ``tts_interface``, with two Cantonese example inputs.
    """
    # Description (Cantonese): "Cantonese speech generator based on the
    # Bert-VITS2 model. Note: the model itself supports Cantonese and English,
    # but this Space does not implement mixed Chinese-English generation."
    description = """廣東話語音生成器,基於Bert-VITS2模型
注意:model 本身支持廣東話同英文,但呢個 space 未實現中英夾雜生成。
"""
    demo = gr.Interface(
        fn=tts_interface,
        inputs=[
            gr.Textbox(label="Input Text", lines=5),
        ],
        outputs=[
            gr.Audio(label="Generated Audio"),
        ],
        examples=[
            ["漆黑之中我心眺望,不出一聲但兩眼發光\n寂寞極淒厲,晚風充滿汗,只因她幽怨目光"],
            ["本身我就係一個言出必達嘅人"],
        ],
        title="Cantonese TTS Text-to-Speech",
        description=description,
        analytics_enabled=False,
        # Documented string value; the boolean False is a deprecated alias
        # that Gradio maps to "never" with a warning.
        allow_flagging="never",
    )
    return demo
# Script entry point: fetch any missing model files, then build and serve the UI.
if __name__ == "__main__":
    download_models()
    # create_demo() is a coroutine function; drive it to completion here.
    demo_coro = create_demo()
    demo = asyncio.run(demo_coro)
    demo.launch()