speech-enhancement-sgmse / custom_pipeline.py
Shokoufeh
Add custom pipeline file
2aa6704
import torch
import torchaudio
from transformers import Pipeline
from librosa import resample
from soundfile import write
from sgmse.model import ScoreModel
from sgmse.util.other import pad_spec
class CustomSpeechEnhancementPipeline(Pipeline):
    """Speech-enhancement pipeline built around an sgmse ``ScoreModel``.

    The pipeline loads an audio file, resamples and peak-normalizes it,
    converts it to a padded spectrogram, runs the model's
    predictor-corrector sampler and converts the result back to audio.
    """

    # Translation from the sgmse-style padding mode names used by this
    # pipeline to the mode names understood by ``torch.nn.functional.pad``.
    _TORCH_PAD_MODES = {
        "zero_pad": "constant",
        "reflection": "reflect",
        "replication": "replicate",
    }

    def __init__(self, model, target_sr=16000, pad_mode="zero_pad", args=None):
        """
        Custom pipeline for speech enhancement using ScoreModel.

        Args:
            model: The speech enhancement model loaded from a checkpoint (ScoreModel).
            target_sr: Target sample rate for the input audio (default is 16 kHz).
            pad_mode: Padding mode for spectrogram (default is "zero_pad").
            args: Parsed arguments. ``preprocess`` reads ``args.device`` and
                ``_forward`` reads ``args.device``, ``args.corrector``,
                ``args.N``, ``args.corrector_steps`` and ``args.snr``, so a
                populated object is required before the pipeline is run.
        """
        # NOTE(review): Pipeline.__init__ may expect a transformers-style
        # model (e.g. one exposing ``.config``) -- confirm ScoreModel is
        # accepted by the installed transformers version.
        super().__init__(model=model)
        self.target_sr = target_sr
        self.pad_mode = pad_mode
        self.args = args

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into per-stage parameter dicts.

        ``Pipeline`` requires every subclass to implement this hook. This
        pipeline takes no call-time parameters, so all three dicts
        (preprocess, forward, postprocess) are empty.
        """
        return {}, {}, {}

    def preprocess(self, audio_path):
        """Load an audio file and turn it into a padded model input.

        Args:
            audio_path: Path of the audio file to enhance.

        Returns:
            A tuple ``(Y, norm_factor, T_orig)``: the padded input
            spectrogram, the peak value used for normalization, and the
            number of samples (size of dimension 1) of the normalized
            waveform.
        """
        # Load the audio file
        y, sr = torchaudio.load(audio_path)

        # Resample if necessary
        if sr != self.target_sr:
            y = torch.tensor(
                resample(y.numpy(), orig_sr=sr, target_sr=self.target_sr)
            )

        # Normalize the audio by its peak. An all-zero (silent) signal has a
        # peak of 0; dividing by it would fill the signal with NaNs, so fall
        # back to a factor of 1, which leaves the silent signal unchanged.
        norm_factor = y.abs().max()
        if norm_factor == 0:
            norm_factor = torch.ones_like(norm_factor)
        y = y / norm_factor

        # Prepare the input for the model by transforming to the frequency
        # domain and adding a leading dimension.
        spec = self.model._stft(y.to(self.args.device))
        Y = torch.unsqueeze(self.model._forward_transform(spec), 0)

        # This is the module-level ``pad_spec`` imported from sgmse, not the
        # method of the same name defined on this class.
        Y = pad_spec(Y, mode=self.pad_mode)

        # Return input spec, normalization factor, and original length
        return Y, norm_factor, y.size(1)

    def _forward(self, model_inputs):
        """Run reverse-diffusion sampling and return the enhanced waveform.

        Args:
            model_inputs: The ``(Y, norm_factor, T_orig)`` tuple produced by
                ``preprocess``.

        Returns:
            The enhanced audio, rescaled by ``norm_factor``.
        """
        Y, norm_factor, T_orig = model_inputs

        # Perform reverse sampling using the model's PC sampler
        sampler = self.model.get_pc_sampler(
            'reverse_diffusion',
            self.args.corrector,
            Y.to(self.args.device),
            N=self.args.N,
            corrector_steps=self.args.corrector_steps,
            snr=self.args.snr,
        )

        # Get the enhanced speech sample (second return value is unused)
        sample, _ = sampler()

        # Convert back to time domain
        x_hat = self.model.to_audio(sample.squeeze(), T_orig)

        # Undo the peak normalization applied in ``preprocess``
        return x_hat * norm_factor

    def postprocess(self, model_outputs):
        """Convert the enhanced output tensor to a NumPy array on the CPU."""
        return model_outputs.cpu().numpy()

    def pad_spec(self, Y):
        """
        Apply padding to the spectrogram as per the configured padding mode.

        One element is appended at the end of the second-to-last dimension
        of ``Y``; the last dimension is left unpadded.

        Args:
            Y: Input spectrogram tensor.

        Returns:
            Padded spectrogram.
        """
        # ``torch.nn.functional.pad`` does not know sgmse-style names such as
        # "zero_pad"; translate them and pass native torch modes through
        # unchanged.
        torch_mode = self._TORCH_PAD_MODES.get(self.pad_mode, self.pad_mode)
        return torch.nn.functional.pad(Y, (0, 0, 0, 1), mode=torch_mode)