File size: 1,801 Bytes
96e64e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from BigVGAN.meldataset import get_mel_spectrogram
from voice_restore import VoiceRestore


class OptimizedAudioRestorationModel(torch.nn.Module):
    """Inference wrapper: mel extraction -> VoiceRestore sampling -> vocoder.

    The constructor builds a ``VoiceRestore`` network with a fixed
    transformer configuration, switches it to eval mode and moves it to
    ``device``. ``forward`` turns input audio into a mel-spectrogram, lets
    ``VoiceRestore.sample`` restore it, and passes the result to
    ``bigvgan_model`` to obtain the output.

    Args:
        target_sample_rate: Stored on the instance as
            ``self.target_sample_rate``; it is not read anywhere else in
            this class.
        device: Device the restoration network is moved to (string or
            ``torch.device``). Any CUDA device (``'cuda'``, ``'cuda:0'``,
            ``torch.device('cuda')``, ...) additionally casts the network
            to bfloat16. May be ``None`` at construction time, but
            ``forward`` refuses to run without it.
        bigvgan_model: Callable vocoder applied to the restored mel; its
            ``h`` attribute is handed to ``get_mel_spectrogram``. May be
            ``None`` at construction time, but ``forward`` refuses to run
            without it.
    """

    def __init__(self, target_sample_rate=24000, device=None, bigvgan_model=None):
        super().__init__()

        # Initialize VoiceRestore
        self.voice_restore = VoiceRestore(
            sigma=0.0,
            transformer={
                'dim': 768,
                'depth': 20,
                'heads': 16,
                'dim_head': 64,
                'skip_connect_type': 'concat',
                'max_seq_len': 2000,
            },
            num_channels=100,
        )

        self.device = device
        # Compare on the device *type* so that 'cuda:0' and torch.device
        # objects get the same bfloat16 treatment as the plain 'cuda' string.
        if self._is_cuda(device):
            self.voice_restore.bfloat16()
        self.voice_restore.eval()
        self.voice_restore.to(self.device)
        self.target_sample_rate = target_sample_rate
        self.bigvgan_model = bigvgan_model

    @staticmethod
    def _is_cuda(device):
        """Return True when ``device`` designates a CUDA device."""
        if device is None:
            return False
        return torch.device(device).type == 'cuda'

    @staticmethod
    def _to_vocoder_layout(restored_mel):
        """Swap the last two axes of the sampled mel, keeping the batch axis.

        A 2-D input gets a leading batch axis of size 1 first. For a batch of
        one this yields the same values and shape as the former
        ``squeeze(0).transpose(0, 1).unsqueeze(0)`` chain, while larger
        batches are no longer scrambled.
        """
        # NOTE(review): assumes the sampler returns (batch, frames, channels),
        # mirroring the layout it is fed below - confirm against VoiceRestore.
        if restored_mel.dim() == 2:
            restored_mel = restored_mel.unsqueeze(0)
        return restored_mel.transpose(1, 2)

    def forward(self, audio, steps=32, cfg_strength=0.5):
        """Restore ``audio`` and return the vocoder output.

        Args:
            audio: Input passed unchanged to ``get_mel_spectrogram``.
            steps: Forwarded to ``VoiceRestore.sample`` as ``steps``.
            cfg_strength: Forwarded to ``VoiceRestore.sample`` as
                ``cfg_strength``.

        Returns:
            Whatever ``bigvgan_model`` returns for the restored mel.

        Raises:
            ValueError: If ``bigvgan_model`` or ``device`` was not provided.
        """
        if self.bigvgan_model is None:
            raise ValueError("BigVGAN model is not provided. Please provide the BigVGAN model.")

        if self.device is None:
            raise ValueError("Device is not provided. Please provide the device (cuda, cpu or mps).")

        # Convert to Mel-spectrogram using the vocoder's own settings (`h`).
        processed_mel = get_mel_spectrogram(audio, self.bigvgan_model.h).to(self.device)

        # Restore audio; the sampler is fed the mel with axes 1 and 2 swapped.
        restored_mel = self.voice_restore.sample(
            processed_mel.transpose(1, 2),
            steps=steps,
            cfg_strength=cfg_strength,
        )

        # Convert restored mel-spectrogram to waveform
        return self.bigvgan_model(self._to_vocoder_layout(restored_mel))