Spaces:
Sleeping
Sleeping
import logging | |
from functools import cache | |
import torch | |
from ..denoiser.denoiser import Denoiser | |
from ..inference import inference | |
from .hparams import HParams | |
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
@cache
def load_denoiser(run_dir, device):
    """Build a ``Denoiser`` in eval mode on ``device``, memoized per arguments.

    Args:
        run_dir: Training run directory, or ``None``. It is combined with
            ``/`` below, so presumably a ``pathlib.Path`` (TODO confirm with
            callers). When ``None``, a ``Denoiser`` built from default
            ``HParams()`` is returned and no checkpoint weights are loaded.
        device: Target device, passed to ``Module.to``.

    Returns:
        The denoiser module, switched to eval mode and moved to ``device``.

    Note:
        Results are cached by ``functools.cache`` keyed on
        ``(run_dir, device)``. Both arguments must therefore be hashable, and
        repeated calls with equal arguments return the *same* module instance
        instead of re-reading the checkpoint from disk.
    """
    if run_dir is None:
        # No run directory: default hyper-parameters, no pretrained weights.
        denoiser = Denoiser(HParams())
    else:
        hp = HParams.load(run_dir)
        denoiser = Denoiser(hp)
        # Path layout looks like a DeepSpeed checkpoint; the model weights are
        # read from its "module" entry.
        path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
        # NOTE(review): torch.load unpickles the file and can execute code
        # embedded in it -- only load checkpoints from trusted run dirs.
        state_dict = torch.load(path, map_location="cpu")["module"]
        denoiser.load_state_dict(state_dict)
    # Applied on both paths so the default (run_dir=None) model is also in
    # eval mode and lives on ``device``, matching the loaded model.
    denoiser.eval()
    denoiser.to(device)
    return denoiser
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Denoise a waveform using the denoiser obtained from ``run_dir``.

    Runs under ``torch.inference_mode()``, so no autograd graph is recorded
    during the forward pass. This is redundant but harmless if ``inference``
    already disables gradient tracking.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``
            (presumably a ``torch.Tensor`` -- TODO confirm).
        sr: Sample rate of ``dwav``, forwarded to ``inference``.
        run_dir: Run directory passed to ``load_denoiser``; ``None`` selects a
            denoiser built from default hyper-parameters without loaded weights.
        device: Device passed to both ``load_denoiser`` and ``inference``.

    Returns:
        Whatever ``inference`` returns for this model and input.
    """
    denoiser = load_denoiser(run_dir, device)
    return inference(model=denoiser, dwav=dwav, sr=sr, device=device)