File size: 789 Bytes
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import logging
from functools import cache

import torch

from ..denoiser.denoiser import Denoiser
from ..inference import inference
from .hparams import HParams

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


@cache
def load_denoiser(run_dir, device):
    """Build a ``Denoiser`` and optionally restore its weights from ``run_dir``.

    The result is memoized with ``functools.cache`` on ``(run_dir, device)``,
    so repeated calls with equal arguments return the *same* model instance.
    Both arguments must therefore be hashable.

    Args:
        run_dir: Run directory (``str`` or path-like) holding the
            hyper-parameters read by ``HParams.load`` and the checkpoint at
            ``ds/G/default/mp_rank_00_model_states.pt``. If ``None``, a
            ``Denoiser`` built from default ``HParams()`` is returned and no
            weights are loaded.
        device: Target passed to ``denoiser.to(...)``.

    Returns:
        The ``Denoiser`` in eval mode, placed on ``device``.
    """
    if run_dir is None:
        denoiser = Denoiser(HParams())
    else:
        # Local import keeps this block self-contained; accepting plain
        # strings avoids a TypeError on the ``/`` joins below.
        from pathlib import Path

        run_dir = Path(run_dir)
        hp = HParams.load(run_dir)
        denoiser = Denoiser(hp)
        path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
        # NOTE(security): torch.load unpickles the file; only point run_dir at
        # checkpoints from a trusted source.
        # Load onto CPU first; the model is moved to ``device`` afterwards.
        state_dict = torch.load(path, map_location="cpu")["module"]
        denoiser.load_state_dict(state_dict)

    # Apply the same inference setup on both paths so the default-initialised
    # model also honours ``device`` and is returned in eval mode.
    denoiser.eval()
    denoiser.to(device)
    return denoiser


@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Run the cached denoiser for ``run_dir`` over ``dwav``.

    The model is obtained from ``load_denoiser(run_dir, device)`` and the
    actual processing is delegated to ``inference``; the whole call executes
    under ``torch.inference_mode()``.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``.
        sr: Sample rate of ``dwav``, forwarded unchanged to ``inference``.
        run_dir: Run directory handed to ``load_denoiser`` (may be ``None``).
        device: Device handed to both ``load_denoiser`` and ``inference``.

    Returns:
        Whatever ``inference`` returns for the given model and input.
    """
    model = load_denoiser(run_dir, device)
    result = inference(model=model, dwav=dwav, sr=sr, device=device)
    return result