import logging
import math
from typing import Optional, Union
import torch
import torchaudio
from torch import nn
from audio_denoiser.helpers.torch_helper import batched_apply
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
from audio_denoiser.helpers.audio_helper import (
    create_spectrogram,
    reconstruct_from_spectrogram,
)

# Trimmed standard deviation of waveforms in the model's training data;
# used by auto_scale (via _trimmed_dev) to normalize input loudness.
_expected_t_std = 0.23
# Suggested torchaudio I/O backend (informational; not set here).
_recommended_backend = "soundfile"


# ref: https://github.com/jose-solorzano/audio-denoiser
class AudioDenoiser:
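    """
    Spectrogram-domain audio denoiser (after jose-solorzano/audio-denoiser).

    The loaded model predicts a noise log-spectrogram for fixed-size segments
    of the input; the prediction is subtracted in log space and the result is
    converted back to a waveform.
    """
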
    def __init__(
        self,
        local_dir: str,
        device: Optional[Union[str, torch.device]] = None,
        num_iterations: int = 100,
    ):
        if device is None:
            is_cuda = torch.cuda.is_available()
            if not is_cuda:
                logging.warning("CUDA not available. Will use CPU.")
            device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
        self.device = torch.device(device)
        self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
        self.model.eval()
        self.model_sample_rate = self.model.sample_rate
        self.scaler = self.model.scaler
        self.n_fft = self.model.n_fft
        self.segment_num_frames = self.model.num_frames
        self.num_iterations = num_iterations

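    # _sp_log / _sp_exp form an invertible log compression of magnitude
    # spectrograms; eps avoids log(0) and is subtracted again on the way back.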
    @staticmethod
    def _sp_log(spectrogram: torch.Tensor, eps=0.01):
        return torch.log(spectrogram + eps)

    @staticmethod
    def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
        return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)

    @staticmethod
    def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
        # Std of the loudest samples (those above the q-th quantile of
        # |waveform|); ~0.23 (_expected_t_std) for the training data.
        abs_waveform = torch.abs(waveform)
        quantile_value = torch.quantile(abs_waveform, q).item()
        trimmed_values = waveform[abs_waveform >= quantile_value]
        return torch.std(trimmed_values).item()

    def process_waveform(
        self,
        waveform: torch.Tensor,
        sample_rate: int,
        return_cpu_tensor: bool = False,
        auto_scale: bool = False,
    ) -> torch.Tensor:
        """
        Denoises a waveform.
        @param waveform: A waveform tensor. Use torchaudio structure.
        @param sample_rate: The sample rate of the waveform in Hz.
        @param return_cpu_tensor: Whether the returned tensor must be a CPU tensor.
        @param auto_scale: Normalize the scale of the waveform before processing. Recommended for low-volume audio.
        @return: A denoised waveform.
        """
        waveform = waveform.cpu()
        if auto_scale:
            w_t_std = self._trimmed_dev(waveform)
            waveform = waveform * _expected_t_std / w_t_std
        if sample_rate != self.model_sample_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.model_sample_rate
            )
            waveform = transform(waveform)
        hop_len = self.n_fft // 2
        spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
        spectrogram = spectrogram.to(self.device)
        num_a_channels = spectrogram.size(0)
        with torch.no_grad():
            results = []
            for c in range(num_a_channels):
                c_spectrogram = spectrogram[c]
                # c_spectrogram: (257, num_frames); 257 = n_fft // 2 + 1 bins (assuming n_fft = 512)
                fft_size, num_frames = c_spectrogram.shape
                num_segments = math.ceil(num_frames / self.segment_num_frames)
                adj_num_frames = num_segments * self.segment_num_frames
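                # Pad the frame axis so it splits into whole model-sized segments.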
                if adj_num_frames > num_frames:
                    c_spectrogram = nn.functional.pad(
                        c_spectrogram, (0, adj_num_frames - num_frames)
                    )
                c_spectrogram = c_spectrogram.view(
                    fft_size, num_segments, self.segment_num_frames
                )
                # c_spectrogram: (257, num_segments, 32)
                c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2))
                # c_spectrogram: (num_segments, 257, 32)
                log_c_spectrogram = self._sp_log(c_spectrogram)
                scaled_log_c_sp = self.scaler(log_c_spectrogram)
                pred_noise_log_sp = batched_apply(
                    self.model, scaled_log_c_sp, detached=True
                )
                log_denoised_sp = log_c_spectrogram - pred_noise_log_sp
                denoised_sp = self._sp_exp(log_denoised_sp)
                # denoised_sp: (num_segments, 257, 32)
                denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
                # denoised_sp: (257, num_segments, 32)
                denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
                # denoised_sp: (1, 257, adj_num_frames)
                denoised_sp = denoised_sp[:, :, :num_frames]
                denoised_sp = denoised_sp.cpu()
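                # Invert the magnitude spectrogram back to a time-domain signal;
                # num_iterations drives the helper's iterative phase
                # reconstruction (Griffin-Lim-style).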
                denoised_waveform = reconstruct_from_spectrogram(
                    denoised_sp, num_iterations=self.num_iterations
                )
                # denoised_waveform: (1, num_samples)
                results.append(denoised_waveform)
            cpu_results = torch.cat(results)
            return cpu_results if return_cpu_tensor else cpu_results.to(self.device)

    def process_audio_file(
        self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
    ):
        """
        Denoises an audio file.
        @param in_audio_file: An input audio file with a format supported by torchaudio.
        @param out_audio_file: Am output audio file with a format supported by torchaudio.
        @param auto_scale: Whether the input waveform scale should be normalized before processing. Recommended for low-volume audio.
        """
        waveform, sample_rate = torchaudio.load(in_audio_file)
        denoised_waveform = self.process_waveform(
            waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
        )
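        # Note: the file is written at the model's sample rate, which may
        # differ from the input file's rate (the waveform was resampled).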
        torchaudio.save(
            out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
        )
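

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: "path/to/model_dir" stands in for a
    # local directory accepted by load_audio_denosier_model, and the .wav
    # paths are illustrative placeholders).
    denoiser = AudioDenoiser(local_dir="path/to/model_dir")
    denoiser.process_audio_file("noisy.wav", "denoised.wav", auto_scale=True)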