|
import torch |
|
import torchaudio |
|
from transformers import Pipeline |
|
from librosa import resample |
|
from soundfile import write |
|
from sgmse.model import ScoreModel |
|
from sgmse.util.other import pad_spec |
|
|
|
class CustomSpeechEnhancementPipeline(Pipeline):
    """Hugging Face ``Pipeline`` wrapper around an SGMSE ``ScoreModel``.

    Given a path to an audio file, the pipeline:

    1. loads it and resamples it to ``target_sr`` if needed;
    2. peak-normalises it;
    3. converts it to the model's spectrogram representation;
    4. runs the model's predictor-corrector sampler;
    5. returns the enhanced waveform as a NumPy array.
    """

    def __init__(self, model, target_sr=16000, pad_mode="zero_pad", args=None):
        """
        Custom pipeline for speech enhancement using ScoreModel.

        Args:
            model: The speech enhancement model loaded from a checkpoint (ScoreModel).
            target_sr: Sample rate the input audio is resampled to
                (default 16 kHz).
            pad_mode: Padding mode forwarded to ``sgmse.util.other.pad_spec``
                (default "zero_pad").
            args: Parsed arguments. The pipeline reads ``device``,
                ``corrector``, ``corrector_steps``, ``N`` and ``snr`` from it,
                so it must be provided before the pipeline is called.
        """
        # The model is driven with PyTorch tensors, so state the framework
        # explicitly. Otherwise the base class tries to infer it, and that
        # inference reads ``model.config``.
        # NOTE(review): presumably ScoreModel has no ``config`` attribute;
        # confirm against the installed transformers version.
        super().__init__(model=model, framework="pt")
        self.target_sr = target_sr
        self.pad_mode = pad_mode
        self.args = args

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage parameter dicts.

        ``transformers.Pipeline`` declares this method abstract, so it must
        be implemented for the class to be instantiable. No stage of this
        pipeline takes extra parameters (all settings come from
        ``self.args``), so keyword arguments are ignored.

        Returns:
            Three empty dicts holding the preprocess, forward and
            postprocess parameters, in that order.
        """
        return {}, {}, {}

    def preprocess(self, audio_path):
        """Load, resample, normalise and transform an audio file.

        Args:
            audio_path: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            A tuple ``(Y, norm_factor, T_orig)``:

            * ``Y``: padded, transformed spectrogram with a leading batch
              dimension, on ``self.args.device``.
            * ``norm_factor``: peak absolute amplitude used to normalise the
              waveform (1 for an all-zero signal).
            * ``T_orig``: number of samples in the (resampled) waveform, used
              later to trim the reconstructed output.
        """
        y, sr = torchaudio.load(audio_path)

        # Bring the signal to the sample rate the model expects.
        if sr != self.target_sr:
            y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=self.target_sr))

        # Peak-normalise; _forward multiplies the factor back in. A silent
        # input has a zero peak, and dividing by it would turn the whole
        # signal into NaNs, so fall back to a factor of 1 in that case.
        norm_factor = y.abs().max()
        if norm_factor == 0:
            norm_factor = torch.ones_like(norm_factor)
        y = y / norm_factor

        spec = self.model._stft(y.to(self.args.device))
        Y = torch.unsqueeze(self.model._forward_transform(spec), 0)
        Y = pad_spec(Y, mode=self.pad_mode)

        return Y, norm_factor, y.size(1)

    def _forward(self, model_inputs):
        """Run the reverse-diffusion sampler and reconstruct the waveform.

        Args:
            model_inputs: The ``(Y, norm_factor, T_orig)`` tuple produced by
                ``preprocess``.

        Returns:
            Enhanced waveform tensor, trimmed to ``T_orig`` samples and
            rescaled by ``norm_factor``.
        """
        Y, norm_factor, T_orig = model_inputs

        sampler = self.model.get_pc_sampler(
            'reverse_diffusion',
            self.args.corrector,
            Y.to(self.args.device),
            N=self.args.N,
            corrector_steps=self.args.corrector_steps,
            snr=self.args.snr,
        )
        sample, _ = sampler()

        # Convert back to the time domain at the original length, then undo
        # the peak normalisation applied in preprocess.
        x_hat = self.model.to_audio(sample.squeeze(), T_orig)
        return x_hat * norm_factor

    def postprocess(self, model_outputs):
        """Return the enhanced waveform as a NumPy array on the CPU."""
        return model_outputs.cpu().numpy()

    def pad_spec(self, Y):
        """
        Apply padding to the spectrogram as per the model's required padding mode.

        Uses the same sgmse helper and ``self.pad_mode`` as ``preprocess``,
        so both always pad identically.

        Args:
            Y: Input spectrogram tensor.

        Returns:
            Padded spectrogram.
        """
        # ``torch.nn.functional.pad`` does not accept sgmse modes such as the
        # default "zero_pad", so delegate to the sgmse helper. Inside a method
        # body, the bare name ``pad_spec`` resolves to the module-level
        # function imported from ``sgmse.util.other``, not to this method.
        return pad_spec(Y, mode=self.pad_mode)
|
|