File size: 6,335 Bytes
bed01bd
 
01e655b
d2b7e94
01e655b
 
d2b7e94
 
d5d0921
 
 
 
 
01e655b
bed01bd
 
01e655b
 
 
 
 
 
 
 
 
 
d5d0921
01e655b
 
d5d0921
01e655b
 
 
 
 
 
 
 
 
 
 
d5d0921
 
 
 
 
 
 
 
01e655b
bed01bd
 
01e655b
 
 
ebc4336
 
 
 
 
 
 
 
 
 
 
 
d5d0921
 
 
ebc4336
d5d0921
 
ebc4336
d5d0921
ebc4336
d5d0921
ebc4336
d5d0921
ebc4336
 
 
 
 
 
 
 
 
01e655b
 
 
d5d0921
 
 
 
01e655b
 
 
 
 
 
 
1df74c6
01e655b
 
 
 
d5d0921
 
01e655b
d5d0921
 
 
01e655b
 
d5d0921
 
01e655b
 
d5d0921
 
 
 
 
 
 
 
 
 
 
01e655b
 
d5d0921
 
 
 
 
 
 
 
 
 
 
01e655b
bed01bd
 
 
 
 
 
 
 
 
 
 
 
01e655b
 
 
 
ebc4336
 
 
 
 
01e655b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import logging

from fastapi import Depends, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import Speaker

logger = logging.getLogger(__name__)


class TTSParams(BaseModel):
    """Request parameters accepted by the TTS synthesis handler.

    Every field is declared with a FastAPI ``Query`` default, and the model
    is injected into ``synthesize_tts`` through ``Depends()``.
    """

    # --- Text and voice selection ---
    text: str = Query(..., description="Text to synthesize")
    spk: str = Query(
        "female2", description="Specific speaker by speaker name or speaker seed"
    )
    style: str = Query("chat", description="Specific style by style name")

    # --- Sampling controls (range-checked in synthesize_tts) ---
    temperature: float = Query(
        0.3, description="Temperature for sampling (may be overridden by style or spk)"
    )
    top_p: float = Query(
        0.5, description="Top P for sampling (may be overridden by style or spk)"
    )
    top_k: int = Query(
        20, description="Top K for sampling (may be overridden by style or spk)"
    )
    seed: int = Query(
        42, description="Seed for generate (may be overridden by style or spk)"
    )

    # --- Output format and inference prompts ---
    format: str = Query("mp3", description="Response audio format: [mp3,wav]")
    prompt1: str = Query("", description="Text prompt for inference")
    prompt2: str = Query("", description="Text prompt for inference")
    prefix: str = Query("", description="Text prefix for inference")

    # NOTE: bs and thr are declared as strings; synthesize_tts converts them
    # with int() before building the inference config.
    bs: str = Query("8", description="Batch size for inference")
    thr: str = Query("100", description="Threshold for sentence spliter")
    eos: str = Query("[uv_break]", description="End of sentence str")

    # --- Post-processing: either flag enables the enhancer in synthesize_tts ---
    enhance: bool = Query(False, description="Enable enhancer")
    denoise: bool = Query(False, description="Enable denoiser")

    # --- Audio adjustments (passed to AdjustConfig in synthesize_tts) ---
    speed: float = Query(1.0, description="Speed of the audio")
    pitch: float = Query(0, description="Pitch of the audio")
    volume_gain: float = Query(0, description="Volume gain of the audio")

    # When true, synthesize_tts uses the handler's streaming path.
    stream: bool = Query(False, description="Stream the audio")


# Audio container formats the endpoint accepts in the `format` parameter.
_SUPPORTED_FORMATS = ("mp3", "wav")


def _validate_tts_params(params: TTSParams) -> None:
    """Raise a 422 HTTPException if a request value is empty or out of range."""
    if not params.text.strip():
        raise HTTPException(status_code=422, detail="Text parameter cannot be empty")

    if not (0 <= params.temperature <= 1):
        raise HTTPException(
            status_code=422, detail="Temperature must be between 0 and 1"
        )

    if not (0 <= params.top_p <= 1):
        raise HTTPException(status_code=422, detail="top_p must be between 0 and 1")

    if params.top_k <= 0:
        raise HTTPException(status_code=422, detail="top_k must be a positive integer")
    if params.top_k > 100:
        raise HTTPException(
            status_code=422, detail="top_k must be less than or equal to 100"
        )

    if params.format not in _SUPPORTED_FORMATS:
        raise HTTPException(
            status_code=422,
            detail="Invalid format. Supported formats are mp3 and wav",
        )


def _parse_int_param(raw, name: str) -> int:
    """Convert a string-typed query value to int, reporting bad input as a 422."""
    try:
        return int(raw)
    except (TypeError, ValueError) as err:
        # A malformed number is a client error, not a server failure.
        raise HTTPException(
            status_code=422, detail=f"{name} must be an integer"
        ) from err


async def synthesize_tts(params: TTSParams = Depends()):
    """Synthesize speech for the requested text and return it as audio.

    The request is validated, speaker/style presets are resolved through
    ``api_utils.calc_spk_style``, the configs are assembled, and the work is
    delegated to a ``TTSHandler``.

    Args:
        params: Query parameters of the request, injected by FastAPI.

    Returns:
        A ``StreamingResponse`` carrying the audio. The media type is
        ``audio/mpeg`` for mp3 and ``audio/<format>`` otherwise.

    Raises:
        HTTPException: 422 when a parameter is empty, out of range, not an
            integer where one is required, or the speaker cannot be resolved;
            500 when anything else fails during synthesis.
    """
    try:
        _validate_tts_params(params)

        calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)

        spk = calc_params.get("spk", params.spk)
        if not isinstance(spk, Speaker):
            raise HTTPException(status_code=422, detail="Invalid speaker")

        # `or` makes a falsy request value (0, "") defer to the value resolved
        # from the speaker/style preset, when one is present.
        style = calc_params.get("style", params.style)
        seed = params.seed or calc_params.get("seed", params.seed)
        temperature = params.temperature or calc_params.get(
            "temperature", params.temperature
        )
        prefix = params.prefix or calc_params.get("prefix", params.prefix)
        prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
        prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
        eos = params.eos or ""

        batch_size = _parse_int_param(params.bs, "bs")
        threshold = _parse_int_param(params.thr, "thr")

        if params.stream and batch_size != 1:
            # Streaming only supports a batch size of 1, so the requested
            # value is ignored. Apply the override here so the warning below
            # is actually true for the config handed to the handler.
            logger.warning(
                "Batch size %s is not supported in streaming mode, will set to 1",
                batch_size,
            )
            batch_size = 1

        tts_config = ChatTTSConfig(
            style=style,
            temperature=temperature,
            top_k=params.top_k,
            top_p=params.top_p,
            prefix=prefix,
            prompt1=prompt1,
            prompt2=prompt2,
        )
        infer_config = InferConfig(
            batch_size=batch_size,
            spliter_threshold=threshold,
            eos=eos,
            seed=seed,
        )
        adjust_config = AdjustConfig(
            pitch=params.pitch,
            speed_rate=params.speed,
            volume_gain_db=params.volume_gain,
        )
        # Requesting denoise alone also enables the enhancer, with a higher lambd.
        enhancer_config = EnhancerConfig(
            enabled=params.enhance or params.denoise or False,
            lambd=0.9 if params.denoise else 0.1,
        )

        handler = TTSHandler(
            text_content=params.text,
            spk=spk,
            tts_config=tts_config,
            infer_config=infer_config,
            adjust_config=adjust_config,
            enhancer_config=enhancer_config,
        )

        # "audio/mp3" is not the registered MIME type; mp3 is served as audio/mpeg.
        media_type = "audio/mpeg" if params.format == "mp3" else f"audio/{params.format}"
        audio_format = AudioFormat(params.format)

        if params.stream:
            buffer_gen = handler.enqueue_to_stream(format=audio_format)
            return StreamingResponse(buffer_gen, media_type=media_type)

        buffer = handler.enqueue_to_buffer(format=audio_format)
        return StreamingResponse(buffer, media_type=media_type)
    except HTTPException:
        # Deliberate client-facing errors pass through untouched, without a
        # traceback in the logs.
        raise
    except Exception as e:
        logger.exception("TTS synthesis failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


def setup(api_manager: APIManager):
    """Register ``synthesize_tts`` as the GET handler for ``/v1/tts``."""
    register_route = api_manager.get("/v1/tts", response_class=FileResponse)
    register_route(synthesize_tts)