"""Spectrogram-domain audio denoising with a learned noise-prediction model."""
import logging
import math
from typing import Optional, Union

import torch
import torchaudio
from torch import nn

from audio_denoiser.helpers.audio_helper import (
    create_spectrogram,
    reconstruct_from_spectrogram,
)
from audio_denoiser.helpers.torch_helper import batched_apply

from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model

# Target trimmed standard deviation used by auto_scale to rescale quiet recordings.
_expected_t_std = 0.23
_recommended_backend = "soundfile"


class AudioDenoiser:
    def __init__(
        self,
        local_dir: str,
        device: Optional[Union[str, torch.device]] = None,
        num_iterations: int = 100,
    ):
        if device is None:
            is_cuda = torch.cuda.is_available()
            if not is_cuda:
                logging.warning("CUDA not available. Will use CPU.")
            device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device
        self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
        self.model.eval()
        # Mirror the model's preprocessing configuration.
        self.model_sample_rate = self.model.sample_rate
        self.scaler = self.model.scaler
        self.n_fft = self.model.n_fft
        self.segment_num_frames = self.model.num_frames
        self.num_iterations = num_iterations

    @staticmethod
    def _sp_log(spectrogram: torch.Tensor, eps=0.01):
        # Log-compress the magnitude spectrogram; eps avoids log(0).
        return torch.log(spectrogram + eps)

    @staticmethod
    def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
        # Inverse of _sp_log, clamped so magnitudes stay non-negative.
        return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)

    @staticmethod
    def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
        # Standard deviation of the loudest samples: those whose absolute
        # value is at or above the q-th quantile of |waveform|.
        abs_waveform = torch.abs(waveform)
        quantile_value = torch.quantile(abs_waveform, q).item()
        trimmed_values = waveform[abs_waveform >= quantile_value]
        return torch.std(trimmed_values).item()

    def process_waveform(
        self,
        waveform: torch.Tensor,
        sample_rate: int,
        return_cpu_tensor: bool = False,
        auto_scale: bool = False,
    ) -> torch.Tensor:
        """
        Denoises a waveform.
        @param waveform: A waveform tensor in torchaudio layout: (channels, samples).
        @param sample_rate: The sample rate of the waveform in Hz.
        @param return_cpu_tensor: Whether the returned tensor must be a CPU tensor.
        @param auto_scale: Normalize the scale of the waveform before processing. Recommended for low-volume audio.
        @return: A denoised waveform, resampled to the model's sample rate if needed.
        """
        waveform = waveform.cpu()
        if auto_scale:
            # Rescale so the trimmed deviation matches the expected level;
            # skip the division when the input is effectively silent.
            w_t_std = self._trimmed_dev(waveform)
            if w_t_std > 0:
                waveform = waveform * _expected_t_std / w_t_std
        if sample_rate != self.model_sample_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.model_sample_rate
            )
            waveform = transform(waveform)
        hop_len = self.n_fft // 2
        spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
        spectrogram = spectrogram.to(self.device)
        num_a_channels = spectrogram.size(0)
        with torch.no_grad():
            results = []
            # Denoise each audio channel independently.
            for c in range(num_a_channels):
                c_spectrogram = spectrogram[c]
                # Split the time axis into fixed-size segments the model can
                # consume, zero-padding the tail segment if needed.
                fft_size, num_frames = c_spectrogram.shape
                num_segments = math.ceil(num_frames / self.segment_num_frames)
                adj_num_frames = num_segments * self.segment_num_frames
                if adj_num_frames > num_frames:
                    c_spectrogram = nn.functional.pad(
                        c_spectrogram, (0, adj_num_frames - num_frames)
                    )
                c_spectrogram = c_spectrogram.view(
                    fft_size, num_segments, self.segment_num_frames
                )
                # Move segments to the batch dimension: (segments, fft, frames).
                c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2))

                # The model predicts the noise log-spectrogram; subtracting it
                # from the input log-spectrogram yields the denoised estimate.
                log_c_spectrogram = self._sp_log(c_spectrogram)
                scaled_log_c_sp = self.scaler(log_c_spectrogram)
                pred_noise_log_sp = batched_apply(
                    self.model, scaled_log_c_sp, detached=True
                )
                log_denoised_sp = log_c_spectrogram - pred_noise_log_sp
                denoised_sp = self._sp_exp(log_denoised_sp)

                # Undo the segmentation: back to (1, fft, frames), then drop
                # the zero-padded tail.
                denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
                denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
                denoised_sp = denoised_sp[:, :, :num_frames]
                denoised_sp = denoised_sp.cpu()
                # Recover a waveform from the magnitude spectrogram via
                # iterative phase reconstruction.
                denoised_waveform = reconstruct_from_spectrogram(
                    denoised_sp, num_iterations=self.num_iterations
                )
                results.append(denoised_waveform)
        cpu_results = torch.cat(results)
        return cpu_results if return_cpu_tensor else cpu_results.to(self.device)
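
    # Tensor-level usage sketch (the file name below is a hypothetical
    # placeholder; `denoiser` stands for an AudioDenoiser instance):
    #   waveform, sr = torchaudio.load("noisy.wav")
    #   clean = denoiser.process_waveform(waveform, sr, return_cpu_tensor=True)
    #   # `clean` is at the model's sample rate if sr differed from it.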

    def process_audio_file(
        self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
    ):
        """
        Denoises an audio file.
        @param in_audio_file: An input audio file in a format supported by torchaudio.
        @param out_audio_file: An output audio file in a format supported by torchaudio.
        @param auto_scale: Whether the input waveform scale should be normalized before processing. Recommended for low-volume audio.
        """
        waveform, sample_rate = torchaudio.load(in_audio_file)
        denoised_waveform = self.process_waveform(
            waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
        )
        # Save at the model's sample rate: process_waveform resamples when the
        # input rate differs.
        torchaudio.save(
            out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
        )
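

# Example usage: a minimal sketch. The model directory and file names below
# are hypothetical placeholders, not paths shipped with this module.
if __name__ == "__main__":
    denoiser = AudioDenoiser(local_dir="models/audio_denoiser")

    # File-to-file denoising; the output is written at the model's sample rate.
    denoiser.process_audio_file("noisy.wav", "denoised.wav", auto_scale=True)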
|