File size: 1,621 Bytes
01e655b
 
 
 
 
 
8c22399
01e655b
 
 
 
 
 
 
 
 
 
 
8c22399
01e655b
 
 
 
 
 
 
 
8c22399
 
 
d5d0921
01e655b
8c22399
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
ebc4336
 
 
 
 
01e655b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import logging

from fastapi import HTTPException
from pydantic import BaseModel

from modules import refiner
from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.normalization import text_normalize


class RefineTextRequest(BaseModel):
    """Request body for the text-refinement endpoint (``refiner_prompt_post``).

    Only ``text`` is required; every other field has a default. ``prompt`` and
    the sampling/generation fields are forwarded as-is to
    ``refiner.refine_text``. ``normalize`` is consumed by the endpoint itself.
    """

    # Text to refine (the only required field).
    text: str
    # Control-tag prompt handed to the refiner; the meaning of the tags is
    # defined by the refiner/model, not by this module.
    prompt: str = "[oral_2][laugh_0][break_6]"
    # Seed forwarded to the refiner; presumably -1 means "pick a random seed"
    # — TODO confirm against refiner.refine_text.
    seed: int = -1
    # Presumably the nucleus (top-p) sampling threshold — verify in the refiner.
    top_P: float = 0.7
    # Presumably the top-k sampling cutoff — verify in the refiner.
    top_K: int = 20
    # Sampling temperature forwarded to the refiner.
    temperature: float = 0.7
    # Repetition penalty forwarded to the refiner (1.0 is the default here).
    repetition_penalty: float = 1.0
    # Upper bound on generated tokens, forwarded to the refiner.
    max_new_token: int = 384
    # When True, the endpoint runs text_normalize() on `text` before refining.
    normalize: bool = True


async def refiner_prompt_post(request: RefineTextRequest):
    """Refine ``request.text`` with the refiner and return the result.

    The text is optionally normalized first (``request.normalize``). It is
    then passed to ``refiner.refine_text`` together with the prompt and the
    sampling parameters carried by the request.

    Args:
        request: Parsed request body.

    Returns:
        ``{"message": "ok", "data": <refined text>}``.

    Raises:
        HTTPException: Propagated unchanged if raised during normalization
            or refinement. Any other exception is logged and converted into
            a 500 response whose ``detail`` is ``str(exc)``.
    """
    try:
        text = request.text
        if request.normalize:
            text = text_normalize(text)
        # TODO: the text could be split into chunks and processed in batches here.
        # NOTE(review): refine_text is called synchronously inside an async
        # endpoint, so if it is compute-bound it blocks the event loop for
        # the whole call. Consider run_in_threadpool once the refiner's
        # thread-safety is confirmed.
        refined_text = refiner.refine_text(
            text=text,
            prompt=request.prompt,
            seed=request.seed,
            top_P=request.top_P,
            top_K=request.top_K,
            temperature=request.temperature,
            repetition_penalty=request.repetition_penalty,
            max_new_token=request.max_new_token,
        )
    except HTTPException:
        # Already carries the intended status code; pass it through as-is.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("Failed to refine text: %s", e)
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "ok", "data": refined_text}


def setup(api_manager: APIManager):
    """Register the text-refinement route on *api_manager*.

    Binds ``refiner_prompt_post`` to ``POST /v1/prompt/refine``, with
    ``api_utils.BaseResponse`` as the declared response model.
    """
    register_route = api_manager.post(
        "/v1/prompt/refine",
        response_model=api_utils.BaseResponse,
    )
    register_route(refiner_prompt_post)