# speech-enhancement-sgmse / custom_pipeline.py
# Last commit by Shokoufeh: 2aa6704 "Add custom pipeline file"
# Original file size: 2.82 kB
import torch
import torchaudio
from transformers import Pipeline
from librosa import resample
from soundfile import write
from sgmse.model import ScoreModel
from sgmse.util.other import pad_spec
class CustomSpeechEnhancementPipeline(Pipeline):
    """Hugging Face ``Pipeline`` wrapper around an SGMSE ``ScoreModel``.

    The pipeline takes the path of an audio file and returns the enhanced
    waveform as a NumPy array. It is rescaled by the same peak factor that was
    used to normalize the input.
    """

    def __init__(self, model, target_sr=16000, pad_mode="zero_pad", args=None):
        """
        Custom pipeline for speech enhancement using ScoreModel.

        Args:
            model: The speech enhancement model loaded from a checkpoint (ScoreModel).
            target_sr: Target sample rate for the input audio (default is 16 kHz).
            pad_mode: Padding mode for the spectrogram (default is "zero_pad").
                It is forwarded to ``sgmse.util.other.pad_spec``.
            args: Parsed arguments exposing ``device``, ``corrector``,
                ``corrector_steps``, ``snr`` and ``N``. They must be set before
                the pipeline is run.
        """
        # ScoreModel is not a transformers PreTrainedModel. Passing the
        # framework explicitly stops Pipeline.__init__ from trying to infer it
        # via ``model.config``, which this model does not have.
        # NOTE(review): depending on the installed transformers version,
        # Pipeline.__init__ may read further attributes of ``model``; confirm.
        super().__init__(model=model, framework="pt")
        self.target_sr = target_sr
        self.pad_mode = pad_mode
        self.args = args

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to preprocess / _forward / postprocess.

        ``Pipeline`` declares this method abstract, so the class cannot be
        instantiated without it. All settings of this pipeline come from
        ``__init__`` and ``args``, so no call-time parameters are accepted.

        Returns:
            Three empty dicts: preprocess, forward and postprocess parameters.

        Raises:
            TypeError: If any unexpected keyword argument is supplied.
        """
        if kwargs:
            raise TypeError(
                "CustomSpeechEnhancementPipeline got unexpected keyword "
                f"arguments: {sorted(kwargs)}"
            )
        return {}, {}, {}

    def _require_args(self):
        """Return ``self.args``, failing with a clear message if it is unset."""
        if self.args is None:
            raise ValueError(
                "CustomSpeechEnhancementPipeline requires `args` (device, "
                "corrector, corrector_steps, snr, N) before it can be run."
            )
        return self.args

    def preprocess(self, audio_path):
        """Load one audio file and turn it into a padded model input.

        Args:
            audio_path: Path of the file to load with ``torchaudio.load``.

        Returns:
            A tuple ``(Y, norm_factor, T_orig)``:
            - ``Y`` is the padded spectrogram with a leading batch dimension.
            - ``norm_factor`` is the peak amplitude used for normalization.
            - ``T_orig`` is the number of time-domain samples after resampling.
        """
        args = self._require_args()
        # Load the audio file. torchaudio returns (channels, samples).
        # NOTE(review): multi-channel input is passed through as-is; confirm
        # the model expects mono.
        y, sr = torchaudio.load(audio_path)
        # Resample if necessary
        if sr != self.target_sr:
            y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=self.target_sr))
        # Peak-normalize. A silent (all-zero) file would give 0/0 = NaN, so
        # fall back to a factor of 1 in that case.
        norm_factor = y.abs().max()
        if norm_factor == 0:
            norm_factor = torch.ones_like(norm_factor)
        y = y / norm_factor
        # Transform to the frequency domain and add a batch dimension
        Y = torch.unsqueeze(self.model._forward_transform(self.model._stft(y.to(args.device))), 0)
        Y = pad_spec(Y, mode=self.pad_mode)
        return Y, norm_factor, y.size(1)

    def _forward(self, model_inputs):
        """Run reverse diffusion with the model's predictor-corrector sampler.

        Args:
            model_inputs: The ``(Y, norm_factor, T_orig)`` tuple from ``preprocess``.

        Returns:
            The enhanced time-domain signal as a tensor, rescaled by ``norm_factor``.
        """
        args = self._require_args()
        Y, norm_factor, T_orig = model_inputs
        # Perform reverse sampling using the model's PC sampler
        sampler = self.model.get_pc_sampler(
            'reverse_diffusion',
            args.corrector,
            Y.to(args.device),
            N=args.N,
            corrector_steps=args.corrector_steps,
            snr=args.snr,
        )
        sample, _ = sampler()
        # Convert back to the time domain, trimmed/padded to the original length
        x_hat = self.model.to_audio(sample.squeeze(), T_orig)
        # Undo the peak normalization applied in preprocess
        return x_hat * norm_factor

    def postprocess(self, model_outputs):
        """Move the enhanced waveform to the CPU and convert it to a NumPy array."""
        return model_outputs.cpu().numpy()

    def pad_spec(self, Y):
        """
        Apply padding to the spectrogram as per the model's required padding mode.

        This method delegates to the module-level ``pad_spec`` helper from
        ``sgmse.util.other``, which ``preprocess`` also uses. As a result,
        ``self.pad_mode`` values such as the default "zero_pad" behave the same
        way in both places. ``torch.nn.functional.pad`` only accepts "constant",
        "reflect", "replicate" and "circular", so it rejected "zero_pad".

        Args:
            Y: Input spectrogram tensor.
        Returns:
            Padded spectrogram.
        """
        # The bare name resolves to the imported module-level function, not to
        # this method (class attributes are not in a method's scope).
        return pad_spec(Y, mode=self.pad_mode)