zhzluke96
update
bed01bd
import logging
from fastapi import Depends, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import Speaker
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
class TTSParams(BaseModel):
    """Query parameters accepted by the TTS synthesis endpoint.

    Every field is declared through ``Query`` so it carries a default value
    and a human-readable description. ``bs`` and ``thr`` are received as
    strings; ``synthesize_tts`` converts them to integers.
    """

    # --- What to say and how to say it ---
    text: str = Query(default=..., description="Text to synthesize")
    spk: str = Query(
        default="female2",
        description="Specific speaker by speaker name or speaker seed",
    )
    style: str = Query(default="chat", description="Specific style by style name")

    # --- Sampling controls ---
    temperature: float = Query(
        default=0.3,
        description="Temperature for sampling (may be overridden by style or spk)",
    )
    top_p: float = Query(
        default=0.5,
        description="Top P for sampling (may be overridden by style or spk)",
    )
    top_k: int = Query(
        default=20,
        description="Top K for sampling (may be overridden by style or spk)",
    )
    seed: int = Query(
        default=42,
        description="Seed for generate (may be overridden by style or spk)",
    )

    # --- Output container ---
    format: str = Query(default="mp3", description="Response audio format: [mp3,wav]")

    # --- Prompting ---
    prompt1: str = Query(default="", description="Text prompt for inference")
    prompt2: str = Query(default="", description="Text prompt for inference")
    prefix: str = Query(default="", description="Text prefix for inference")

    # --- Inference batching / sentence splitting ---
    bs: str = Query(default="8", description="Batch size for inference")
    thr: str = Query(default="100", description="Threshold for sentence spliter")
    eos: str = Query(default="[uv_break]", description="End of sentence str")

    # --- Post-processing ---
    enhance: bool = Query(default=False, description="Enable enhancer")
    denoise: bool = Query(default=False, description="Enable denoiser")
    speed: float = Query(default=1.0, description="Speed of the audio")
    pitch: float = Query(default=0, description="Pitch of the audio")
    volume_gain: float = Query(default=0, description="Volume gain of the audio")

    # --- Delivery ---
    stream: bool = Query(default=False, description="Stream the audio")
def _parse_int_param(raw_value, name):
    """Convert a string query parameter to ``int`` or reject the request.

    Args:
        raw_value: The raw parameter value as received from the query string.
        name: The parameter name, used in the error detail.

    Returns:
        The parsed integer.

    Raises:
        HTTPException: 422 when ``raw_value`` cannot be parsed as an integer.
    """
    try:
        return int(raw_value)
    except (TypeError, ValueError):
        # A malformed number is a client error, not a server fault.
        raise HTTPException(
            status_code=422, detail=f"{name} must be an integer"
        ) from None


async def synthesize_tts(params: TTSParams = Depends()):
    """Validate the request, build the TTS configs and return the audio.

    Args:
        params: Query parameters of the request, injected via ``Depends``.

    Returns:
        A ``StreamingResponse`` wrapping either the handler's stream
        (``params.stream`` true) or the handler's buffer, with a media type
        derived from ``params.format``.

    Raises:
        HTTPException: 422 for invalid parameters (empty text, out-of-range
            temperature/top_p/top_k, unsupported format, non-integer bs/thr,
            unresolved speaker); 500 for any other failure.
    """
    try:
        # --- Parameter validation -------------------------------------
        if not params.text.strip():
            raise HTTPException(
                status_code=422, detail="Text parameter cannot be empty"
            )
        if not (0 <= params.temperature <= 1):
            raise HTTPException(
                status_code=422, detail="Temperature must be between 0 and 1"
            )
        if not (0 <= params.top_p <= 1):
            raise HTTPException(status_code=422, detail="top_p must be between 0 and 1")
        if params.top_k <= 0:
            raise HTTPException(
                status_code=422, detail="top_k must be a positive integer"
            )
        if params.top_k > 100:
            raise HTTPException(
                status_code=422, detail="top_k must be less than or equal to 100"
            )
        if params.format not in ["mp3", "wav"]:
            raise HTTPException(
                status_code=422,
                detail="Invalid format. Supported formats are mp3 and wav",
            )
        # bs / thr arrive as strings; reject non-numeric values with a 422
        # instead of letting int() raise and surface as a 500.
        batch_size = _parse_int_param(params.bs, "bs")
        threshold = _parse_int_param(params.thr, "thr")

        # --- Resolve speaker / style presets --------------------------
        calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)

        spk = calc_params.get("spk", params.spk)
        if not isinstance(spk, Speaker):
            raise HTTPException(status_code=422, detail="Invalid speaker")

        style = calc_params.get("style", params.style)
        # Request values win; the preset value is used only when the request
        # value is falsy.
        # NOTE(review): `or` also treats seed=0 / temperature=0 as "not
        # provided", so those fall back to the preset - confirm this is
        # intended.
        seed = params.seed or calc_params.get("seed", params.seed)
        temperature = params.temperature or calc_params.get(
            "temperature", params.temperature
        )
        prefix = params.prefix or calc_params.get("prefix", params.prefix)
        prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
        prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
        eos = params.eos or ""

        # --- Build configs and handler --------------------------------
        tts_config = ChatTTSConfig(
            style=style,
            temperature=temperature,
            top_k=params.top_k,
            top_p=params.top_p,
            prefix=prefix,
            prompt1=prompt1,
            prompt2=prompt2,
        )
        infer_config = InferConfig(
            batch_size=batch_size,
            spliter_threshold=threshold,
            eos=eos,
            seed=seed,
        )
        adjust_config = AdjustConfig(
            pitch=params.pitch,
            speed_rate=params.speed,
            volume_gain_db=params.volume_gain,
        )
        enhancer_config = EnhancerConfig(
            # Requesting denoise alone also enables the enhancer.
            enabled=params.enhance or params.denoise or False,
            lambd=0.9 if params.denoise else 0.1,
        )

        handler = TTSHandler(
            text_content=params.text,
            spk=spk,
            tts_config=tts_config,
            infer_config=infer_config,
            adjust_config=adjust_config,
            enhancer_config=enhancer_config,
        )

        # mp3 is served as audio/mpeg; other formats as audio/<format>.
        if params.format == "mp3":
            media_type = "audio/mpeg"
        else:
            media_type = f"audio/{params.format}"

        if params.stream:
            if infer_config.batch_size != 1:
                # Streaming generation only supports a batch size of 1; the
                # value from the request is ignored.
                # NOTE(review): only a warning is emitted here and
                # infer_config.batch_size is not changed in this function -
                # presumably the handler enforces it; confirm.
                logger.warning(
                    "Batch size %s is not supported in streaming mode, will set to 1",
                    infer_config.batch_size,
                )
            buffer_gen = handler.enqueue_to_stream(format=AudioFormat(params.format))
            return StreamingResponse(buffer_gen, media_type=media_type)

        buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
        return StreamingResponse(buffer, media_type=media_type)
    except HTTPException as exc:
        # Deliberate rejections are not server faults: record them briefly
        # (no traceback) and let them propagate unchanged.
        logger.warning(
            "TTS request rejected (%s): %s", exc.status_code, exc.detail
        )
        raise
    except Exception as exc:
        # Unexpected failure: log with traceback through the module logger
        # and report it to the client as a 500.
        logger.exception("TTS synthesis failed: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def setup(api_manager: APIManager):
    """Attach ``synthesize_tts`` to the ``/v1/tts`` path on ``api_manager``.

    Calls ``api_manager.get`` with ``FileResponse`` as the response class and
    applies the callable it returns to ``synthesize_tts``.
    """
    register_route = api_manager.get("/v1/tts", response_class=FileResponse)
    register_route(synthesize_tts)