ChatTTS-Forge / modules /webui /webui_utils.py
zhzluke96
update
374f426
raw
history blame
3.97 kB
import os
import logging
import sys
import numpy as np
from modules.devices import devices
from modules.synthesize_audio import synthesize_audio
from modules.hf import spaces
from modules.webui import webui_config
# Configure root logging when this module is imported. The level is read from
# the LOG_LEVEL environment variable and defaults to "INFO".
# Per the stdlib docs, basicConfig does nothing if the root logger already has
# handlers, so whichever module calls it first wins.
# NOTE(review): this runs before the gradio/torch imports below, presumably so
# the format applies before those libraries touch logging — confirm intent.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
import gradio as gr
import torch
from modules.ssml import parse_ssml
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
from modules.speaker import speaker_mgr
from modules.data import styles_mgr
from modules.api.utils import calc_spk_style
import modules.generate_audio as generate
from modules.normalization import text_normalize
from modules import refiner, config
from modules.utils import env, audio
from modules.SentenceSplitter import SentenceSplitter
def get_speakers():
    """Return the speakers reported by ``speaker_mgr.list_speakers()``."""
    speakers = speaker_mgr.list_speakers()
    return speakers
def get_styles():
    """Return the style entries reported by ``styles_mgr.list_items()``."""
    style_items = styles_mgr.list_items()
    return style_items
def segments_length_limit(segments, total_max: int):
    """Keep the leading segments whose combined text fits in ``total_max``.

    Segments are consumed in order. Each segment that has a ``"text"`` key
    adds ``len(seg["text"])`` to a running total. As soon as that total
    exceeds ``total_max``, that segment and every segment after it are
    dropped.

    Segments without a ``"text"`` key contribute nothing to the total and are
    kept, as long as the limit has not yet been exceeded. These are presumably
    break/pause segments produced by the SSML parser; verify against
    ``parse_ssml``.

    Args:
        segments: Iterable of segments supporting ``"text" in seg`` and
            ``seg["text"]`` (dict-like).
        total_max: Maximum total number of text characters to keep.

    Returns:
        A new list holding the kept segments in their original order. The
        input is not modified.
    """
    ret_segments = []
    total_len = 0
    for seg in segments:
        if "text" not in seg:
            # Non-text segments have no length cost. Dropping them here would
            # silently remove them (e.g. pauses) from the synthesized audio.
            ret_segments.append(seg)
            continue
        total_len += len(seg["text"])
        if total_len > total_max:
            break
        ret_segments.append(seg)
    return ret_segments
@torch.inference_mode()
@spaces.GPU
def synthesize_ssml(ssml: str, batch_size=4):
    """Synthesize an SSML document into a single numpy audio payload.

    The SSML is parsed into segments. The segments are truncated to
    ``webui_config.ssml_max`` characters of text via
    ``segments_length_limit``. They are then synthesized in batches and
    concatenated.

    Args:
        ssml: SSML markup. ``None`` or whitespace-only input yields ``None``.
        batch_size: Segments per synthesis batch. A value that ``int()``
            cannot convert, or one below 1, falls back to 4. That is the
            declared default and the same fallback ``tts_generate`` uses.

    Returns:
        The result of ``audio.pydub_to_np`` on the combined audio, or ``None``
        when there is nothing to synthesize. That happens for empty input, or
        when no segments remain after length limiting.
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        # Previously fell back to 8, disagreeing with the default of 4.
        batch_size = 4
    if batch_size < 1:
        # NOTE(review): a non-positive batch size cannot form batches; treat
        # it like an invalid value rather than passing it downstream.
        batch_size = 4

    ssml = (ssml or "").strip()
    if not ssml:
        return None

    segments = parse_ssml(ssml)
    segments = segments_length_limit(segments, webui_config.ssml_max)
    if len(segments) == 0:
        return None

    synthesize = SynthesizeSegments(batch_size=batch_size)
    audio_segments = synthesize.synthesize_segments(segments)
    combined_audio = combine_audio_segments(audio_segments)
    return audio.pydub_to_np(combined_audio)
@torch.inference_mode()
@spaces.GPU
def tts_generate(
    text,
    temperature,
    top_p,
    top_k,
    spk,
    infer_seed,
    use_decoder,
    prompt1,
    prompt2,
    prefix,
    style,
    disable_normalize=False,
    batch_size=4,
):
    """Synthesize ``text`` with the given sampling, speaker and style options.

    Args:
        text: Input text. It is stripped and truncated to
            ``webui_config.tts_max`` characters. ``None`` or empty input
            yields ``None``.
        temperature: Sampling temperature. If falsy, the value from the
            speaker/style params is used when present.
        top_p: Passed to ``synthesize_audio`` as ``top_P``.
        top_k: Passed as ``top_K``. A float value is converted to ``int``.
        spk: Speaker selection. It may be replaced by the ``"spk"`` entry
            returned by ``calc_spk_style``.
        infer_seed: Inference seed. If falsy, the style/speaker ``"seed"`` is
            used when present. The value is converted to ``int`` and clamped
            to ``[-1, 2**32 - 1]``.
        use_decoder: Forwarded unchanged to ``synthesize_audio``.
        prompt1, prompt2, prefix: Prompt pieces. Falsy values fall back to
            the style/speaker params.
        style: Style name. ``"*auto"`` means "no explicit style" (``None``).
        disable_normalize: When true, ``text_normalize`` is skipped.
        batch_size: Synthesis batch size. Values ``int()`` cannot convert
            fall back to 4.

    Returns:
        ``(sample_rate, audio_data)``, where ``audio_data`` has been passed
        through ``audio.audio_to_int16``. Returns ``None`` for empty text.
    """
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError, OverflowError):
        batch_size = 4

    max_len = webui_config.tts_max
    text = (text or "").strip()[:max_len]
    if not text:
        return None

    if style == "*auto":
        style = None

    if isinstance(top_k, float):
        top_k = int(top_k)

    # Speaker/style presets only fill in the fields the caller left falsy.
    params = calc_spk_style(spk=spk, style=style)
    spk = params.get("spk", spk)
    infer_seed = infer_seed or params.get("seed", infer_seed)
    temperature = temperature or params.get("temperature", temperature)
    prefix = prefix or params.get("prefix", prefix)
    prompt1 = prompt1 or params.get("prompt1", "")
    prompt2 = prompt2 or params.get("prompt2", "")

    # Convert to int *before* clamping. The previous
    # np.clip(..., dtype=np.int64) raised UFuncTypeError for float seeds
    # (e.g. 42.0), because float64 -> int64 is not a 'same_kind' cast.
    infer_seed = min(max(int(infer_seed), -1), 2**32 - 1)

    if not disable_normalize:
        text = text_normalize(text)

    sample_rate, audio_data = synthesize_audio(
        text=text,
        temperature=temperature,
        top_P=top_p,
        top_K=top_k,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
        batch_size=batch_size,
    )

    audio_data = audio.audio_to_int16(audio_data)
    return sample_rate, audio_data
@torch.inference_mode()
@spaces.GPU
def refine_text(text: str, prompt: str):
    """Normalize ``text``, then run it through the refiner with ``prompt``.

    Returns whatever ``refiner.refine_text`` produces for the normalized text.
    """
    normalized_text = text_normalize(text)
    refined = refiner.refine_text(normalized_text, prompt=prompt)
    return refined
@torch.inference_mode()
@spaces.GPU
def split_long_text(long_text_input):
    """Split long text into normalized sentences as ``[index, text, length]`` rows.

    The text is split with ``SentenceSplitter``, configured with
    ``webui_config.spliter_threshold``. Every sentence is normalized first.
    Rows are then built with a zero-based index and ``len()`` of the
    normalized sentence.
    """
    splitter = SentenceSplitter(webui_config.spliter_threshold)
    normalized = [text_normalize(piece) for piece in splitter.parse(long_text_input)]
    return [[idx, sentence, len(sentence)] for idx, sentence in enumerate(normalized)]