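"""Text-to-audio inference for AudioLDM.

Pipeline (as implemented below): encode the text prompt with a frozen T5
encoder, run classifier-free guided denoising with a PNDM scheduler in the
latent space of a frozen AutoencoderKL, decode the resulting latent to a mel
spectrogram, and vocode it to a 16 kHz waveform.
"""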
import os
import time
import json
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.io.wavfile import write
from tqdm import tqdm
from transformers import T5EncoderModel, AutoTokenizer
from diffusers import PNDMScheduler

from models.tta.autoencoder.autoencoder import AutoencoderKL
from models.tta.ldm.inference_utils.vocoder import Generator
from models.tta.ldm.audioldm import AudioLDM


class AttrDict(dict):
    """Dictionary whose keys are also accessible as attributes (used for the vocoder config)."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


class AudioLDMInference:
    def __init__(self, args, cfg):
        self.cfg = cfg
        self.args = args

        # Frozen mel-spectrogram VAE and frozen T5 text encoder.
        self.build_autoencoderkl()
        self.build_textencoder()

        # Latent diffusion model plus its checkpoint.
        self.model = self.build_model()
        self.load_state_dict()

        # Neural vocoder that converts mel spectrograms to waveforms.
        self.build_vocoder()

        self.out_path = self.args.output_dir
        self.out_mel_path = os.path.join(self.out_path, "mel")
        self.out_wav_path = os.path.join(self.out_path, "wav")
        os.makedirs(self.out_mel_path, exist_ok=True)
        os.makedirs(self.out_wav_path, exist_ok=True)

    def build_autoencoderkl(self):
        self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
        self.autoencoder_path = self.cfg.model.autoencoder_path
        checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
        self.autoencoderkl.load_state_dict(checkpoint["model"])
        self.autoencoderkl.cuda(self.args.local_rank)
        self.autoencoderkl.requires_grad_(requires_grad=False)
        self.autoencoderkl.eval()

    def build_textencoder(self):
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
        self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
        self.text_encoder.cuda(self.args.local_rank)
        self.text_encoder.requires_grad_(requires_grad=False)
        self.text_encoder.eval()

    def build_vocoder(self):
        config_file = os.path.join(self.args.vocoder_config_path)
        with open(config_file) as f:
            data = f.read()
        json_config = json.loads(data)
        h = AttrDict(json_config)
        self.vocoder = Generator(h).to(self.args.local_rank)
        # Load the generator weights on CPU, then copy them into the GPU-resident module.
        checkpoint_dict = torch.load(self.args.vocoder_path, map_location="cpu")
        self.vocoder.load_state_dict(checkpoint_dict["generator"])

    def build_model(self):
        self.model = AudioLDM(self.cfg.model.audioldm)
        return self.model

    def load_state_dict(self):
        self.checkpoint_path = self.args.checkpoint_path
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint["model"])
        self.model.cuda(self.args.local_rank)

    def get_text_embedding(self):
        text = self.args.text
        prompt = [text]

        # Conditional embedding for the prompt.
        text_input = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            padding="do_not_pad",
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(
            text_input.input_ids.to(self.args.local_rank)
        )[0]

        # Unconditional embedding (empty prompt) padded to the same length,
        # needed for classifier-free guidance.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.args.local_rank)
        )[0]

        # Stack [uncond, cond] so both branches run in one batched forward pass.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def inference(self):
        text_embeddings = self.get_text_embedding()
        print(text_embeddings.shape)

        num_steps = self.args.num_steps
        guidance_scale = self.args.guidance_scale

        # PNDM scheduler predicting epsilon with a scaled-linear beta schedule.
        noise_scheduler = PNDMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            skip_prk_steps=True,
            set_alpha_to_one=False,
            steps_offset=1,
            prediction_type="epsilon",
        )
        noise_scheduler.set_timesteps(num_steps)

        # Start from Gaussian noise in the VAE latent space; the spatial size is the
        # 80 x 624 mel-spectrogram size divided by the autoencoder's downsampling factor.
        latents = torch.randn(
            (
                1,
                self.cfg.model.autoencoderkl.z_channels,
                80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
                624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
            )
        ).to(self.args.local_rank)

        self.model.eval()
        for t in tqdm(noise_scheduler.timesteps):
            t = t.to(self.args.local_rank)

            # Duplicate the latents so the unconditional and conditional branches
            # share one batched forward pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(
                latent_model_input, timestep=t
            )

            with torch.no_grad():
                noise_pred = self.model(
                    latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
                )

            # Classifier-free guidance: move the prediction away from the unconditional
            # branch and toward the text-conditional one.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        latents_out = latents
        print(latents_out.shape)

        # Decode the final latent back to a mel spectrogram with the frozen VAE decoder.
        with torch.no_grad():
            mel_out = self.autoencoderkl.decode(latents_out)
        print(mel_out.shape)

        melspec = mel_out[0, 0].cpu().detach().numpy()
        plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec)

        # Vocode the mel spectrogram into a 16 kHz waveform and save it as 16-bit PCM.
        self.vocoder.eval()
        self.vocoder.remove_weight_norm()
        with torch.no_grad():
            melspec = np.expand_dims(melspec, 0)
            melspec = torch.FloatTensor(melspec).to(self.args.local_rank)

            y = self.vocoder(melspec)
            audio = y.squeeze()
            audio = audio * 32768.0
            audio = audio.cpu().numpy().astype("int16")

        write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio)
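

# Minimal usage sketch (illustrative only, not part of the original script). The
# real entry point is expected to build `args` (command-line options) and `cfg`
# (an experiment config exposing `cfg.model.autoencoderkl`,
# `cfg.model.autoencoder_path`, and `cfg.model.audioldm`) with the project's own
# parser and config loader. The attribute names below are exactly the ones this
# class reads; the paths and numeric values are placeholders, and `load_config`
# is a hypothetical stand-in for the project's config loader.
#
#     from argparse import Namespace
#
#     args = Namespace(
#         text="a dog barking",                               # prompt
#         output_dir="output",                                # placeholder path
#         checkpoint_path="ckpts/audioldm/model.pt",          # placeholder path
#         vocoder_config_path="ckpts/vocoder/config.json",    # placeholder path
#         vocoder_path="ckpts/vocoder/g_latest",              # placeholder path
#         local_rank=0,
#         num_steps=200,                                      # example value
#         guidance_scale=4.0,                                 # example value
#     )
#     cfg = load_config("config/audioldm.json")  # hypothetical helper
#
#     AudioLDMInference(args, cfg).inference()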