Spaces:
Running
on
A10G
Running
on
A10G
import os | |
import argparse | |
import yaml | |
import torch | |
from torch import autocast | |
from tqdm import tqdm, trange | |
from audioldm import LatentDiffusion, seed_everything | |
from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint | |
from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file | |
from audioldm.latent_diffusion.ddim import DDIMSampler | |
from einops import repeat | |
import os | |
def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
    """Assemble the 6-element batch tuple handed to the generation calls.

    Args:
        text: Prompt; it is repeated ``batchsize`` times in the batch.
        waveform: Optional audio samples in any form ``torch.FloatTensor``
            accepts. It is expanded to ``(batchsize, -1)``. When omitted, a
            zero tensor of shape ``(batchsize, 160000)`` is used instead.
        fbank: Optional filterbank features in any form ``torch.FloatTensor``
            accepts. It is expanded to ``(batchsize, 1024, 64)``. When
            omitted, a zero tensor of that shape is used instead.
        batchsize: Number of items in the batch. Values below 1 are clamped
            to 1 and a warning is printed.

    Returns:
        The tuple ``(fbank, stft, None, fname, waveform, text)``.
    """
    # Clamp first: every tensor and list below is sized from batchsize, so a
    # value below 1 would otherwise silently yield an empty batch.
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1

    text = [text] * batchsize

    if fbank is None:
        fbank = torch.zeros((batchsize, 1024, 64))  # Not used, here to keep the code format
    else:
        fbank = torch.FloatTensor(fbank)
        fbank = fbank.expand(batchsize, 1024, 64)
        assert fbank.size(0) == batchsize

    stft = torch.zeros((batchsize, 1024, 512))  # Not used

    if waveform is None:
        waveform = torch.zeros((batchsize, 160000))  # Not used
    else:
        waveform = torch.FloatTensor(waveform)
        waveform = waveform.expand(batchsize, -1)
        assert waveform.size(0) == batchsize

    fname = [""] * batchsize  # Not used

    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch
def round_up_duration(duration):
    """Snap ``duration`` to the 2.5 grid and add one extra 2.5 step.

    The number of 2.5-sized steps is ``round(duration / 2.5)`` (Python's
    built-in rounding), plus one; the result is that count times 2.5.
    """
    step = 2.5
    step_count = int(round(duration / step) + 1)
    return step_count * step
def build_model(
    ckpt_path=None,
    config=None,
    model_name="audioldm-s-full"
):
    """Create a ``LatentDiffusion`` model, load its checkpoint, and return it.

    Args:
        ckpt_path: Path to the checkpoint file. When ``None``, the path is
            looked up with ``get_metadata()[model_name]["path"]``.
        config: Optional path to a YAML configuration file. When ``None``,
            ``default_audioldm_config(model_name)`` is used.
        model_name: Key used for the metadata lookup, the checkpoint
            download and the default configuration.

    Returns:
        The model in eval mode, moved to ``cuda:0`` when CUDA is available
        (otherwise the CPU), with text conditioning selected.
    """
    # Interpolate the name; passing it as a second print() argument would
    # emit the literal "%s" followed by the name.
    print("Load AudioLDM: %s" % model_name)

    if ckpt_path is None:
        ckpt_path = get_metadata()[model_name]["path"]

    if not os.path.exists(ckpt_path):
        download_checkpoint(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if config is not None:
        assert isinstance(config, str), "config must be a path to a YAML file"
        # NOTE(review): FullLoader is not meant for untrusted YAML; only pass
        # configuration files you trust.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion
def duration_to_latent_t_size(duration):
    """Return ``duration`` scaled by 25.6 and truncated to an ``int``."""
    scale = 25.6
    scaled = duration * scale
    return int(scaled)
def set_cond_audio(latent_diffusion):
    """Switch the model to audio conditioning and return the same object.

    Sets ``cond_stage_key`` to ``"waveform"`` and the conditioning model's
    ``embed_mode`` to ``"audio"``.
    """
    setattr(latent_diffusion, "cond_stage_key", "waveform")
    conditioner = latent_diffusion.cond_stage_model
    setattr(conditioner, "embed_mode", "audio")
    return latent_diffusion
def set_cond_text(latent_diffusion):
    """Switch the model to text conditioning and return the same object.

    Sets ``cond_stage_key`` to ``"text"`` and the conditioning model's
    ``embed_mode`` to ``"text"``.
    """
    setattr(latent_diffusion, "cond_stage_key", "text")
    conditioner = latent_diffusion.cond_stage_model
    setattr(conditioner, "embed_mode", "text")
    return latent_diffusion
def text_to_audio(
    latent_diffusion,
    text,
    original_audio_file_path = None,
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    """Generate audio from a text prompt, or from a reference audio file.

    Args:
        latent_diffusion: Model whose ``generate_sample`` method is called.
        text: Prompt placed in the batch.
        original_audio_file_path: Optional wav file. When it yields a
            waveform, the model is switched to audio conditioning;
            otherwise text conditioning is used.
        seed: Passed to ``seed_everything`` after conversion to ``int``.
        ddim_steps: Forwarded to ``generate_sample``.
        duration: Used to size the wav read and the model's
            ``latent_t_size``; also forwarded to ``generate_sample``.
        batchsize: Number of items in the batch.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``.
        n_candidate_gen_per_text: Forwarded to ``generate_sample``.
        config: Not used by this function.

    Returns:
        Whatever ``latent_diffusion.generate_sample`` returns.
    """
    seed_everything(int(seed))

    if original_audio_file_path is None:
        reference_audio = None
    else:
        segment_length = int(duration * 102.4) * 160
        reference_audio = read_wav_file(original_audio_file_path, segment_length)

    batch = make_batch_for_text_to_audio(
        text, waveform=reference_audio, batchsize=batchsize
    )

    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)

    # The conditioning mode follows what was actually read, not the argument.
    if reference_audio is None:
        print("Generate audio using text %s" % text)
        latent_diffusion = set_cond_text(latent_diffusion)
    else:
        print("Generate audio that has similar content as %s" % original_audio_file_path)
        latent_diffusion = set_cond_audio(latent_diffusion)

    sampling_kwargs = dict(
        unconditional_guidance_scale=guidance_scale,
        ddim_steps=ddim_steps,
        n_candidate_gen_per_text=n_candidate_gen_per_text,
        duration=duration,
    )
    with torch.no_grad():
        return latent_diffusion.generate_sample([batch], **sampling_kwargs)
def style_transfer(
    latent_diffusion,
    text,
    original_audio_file_path,
    transfer_strength,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    ddim_steps=200,
    config=None,
):
    """Re-generate an existing audio file under the guidance of a text prompt.

    The file's mel spectrogram is encoded into the latent space, noised with
    the DDIM sampler up to ``int(transfer_strength * ddim_steps)`` steps,
    decoded back under text conditioning and converted to a waveform.

    Args:
        latent_diffusion: Model used for encoding, conditioning and decoding.
        text: Prompt used for the learned conditioning.
        original_audio_file_path: Source audio file; its bit depth must be 16.
        transfer_strength: Multiplied by ``ddim_steps`` to get the number of
            encoding steps.
        seed: Passed to ``seed_everything`` after conversion to ``int``.
        duration: Requested length. When it is not smaller than the file's
            duration it is replaced by ``round_up_duration`` of the file's
            duration.
        batchsize: Number of copies of the input processed together.
        guidance_scale: Unconditional guidance scale; when it equals 1.0 no
            unconditional conditioning is computed.
        ddim_steps: Number of DDIM steps in the sampler schedule.
        config: Optional path to a YAML configuration file. When ``None``,
            ``default_audioldm_config()`` is used.

    Returns:
        The result of ``first_stage_model.decode_to_waveform``.

    Raises:
        AssertionError: If no file path is given, the file's bit depth is
            not 16, or ``config`` is neither ``None`` nor a string.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    assert original_audio_file_path is not None, "You need to provide the original audio file path"

    audio_file_duration = get_duration(original_audio_file_path)

    assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path

    if duration >= audio_file_duration:
        print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration))
        duration = round_up_duration(audio_file_duration)
        print("Set new duration as %s-seconds" % duration)

    latent_diffusion = set_cond_text(latent_diffusion)

    if config is not None:
        assert isinstance(config, str), "config must be a path to a YAML file"
        # NOTE(review): FullLoader is not meant for untrusted YAML; only pass
        # configuration files you trust.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config()

    seed_everything(int(seed))
    latent_diffusion.cond_stage_model.embed_mode = "text"

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )
    mel = mel.unsqueeze(0).unsqueeze(0).to(device)
    mel = repeat(mel, "1 ... -> b ...", b=batchsize)

    init_latent = latent_diffusion.get_first_stage_encoding(
        latent_diffusion.encode_first_stage(mel)
    )  # move to latent space, encode and sample

    # Only when an extreme value (> 1e2) appears is the latent clipped to [-10, 10].
    if torch.max(torch.abs(init_latent)) > 1e2:
        init_latent = torch.clip(init_latent, min=-10, max=10)

    sampler = DDIMSampler(latent_diffusion)
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)

    t_enc = int(transfer_strength * ddim_steps)
    prompts = text

    with torch.no_grad():
        # NOTE(review): autocast is requested for "cuda" even when the CPU
        # device was selected above - confirm this is intended.
        with autocast("cuda"):
            with latent_diffusion.ema_scope():
                uc = None
                if guidance_scale != 1.0:
                    uc = latent_diffusion.cond_stage_model.get_unconditional_condition(
                        batchsize
                    )

                c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc] * batchsize).to(device)
                )
                samples = sampler.decode(
                    z_enc,
                    c,
                    t_enc,
                    unconditional_guidance_scale=guidance_scale,
                    unconditional_conditioning=uc,
                )
                # Decode once, on the latent with its last 3 entries along
                # dim 2 dropped. An earlier author note says decoding the
                # untrimmed latent produced NaNs; that extra decode was
                # discarded anyway, so it is no longer performed.
                x_samples = latent_diffusion.decode_first_stage(samples[:, :, :-3, :])

                waveform = latent_diffusion.first_stage_model.decode_to_waveform(
                    x_samples
                )

    return waveform
def super_resolution_and_inpainting(
    latent_diffusion,
    text,
    original_audio_file_path = None,
    seed=42,
    ddim_steps=200,
    duration=None,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    time_mask_ratio_start_and_end=(0.10, 0.15),  # regenerate the 10% to 15% of the time steps in the spectrogram
    # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting
    # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins
    freq_mask_ratio_start_and_end=(1.0, 1.0),  # no super-resolution
    config=None,
):
    """Regenerate a masked region of an audio file's spectrogram from text.

    Args:
        latent_diffusion: Model whose ``generate_sample_masked`` is called.
        text: Prompt placed in the batch.
        original_audio_file_path: Source audio file. Required, even though
            the signature defaults it to ``None``.
        seed: Passed to ``seed_everything`` after conversion to ``int``.
        ddim_steps: Forwarded to ``generate_sample_masked``.
        duration: Length used to size the mel spectrogram
            (``int(duration * 102.4)``) and forwarded to
            ``generate_sample_masked``. Required, even though the signature
            defaults it to ``None``.
        batchsize: Number of items in the batch.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``.
        n_candidate_gen_per_text: Forwarded to ``generate_sample_masked``.
        time_mask_ratio_start_and_end: Forwarded unchanged; see the inline
            notes on the signature for the intended meaning.
        freq_mask_ratio_start_and_end: Forwarded unchanged; see the inline
            notes on the signature for the intended meaning.
        config: Optional path to a YAML configuration file. When ``None``,
            ``default_audioldm_config()`` is used.

    Returns:
        Whatever ``latent_diffusion.generate_sample_masked`` returns.

    Raises:
        ValueError: If ``original_audio_file_path`` is ``None``.
        TypeError: If ``duration`` is ``None``.
    """
    # Both arguments default to None but are required below. Fail here with a
    # clear message, before seeding or reading the config, instead of deep
    # inside the feature extraction.
    if original_audio_file_path is None:
        raise ValueError("You need to provide the original audio file path")
    if duration is None:
        raise TypeError("duration must be a number of seconds, not None")

    seed_everything(int(seed))

    if config is not None:
        assert isinstance(config, str), "config must be a path to a YAML file"
        # NOTE(review): FullLoader is not meant for untrusted YAML; only pass
        # configuration files you trust.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config()

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )

    # A leading axis is added to the mel so it can be expanded across the batch.
    batch = make_batch_for_text_to_audio(text, fbank=mel[None, ...], batchsize=batchsize)

    latent_diffusion = set_cond_text(latent_diffusion)

    with torch.no_grad():
        waveform = latent_diffusion.generate_sample_masked(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
            time_mask_ratio_start_and_end=time_mask_ratio_start_and_end,
            freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end,
        )
    return waveform