import os
import sys

sys.path.insert(0, os.getcwd())

import ChatTTS
import re
import time
import io
import json
import random
import subprocess  # needed by pack_aac below; was missing from the original imports
import wave
from io import BytesIO
from typing import Generator

import numpy as np
import pandas
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware  # CORS middleware
from tqdm import tqdm

from utils import batch_split, normalize_zh

# Origins allowed to access the API; "*" means all origins.
origins = ["*"]
chat = ChatTTS.Chat()


def clear_cuda_cache():
    """
    Clear the CUDA memory cache between batches.
    """
    torch.cuda.empty_cache()


def deterministic(seed=0):
    """
    Set random seeds for reproducible sampling.
    """
    # ref: https://github.com/Jackiexiao/ChatTTS-api-ui-docker/blob/main/api.py#L27
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
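
# Illustrative note (an assumption, not from the original source): calling
# deterministic(2581) right before chat.sample_random_speaker() pins the sampled
# speaker embedding, which is how a "seed" maps to a stable voice in this API.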
class TTS_Request(BaseModel):
    text: str = None
    seed: int = 2581
    speed: int = 3
    media_type: str = "wav"
    streaming: int = 0


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,    # origins allowed to make cross-site requests
    allow_credentials=True,
    allow_methods=["*"],      # HTTP methods allowed cross-origin (GET, POST, PUT, ...)
    allow_headers=["*"],      # headers allowed cross-origin, e.g. for authentication
)
def cut5(inp):
    """Split text on CJK/ASCII punctuation, keeping each delimiter attached to its sentence."""
    # if not re.search(r'[^\w\s]', inp[-1]):
    #     inp += '。'
    inp = inp.strip("\n")
    punds = r'[,.;?!、,。?!;:…]'
    items = re.split(f'({punds})', inp)
    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
    # Keep the text intact when it has no punctuation, or no trailing punctuation.
    if len(items) % 2 == 1:
        mergeitems.append(items[-1])
    return mergeitems
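
# Illustrative example (not from the original source): a trailing unpunctuated
# clause is appended as its own segment, e.g.
#   cut5("天气不错,出去走走")  ->  ["天气不错,", "出去走走"]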
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
    # Create a WAV header and append the frame input.
    # It must be the first chunk of a streaming WAV file; later chunks should not
    # repeat it, or you will hear an artifact at the start of each chunk.
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()
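
# Sanity check (illustrative; assumes the standard 44-byte PCM RIFF header that
# Python's wave module writes for zero frames):
#   assert len(wave_header_chunk()) == 44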
### modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
    with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
        audio_file.write(data)
    return io_buffer


def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
    io_buffer.write(data.tobytes())
    return io_buffer


def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
    io_buffer = BytesIO()  # note: replaces the caller's buffer; soundfile writes a complete WAV
    sf.write(io_buffer, data, rate, format='wav')
    return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
    process = subprocess.Popen([
        'ffmpeg',
        '-f', 's16le',      # input: 16-bit signed little-endian PCM
        '-ar', str(rate),   # input sample rate
        '-ac', '1',         # mono
        '-i', 'pipe:0',     # read input from stdin
        '-c:a', 'aac',      # encode audio as AAC
        '-b:a', '192k',     # bitrate
        '-vn',              # no video
        '-f', 'adts',       # output raw AAC in an ADTS stream
        'pipe:1'            # write output to stdout
    ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, _ = process.communicate(input=data.tobytes())
    io_buffer.write(out)
    return io_buffer
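
# Assumption: the `ffmpeg` binary is installed and on PATH; if it is not,
# subprocess.Popen raises FileNotFoundError when pack_aac tries to spawn it.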
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
    if media_type == "ogg":
        io_buffer = pack_ogg(io_buffer, data, rate)
    elif media_type == "aac":
        io_buffer = pack_aac(io_buffer, data, rate)
    elif media_type == "wav":
        io_buffer = pack_wav(io_buffer, data, rate)
    else:
        io_buffer = pack_raw(io_buffer, data, rate)
    io_buffer.seek(0)
    return io_buffer
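
# Illustrative usage (values are assumptions, not from the original source):
#   pcm = np.zeros(2400, dtype=np.int16)   # 0.1 s of silence at 24 kHz
#   wav_bytes = pack_audio(BytesIO(), pcm, 24000, "wav").getvalue()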
def generate_tts_audio(text_file, seed=2581, speed=1, oral=0, laugh=0, bk=4, min_length=80,
                       batch_size=5, temperature=0.01, top_P=0.7, top_K=20,
                       streaming=0, cur_tqdm=None):
    from utils import combine_audio, save_audio, batch_split
    from utils import split_text, replace_tokens, restore_tokens

    if seed in [0, -1, None]:
        seed = random.randint(1, 9999)

    content = text_file
    # Replace [uv_break] / [laugh] with _uv_break_ / _laugh_ before splitting,
    # then restore them afterwards so the splitter does not break the tokens.
    content = replace_tokens(content)
    texts = split_text(content, min_length=min_length)
    for i, text in enumerate(texts):
        texts[i] = restore_tokens(text)

    if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
        raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")

    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"

    deterministic(seed)
    rnd_spk_emb = chat.sample_random_speaker()
    params_infer_code = {
        'spk_emb': rnd_spk_emb,
        'prompt': f'[speed_{speed}]',
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }
    params_refine_text = {
        'prompt': refine_text_prompt,
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }

    if not cur_tqdm:
        cur_tqdm = tqdm

    start_time = time.time()

    if not streaming:
        all_wavs = []
        for batch in cur_tqdm(batch_split(texts, batch_size), desc=f"Inferring audio for seed={seed}"):
            print(batch)
            wavs = chat.infer(batch, params_infer_code=params_infer_code,
                              params_refine_text=params_refine_text,
                              use_decoder=True, skip_refine_text=True)
            audio_data = wavs[0][0]
            audio_data = audio_data / np.max(np.abs(audio_data))  # peak-normalize
            all_wavs.append(audio_data)
            clear_cuda_cache()
        # Scale to int16; 32767 avoids overflow when the peak sample is exactly +1.0.
        audio = (np.concatenate(all_wavs) * 32767).astype(np.int16)
        # end_time = time.time()
        # elapsed_time = end_time - start_time
        # print(f"Saving audio for seed {seed}, took {elapsed_time:.2f}s")
        yield audio
    else:
        print("Streaming generation")
        texts = [normalize_zh(_) for _ in content.split('\n') if _.strip()]
        for text in texts:
            wavs_gen = chat.infer(text, params_infer_code=params_infer_code,
                                  params_refine_text=params_refine_text,
                                  use_decoder=True, skip_refine_text=True, stream=True)
            for gen in wavs_gen:
                wavs = [np.array([[]])]
                wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
                audio_data = wavs[0][0]
                audio_data = audio_data / np.max(np.abs(audio_data))
                yield (audio_data * 32767).astype(np.int16)
        # clear_cuda_cache()
async def tts_handle(req: dict):
    media_type = req["media_type"]
    print(req["streaming"])
    print(req["media_type"])
    if not req["streaming"]:
        audio_data = next(generate_tts_audio(req["text"], req["seed"]))
        sr = 24000
        audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
        return Response(audio_data, media_type=f"audio/{media_type}")
    else:
        tts_generator = generate_tts_audio(req["text"], req["seed"], streaming=1)
        sr = 24000

        def streaming_generator(tts_generator: Generator, media_type: str):
            # For WAV, send the header once, then stream the raw PCM chunks.
            if media_type == "wav":
                yield wave_header_chunk()
                media_type = "raw"
            for chunk in tts_generator:
                print(chunk)
                yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()

        return StreamingResponse(streaming_generator(tts_generator, media_type),
                                 media_type=f"audio/{media_type}")
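
# Design note: a streaming WAV response sends one header (declaring a zero-length
# data chunk) followed by bare PCM chunks; most players tolerate the understated
# size, which is why later chunks are packed with media_type "raw".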
# NOTE: the route decorators below were missing from the extracted source; the
# paths are educated guesses modeled on similar TTS APIs, not confirmed.
@app.get("/")
async def tts_get(text: str = None, media_type: str = "wav", seed: int = 2581, streaming: int = 0):
    req = {
        "text": text,
        "media_type": media_type,
        "seed": seed,
        "streaming": streaming,
    }
    return await tts_handle(req)


@app.get("/speakers")
def speakers_endpoint():
    return JSONResponse([{"name": "default", "vid": 1}], status_code=200)


@app.get("/speakers_list")
def speakerlist_endpoint():
    return JSONResponse(["female_calm", "female", "male"], status_code=200)


@app.post("/")
async def tts_post_endpoint(request: TTS_Request):
    req = request.dict()
    return await tts_handle(req)


@app.post("/tts_to_audio/")
async def tts_to_audio(request: TTS_Request):
    # Use the fixed seed from config so the voice matches the rest of the app.
    req = request.dict()
    from config import llama_seed
    req["seed"] = llama_seed
    return await tts_handle(req)
if __name__ == "__main__":
    chat.load_models(source="local", local_path="models")
    # chat = load_chat_tts_model(source="local", local_path="models")
    uvicorn.run(app, host="0.0.0.0", port=9880, workers=1)
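
# Quick smoke test (illustrative; the route paths above are assumptions):
#   curl "http://127.0.0.1:9880/?text=hello&seed=2581&media_type=wav" -o out.wav
#   curl -X POST "http://127.0.0.1:9880/" -H "Content-Type: application/json" \
#        -d '{"text": "hello", "seed": 2581, "media_type": "wav", "streaming": 0}' \
#        -o out.wav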