File size: 3,548 Bytes
da8d589
 
c5dfbfb
 
 
 
 
 
5e0d8b8
 
da8d589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
from typing import List

try:
    from resemble_enhance.enhancer.enhancer import Enhancer
    from resemble_enhance.enhancer.hparams import HParams
    from resemble_enhance.inference import inference
except:
    HParams = dict
    Enhancer = dict

import torch

from modules.utils.constants import MODELS_DIR
from pathlib import Path

from threading import Lock

# Process-wide ResembleEnhance instance, created lazily by load_enhancer();
# None until the first successful call.
resemble_enhance = None
# Serialises the lazy creation of ``resemble_enhance`` in load_enhancer().
lock = Lock()


def load_enhancer(device: torch.device):
    """Return the process-wide ResembleEnhance instance, creating it on first use.

    Creation and weight loading happen while holding ``lock``, so concurrent
    first calls construct and load the model only once.

    Args:
        device: Device the model is moved to. Only used by the call that
            actually creates the instance; later calls return the cached
            instance unchanged regardless of the ``device`` they pass.

    Returns:
        The shared, fully loaded ``ResembleEnhance`` instance.

    Raises:
        Whatever ``ResembleEnhance.load_model`` raises. In that case nothing
        is cached, so a later call retries the load.
    """
    global resemble_enhance
    with lock:
        if resemble_enhance is None:
            # Publish the instance only after load_model() succeeds. Assigning
            # the global first would cache a half-initialised object
            # (enhancer=None) if loading raised, and every later call would
            # return it and fail with "Model not loaded".
            instance = ResembleEnhance(device)
            instance.load_model()
            resemble_enhance = instance
    return resemble_enhance


class ResembleEnhance:
    """Thin wrapper around a resemble-enhance ``Enhancer`` stored under ``MODELS_DIR``.

    Usage: construct with a target device, call ``load_model()`` once, then
    call ``denoise()`` or ``enhance()``.
    """

    hparams: HParams
    enhancer: Enhancer

    def __init__(self, device: torch.device):
        # Only the target device is recorded here; the model itself is not
        # loaded until load_model() runs, so both handles start as None.
        self.device = device
        self.enhancer = None
        self.hparams = None

    def load_model(self):
        """Build the Enhancer from its on-disk hparams and checkpoint.

        Hyper-parameters come from ``MODELS_DIR/resemble-enhance``. Weights
        come from the ``"module"`` entry of ``mp_rank_00_model_states.pt`` in
        the same directory and are loaded onto CPU first. The model is then
        put in eval mode and moved to ``self.device``. ``self.hparams`` and
        ``self.enhancer`` are assigned only after all of this has succeeded.
        """
        model_dir = Path(MODELS_DIR) / "resemble-enhance"

        params = HParams.load(model_dir)
        model = Enhancer(params)

        weights = torch.load(
            model_dir / "mp_rank_00_model_states.pt",
            map_location="cpu",
        )["module"]
        model.load_state_dict(weights)

        model.eval()
        model.to(self.device)
        # The denoiser is moved explicitly as well.
        # NOTE(review): presumably redundant if it is a registered submodule.
        model.denoiser.to(self.device)

        self.hparams = params
        self.enhancer = model

    @torch.inference_mode()
    def denoise(self, dwav, sr, device) -> tuple[torch.Tensor, int]:
        """Run only the enhancer's denoiser stage over ``dwav``.

        Args:
            dwav: Input waveform tensor. The layout is whatever ``inference``
                expects; the demo below passes a 1-D tensor (TODO confirm).
            sr: Sample rate of ``dwav``.
            device: Device handed straight to ``inference``; ``self.device``
                is not consulted here.

        Returns:
            The result of ``inference``, annotated as ``(waveform, sample_rate)``.
        """
        # NOTE(review): these checks use ``assert`` and are skipped under ``python -O``.
        assert self.enhancer is not None, "Model not loaded"
        assert self.enhancer.denoiser is not None, "Denoiser not loaded"
        return inference(
            model=self.enhancer.denoiser,
            dwav=dwav,
            sr=sr,
            device=device,
        )

    @torch.inference_mode()
    def enhance(
        self,
        dwav,
        sr,
        device,
        nfe=32,
        solver="midpoint",
        lambd=0.5,
        tau=0.5,
    ) -> tuple[torch.Tensor, int]:
        """Run the full enhancer over ``dwav`` with the given sampling settings.

        Args:
            dwav: Input waveform tensor (layout as for ``denoise``; TODO confirm).
            sr: Sample rate of ``dwav``.
            device: Device handed straight to ``inference``.
            nfe: Value forwarded to ``configurate_``; must lie in (0, 128].
            solver: One of ``"midpoint"``, ``"rk4"``, ``"euler"``.
            lambd: Value forwarded to ``configurate_``; must lie in [0, 1].
            tau: Value forwarded to ``configurate_``; must lie in [0, 1].

        Returns:
            The result of ``inference``, annotated as ``(waveform, sample_rate)``.
        """
        # NOTE(review): validation uses ``assert`` and is skipped under ``python -O``.
        assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
        known_solvers = ("midpoint", "rk4", "euler")
        assert solver in known_solvers, (
            f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
        )
        assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
        assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
        assert self.enhancer is not None, "Model not loaded"

        model = self.enhancer
        # configurate_ is called on the shared model object. The trailing
        # underscore suggests it mutates the model in place, so the settings
        # presumably persist until the next enhance() call — confirm.
        model.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
        return inference(model=model, dwav=dwav, sr=sr, device=device)


if __name__ == "__main__":
    # Manual smoke test: denoise test.wav, then sweep enhance() settings and
    # write each result to a .wav file in the working directory.
    import torchaudio
    from modules.models import load_chat_tts

    # NOTE(review): kept from the original; unclear why ChatTTS must be loaded
    # for this demo — confirm before removing.
    load_chat_tts()

    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ench = ResembleEnhance(device)
    ench.load_model()

    wav, sr = torchaudio.load("test.wav")
    print(wav.shape, type(wav), sr, type(sr))

    # torchaudio.load returns (channels, frames); downmix to a 1-D waveform.
    # For mono input this is identical to wav.squeeze(0). The tensor stays on
    # CPU, which is what the original ultimately passed in (wav.cpu()).
    wav = wav.mean(dim=0)

    denoised, d_sr = ench.denoise(wav, sr, device)
    # torchaudio.save expects a (channels, frames) tensor on CPU.
    denoised = denoised.cpu().unsqueeze(0)
    print(denoised.shape)
    torchaudio.save("denoised.wav", denoised, d_sr)

    for solver in ("midpoint", "rk4", "euler"):
        for lambd in (0.1, 0.5, 0.9):
            for tau in (0.1, 0.5, 0.9):
                enhanced, e_sr = ench.enhance(
                    wav, sr, device, solver=solver, lambd=lambd, tau=tau, nfe=128
                )
                enhanced = enhanced.cpu().unsqueeze(0)
                print(enhanced.shape)
                torchaudio.save(f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced, e_sr)