# Spaces: Runtime error (page banner text captured with this file; not part of the module)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import shutil | |
import warnings | |
import argparse | |
import torch | |
import os | |
import yaml | |
warnings.simplefilter("ignore") | |
from .modules.commons import * | |
import time | |
import torchaudio | |
import librosa | |
from collections import OrderedDict | |
class FAcodecInference(object):
    """Inference wrapper around the FAcodec model components.

    On construction the model is built from ``cfg.model_params``, moved to
    CUDA when available (CPU otherwise), populated from the checkpoint at
    ``args.checkpoint_path`` and switched to eval mode.

    Args:
        args: Object exposing ``checkpoint_path``. Required in practice even
            though the parameter defaults to ``None``.
        cfg: Config object exposing ``model_params`` (including
            ``n_c_codebooks``) and ``preprocess_params.sr``. Required in
            practice even though the parameter defaults to ``None``.
    """

    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._build_model()
        self._load_checkpoint()

    def _build_model(self):
        """Build the model components and move each one to ``self.device``."""
        model = build_model(self.cfg.model_params)
        for key in model:
            model[key].to(self.device)
        return model

    def _load_checkpoint(self):
        """Load weights from ``args.checkpoint_path`` and set eval mode.

        The checkpoint maps component names to state dicts, optionally nested
        under a top-level ``"net"`` key. Entries whose name is not a component
        of ``self.model`` are ignored.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # obtained from a trusted source.
        sd = torch.load(self.args.checkpoint_path, map_location="cpu")
        sd = sd["net"] if "net" in sd else sd

        # Presumably left behind by a DataParallel/DDP wrapper at save time.
        prefix = "module."
        for key, state_dict in sd.items():
            # Skip unknown entries before touching them so that non-dict
            # metadata stored next to the weights cannot break loading.
            if key not in self.model:
                continue
            cleaned = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items()
            )
            self.model[key].load_state_dict(cleaned)

        for key in self.model:
            self.model[key].eval()

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` truncated at its first dot.

        ``os.path.basename`` is used so the platform's path separators are
        honoured; the "cut at the first dot" rule is kept so output file
        names stay the same as before.
        """
        return os.path.basename(path).split(".")[0]

    def _load_audio(self, path):
        """Load ``path`` at the configured sample rate as a float tensor.

        The waveform returned by ``librosa.load`` gets two leading singleton
        dimensions and is placed on ``self.device``.
        """
        wave = librosa.load(path, sr=self.cfg.preprocess_params.sr)[0]
        wave = torch.tensor(wave).unsqueeze(0).float().to(self.device)
        return wave[None, ...]

    @torch.no_grad()
    def inference(self, source, output_dir):
        """Encode, quantize and decode ``source``; save the reconstruction.

        The result is written to
        ``{output_dir}/reconstructed_{name}.wav`` (the directory is created
        if needed). Runs without gradient tracking.

        Args:
            source: Path of the audio file to reconstruct.
            output_dir: Directory receiving the reconstructed wav file.

        Returns:
            The ``(quantized, codes)`` pair produced by the quantizer.
        """
        source_audio = self._load_audio(source)

        z = self.model.encoder(source_audio)
        z, quantized, _commitment_loss, _codebook_loss, _timbre, codes = (
            self.model.quantizer(
                z,
                source_audio,
                n_c=self.cfg.model_params.n_c_codebooks,
                return_codes=True,
            )
        )
        full_pred_wave = self.model.decoder(z)

        os.makedirs(output_dir, exist_ok=True)
        output_path = f"{output_dir}/reconstructed_{self._stem(source)}.wav"
        torchaudio.save(
            output_path,
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )
        print(
            "Reconstructed audio saved as: ",
            output_path,
        )
        return quantized, codes

    @torch.no_grad()
    def voice_conversion(self, source, reference, output_dir):
        """Convert ``source`` using ``reference`` and save the result.

        The result is written to
        ``{output_dir}/converted_{source_name}_to_{reference_name}.wav`` (the
        directory is created if needed). Runs without gradient tracking.

        Args:
            source: Path of the audio file to convert.
            reference: Path of the audio file passed to the quantizer's
                ``voice_conversion`` as the reference.
            output_dir: Directory receiving the converted wav file.
        """
        source_audio = self._load_audio(source)
        reference_audio = self._load_audio(reference)

        z = self.model.encoder(source_audio)
        _z, quantized, _commitment_loss, _codebook_loss, _timbre = (
            self.model.quantizer(
                z,
                source_audio,
                n_c=self.cfg.model_params.n_c_codebooks,
            )
        )

        # NOTE(review): the outputs of this reference pass are never used
        # below; it is kept in case the modules rely on it being executed.
        # Confirm and drop it if it is side-effect free.
        z_ref = self.model.encoder(reference_audio)
        self.model.quantizer(
            z_ref,
            reference_audio,
            n_c=self.cfg.model_params.n_c_codebooks,
        )

        # Only the first two quantized streams are combined for conversion
        # (presumably the non-timbre streams — TODO confirm).
        z_conv = self.model.quantizer.voice_conversion(
            quantized[0] + quantized[1],
            reference_audio,
        )
        full_pred_wave = self.model.decoder(z_conv)

        os.makedirs(output_dir, exist_ok=True)
        output_path = (
            f"{output_dir}/converted_{self._stem(source)}"
            f"_to_{self._stem(reference)}.wav"
        )
        torchaudio.save(
            output_path,
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )
        print(
            "Voice conversion results saved as: ",
            output_path,
        )