# chattts/modules/refiner.py
# Author: zhzluke96 (commit bed01bd)
from typing import Generator
import numpy as np
import torch
from modules import config, models
from modules.utils.SeedContext import SeedContext
@torch.inference_mode()
def refine_text(
    text: str,
    prompt="[oral_2][laugh_0][break_6]",
    seed=-1,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
) -> str:
    """Refine *text* with the ChatTTS refiner model and return the result.

    Runs the refiner under ``SeedContext(seed)`` for reproducible sampling.
    Sampling knobs (``top_P``, ``top_K``, ``temperature``,
    ``repetition_penalty``, ``max_new_token``) and the control-token
    ``prompt`` are forwarded to the refiner unchanged.

    Raises:
        NotImplementedError: if the refiner yields a generator (streaming
            output is not supported here).
    """
    refiner_params = {
        "prompt": prompt,
        "top_K": top_K,
        "top_P": top_P,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_token": max_new_token,
        "disable_tqdm": config.runtime_env_vars.off_tqdm,
    }

    chat_tts = models.load_chat_tts()
    with SeedContext(seed):
        result = chat_tts.refiner_prompt(text, refiner_params)

    # Streaming (generator) output is not handled by this synchronous wrapper.
    if isinstance(result, Generator):
        raise NotImplementedError(
            "Refiner is not yet implemented for generator output"
        )
    # A list of refined fragments is collapsed into one newline-separated string.
    return "\n".join(result) if isinstance(result, list) else result