zhzluke96
update
32b2aaa
raw
history blame
1.49 kB
import logging
from functools import cache
from pathlib import Path
import torch
from ..inference import inference
from .download import download
from .hparams import HParams
from .enhancer import Enhancer
logger = logging.getLogger(__name__)
@cache
def load_enhancer(run_dir: str | Path | None, device):
run_dir = download(run_dir)
hp = HParams.load(run_dir)
enhancer = Enhancer(hp)
path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
state_dict = torch.load(path, map_location="cpu")["module"]
enhancer.load_state_dict(state_dict)
enhancer.eval()
enhancer.to(device)
return enhancer
@torch.inference_mode()
def denoise(dwav, sr, device, run_dir=None):
enhancer = load_enhancer(run_dir, device)
return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
@torch.inference_mode()
def enhance(
dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None
):
assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
assert solver in (
"midpoint",
"rk4",
"euler",
), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
enhancer = load_enhancer(run_dir, device)
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
return inference(model=enhancer, dwav=dwav, sr=sr, device=device)