Spaces:
Paused
Paused
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ...core import AudioSignal | |
from ...core import STFTParams | |
from ...core import util | |
class SpectralGate(nn.Module):
    """Spectral gating algorithm for noise reduction,
    as in Audacity/Ocenaudio. The steps are as follows:

    1. An FFT is calculated over the noise audio clip.
    2. Statistics (mean and standard deviation over time) are
       calculated for every frequency bin of the noise FFT.
    3. A threshold is calculated per frequency bin from those
       statistics and the desired sensitivity (``n_std``).
    4. An FFT is calculated over the signal.
    5. A mask is determined by comparing the signal FFT to the
       threshold; bins below the threshold are treated as noise.
    6. The mask is smoothed with a triangular filter over frequency
       and time.
    7. The mask is applied to the FFT of the signal, which is then
       inverted back to the time domain.

    Implementation inspired by Tim Sainburg's noisereduce:
    https://timsainburg.com/noise-reduction-python.html

    Parameters
    ----------
    n_freq : int, optional
        Half-width, in frequency bins, of the smoothing filter; the
        filter spans ``2 * n_freq + 1`` bins. By default 3.
    n_time : int, optional
        Half-width, in time frames, of the smoothing filter; the
        filter spans ``2 * n_time + 1`` frames. By default 5.

    Raises
    ------
    ValueError
        If ``n_freq`` or ``n_time`` is negative.
    """

    def __init__(self, n_freq: int = 3, n_time: int = 5):
        super().__init__()
        if n_freq < 0 or n_time < 0:
            raise ValueError(
                f"n_freq and n_time must be non-negative, got {n_freq} and {n_time}"
            )

        # 2D kernel: outer product of two triangular windows, normalized to
        # sum to 1 so the convolution computes a weighted average of the mask.
        smoothing_filter = torch.outer(
            self._triangular_window(n_freq),
            self._triangular_window(n_time),
        )
        smoothing_filter = smoothing_filter / smoothing_filter.sum()
        # Add the (out_channels, in_channels) dims expected by F.conv2d.
        smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
        self.register_buffer("smoothing_filter", smoothing_filter)

    @staticmethod
    def _triangular_window(n: int) -> torch.Tensor:
        """Return a ``2 * n + 1`` point triangular window peaking at 1.

        The zero-valued outer endpoints of the two ramps are trimmed, so
        every returned value is strictly positive.
        """
        ramp_up = torch.linspace(0, 1, n + 2)[:-1]
        ramp_down = torch.linspace(1, 0, n + 2)
        return torch.cat([ramp_up, ramp_down])[..., 1:-1]

    def forward(
        self,
        audio_signal: AudioSignal,
        nz_signal: AudioSignal,
        denoise_amount: float = 1.0,
        n_std: float = 3.0,
        win_length: int = 2048,
        hop_length: int = 512,
    ):
        """Perform noise reduction.

        Parameters
        ----------
        audio_signal : AudioSignal
            Audio signal that noise will be removed from. A clone of it is
            processed and returned.
        nz_signal : AudioSignal
            Noise signal to compute noise statistics from. Its batch and
            channel dimensions must be 1 or equal to those of
            ``audio_signal`` (they are broadcast with ``expand``).
        denoise_amount : float, optional
            Scale applied to the smoothed noise mask; the spectrum is
            multiplied by ``1 - denoise_amount * mask``, so 0.0 leaves it
            unchanged. By default 1.0
        n_std : float, optional
            Number of standard deviations above the per-frequency noise
            mean at which the gating threshold is placed; signal bins below
            the threshold are treated as noise. By default 3.0
        win_length : int, optional
            Length of window for STFT, by default 2048
        hop_length : int, optional
            Hop length for STFT, by default 512

        Returns
        -------
        AudioSignal
            Denoised audio signal.
        """
        stft_params = STFTParams(win_length, hop_length, "sqrt_hann")

        audio_signal = audio_signal.clone()
        audio_signal.stft_data = None
        audio_signal.stft_params = stft_params

        # Drop any cached STFT on the noise clone as well, so its statistics
        # are always computed with ``stft_params`` (same as audio_signal).
        nz_signal = nz_signal.clone()
        nz_signal.stft_data = None
        nz_signal.stft_params = stft_params

        # Per-frequency noise statistics in dB, reduced over the last axis.
        # The clamp floors magnitudes at 1e-4 (-80 dB) to avoid log10(0).
        # NOTE(review): a single-frame noise clip makes ``std`` NaN, so every
        # comparison below is False and no gating is applied.
        nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
        nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
        nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)
        nz_thresh = nz_freq_mean + nz_freq_std * n_std

        stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
        nb, nac, nf, nt = stft_db.shape
        db_thresh = nz_thresh.expand(nb, nac, -1, nt)

        # 1.0 where the signal is below the noise threshold, 0.0 elsewhere.
        stft_mask = (stft_db < db_thresh).float()
        shape = stft_mask.shape

        # Smooth the mask with a single-channel 2D convolution, folding the
        # two leading dims into the conv batch dim. The filter is aligned to
        # the mask's device/dtype so the module need not be moved alongside
        # every signal.
        smoothing_filter = self.smoothing_filter.to(
            device=stft_mask.device, dtype=stft_mask.dtype
        )
        stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
        # The kernel has odd size on both axes, so half-size padding keeps
        # the output the same size as the input.
        pad_tuple = (
            smoothing_filter.shape[-2] // 2,
            smoothing_filter.shape[-1] // 2,
        )
        stft_mask = F.conv2d(stft_mask, smoothing_filter, padding=pad_tuple)
        stft_mask = stft_mask.reshape(*shape)

        stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
            audio_signal.device
        )
        # Turn the (scaled) noise mask into a gain applied to the spectrum.
        stft_mask = 1 - stft_mask

        audio_signal.stft_data *= stft_mask
        audio_signal.istft()

        return audio_signal