File size: 1,801 Bytes
96e64e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import torch
from BigVGAN.meldataset import get_mel_spectrogram
from voice_restore import VoiceRestore
class OptimizedAudioRestorationModel(torch.nn.Module):
    """Restore audio by chaining mel extraction, VoiceRestore and BigVGAN.

    The forward pass converts the input audio to a mel-spectrogram with
    ``get_mel_spectrogram`` (using the BigVGAN model's ``h`` config), runs
    ``VoiceRestore.sample`` on it, and feeds the restored mel-spectrogram to
    the BigVGAN model to obtain a waveform.
    """

    def __init__(self, target_sample_rate=24000, device=None, bigvgan_model=None):
        """Build the VoiceRestore network and store the runtime configuration.

        Args:
            target_sample_rate: Stored on the instance as
                ``target_sample_rate``; not otherwise used by this class.
            device: Device the VoiceRestore network is moved to. May be a
                string (e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'``, ``'mps'``)
                or a ``torch.device``. ``None`` leaves the network where it
                is, but ``forward`` will then refuse to run.
            bigvgan_model: Callable BigVGAN model exposing an ``h`` attribute.
                Required by ``forward``.
        """
        super().__init__()

        # Initialize VoiceRestore
        self.voice_restore = VoiceRestore(
            sigma=0.0,
            transformer={
                'dim': 768,
                'depth': 20,
                'heads': 16,
                'dim_head': 64,
                'skip_connect_type': 'concat',
                'max_seq_len': 2000,
            },
            num_channels=100,
        )

        self.device = device
        # Compare on the device *type* rather than against the literal string
        # 'cuda', so that torch.device objects and indexed devices such as
        # 'cuda:0' also get the bfloat16 conversion.
        if device is not None and torch.device(device).type == 'cuda':
            self.voice_restore.bfloat16()
        self.voice_restore.eval()
        self.voice_restore.to(self.device)

        self.target_sample_rate = target_sample_rate
        self.bigvgan_model = bigvgan_model

    def forward(self, audio, steps=32, cfg_strength=0.5):
        """Restore ``audio`` and return the waveform produced by BigVGAN.

        Args:
            audio: Input passed straight to ``get_mel_spectrogram``
                (presumably a waveform tensor -- confirm against
                ``BigVGAN.meldataset``).
            steps: Forwarded unchanged to ``VoiceRestore.sample``.
            cfg_strength: Forwarded unchanged to ``VoiceRestore.sample``.

        Returns:
            The output of ``self.bigvgan_model`` for the restored
            mel-spectrogram.

        Raises:
            ValueError: If no BigVGAN model or no device was provided at
                construction time.
        """
        if self.bigvgan_model is None:
            raise ValueError("BigVGAN model is not provided. Please provide the BigVGAN model.")
        if self.device is None:
            raise ValueError("Device is not provided. Please provide the device (cuda, cpu or mps).")

        # Convert to Mel-spectrogram
        processed_mel = get_mel_spectrogram(audio, self.bigvgan_model.h).to(self.device)

        # Restore audio; the last two axes are swapped before sampling.
        restored_mel = self.voice_restore.sample(
            processed_mel.transpose(1, 2),
            steps=steps,
            cfg_strength=cfg_strength,
        )
        # NOTE(review): squeeze(0) only removes a leading dimension of size 1,
        # so this squeeze/transpose/unsqueeze sequence appears to assume a
        # single-item batch -- confirm before feeding batched input.
        restored_mel = restored_mel.squeeze(0).transpose(0, 1)

        # Convert restored mel-spectrogram to waveform
        restored_wav = self.bigvgan_model(restored_mel.unsqueeze(0))
        return restored_wav