# ChatTTS-Forge / modules/Enhancer/ResembleEnhance.py
import gc
import logging
import os
from pathlib import Path
from threading import Lock
from typing import List, Literal

import numpy as np
import torch

from modules import config
from modules.devices import devices
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
from modules.repos_static.resemble_enhance.inference import inference
from modules.utils.constants import MODELS_DIR

logger = logging.getLogger(__name__)
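
# Module-level singleton guarded by a lock so concurrent callers share one loaded model.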
resemble_enhance = None
lock = Lock()
class ResembleEnhance:
def __init__(self, device: torch.device, dtype=torch.float32):
self.device = device
self.dtype = dtype
        self.enhancer: Enhancer = None
        self.hparams: HParams = None
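
    # Load hparams and pretrained weights from MODELS_DIR/resemble-enhance onto the target device.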
def load_model(self):
hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
enhancer = Enhancer(hparams).to(device=self.device, dtype=self.dtype).eval()
state_dict = torch.load(
Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
map_location=self.device,
)["module"]
enhancer.load_state_dict(state_dict)
self.hparams = hparams
self.enhancer = enhancer
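
    # Run only the denoiser sub-module over a 1-D waveform; returns (denoised_wav, new_sample_rate).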
@torch.inference_mode()
def denoise(self, dwav, sr) -> tuple[torch.Tensor, int]:
assert self.enhancer is not None, "Model not loaded"
assert self.enhancer.denoiser is not None, "Denoiser not loaded"
enhancer = self.enhancer
return inference(
model=enhancer.denoiser,
dwav=dwav,
sr=sr,
            device=self.device,
dtype=self.dtype,
)
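
    # Full enhancement pass. nfe is the number of CFM function evaluations,
    # solver selects the ODE solver, lambd is the denoise strength, and tau
    # is the CFM prior temperature.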
@torch.inference_mode()
def enhance(
self,
dwav,
sr,
nfe=32,
solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
lambd=0.5,
tau=0.5,
) -> tuple[torch.Tensor, int]:
assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
assert solver in (
"midpoint",
"rk4",
"euler",
), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
assert self.enhancer is not None, "Model not loaded"
enhancer = self.enhancer
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
return inference(
model=enhancer, dwav=dwav, sr=sr, device=self.device, dtype=self.dtype
)
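
# Lazily create and load the shared ResembleEnhance instance; the lock keeps
# concurrent callers from loading it twice.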
def load_enhancer() -> ResembleEnhance:
global resemble_enhance
with lock:
if resemble_enhance is None:
logger.info("Loading ResembleEnhance model")
resemble_enhance = ResembleEnhance(
device=devices.device, dtype=devices.dtype
)
resemble_enhance.load_model()
logger.info("ResembleEnhance model loaded")
return resemble_enhance
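
# Drop the shared instance and reclaim GPU / CPU memory.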
def unload_enhancer():
global resemble_enhance
with lock:
if resemble_enhance is not None:
logger.info("Unloading ResembleEnhance model")
del resemble_enhance
resemble_enhance = None
devices.torch_gc()
gc.collect()
logger.info("ResembleEnhance model unloaded")
def reload_enhancer():
logger.info("Reloading ResembleEnhance model")
unload_enhancer()
load_enhancer()
logger.info("ResembleEnhance model reloaded")
def apply_audio_enhance_full(
audio_data: np.ndarray,
sr: int,
nfe=32,
solver: Literal["midpoint", "rk4", "euler"] = "midpoint",
lambd=0.5,
tau=0.5,
):
    # FIXME: moving this with .to(device) here might be a small optimization?
tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
enhancer = load_enhancer()
tensor, sr = enhancer.enhance(
tensor, sr, tau=tau, nfe=nfe, solver=solver, lambd=lambd
)
audio_data = tensor.cpu().numpy()
return audio_data, int(sr)
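
# Pipeline-facing wrapper: denoising is approximated by running enhance with a
# high lambd (denoise strength) when enable_denoise is set.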
def apply_audio_enhance(
audio_data: np.ndarray, sr: int, enable_denoise: bool, enable_enhance: bool
):
if not enable_denoise and not enable_enhance:
return audio_data, sr
    # FIXME: moving this with .to(device) here might be a small optimization?
tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
enhancer = load_enhancer()
if enable_enhance or enable_denoise:
lambd = 0.9 if enable_denoise else 0.1
tensor, sr = enhancer.enhance(
tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd
)
audio_data = tensor.cpu().numpy()
return audio_data, int(sr)
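
# Ad-hoc manual test entry point; the commented blocks below are scratch examples.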
if __name__ == "__main__":
import torchaudio
import gradio as gr
device = torch.device("cuda")
# def enhance(file):
# print(file)
# ench = load_enhancer(device)
# dwav, sr = torchaudio.load(file)
# dwav = dwav.mean(dim=0).to(device)
# enhanced, e_sr = ench.enhance(dwav, sr)
# return e_sr, enhanced.cpu().numpy()
    # # just a quick example
# gr.Interface(
# fn=enhance, inputs=[gr.Audio(type="filepath")], outputs=[gr.Audio()]
# ).launch()
# load_chat_tts()
# ench = load_enhancer(device)
# devices.torch_gc()
# wav, sr = torchaudio.load("test.wav")
# print(wav.shape, type(wav), sr, type(sr))
# # exit()
# wav = wav.squeeze(0).cuda()
# print(wav.device)
# denoised, d_sr = ench.denoise(wav, sr)
# denoised = denoised.unsqueeze(0)
# print(denoised.shape)
# torchaudio.save("denoised.wav", denoised.cpu(), d_sr)
# for solver in ("midpoint", "rk4", "euler"):
# for lambd in (0.1, 0.5, 0.9):
# for tau in (0.1, 0.5, 0.9):
# enhanced, e_sr = ench.enhance(
# wav, sr, solver=solver, lambd=lambd, tau=tau, nfe=128
# )
# enhanced = enhanced.unsqueeze(0)
# print(enhanced.shape)
# torchaudio.save(
# f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced.cpu(), e_sr
# )
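
    # A minimal runnable sketch (assumes a local "test.wav" exists): load the
    # audio as mono on CPU, run the full enhance pass through the shared
    # enhancer, and write the result back to disk.
    wav, sr = torchaudio.load("test.wav")
    wav = wav.mean(dim=0).cpu()
    ench = load_enhancer()
    enhanced, e_sr = ench.enhance(wav, sr)
    torchaudio.save("enhanced.wav", enhanced.unsqueeze(0).cpu(), e_sr)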