import logging
from functools import cache
from pathlib import Path
from typing import Union

import torch

from ..inference import inference
from .download import download
from .enhancer import Enhancer
from .hparams import HParams

logger = logging.getLogger(__name__)

# ODE solvers accepted by ``enhance``; kept in sync with its error message.
_SOLVERS = ("midpoint", "rk4", "euler")


@cache
def load_enhancer(run_dir: Union[str, Path, None], device):
    """Build an ``Enhancer`` from a run directory and move it to ``device``.

    The directory is first passed through ``download`` (presumably this fetches
    or locates the run; its result must support ``/`` path joining). The
    hyper-parameters are read with ``HParams.load``. Weights come from the
    ``"module"`` entry of
    ``<run_dir>/ds/G/default/mp_rank_00_model_states.pt``.

    The checkpoint is deserialised onto the CPU, and only then is the model put
    into eval mode and moved to ``device``.

    Results are memoised with ``functools.cache`` on the exact
    ``(run_dir, device)`` arguments. Every caller passing the same arguments
    therefore receives the *same* model object. Note that ``"x"`` and
    ``Path("x")`` are distinct cache keys.

    Args:
        run_dir: Run directory (or ``None``) handed to ``download``.
        device: Target device for the loaded model.

    Returns:
        The loaded ``Enhancer`` in eval mode.
    """
    resolved_dir = download(run_dir)
    model = Enhancer(HParams.load(resolved_dir))

    checkpoint_path = (
        resolved_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
    )
    # Index straight into "module" so the rest of the checkpoint dict is not
    # kept alive for the remainder of this function.
    weights = torch.load(checkpoint_path, map_location="cpu")["module"]
    model.load_state_dict(weights)

    model.eval()
    model.to(device)
    return model


@torch.inference_mode()
def denoise(dwav, sr, device, run_dir=None):
    """Run only the ``denoiser`` sub-module of the cached enhancer on ``dwav``.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``. Its
            expected shape/dtype is defined there; TODO confirm.
        sr: Presumably the sample rate of ``dwav``; forwarded to ``inference``.
        device: Device used both to load the model and to run inference.
        run_dir: Run directory forwarded to ``load_enhancer``.

    Returns:
        Whatever ``inference`` returns for the denoiser.
    """
    denoiser = load_enhancer(run_dir, device).denoiser
    return inference(model=denoiser, dwav=dwav, sr=sr, device=device)


@torch.inference_mode()
def enhance(
    dwav,
    sr,
    device,
    nfe=32,
    solver="midpoint",
    lambd=0.5,
    tau=0.5,
    run_dir=None,
):
    """Validate settings, configure the cached enhancer, and run it on ``dwav``.

    The settings are applied with ``Enhancer.configurate_`` to the shared
    instance returned by ``load_enhancer``. Presumably this happens in place
    (trailing underscore), so the settings stay on that cached object for
    later callers using the same ``(run_dir, device)``; confirm in
    ``configurate_``.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``.
        sr: Presumably the sample rate of ``dwav``; forwarded to ``inference``.
        device: Device used both to load the model and to run inference.
        nfe: Must satisfy ``0 < nfe <= 128``. Presumably the number of
            function evaluations for the solver; confirm in ``configurate_``.
        solver: One of ``"midpoint"``, ``"rk4"`` or ``"euler"``.
        lambd: Must lie in ``[0, 1]``; meaning defined by ``configurate_``.
        tau: Must lie in ``[0, 1]``; meaning defined by ``configurate_``.
        run_dir: Run directory forwarded to ``load_enhancer``.

    Returns:
        Whatever ``inference`` returns for the full enhancer.

    Raises:
        AssertionError: If an argument is out of range. These checks are
            ``assert`` statements, so they are skipped under ``python -O``.
    """
    assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
    assert (
        solver in _SOLVERS
    ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
    assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
    assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"

    model = load_enhancer(run_dir, device)
    model.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
    return inference(model=model, dwav=dwav, sr=sr, device=device)