maskgct / models /codec /facodec /facodec_inference.py
Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
4.56 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import shutil
import warnings
import argparse
import torch
import os
import yaml
warnings.simplefilter("ignore")
from .modules.commons import *
import time
import torchaudio
import librosa
from collections import OrderedDict
class FAcodecInference(object):
def __init__(self, args=None, cfg=None):
self.args = args
self.cfg = cfg
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = self._build_model()
self._load_checkpoint()
def _build_model(self):
model = build_model(self.cfg.model_params)
_ = [model[key].to(self.device) for key in model]
return model
def _load_checkpoint(self):
sd = torch.load(self.args.checkpoint_path, map_location="cpu")
sd = sd["net"] if "net" in sd else sd
new_params = dict()
for key, state_dict in sd.items():
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith("module."):
k = k[7:]
new_state_dict[k] = v
new_params[key] = new_state_dict
for key in new_params:
if key in self.model:
self.model[key].load_state_dict(new_params[key])
_ = [self.model[key].eval() for key in self.model]
@torch.no_grad()
def inference(self, source, output_dir):
source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
z = self.model.encoder(source_audio[None, ...].to(self.device).float())
(
z,
quantized,
commitment_loss,
codebook_loss,
timbre,
codes,
) = self.model.quantizer(
z,
source_audio[None, ...].to(self.device).float(),
n_c=self.cfg.model_params.n_c_codebooks,
return_codes=True,
)
full_pred_wave = self.model.decoder(z)
os.makedirs(output_dir, exist_ok=True)
source_name = source.split("/")[-1].split(".")[0]
torchaudio.save(
f"{output_dir}/reconstructed_{source_name}.wav",
full_pred_wave[0].cpu(),
self.cfg.preprocess_params.sr,
)
print(
"Reconstructed audio saved as: ",
f"{output_dir}/reconstructed_{source_name}.wav",
)
return quantized, codes
@torch.no_grad()
def voice_conversion(self, source, reference, output_dir):
source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0]
reference_audio = (
torch.tensor(reference_audio).unsqueeze(0).float().to(self.device)
)
z = self.model.encoder(source_audio[None, ...].to(self.device).float())
z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
z,
source_audio[None, ...].to(self.device).float(),
n_c=self.cfg.model_params.n_c_codebooks,
)
z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float())
(
z_ref,
quantized_ref,
commitment_loss_ref,
codebook_loss_ref,
timbre_ref,
) = self.model.quantizer(
z_ref,
reference_audio[None, ...].to(self.device).float(),
n_c=self.cfg.model_params.n_c_codebooks,
)
z_conv = self.model.quantizer.voice_conversion(
quantized[0] + quantized[1],
reference_audio[None, ...].to(self.device).float(),
)
full_pred_wave = self.model.decoder(z_conv)
os.makedirs(output_dir, exist_ok=True)
source_name = source.split("/")[-1].split(".")[0]
reference_name = reference.split("/")[-1].split(".")[0]
torchaudio.save(
f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
full_pred_wave[0].cpu(),
self.cfg.preprocess_params.sr,
)
print(
"Voice conversion results saved as: ",
f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
)