File size: 7,100 Bytes
a09029c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import os
import random
import pandas as pd
import torch
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from .utils import scale_shift_re
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Blend a CFG prediction with a copy rescaled to the text branch's std.

    Implements the guidance rescale from "Common Diffusion Noise Schedules and
    Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf, Sec. 3.4).
    ``noise_cfg`` is rescaled so that its per-sample standard deviation (taken
    over every non-batch dimension) matches that of ``noise_pred_text``; the
    result is then linearly mixed with the un-rescaled ``noise_cfg``.

    Args:
        noise_cfg: Guided prediction; dim 0 is treated as the batch axis.
        noise_pred_text: Conditional (text) prediction used as the std reference.
        guidance_rescale: Mixing weight: 0 keeps ``noise_cfg``, 1 uses the fully
            rescaled tensor, values in between interpolate.

    Returns:
        Tensor of the same shape as ``noise_cfg``.
    """
    def _per_sample_std(x):
        # Std over all axes except the batch axis, kept broadcastable.
        return x.std(dim=list(range(1, x.ndim)), keepdim=True)

    ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    rescaled = noise_cfg * ratio  # counteracts the std blow-up caused by CFG
    keep = 1 - guidance_rescale
    return guidance_rescale * rescaled + keep * noise_cfg
@torch.no_grad()
def inference(autoencoder, unet, gt, gt_mask,
              tokenizer, text_encoder,
              params, noise_scheduler,
              text_raw, neg_text=None,
              audio_frames=500,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2024,
              device='cuda',
              ):
    """Generate a waveform from a text prompt with the latent diffusion UNet.

    Runs ``noise_scheduler``'s reverse process on latents of shape
    ``(1, params['model']['out_chans'], audio_frames)``, optionally with
    classifier-free guidance (CFG), maps the result through ``scale_shift_re``
    and decodes it with ``autoencoder``.

    Args:
        autoencoder: Decoder, called as ``autoencoder(embedding=latents)``.
        unet: Denoiser called as ``unet(x, t, context, context_mask=...,
            cls_token=None, gt=..., mae_mask_infer=...)``; its first returned
            element is used as the model prediction.
        gt: Optional reference latents for masked generation, or ``None``.
        gt_mask: Boolean mask with the same shape as ``gt``. Positions where it
            is ``False`` are copied from ``gt`` into the final latents; ``True``
            positions keep the generated values.
        tokenizer: Text tokenizer, or ``None`` for unconditional generation
            (which also disables CFG).
        text_encoder: Encoder whose ``last_hidden_state`` is the text context.
        params: Config dict; reads ``['text_encoder']['max_length']``,
            ``['model']['out_chans']`` and ``['autoencoder']['scale'|'shift']``.
        noise_scheduler: Scheduler exposing ``set_timesteps``, ``timesteps``,
            ``scale_model_input`` and a ``step`` that accepts ``eta`` and
            ``generator``.
        text_raw: Prompt(s) passed to the tokenizer.
        neg_text: Negative prompt(s) for the unconditional CFG branch;
            defaults to ``[""]``.
        audio_frames: Length of the generated latent sequence.
        guidance_scale: CFG weight. A falsy value (``0``/``None``) skips CFG
            and uses the conditional prediction alone.
        guidance_rescale: Weight for :func:`rescale_noise_cfg`; ``0`` disables it.
        ddim_steps: Number of scheduler timesteps.
        eta: Forwarded to ``noise_scheduler.step``.
        random_seed: Seed for the noise generator; ``None`` draws a fresh seed.
        device: Device for the generator and all created tensors.

    Returns:
        The output of ``autoencoder(embedding=...)`` for the final latents.
    """
    if neg_text is None:
        neg_text = [""]

    if tokenizer is not None:
        max_length = params['text_encoder']['max_length']

        def _encode(prompts):
            # Tokenize to a fixed length; return (hidden states, bool attention mask).
            batch = tokenizer(prompts, max_length=max_length,
                              padding="max_length", truncation=True,
                              return_tensors="pt")
            ids = batch.input_ids.to(device)
            mask = batch.attention_mask.to(device).bool()
            hidden = text_encoder(input_ids=ids, attention_mask=mask).last_hidden_state
            return hidden, mask

        text, text_mask = _encode(text_raw)
        uncond_text, uncond_text_mask = _encode(neg_text)
    else:
        # No text conditioning: there is no unconditional branch to contrast.
        text, text_mask = None, None
        guidance_scale = None

    use_cfg = bool(guidance_scale)
    if use_cfg:
        # The conditional and unconditional branches share one batched UNet
        # call. These tensors do not depend on t, so build them only once.
        text_combined = torch.cat([text, uncond_text], dim=0)
        text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
        if gt is not None:
            gt_combined = torch.cat([gt, gt], dim=0)
            gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
        else:
            gt_combined, gt_mask_combined = None, None

    codec_dim = params['model']['out_chans']

    if random_seed is not None:
        generator = torch.Generator(device=device).manual_seed(random_seed)
    else:
        generator = torch.Generator(device=device)
        generator.seed()

    noise_scheduler.set_timesteps(ddim_steps)
    latents = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
    # Schedulers may expect initial noise scaled by their starting sigma
    # (1.0 for DDIM, so the draw is unchanged there).
    latents = latents * getattr(noise_scheduler, "init_noise_sigma", 1.0)

    # Put the UNet in eval mode for sampling, but restore the caller's mode
    # afterwards so a mid-training evaluation does not disable dropout etc.
    was_training = unet.training
    unet.eval()
    try:
        for t in noise_scheduler.timesteps:
            # Only the UNet input is scaled; step() must receive the unscaled
            # running sample.
            model_input = noise_scheduler.scale_model_input(latents, t)
            if use_cfg:
                model_input_combined = torch.cat([model_input, model_input], dim=0)
                output_combined, _ = unet(model_input_combined, t, text_combined,
                                          context_mask=text_mask_combined,
                                          cls_token=None, gt=gt_combined,
                                          mae_mask_infer=gt_mask_combined)
                output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
                output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(output_pred, output_text,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred, _ = unet(model_input, t, text, context_mask=text_mask,
                                      cls_token=None, gt=gt, mae_mask_infer=gt_mask)
            latents = noise_scheduler.step(model_output=output_pred, timestep=t,
                                           sample=latents,
                                           eta=eta, generator=generator).prev_sample
    finally:
        unet.train(was_training)

    pred = scale_shift_re(latents, params['autoencoder']['scale'],
                          params['autoencoder']['shift'])
    if gt is not None:
        # Outside the generation mask, keep the reference latents verbatim.
        pred[~gt_mask] = gt[~gt_mask]
    pred_wav = autoencoder(embedding=pred)
    return pred_wav
def _caption_to_filename(caption, max_bytes=200):
    """Return ``caption`` as a single filesystem-safe path component.

    Path separators are replaced with ``_`` and NUL characters dropped, so a
    caption cannot point into a (missing) sub-directory. The name is truncated
    to ``max_bytes`` UTF-8 bytes, which leaves room for suffixes under the
    usual 255-byte file-name limit.
    """
    name = str(caption).replace('/', '_').replace('\\', '_').replace('\0', '')
    return name.encode('utf-8')[:max_bytes].decode('utf-8', errors='ignore')


def _random_span_mask(shape, rng, mask_ratio=0.2, num_spans=2, device='cpu'):
    """Return a bool mask of 3-D ``shape`` (B, D, L) with random spans set to True.

    Each of ``num_spans`` spans covers ``int(L * mask_ratio)`` positions of
    the last axis for every row and channel; spans may overlap. Span starts
    are drawn from ``rng`` (a ``random.Random``-like object).
    """
    batch, dim, length = shape
    span = int(length * mask_ratio)
    mask = torch.zeros(batch, dim, length, dtype=torch.bool, device=device)
    for _ in range(num_spans):
        start = rng.randint(0, length - span)
        mask[:, :, start:start + span] = True
    return mask


@torch.no_grad()
def eval_udit(autoencoder, unet,
              tokenizer, text_encoder,
              params, noise_scheduler,
              val_df, subset,
              audio_frames, mae=False,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2023,
              device='cuda',
              epoch=0, save_path='logs/eval/', val_num=5):
    """Generate a few validation clips with :func:`inference` and save them as WAV.

    Reads the CSV at path ``val_df`` and keeps rows whose ``split`` column
    equals ``subset``. For each of the first ``val_num`` rows (at least one),
    it generates audio from the row's ``caption``. Results are written to
    ``save_path + str(epoch) + '/'`` as ``<caption>.wav``; the caption is made
    filesystem-safe (path separators replaced, truncated to 200 UTF-8 bytes).

    When ``mae`` is true, rows with ``audio_length == 0`` are dropped first.
    For each row, the reference audio at ``params['data']['val_dir'] +
    audio_path`` is then processed as follows:

    1. Load it at ``params['data']['sr']`` and peak-normalize it.
    2. Save it as ``<caption>_gt.wav``.
    3. Zero-pad or trim it to 10 seconds and encode it with
       ``autoencoder(audio=...)``.
    4. Mask two random spans (20% of the latent length each) for the model to
       fill in.

    The spans are drawn from ``random.Random(random_seed)``, so repeated
    evaluations use the same masks. In this mode, the generated length
    follows the encoded reference instead of ``audio_frames``.

    Returns:
        None; outputs are written to disk.
    """
    val_df = pd.read_csv(val_df)
    val_df = val_df[val_df['split'] == subset]
    if mae:
        val_df = val_df[val_df['audio_length'] != 0]

    save_path = save_path + str(epoch) + '/'
    os.makedirs(save_path, exist_ok=True)

    target_sr = params['data']['sr']
    # The row limit was historically checked only after processing a row,
    # so at least one row is always evaluated; keep that contract.
    num_rows = min(len(val_df), max(val_num, 1))
    # Dedicated RNG: masks are reproducible and independent of global state.
    mask_rng = random.Random(random_seed)

    for i in tqdm(range(num_rows)):
        row = val_df.iloc[i]
        text = [row['caption']]
        out_stem = os.path.join(save_path, _caption_to_filename(text[0]))
        frames = audio_frames

        if mae:
            audio_path = params['data']['val_dir'] + str(row['audio_path'])
            wav, sr = librosa.load(audio_path, sr=target_sr)
            wav = wav / (np.max(np.abs(wav)) + 1e-9)  # peak-normalize
            sf.write(out_stem + '_gt.wav', wav, samplerate=target_sr)
            # Zero-pad or trim to a fixed 10-second reference.
            num_samples = 10 * sr
            wav = np.pad(wav, (0, max(num_samples - len(wav), 0)), 'constant')[:num_samples]
            gt = autoencoder(audio=torch.tensor(wav).unsqueeze(0).unsqueeze(1).to(device))
            gt_mask = _random_span_mask(gt.shape, mask_rng, device=device)
            # inference() merges generated latents with gt element-wise, so the
            # generated length must match the encoded reference.
            frames = gt.shape[-1]
        else:
            gt, gt_mask = None, None

        pred = inference(autoencoder, unet, gt, gt_mask,
                         tokenizer, text_encoder,
                         params, noise_scheduler,
                         text, neg_text=None,
                         audio_frames=frames,
                         guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
                         ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
                         device=device)
        pred = pred.cpu().numpy().squeeze(0).squeeze(0)
        sf.write(out_stem + '.wav', pred, samplerate=target_sr)
|