Spaces:
Running
Running
import logging

from fastapi import HTTPException
from pydantic import BaseModel

from modules import refiner
from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.normalization import text_normalize
class RefineTextRequest(BaseModel):
    """Request body for the ``/v1/prompt/refine`` endpoint.

    ``text`` (optionally normalized, see ``normalize``) and every other field
    except ``normalize`` are forwarded as keyword arguments to
    ``refiner.refine_text``.
    """

    # Input text to be refined.
    text: str
    # Prompt passed to the refiner; presumably ChatTTS-style control tokens
    # (oral / laugh / break levels) -- confirm against modules.refiner.
    prompt: str = "[oral_2][laugh_0][break_6]"
    # Seed forwarded to the refiner; -1 presumably means "no fixed seed" --
    # TODO confirm in modules.refiner.
    seed: int = -1
    # Sampling parameters forwarded unchanged to refiner.refine_text;
    # presumably top-p / top-k / temperature sampling as in typical LLM decoding.
    top_P: float = 0.7
    top_K: int = 20
    temperature: float = 0.7
    # Forwarded unchanged; 1.0 presumably disables the penalty -- confirm.
    repetition_penalty: float = 1.0
    # Generation length limit forwarded as max_new_token.
    max_new_token: int = 384
    # When True, text is passed through text_normalize before refinement.
    normalize: bool = True
async def refiner_prompt_post(request: RefineTextRequest):
    """
    Refine the given text with the prompt refiner and return the result.
    \f
    Optionally runs ``request.text`` through ``text_normalize`` (when
    ``request.normalize`` is true), then forwards it together with the
    request's prompt and sampling parameters to ``refiner.refine_text``.

    Returns:
        ``{"message": "ok", "data": <result of refiner.refine_text>}``.

    Raises:
        HTTPException: an ``HTTPException`` raised while processing is
            re-raised unchanged; any other exception is logged with its
            traceback and converted into a 500 whose ``detail`` is the
            exception message (the original is chained as ``__cause__``).
    """
    logger = logging.getLogger(__name__)
    try:
        text = request.text
        if request.normalize:
            text = text_normalize(text)
        # TODO: the text could be split into segments and refined in batches here.
        # NOTE(review): refine_text is called synchronously inside an async
        # handler; if it runs model inference it blocks the event loop for the
        # whole call. Consider run_in_threadpool once the refiner is confirmed
        # to be safe to call from worker threads.
        refined_text = refiner.refine_text(
            text=text,
            prompt=request.prompt,
            seed=request.seed,
            top_P=request.top_P,
            top_K=request.top_K,
            temperature=request.temperature,
            repetition_penalty=request.repetition_penalty,
            max_new_token=request.max_new_token,
        )
    except HTTPException:
        # Deliberate HTTP errors already carry the status/detail the client
        # should receive; pass them through untouched.
        raise
    except Exception as e:
        logger.exception("Prompt refinement failed")
        # Chain the original error so the root cause survives in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "ok", "data": refined_text}
def setup(api_manager: APIManager):
    """Register the text-refinement route on *api_manager*.

    Exposes ``refiner_prompt_post`` as ``POST /v1/prompt/refine`` with
    ``api_utils.BaseResponse`` as the declared response model.
    """
    register_route = api_manager.post(
        "/v1/prompt/refine",
        response_model=api_utils.BaseResponse,
    )
    register_route(refiner_prompt_post)