Spaces:
Sleeping
Sleeping
from modules.ChatTTS import ChatTTS | |
import torch | |
from modules import config | |
import logging | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device use {device}")

# Cached ChatTTS instance; stays None until load_chat_tts() populates it.
chat_tts = None
def load_chat_tts():
    """Return the shared ChatTTS instance, loading it on the first call.

    The first call builds a ``ChatTTS.Chat`` object, loads its models from
    ``./models/ChatTTS`` onto the module-level ``device`` and, when
    ``config.model_config["half"]`` is truthy, converts every
    ``torch.nn.Module`` found in ``pretrain_models`` to half precision.
    Later calls return the cached instance.

    Returns:
        The loaded ``ChatTTS.Chat`` instance (also stored in the module
        global ``chat_tts``).

    Raises:
        Any exception raised while loading or converting the models is
        propagated. In that case nothing is cached, so the next call
        retries from scratch instead of returning a half-initialised object.
    """
    global chat_tts
    if chat_tts is not None:
        return chat_tts

    # Build into a local and publish to the global only after everything
    # succeeded; otherwise a failed load would leave an unusable instance
    # cached and the guard above would keep returning it.
    instance = ChatTTS.Chat()
    instance.load_models(
        compile=config.enable_model_compile,
        source="local",
        local_path="./models/ChatTTS",
        device=device,
    )

    if config.model_config.get("half", False):
        logger.info("half precision enabled")
        cuda_available = torch.cuda.is_available()
        for model_name, model in instance.pretrain_models.items():
            # Non-module entries are left untouched.
            if not isinstance(model, torch.nn.Module):
                continue
            # Convert on the CPU and release cached GPU blocks first --
            # presumably to keep peak GPU memory low during the conversion.
            model.cpu()
            if cuda_available:
                torch.cuda.empty_cache()
            model.half()
            if cuda_available:
                model.cuda()
            model.eval()
            logger.info("%s converted to half precision.", model_name)

    chat_tts = instance
    return chat_tts