diff --git a/api.py b/api.py new file mode 100644 index 0000000000000000000000000000000000000000..531c8197686ff15be7e0a00b065a3309bd9a781c --- /dev/null +++ b/api.py @@ -0,0 +1,117 @@ +import os +import torch +import random +import numpy as np +import gradio as gr +import soundfile as sf +from transformers import T5Tokenizer, T5EncoderModel +from diffusers import DDIMScheduler +from src.models.conditioners import MaskDiT +from src.modules.autoencoder_wrapper import Autoencoder +from src.inference import inference +from src.utils import load_yaml_with_includes + + +# Load model and configs +def load_models(config_name, ckpt_path, vae_path, device): + params = load_yaml_with_includes(config_name) + + # Load codec model + autoencoder = Autoencoder(ckpt_path=vae_path, + model_type=params['autoencoder']['name'], + quantization_first=params['autoencoder']['q_first']).to(device) + autoencoder.eval() + + # Load text encoder + tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model']) + text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device) + text_encoder.eval() + + # Load main U-Net model + unet = MaskDiT(**params['model']).to(device) + unet.load_state_dict(torch.load(ckpt_path)['model']) + unet.eval() + + # Load noise scheduler + noise_scheduler = DDIMScheduler(**params['diff']) + + return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params + +MAX_SEED = np.iinfo(np.int32).max + +# Model and config paths +config_name = 'ckpts/ezaudio-xl.yml' +ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt' +vae_path = 'ckpts/vae/1m.pt' +save_path = 'output/' +os.makedirs(save_path, exist_ok=True) + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path, + device) + +latents = torch.randn((1, 128, 128), device=device) +noise = torch.randn_like(latents) +timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device) +_ = noise_scheduler.add_noise(latents, noise, timesteps) + + +# Inference function +def generate_audio(text, length, + guidance_scale, guidance_rescale, ddim_steps, eta, + random_seed, randomize_seed): + neg_text = None + length = length * params['autoencoder']['latent_sr'] + + if randomize_seed: + random_seed = random.randint(0, MAX_SEED) + + pred = inference(autoencoder, unet, None, None, + tokenizer, text_encoder, + params, noise_scheduler, + text, neg_text, + length, + guidance_scale, guidance_rescale, + ddim_steps, eta, random_seed, + device) + + pred = pred.cpu().numpy().squeeze(0).squeeze(0) + # output_file = f"{save_path}/{text}.wav" + # sf.write(output_file, pred, samplerate=params['autoencoder']['sr']) + + return params['autoencoder']['sr'], pred + + +# Gradio Interface +def gradio_interface(): + # Input components + text_input = gr.Textbox(label="Text Prompt", value="the sound of dog barking") + length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)") + + # Advanced settings + guidance_scale_input = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5, label="Guidance Scale") + guidance_rescale_input = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale") + ddim_steps_input = gr.Slider(minimum=25, maximum=200, step=5, value=100, label="DDIM Steps") + eta_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Eta") + random_seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0,) + + 
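+    # Seed controls: when the "Randomize seed" checkbox below is ticked,
+    # generate_audio() ignores this slider and draws a fresh seed in [0, MAX_SEED].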
randomize_seed = gr.Checkbox(label="Randomize seed", value=False) + + # Output component + output_audio = gr.Audio(label="Converted Audio", type="numpy") + + # Interface + gr.Interface( + fn=generate_audio, + inputs=[text_input, length_input, guidance_scale_input, guidance_rescale_input, ddim_steps_input, eta_input, + random_seed_input, randomize_seed], + outputs=output_audio, + title="EzAudio Text-to-Audio Generator", + description="Generate audio from text using a diffusion model. Adjust advanced settings for more control.", + allow_flagging="never" + ).launch() + + +if __name__ == "__main__": + gradio_interface() diff --git a/src/.idea/.gitignore b/src/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1c2fda565b94d0f2b94cb65ba7cca866e7a25478 --- /dev/null +++ b/src/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/src/.idea/inspectionProfiles/Project_Default.xml b/src/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000000000000000000000000000000000000..fbb15db07580fc05034be15a48d3471c912c3f63 --- /dev/null +++ b/src/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,34 @@ + + + + \ No newline at end of file diff --git a/src/.idea/inspectionProfiles/profiles_settings.xml b/src/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/src/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/src/.idea/misc.xml b/src/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..9de286525ff35cf3ec9f171ef56fd1557939f2a0 --- /dev/null +++ b/src/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/src/.idea/modules.xml b/src/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..7210df4ddc1a7504320efcdcec08b9961df50d0e --- /dev/null +++ b/src/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/src/.idea/src.iml b/src/.idea/src.iml new file mode 100644 index 0000000000000000000000000000000000000000..2946dc0d137bdbc1a0db1f730491160c3bdb883e --- /dev/null +++ b/src/.idea/src.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/src/.idea/workspace.xml b/src/.idea/workspace.xml new file mode 100644 index 0000000000000000000000000000000000000000..425bc7ce443a61597f14ee784049a292dabf9599 --- /dev/null +++ b/src/.idea/workspace.xml @@ -0,0 +1,128 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1726457759523 + + + + + + \ No newline at end of file diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..93a9fd81a73480fa348a225a70cba20bcc35dc93 --- /dev/null +++ b/src/inference.py @@ -0,0 +1,169 @@ +import os +import random +import pandas as pd +import torch +import librosa +import numpy as np +import soundfile as sf +from tqdm import tqdm +from utils import scale_shift_re + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
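+    With phi = `guidance_rescale` and std(.) taken per sample over all non-batch
+    dimensions, the code below computes
+        noise_rescaled = noise_cfg * std(noise_pred_text) / std(noise_cfg)
+        return phi * noise_rescaled + (1 - phi) * noise_cfg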
See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +@torch.no_grad() +def inference(autoencoder, unet, gt, gt_mask, + tokenizer, text_encoder, + params, noise_scheduler, + text_raw, neg_text=None, + audio_frames=500, + guidance_scale=3, guidance_rescale=0.0, + ddim_steps=50, eta=1, random_seed=2024, + device='cuda', + ): + if neg_text is None: + neg_text = [""] + if tokenizer is not None: + text_batch = tokenizer(text_raw, + max_length=params['text_encoder']['max_length'], + padding="max_length", truncation=True, return_tensors="pt") + text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool() + text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state + + uncond_text_batch = tokenizer(neg_text, + max_length=params['text_encoder']['max_length'], + padding="max_length", truncation=True, return_tensors="pt") + uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool() + uncond_text = text_encoder(input_ids=uncond_text, + attention_mask=uncond_text_mask).last_hidden_state + else: + text, text_mask = None, None + guidance_scale = None + + codec_dim = params['model']['out_chans'] + unet.eval() + + if random_seed is not None: + generator = torch.Generator(device=device).manual_seed(random_seed) + else: + generator = torch.Generator(device=device) + generator.seed() + + noise_scheduler.set_timesteps(ddim_steps) + + # init noise + noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device) + latents = noise + + for t in noise_scheduler.timesteps: + latents = noise_scheduler.scale_model_input(latents, t) + + if guidance_scale: + + latents_combined = torch.cat([latents, latents], dim=0) + text_combined = torch.cat([text, uncond_text], dim=0) + text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0) + + if gt is not None: + gt_combined = torch.cat([gt, gt], dim=0) + gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0) + else: + gt_combined = None + gt_mask_combined = None + + output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined, + cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined) + output_text, output_uncond = torch.chunk(output_combined, 2, dim=0) + + output_pred = output_uncond + guidance_scale * (output_text - output_uncond) + if guidance_rescale > 0.0: + output_pred = rescale_noise_cfg(output_pred, output_text, + guidance_rescale=guidance_rescale) + else: + output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask, + cls_token=None, gt=gt, mae_mask_infer=gt_mask) + + latents = noise_scheduler.step(model_output=output_pred, timestep=t, + sample=latents, + eta=eta, generator=generator).prev_sample + + pred = scale_shift_re(latents, params['autoencoder']['scale'], + params['autoencoder']['shift']) + if gt is not None: + pred[~gt_mask] = gt[~gt_mask] + pred_wav = autoencoder(embedding=pred) + return pred_wav + + +@torch.no_grad() +def eval_udit(autoencoder, unet, + tokenizer, text_encoder, + params, 
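+              # evaluates on a CSV-defined validation subset; with mae=True, random
+              # spans of the encoded ground-truth latent are masked and re-generated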
noise_scheduler, + val_df, subset, + audio_frames, mae=False, + guidance_scale=3, guidance_rescale=0.0, + ddim_steps=50, eta=1, random_seed=2023, + device='cuda', + epoch=0, save_path='logs/eval/', val_num=5): + val_df = pd.read_csv(val_df) + val_df = val_df[val_df['split'] == subset] + if mae: + val_df = val_df[val_df['audio_length'] != 0] + + save_path = save_path + str(epoch) + '/' + os.makedirs(save_path, exist_ok=True) + + for i in tqdm(range(len(val_df))): + row = val_df.iloc[i] + text = [row['caption']] + if mae: + audio_path = params['data']['val_dir'] + str(row['audio_path']) + gt, sr = librosa.load(audio_path, sr=params['data']['sr']) + gt = gt / (np.max(np.abs(gt)) + 1e-9) + sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr']) + num_samples = 10 * sr + if len(gt) < num_samples: + padding = num_samples - len(gt) + gt = np.pad(gt, (0, padding), 'constant') + else: + gt = gt[:num_samples] + gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device) + gt = autoencoder(audio=gt) + B, D, L = gt.shape + mask_len = int(L * 0.2) + gt_mask = torch.zeros(B, D, L).to(device) + for _ in range(2): + start = random.randint(0, L - mask_len) + gt_mask[:, :, start:start + mask_len] = 1 + gt_mask = gt_mask.bool() + else: + gt = None + gt_mask = None + + pred = inference(autoencoder, unet, gt, gt_mask, + tokenizer, text_encoder, + params, noise_scheduler, + text, neg_text=None, + audio_frames=audio_frames, + guidance_scale=guidance_scale, guidance_rescale=guidance_rescale, + ddim_steps=ddim_steps, eta=eta, random_seed=random_seed, + device=device) + + pred = pred.cpu().numpy().squeeze(0).squeeze(0) + + sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr']) + + if i + 1 >= val_num: + break diff --git a/src/models/blocks.py b/src/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..07fb79082e066098cd465c0317fc0fe6c7285266 --- /dev/null +++ b/src/models/blocks.py @@ -0,0 +1,325 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from .utils.attention import Attention, JointAttention +from .utils.modules import unpatchify, FeedForward +from .utils.modules import film_modulate + + +class AdaLN(nn.Module): + def __init__(self, dim, ada_mode='ada', r=None, alpha=None): + super().__init__() + self.ada_mode = ada_mode + self.scale_shift_table = None + if ada_mode == 'ada': + # move nn.silu outside + self.time_ada = nn.Linear(dim, 6 * dim, bias=True) + elif ada_mode == 'ada_single': + # adaln used in pixel-art alpha + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + elif ada_mode in ['ada_lora', 'ada_lora_bias']: + self.lora_a = nn.Linear(dim, r * 6, bias=False) + self.lora_b = nn.Linear(r * 6, dim * 6, bias=False) + self.scaling = alpha / r + if ada_mode == 'ada_lora_bias': + # take bias out for consistency + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + else: + raise NotImplementedError + + def forward(self, time_token=None, time_ada=None): + if self.ada_mode == 'ada': + assert time_ada is None + B = time_token.shape[0] + time_ada = self.time_ada(time_token).reshape(B, 6, -1) + elif self.ada_mode == 'ada_single': + B = time_ada.shape[0] + time_ada = time_ada.reshape(B, 6, -1) + time_ada = self.scale_shift_table[None] + time_ada + elif self.ada_mode in ['ada_lora', 'ada_lora_bias']: + B = time_ada.shape[0] + time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling + time_ada = time_ada + time_ada_lora + time_ada = time_ada.reshape(B, 6, -1) + if 
self.scale_shift_table is not None: + time_ada = self.scale_shift_table[None] + time_ada + else: + raise NotImplementedError + return time_ada + + +class DiTBlock(nn.Module): + """ + A modified PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, dim, context_dim=None, + num_heads=8, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer=nn.LayerNorm, + time_fusion='none', + ada_lora_rank=None, ada_lora_alpha=None, + skip=False, skip_norm=False, + rope_mode='none', + context_norm=False, + use_checkpoint=False): + + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode=rope_mode) + + if context_dim is not None: + self.use_context = True + self.cross_attn = Attention(dim=dim, + num_heads=num_heads, + context_dim=context_dim, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode='none') + self.norm2 = norm_layer(dim) + if context_norm: + self.norm_context = norm_layer(context_dim) + else: + self.norm_context = nn.Identity() + else: + self.use_context = False + + self.norm3 = norm_layer(dim) + self.mlp = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + + self.use_adanorm = True if time_fusion != 'token' else False + if self.use_adanorm: + self.adaln = AdaLN(dim, ada_mode=time_fusion, + r=ada_lora_rank, alpha=ada_lora_alpha) + if skip: + self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity() + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + self.use_checkpoint = use_checkpoint + + def forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + if self.use_checkpoint: + return checkpoint(self._forward, x, + time_token, time_ada, skip, context, + x_mask, context_mask, extras, + use_reentrant=False) + else: + return self._forward(x, + time_token, time_ada, skip, context, + x_mask, context_mask, extras) + + def _forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + B, T, C = x.shape + if self.skip_linear is not None: + assert skip is not None + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, + shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + if self.use_adanorm: + x_norm = film_modulate(self.norm1(x), shift=shift_msa, + scale=scale_msa) + x = x + (1 - gate_msa) * self.attn(x_norm, context=None, + context_mask=x_mask, + extras=extras) + else: + x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask, + extras=extras) + + # cross attention + if self.use_context: + assert context is not None + x = x + self.cross_attn(x=self.norm2(x), + context=self.norm_context(context), + context_mask=context_mask, extras=extras) + + # mlp + if self.use_adanorm: + x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp) + x = x + (1 - gate_mlp) * self.mlp(x_norm) + else: + x = x + self.mlp(self.norm3(x)) + + return x + + +class JointDiTBlock(nn.Module): + """ + A modified PixArt block with adaptive layer norm (adaLN-single) conditioning. 
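+    Joint variant: the first `extras` tokens of the sequence are treated as the
+    context stream and the remainder as the latent stream. Both streams share a
+    single JointAttention call and are then processed by separate FeedForward
+    branches (see _forward), so no cross-attention layer is used.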
+ """ + + def __init__(self, dim, context_dim=None, + num_heads=8, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer=nn.LayerNorm, + time_fusion='none', + ada_lora_rank=None, ada_lora_alpha=None, + skip=(False, False), + rope_mode=False, + context_norm=False, + use_checkpoint=False,): + + super().__init__() + # no cross attention + assert context_dim is None + self.attn_norm_x = norm_layer(dim) + self.attn_norm_c = norm_layer(dim) + self.attn = JointAttention(dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode=rope_mode) + self.ffn_norm_x = norm_layer(dim) + self.ffn_norm_c = norm_layer(dim) + self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + + # Zero-out the shift table + self.use_adanorm = True if time_fusion != 'token' else False + if self.use_adanorm: + self.adaln = AdaLN(dim, ada_mode=time_fusion, + r=ada_lora_rank, alpha=ada_lora_alpha) + + if skip is False: + skip_x, skip_c = False, False + else: + skip_x, skip_c = skip + + self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None + self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None + + self.use_checkpoint = use_checkpoint + + def forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + if self.use_checkpoint: + return checkpoint(self._forward, x, + time_token, time_ada, skip, + context, x_mask, context_mask, extras, + use_reentrant=False) + else: + return self._forward(x, + time_token, time_ada, skip, + context, x_mask, context_mask, extras) + + def _forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + + assert context is None and context_mask is None + + context, x = x[:, :extras, :], x[:, extras:, :] + context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:] + + if skip is not None: + skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :] + + B, T, C = x.shape + if self.skip_linear_x is not None: + x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1)) + + if self.skip_linear_c is not None: + context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1)) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, + shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + x_norm = self.attn_norm_x(x) + c_norm = self.attn_norm_c(context) + if self.use_adanorm: + x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa) + x_out, c_out = self.attn(x_norm, context=c_norm, + x_mask=x_mask, context_mask=context_mask, + extras=extras) + if self.use_adanorm: + x = x + (1 - gate_msa) * x_out + else: + x = x + x_out + context = context + c_out + + # mlp + if self.use_adanorm: + x_norm = film_modulate(self.ffn_norm_x(x), + shift=shift_mlp, scale=scale_mlp) + x = x + (1 - gate_mlp) * self.mlp_x(x_norm) + else: + x = x + self.mlp_x(self.ffn_norm_x(x)) + + c_norm = self.ffn_norm_c(context) + context = context + self.mlp_c(c_norm) + + return torch.cat((context, x), dim=1) + + +class FinalBlock(nn.Module): + def __init__(self, embed_dim, patch_size, in_chans, + img_size, + input_type='2d', + norm_layer=nn.LayerNorm, + use_conv=True, + use_adanorm=True): + super().__init__() + self.in_chans = in_chans + self.img_size = img_size + self.input_type = input_type + + self.norm = 
norm_layer(embed_dim) + if use_adanorm: + self.use_adanorm = True + else: + self.use_adanorm = False + + if input_type == '2d': + self.patch_dim = patch_size ** 2 * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, + 3, padding=1) + else: + self.final_layer = nn.Identity() + + elif input_type == '1d': + self.patch_dim = patch_size * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv1d(self.in_chans, self.in_chans, + 3, padding=1) + else: + self.final_layer = nn.Identity() + + def forward(self, x, time_ada=None, extras=0): + B, T, C = x.shape + x = x[:, extras:, :] + # only handle generation target + if self.use_adanorm: + shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1) + x = film_modulate(self.norm(x), shift, scale) + else: + x = self.norm(x) + x = self.linear(x) + x = unpatchify(x, self.in_chans, self.input_type, self.img_size) + x = self.final_layer(x) + return x \ No newline at end of file diff --git a/src/models/conditioners.py b/src/models/conditioners.py new file mode 100644 index 0000000000000000000000000000000000000000..7414cfe2b9dc3dc7dec25e698c2ce43db2765f17 --- /dev/null +++ b/src/models/conditioners.py @@ -0,0 +1,180 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +import math +from .udit import UDiT +from .utils.span_mask import compute_mask_indices + + +class EmbeddingCFG(nn.Module): + """ + Handles label dropout for classifier-free guidance. + """ + # todo: support 2D input + + def __init__(self, in_channels): + super().__init__() + self.cfg_embedding = nn.Parameter( + torch.randn(in_channels) / in_channels ** 0.5) + + def token_drop(self, condition, condition_mask, cfg_prob): + """ + Drops labels to enable classifier-free guidance. 
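+        Each batch item is replaced by the learned unconditional embedding with
+        probability `cfg_prob`; for dropped items the condition mask is reset so
+        that only the first position remains valid.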
+ """ + b, t, device = condition.shape[0], condition.shape[1], condition.device + drop_ids = torch.rand(b, device=device) < cfg_prob + uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t) + condition = torch.where(drop_ids[:, None, None], uncond, condition) + if condition_mask is not None: + condition_mask[drop_ids] = False + condition_mask[drop_ids, 0] = True + + return condition, condition_mask + + def forward(self, condition, condition_mask, cfg_prob=0.0): + if condition_mask is not None: + condition_mask = condition_mask.clone() + if cfg_prob > 0: + condition, condition_mask = self.token_drop(condition, + condition_mask, + cfg_prob) + return condition, condition_mask + + +class DiscreteCFG(nn.Module): + def __init__(self, replace_id=2): + super(DiscreteCFG, self).__init__() + self.replace_id = replace_id + + def forward(self, context, context_mask, cfg_prob): + context = context.clone() + if context_mask is not None: + context_mask = context_mask.clone() + if cfg_prob > 0: + cfg_mask = torch.rand(len(context)) < cfg_prob + if torch.any(cfg_mask): + context[cfg_mask] = 0 + context[cfg_mask, 0] = self.replace_id + if context_mask is not None: + context_mask[cfg_mask] = False + context_mask[cfg_mask, 0] = True + return context, context_mask + + +class CFGModel(nn.Module): + def __init__(self, context_dim, backbone): + super().__init__() + self.model = backbone + self.context_cfg = EmbeddingCFG(context_dim) + + def forward(self, x, timesteps, + context, x_mask=None, context_mask=None, + cfg_prob=0.0): + context = self.context_cfg(context, cfg_prob) + x = self.model(x=x, timesteps=timesteps, + context=context, + x_mask=x_mask, context_mask=context_mask) + return x + + +class ConcatModel(nn.Module): + def __init__(self, backbone, in_dim, stride=[]): + super().__init__() + self.model = backbone + + self.downsample_layers = nn.ModuleList() + for i, s in enumerate(stride): + downsample_layer = nn.Conv1d( + in_dim, + in_dim * 2, + kernel_size=2 * s, + stride=s, + padding=math.ceil(s / 2), + ) + self.downsample_layers.append(downsample_layer) + in_dim = in_dim * 2 + + self.context_cfg = EmbeddingCFG(in_dim) + + def forward(self, x, timesteps, + context, x_mask=None, + cfg=False, cfg_prob=0.0): + + # todo: support 2D input + # x: B, C, L + # context: B, C, L + + for downsample_layer in self.downsample_layers: + context = downsample_layer(context) + + context = context.transpose(1, 2) + context = self.context_cfg(caption=context, + cfg=cfg, cfg_prob=cfg_prob) + context = context.transpose(1, 2) + + assert context.shape[-1] == x.shape[-1] + x = torch.cat([context, x], dim=1) + x = self.model(x=x, timesteps=timesteps, + context=None, x_mask=x_mask, context_mask=None) + return x + + +class MaskDiT(nn.Module): + def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs): + super().__init__() + self.model = UDiT(**kwargs) + self.mae = mae + if self.mae: + out_channel = kwargs.pop('out_chans', None) + self.mask_embed = nn.Parameter(torch.zeros((out_channel))) + self.mae_prob = mae_prob + self.mask_ratio = mask_ratio + self.mask_span = mask_span + + def random_masking(self, gt, mask_ratios, mae_mask_infer=None): + B, D, L = gt.shape + if mae_mask_infer is None: + # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1) + mask_ratios = mask_ratios.cpu().numpy() + mask = compute_mask_indices(shape=[B, L], + padding_mask=None, + mask_prob=mask_ratios, + mask_length=self.mask_span, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + 
min_space=0,) + mask = mask.unsqueeze(1).expand_as(gt) + else: + mask = mae_mask_infer + mask = mask.expand_as(gt) + gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask] + return gt, mask.type_as(gt) + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, cls_token=None, + gt=None, mae_mask_infer=None): + mae_mask = torch.ones_like(x) + if self.mae: + if gt is not None: + B, D, L = gt.shape + mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device) + gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer) + # apply mae only to the selected batches + if mae_mask_infer is None: + # determine mae batch + mae_batch = torch.rand(B) < self.mae_prob + gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch] + mae_mask[~mae_batch] = 1.0 + else: + B, D, L = x.shape + gt = self.mask_embed.view(1, D, 1).expand_as(x) + x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1) + + x = self.model(x=x, timesteps=timesteps, context=context, + x_mask=x_mask, context_mask=context_mask, + cls_token=cls_token) + # print(mae_mask[:, 0, :].sum(dim=-1)) + return x, mae_mask diff --git a/src/models/udit.py b/src/models/udit.py new file mode 100644 index 0000000000000000000000000000000000000000..db86b3e01acaea2de42ae85e3ebbad83f7c19fe9 --- /dev/null +++ b/src/models/udit.py @@ -0,0 +1,356 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint +import math +from .utils.modules import PatchEmbed, TimestepEmbedder +from .utils.modules import PE_wrapper, RMSNorm +from .blocks import DiTBlock, JointDiTBlock, FinalBlock + + +class UDiT(nn.Module): + def __init__(self, + img_size=224, patch_size=16, in_chans=3, + input_type='2d', out_chans=None, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + # time fusion ada or token + time_fusion='token', + ada_lora_rank=None, ada_lora_alpha=None, + cls_dim=None, + # max length is only used for concat + context_dim=768, context_fusion='concat', + context_max_length=128, context_pe_method='sinu', + pe_method='abs', rope_mode='none', + use_conv=True, + skip=True, skip_norm=True): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, input_type=input_type) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method, + length=num_patches) + + print(f'x position embedding: {pe_method}') + print(f'rope mode: {self.rope}') + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper(dim=embed_dim, 
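+                                       # learnable absolute PE for the prepended
+                                       # time (and optional cls) tokens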
method='abs', length=self.extras) + elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']: + self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True) + if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError + print(f'time fusion mode: {self.time_fusion}') + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + print(f'context fusion mode: {context_fusion}') + print(f'context position embedding: {context_pe_method}') + + if self.context_fusion == 'joint': + Block = JointDiTBlock + self.use_skip = skip[0] + else: + Block = DiTBlock + self.use_skip = skip + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + print(f'use long skip connection: {skip}') + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.mid_block = Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + + self.out_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=skip, skip_norm=skip_norm, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + # FinalLayer block + self.use_conv = use_conv + self.final_block = FinalBlock(embed_dim=embed_dim, + patch_size=patch_size, + img_size=img_size, + in_chans=out_chans, + input_type=input_type, + norm_layer=norm_layer, + use_conv=use_conv, + 
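+                                      # FinalBlock: (AdaLN-modulated) norm -> linear
+                                      # patch projection -> unpatchify -> optional conv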
use_adanorm=self.use_adanorm) + self.initialize_weights() + + def _init_ada(self): + if self.time_fusion == 'ada': + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.constant_(block.adaln.time_ada.weight, 0) + nn.init.constant_(block.adaln.time_ada.bias, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0) + for block in self.out_blocks: + nn.init.constant_(block.adaln.time_ada.weight, 0) + nn.init.constant_(block.adaln.time_ada.bias, 0) + elif self.time_fusion == 'ada_single': + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + elif self.time_fusion in ['ada_lora', 'ada_lora_bias']: + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.kaiming_uniform_(block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(block.adaln.lora_b.weight, 0) + nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0) + for block in self.out_blocks: + nn.init.kaiming_uniform_(block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(block.adaln.lora_b.weight, 0) + + def initialize_weights(self): + # Basic init for all layers + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # init patch Conv like Linear + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.patch_embed.proj.bias, 0) + + # Zero-out AdaLN + if self.use_adanorm: + self._init_ada() + + # Zero-out Cross Attention + if self.context_cross: + for block in self.in_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0) + for block in self.out_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out cls embedding + if self.cls_embed: + if self.use_adanorm: + nn.init.constant_(self.cls_embed[-1].weight, 0) + nn.init.constant_(self.cls_embed[-1].bias, 0) + + # Zero-out Output + # might not zero-out this when using v-prediction + # it could be good when using noise-prediction + # nn.init.constant_(self.final_block.linear.weight, 0) + # nn.init.constant_(self.final_block.linear.bias, 0) + # if self.use_conv: + # nn.init.constant_(self.final_block.final_layer.weight.data, 0) + # nn.init.constant_(self.final_block.final_layer.bias, 0) + + # init out Conv + if self.use_conv: + nn.init.xavier_uniform_(self.final_block.final_layer.weight) + nn.init.constant_(self.final_block.final_layer.bias, 0) + + def _concat_x_context(self, x, context, x_mask=None, context_mask=None): + assert context.shape[-2] == self.context_max_length + # Check if either x_mask or context_mask is provided + B = x.shape[0] + # Create default masks if they are not provided + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is 
None: + context_mask = torch.ones(B, context.shape[-2], + device=context.device).bool() + # Concatenate the masks along the second dimension (dim=1) + x_mask = torch.cat([context_mask, x_mask], dim=1) + # Concatenate context and x along the second dimension (dim=1) + x = torch.cat((context, x), dim=1) + return x, x_mask + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, + cls_token=None + ): + # make it compatible with int time step during inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long) + + x = self.patch_embed(x) + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context(x=x, context=context_token, + x_mask=x_mask, + context_mask=context_mask) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + time_ada_final = None + if self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = self.time_act(time_token) + time_ada_final = self.time_ada_final(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat( + [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(), + x_mask], dim=1) + time_token = None + + skips = [] + for blk in self.in_blocks: + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + if self.use_skip: + skips.append(x) + + x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + + for blk in self.out_blocks: + skip = skips.pop() if self.use_skip else None + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=skip, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + + x = self.final_block(x, time_ada=time_ada_final, extras=self.extras) + + return x \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py b/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ffad4542ada8f1824e93c76647da232db7f2da4e --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py @@ -0,0 +1,290 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding +from .modules import RMSNorm + + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + 
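+    # PyTorch builds without F.scaled_dot_product_attention (pre-2.0) fall back
+    # to an explicit matmul + softmax implementation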
ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + if context_dim is None: + self.cross_attn = False + else: + self.cross_attn = True + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + + if qk_norm is None: + self.norm_q = nn.Identity() + self.norm_k = nn.Identity() + elif qk_norm == 'layernorm': + self.norm_q = nn.LayerNorm(head_dim) + self.norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + self.norm_q = RMSNorm(head_dim) + self.norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if self.cross_attn: + assert rope_mode == 'none' + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def forward(self, x, 
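+                # x: (B, L, C); context defaults to x (self-attention) when None;
+                # context_mask marks the valid key/value positions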
context=None, context_mask=None, extras=0): + B, L, C = x.shape + if context is None: + context = x + + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x.shape, context.shape, + x.device, None, context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads) + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads) + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads) + + q = self.norm_q(q) + k = self.norm_k(k) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JointAttention(nn.Module): + def __init__(self, dim, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., + rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias) + self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias) + + self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim) + self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim) + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + + self.proj_x = nn.Linear(dim, dim) + self.proj_drop_x = nn.Dropout(proj_drop) + + self.proj_c = nn.Linear(dim, dim) + self.proj_drop_c = nn.Dropout(proj_drop) + + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _make_qkv_layers(self, dim, qkv_bias): + return (nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias)) + + def _make_norm_layers(self, qk_norm, head_dim): + if qk_norm is None: + norm_q = nn.Identity() + norm_k = nn.Identity() + elif qk_norm == 'layernorm': + norm_q = nn.LayerNorm(head_dim) + norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + norm_q = RMSNorm(head_dim) + norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + return norm_q, norm_k + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + 
raise NotImplementedError + return x + + def _cat_mask(self, x, context, x_mask=None, context_mask=None): + B = x.shape[0] + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones(B, context.shape[-2], device=context.device).bool() + mask = torch.cat([context_mask, x_mask], dim=1) + return mask + + def forward(self, x, context, x_mask=None, context_mask=None, extras=0): + B, Lx, C = x.shape + _, Lc, _ = context.shape + if x_mask is not None or context_mask is not None: + mask = self._cat_mask(x, context, + x_mask=x_mask, + context_mask=context_mask) + shape = [B, Lx+Lc, C] + mask_binary = create_mask(q_shape=shape, k_shape=shape, + device=x.device, + q_mask=None, k_mask=mask) + else: + mask_binary = None + + qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x) + qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context) + + qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qx, kx, vx]) + qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qc, kc, vc]) + + qx, kx = self.norm_qx(qx), self.norm_kx(kx) + qc, kc = self.norm_qc(qc), self.norm_kc(kc) + + q, k, v = (torch.cat([qc, qx], dim=2), + torch.cat([kc, kx], dim=2), + torch.cat([vc, vx], dim=2)) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + context, x = x[:, :Lc, :], x[:, Lc:, :] + + x = self.proj_x(x) + x = self.proj_drop_x(x) + + context = self.proj_c(context) + context = self.proj_drop_c(context) + + return x, context \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2a8c841b62748120d7cb33a6aa10860ecdb674 --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.cuda.amp import autocast +import math +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .timm import trunc_normal_ + + +# disable in checkpoint mode +# @torch.jit.script +def film_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
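+    A sinusoidal embedding of size `frequency_embedding_size` (see
+    timestep_embedding above) is projected to the output size by a
+    two-layer MLP with a SiLU in between.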
+ """ + + def __init__(self, hidden_size, frequency_embedding_size=256, + out_size=None): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type( + self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def patchify(imgs, patch_size, input_type='2d'): + if input_type == '2d': + x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) + elif input_type == '1d': + x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size) + return x + + +def unpatchify(x, channels=3, input_type='2d', img_size=None): + if input_type == '2d': + patch_size = int((x.shape[2] // channels) ** 0.5) + # h = w = int(x.shape[1] ** .5) + h, w = img_size[0] // patch_size, img_size[1] // patch_size + assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, + p1=patch_size, p2=patch_size) + elif input_type == '1d': + patch_size = int((x.shape[2] // channels)) + h = x.shape[1] + assert patch_size * channels == x.shape[2] + x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size) + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'): + super().__init__() + self.patch_size = patch_size + self.input_type = input_type + if input_type == '2d': + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + elif input_type == '1d': + self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + + def forward(self, x): + if self.input_type == '2d': + B, C, H, W = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0 + elif self.input_type == '1d': + B, C, H = x.shape + assert H % self.patch_size == 0 + + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PositionalConvEmbedding(nn.Module): + """ + Relative positional embedding used in HuBERT + """ + + def __init__(self, dim=768, kernel_size=128, groups=16): + super().__init__() + self.conv = nn.Conv1d( + dim, + dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + bias=True + ) + self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x): + # B C T + x = self.conv(x) + x = F.gelu(x[:, :, :-1]) + return x + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, dim, length): + super(SinusoidalPositionalEncoding, self).__init__() + self.length = length + self.dim = dim + self.register_buffer('pe', self._generate_positional_encoding(length, dim)) + + def _generate_positional_encoding(self, length, dim): + pe = torch.zeros(length, dim) + position = torch.arange(0, length, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + return pe + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return x + + +class PE_wrapper(nn.Module): + def __init__(self, dim=768, method='abs', length=None, **kwargs): + 
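+        # method: 'abs'  -> learnable absolute PE (UViT-style, trunc-normal init)
+        #         'conv' -> HuBERT-style convolutional relative PE
+        #         'sinu' -> fixed sinusoidal PE
+        #         'none' -> identity (no positional encoding)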
super().__init__() + self.method = method + if method == 'abs': + # init absolute pe like UViT + self.length = length + self.abs_pe = nn.Parameter(torch.zeros(1, length, dim)) + trunc_normal_(self.abs_pe, std=.02) + elif method == 'conv': + self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs) + elif method == 'sinu': + self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length) + elif method == 'none': + # skip pe + self.id = nn.Identity() + else: + raise NotImplementedError + + def forward(self, x): + if self.method == 'abs': + _, L, _ = x.shape + assert L <= self.length + x = x + self.abs_pe[:, :L, :] + elif self.method == 'conv': + x = x + self.conv_pe(x) + elif self.method == 'sinu': + x = self.sinu_pe(x) + elif self.method == 'none': + x = self.id(x) + else: + raise NotImplementedError + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. 
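+            Computed as weight * x / sqrt(mean(x**2, dim=-1) + eps), with the
+            normalization done in float32 and cast back to the input dtype.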
+ + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", + bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), + approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +# disable in checkpoint mode +# @torch.jit.script +def snake_beta(x, alpha, beta): + return x + beta * torch.sin(x * alpha).pow(2) + + +class Snake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x = snake_beta(x, self.alpha, self.beta) + return x + + +class GESnake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * snake_beta(gate, self.alpha, self.beta) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + dropout=0.0, + activation_fn="geglu", + final_dropout=False, + inner_dim=None, + bias=True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "snake": + act_fn = Snake(dim, inner_dim, bias=bias) + elif activation_fn == "gesnake": + act_fn = GESnake(dim, inner_dim, bias=bias) 
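+        # supported activations: gelu, gelu-approximate, geglu, geglu-approximate,
+        # snake, gesnake (the gated variants project to twice the inner width)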
+ else: + raise NotImplementedError + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py b/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..04f8b199ced89d0ed0365b8d74c1088749e7c441 --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py @@ -0,0 +1,91 @@ +import torch + +"this rope is faster than llama rope with jit script" + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +# disable in checkpoint mode +# @torch.jit.script +def apply_rotary_pos_emb(x, cos, sin): + # NOTE: This could probably be moved to Triton + # Handle a possible sequence length mismatch in between q and k + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + + .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative + (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis + """ + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=-2): + # expect input: B, H, L, D + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + # also make sure dtype wont change + if ( + seq_len != self._seq_len_cached + or self._cos_cached.device != x.device + or self._cos_cached.dtype != x.dtype + ): + self._seq_len_cached = seq_len + t = torch.arange( + x.shape[seq_dimension], device=x.device, dtype=torch.float32 + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) + self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) + + return self._cos_cached, self._sin_cached + + def forward(self, q, k): + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + q.float(), seq_dimension=-2 + ) + if k is not None: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + apply_rotary_pos_emb(k.float(), + self._cos_cached, + self._sin_cached).type_as(k), + ) + else: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + None + ) \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py b/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..23f8557e9907c4f9ec17efa36ebd035d8667ff00 --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +from typing import Optional, Tuple + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element
+            poisson = sample from poisson distribution with lambda = mask length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+    """
+
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    # Convert mask_prob to a NumPy array
+    mask_prob = np.array(mask_prob)
+
+    # Calculate all_num_mask for each element in the batch
+    all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int)
+
+    # Apply the max operation with min_masks for each element
+    all_num_mask = np.maximum(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * sz / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask[i]
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+    # min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        # if len(mask_idc) > min_len:
+        #     mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+
+    return torch.tensor(mask)
+
+
+if __name__ == '__main__':
+    mask = compute_mask_indices(
+        shape=[4, 500],
+        padding_mask=None,
+        mask_prob=[0.65, 0.5, 0.65, 0.65],
+        mask_length=10,
+        mask_type="static",
+        mask_other=0.0,
+        min_masks=1,
+        no_overlap=False,
+        min_space=0,
+    )
+    print(mask)
+    print(mask.sum(dim=1))
\ No newline at end of file
diff --git
a/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py b/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ae72f4fca681617778028b64da5daed26b3ed3d4 --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py @@ -0,0 +1,114 @@ +# code from timm 0.3.2 +import torch +import torch.nn as nn +import math +import warnings + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x \ No newline at end of file diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/utils/__pycache__/__init__.cpython-310.pyc b/src/models/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc31b0a19bf83eef2a703df69b4272a12bfbe577 Binary files /dev/null and b/src/models/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/__init__.cpython-311.pyc b/src/models/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6136f879b34c073f2bb3a13f6ef55d72b7ecf028 Binary files /dev/null and b/src/models/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/attention.cpython-310.pyc b/src/models/utils/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebbb2fc665b2c2ff2115bc98ec054430333e6ee Binary files /dev/null and b/src/models/utils/__pycache__/attention.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/attention.cpython-311.pyc b/src/models/utils/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f457ce350f5ec3c6dfd75532635bfaf27cd72ba Binary files /dev/null and b/src/models/utils/__pycache__/attention.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/modules.cpython-310.pyc b/src/models/utils/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d507df3df9cbf9fa29fbabb4591df36aedf6bdd4 Binary files /dev/null and b/src/models/utils/__pycache__/modules.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/modules.cpython-311.pyc b/src/models/utils/__pycache__/modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..917e6f3d92cc78d984aeeb0e5c97b61fa7a97a84 Binary files /dev/null and b/src/models/utils/__pycache__/modules.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/rotary.cpython-310.pyc b/src/models/utils/__pycache__/rotary.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..716bfd4abc834b3e912c4f4574ddc3d3597183a5 Binary files /dev/null and b/src/models/utils/__pycache__/rotary.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/rotary.cpython-311.pyc b/src/models/utils/__pycache__/rotary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9eff5238b9477aad457d21e785db9b986ff456e Binary files /dev/null and b/src/models/utils/__pycache__/rotary.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/span_mask.cpython-310.pyc b/src/models/utils/__pycache__/span_mask.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dc66ee9fc445b41e939f29ba32cfdeb6169bcc2 Binary files /dev/null and b/src/models/utils/__pycache__/span_mask.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/span_mask.cpython-311.pyc b/src/models/utils/__pycache__/span_mask.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55645308b0731bc848b667d789370364fe94426e Binary files /dev/null and b/src/models/utils/__pycache__/span_mask.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/timm.cpython-310.pyc b/src/models/utils/__pycache__/timm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa38c2a3330015ecde142ac1e187afd6afd3aa5 Binary files /dev/null and b/src/models/utils/__pycache__/timm.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/timm.cpython-311.pyc b/src/models/utils/__pycache__/timm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8ca43e9c8f24c9e6634babca5adadc78e686949 Binary files /dev/null and b/src/models/utils/__pycache__/timm.cpython-311.pyc differ diff --git a/src/models/utils/attention.py b/src/models/utils/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ffad4542ada8f1824e93c76647da232db7f2da4e --- /dev/null +++ b/src/models/utils/attention.py @@ -0,0 +1,290 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding +from .modules import RMSNorm + + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + if 
context_dim is None: + self.cross_attn = False + else: + self.cross_attn = True + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + + if qk_norm is None: + self.norm_q = nn.Identity() + self.norm_k = nn.Identity() + elif qk_norm == 'layernorm': + self.norm_q = nn.LayerNorm(head_dim) + self.norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + self.norm_q = RMSNorm(head_dim) + self.norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if self.cross_attn: + assert rope_mode == 'none' + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def forward(self, x, context=None, context_mask=None, extras=0): + B, L, C = x.shape + if context is None: + context = x + + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x.shape, context.shape, + x.device, None, context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads) + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads) + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads) + + q = self.norm_q(q) + k = self.norm_k(k) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JointAttention(nn.Module): + def __init__(self, dim, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., + rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias) + self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias) + + self.norm_qx, self.norm_kx 
= self._make_norm_layers(qk_norm, head_dim) + self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim) + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + + self.proj_x = nn.Linear(dim, dim) + self.proj_drop_x = nn.Dropout(proj_drop) + + self.proj_c = nn.Linear(dim, dim) + self.proj_drop_c = nn.Dropout(proj_drop) + + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _make_qkv_layers(self, dim, qkv_bias): + return (nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias)) + + def _make_norm_layers(self, qk_norm, head_dim): + if qk_norm is None: + norm_q = nn.Identity() + norm_k = nn.Identity() + elif qk_norm == 'layernorm': + norm_q = nn.LayerNorm(head_dim) + norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + norm_q = RMSNorm(head_dim) + norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + return norm_q, norm_k + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def _cat_mask(self, x, context, x_mask=None, context_mask=None): + B = x.shape[0] + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones(B, context.shape[-2], device=context.device).bool() + mask = torch.cat([context_mask, x_mask], dim=1) + return mask + + def forward(self, x, context, x_mask=None, context_mask=None, extras=0): + B, Lx, C = x.shape + _, Lc, _ = context.shape + if x_mask is not None or context_mask is not None: + mask = self._cat_mask(x, context, + x_mask=x_mask, + context_mask=context_mask) + shape = [B, Lx+Lc, C] + mask_binary = create_mask(q_shape=shape, k_shape=shape, + device=x.device, + q_mask=None, k_mask=mask) + else: + mask_binary = None + + qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x) + qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context) + + qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qx, kx, vx]) + qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qc, kc, vc]) + + qx, kx = self.norm_qx(qx), 
self.norm_kx(kx) + qc, kc = self.norm_qc(qc), self.norm_kc(kc) + + q, k, v = (torch.cat([qc, qx], dim=2), + torch.cat([kc, kx], dim=2), + torch.cat([vc, vx], dim=2)) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + context, x = x[:, :Lc, :], x[:, Lc:, :] + + x = self.proj_x(x) + x = self.proj_drop_x(x) + + context = self.proj_c(context) + context = self.proj_drop_c(context) + + return x, context \ No newline at end of file diff --git a/src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py b/src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..fd201997bfe10a721d3473692a820ccde7189797 --- /dev/null +++ b/src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q, k, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, + attn_drop=0., proj_drop=0., use_rope=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.use_rope = use_rope + if self.use_rope: + self.rotary = RotaryEmbedding(dim=head_dim) + + def forward(self, x, context=None, context_mask=None): + B, L, C = x.shape + q = self.to_q(x) + if context is None: + context = x + else: + assert self.use_rope is False + + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x, context, None, context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float() + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float() + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float() + + if self.use_rope: + q, k = self.rotary(q=q, k=k) + + if ATTENTION_MODE == 'flash': + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif 
ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplementedError + + x = self.proj(x) + x = self.proj_drop(x) + return x \ No newline at end of file diff --git a/src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py b/src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..52d66b7a08773f2aadab74ac1ffd6d35409ff52b --- /dev/null +++ b/src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py @@ -0,0 +1,74 @@ +import torch +from typing import Tuple +from rotary import RotaryEmbedding +import time + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, + x: torch.Tensor,): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def compute_rope(q, freqs_cis): + return q * freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq1, xq2 = xq.chunk(2, dim=-1) + xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float()) + + xk1, xk2 = xk.chunk(2, dim=-1) + xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float()) + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3) + xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +if __name__ == '__main__': + # Move data to CUDA + freq_cis = precompute_freqs_cis(4, 5).cuda() + x = torch.rand(1, 5, 1, 4).cuda() + y = torch.rand(1, 5, 1, 4).cuda() + + # First method + start_time = time.time() + for _ in range(20000): + x1, y1 = apply_rotary_emb(x, y, freq_cis) + end_time = time.time() + print(f"Method 1 time cost: {end_time - start_time} seconds") + + # Prepare data for the second method + x = x.permute(0, 2, 1, 3) + y = y.permute(0, 2, 1, 3) + rope = RotaryEmbedding(4).cuda() + + # Second method + start_time = time.time() + for _ in range(20000): + x2, y2 = rope(x, y) + end_time = time.time() + print(f"Method 2 time cost: {end_time - start_time} seconds") + + # Print the results + print(x1) + print(x2) \ No newline at end of file diff --git a/src/models/utils/bk/__pycache__/rotary.cpython-311.pyc b/src/models/utils/bk/__pycache__/rotary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de78be925584642b52de19239fd67bdcf6173d95 Binary files /dev/null and b/src/models/utils/bk/__pycache__/rotary.cpython-311.pyc differ diff --git a/src/models/utils/bk/attention.py b/src/models/utils/bk/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..fd201997bfe10a721d3473692a820ccde7189797 --- /dev/null +++ 
b/src/models/utils/bk/attention.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q, k, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, + attn_drop=0., proj_drop=0., use_rope=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.use_rope = use_rope + if self.use_rope: + self.rotary = RotaryEmbedding(dim=head_dim) + + def forward(self, x, context=None, context_mask=None): + B, L, C = x.shape + q = self.to_q(x) + if context is None: + context = x + else: + assert self.use_rope is False + + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x, context, None, context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float() + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float() + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float() + + if self.use_rope: + q, k = self.rotary(q=q, k=k) + + if ATTENTION_MODE == 'flash': + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplementedError + + x = self.proj(x) + x = self.proj_drop(x) + return x \ No newline at end of file diff --git a/src/models/utils/bk/llama_rotary.py b/src/models/utils/bk/llama_rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..52d66b7a08773f2aadab74ac1ffd6d35409ff52b --- /dev/null +++ b/src/models/utils/bk/llama_rotary.py @@ -0,0 +1,74 @@ +import torch +from typing import Tuple +from rotary import RotaryEmbedding +import time + + +def 
precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, + x: torch.Tensor,): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def compute_rope(q, freqs_cis): + return q * freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq1, xq2 = xq.chunk(2, dim=-1) + xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float()) + + xk1, xk2 = xk.chunk(2, dim=-1) + xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float()) + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3) + xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +if __name__ == '__main__': + # Move data to CUDA + freq_cis = precompute_freqs_cis(4, 5).cuda() + x = torch.rand(1, 5, 1, 4).cuda() + y = torch.rand(1, 5, 1, 4).cuda() + + # First method + start_time = time.time() + for _ in range(20000): + x1, y1 = apply_rotary_emb(x, y, freq_cis) + end_time = time.time() + print(f"Method 1 time cost: {end_time - start_time} seconds") + + # Prepare data for the second method + x = x.permute(0, 2, 1, 3) + y = y.permute(0, 2, 1, 3) + rope = RotaryEmbedding(4).cuda() + + # Second method + start_time = time.time() + for _ in range(20000): + x2, y2 = rope(x, y) + end_time = time.time() + print(f"Method 2 time cost: {end_time - start_time} seconds") + + # Print the results + print(x1) + print(x2) \ No newline at end of file diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2a8c841b62748120d7cb33a6aa10860ecdb674 --- /dev/null +++ b/src/models/utils/modules.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.cuda.amp import autocast +import math +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .timm import trunc_normal_ + + +# disable in checkpoint mode +# @torch.jit.script +def film_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
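+    Example (illustrative):
+        >>> t = torch.tensor([0, 250, 999])
+        >>> emb = timestep_embedding(t, dim=256)
+        >>> emb.shape
+        torch.Size([3, 256])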
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256, + out_size=None): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type( + self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def patchify(imgs, patch_size, input_type='2d'): + if input_type == '2d': + x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) + elif input_type == '1d': + x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size) + return x + + +def unpatchify(x, channels=3, input_type='2d', img_size=None): + if input_type == '2d': + patch_size = int((x.shape[2] // channels) ** 0.5) + # h = w = int(x.shape[1] ** .5) + h, w = img_size[0] // patch_size, img_size[1] // patch_size + assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, + p1=patch_size, p2=patch_size) + elif input_type == '1d': + patch_size = int((x.shape[2] // channels)) + h = x.shape[1] + assert patch_size * channels == x.shape[2] + x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size) + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'): + super().__init__() + self.patch_size = patch_size + self.input_type = input_type + if input_type == '2d': + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + elif input_type == '1d': + self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + + def forward(self, x): + if self.input_type == '2d': + B, C, H, W = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0 + elif self.input_type == '1d': + B, C, H = x.shape + assert H % self.patch_size == 0 + + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PositionalConvEmbedding(nn.Module): + """ + Relative positional embedding used in HuBERT + """ + + def __init__(self, dim=768, kernel_size=128, groups=16): + super().__init__() + self.conv = nn.Conv1d( + dim, + dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + bias=True + ) + self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x): + # B C T + x = self.conv(x) + x = F.gelu(x[:, :, :-1]) + return x + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, dim, length): + super(SinusoidalPositionalEncoding, self).__init__() + self.length = length + self.dim = dim + self.register_buffer('pe', self._generate_positional_encoding(length, dim)) + + def _generate_positional_encoding(self, length, dim): + pe = 
torch.zeros(length, dim) + position = torch.arange(0, length, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + return pe + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return x + + +class PE_wrapper(nn.Module): + def __init__(self, dim=768, method='abs', length=None, **kwargs): + super().__init__() + self.method = method + if method == 'abs': + # init absolute pe like UViT + self.length = length + self.abs_pe = nn.Parameter(torch.zeros(1, length, dim)) + trunc_normal_(self.abs_pe, std=.02) + elif method == 'conv': + self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs) + elif method == 'sinu': + self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length) + elif method == 'none': + # skip pe + self.id = nn.Identity() + else: + raise NotImplementedError + + def forward(self, x): + if self.method == 'abs': + _, L, _ = x.shape + assert L <= self.length + x = x + self.abs_pe[:, :L, :] + elif self.method == 'conv': + x = x + self.conv_pe(x) + elif self.method == 'sinu': + x = self.sinu_pe(x) + elif self.method == 'none': + x = self.id(x) + else: + raise NotImplementedError + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. 
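+        Example (illustrative):
+            >>> norm = RMSNorm(dim=512)
+            >>> x = torch.randn(2, 100, 512)
+            >>> norm(x).shape
+            torch.Size([2, 100, 512])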
+ + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", + bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), + approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +# disable in checkpoint mode +# @torch.jit.script +def snake_beta(x, alpha, beta): + return x + beta * torch.sin(x * alpha).pow(2) + + +class Snake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x = snake_beta(x, self.alpha, self.beta) + return x + + +class GESnake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * snake_beta(gate, self.alpha, self.beta) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + dropout=0.0, + activation_fn="geglu", + final_dropout=False, + inner_dim=None, + bias=True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "snake": + act_fn = Snake(dim, inner_dim, bias=bias) + elif activation_fn == "gesnake": + act_fn = GESnake(dim, inner_dim, bias=bias) 
+ else: + raise NotImplementedError + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/src/models/utils/rotary.py b/src/models/utils/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..04f8b199ced89d0ed0365b8d74c1088749e7c441 --- /dev/null +++ b/src/models/utils/rotary.py @@ -0,0 +1,91 @@ +import torch + +"this rope is faster than llama rope with jit script" + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +# disable in checkpoint mode +# @torch.jit.script +def apply_rotary_pos_emb(x, cos, sin): + # NOTE: This could probably be moved to Triton + # Handle a possible sequence length mismatch in between q and k + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + + .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative + (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis + """ + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=-2): + # expect input: B, H, L, D + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + # also make sure dtype wont change + if ( + seq_len != self._seq_len_cached + or self._cos_cached.device != x.device + or self._cos_cached.dtype != x.dtype + ): + self._seq_len_cached = seq_len + t = torch.arange( + x.shape[seq_dimension], device=x.device, dtype=torch.float32 + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) + self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) + + return self._cos_cached, self._sin_cached + + def forward(self, q, k): + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + q.float(), seq_dimension=-2 + ) + if k is not None: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + apply_rotary_pos_emb(k.float(), + self._cos_cached, + self._sin_cached).type_as(k), + ) + else: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + None + ) \ No newline at end of file diff --git a/src/models/utils/span_mask.py b/src/models/utils/span_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..23f8557e9907c4f9ec17efa36ebd035d8667ff00 --- /dev/null +++ b/src/models/utils/span_mask.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +from typing import Optional, Tuple + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element
+            poisson = sample from poisson distribution with lambda = mask length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+    """
+
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    # Convert mask_prob to a NumPy array
+    mask_prob = np.array(mask_prob)
+
+    # Calculate all_num_mask for each element in the batch
+    all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int)
+
+    # Apply the max operation with min_masks for each element
+    all_num_mask = np.maximum(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * sz / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask[i]
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+    # min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        # if len(mask_idc) > min_len:
+        #     mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+
+    return torch.tensor(mask)
+
+
+if __name__ == '__main__':
+    mask = compute_mask_indices(
+        shape=[4, 500],
+        padding_mask=None,
+        mask_prob=[0.65, 0.5, 0.65, 0.65],
+        mask_length=10,
+        mask_type="static",
+        mask_other=0.0,
+        min_masks=1,
+        no_overlap=False,
+        min_space=0,
+    )
+    print(mask)
+    print(mask.sum(dim=1))
\ No newline at end of file
diff --git a/src/models/utils/timm.py
b/src/models/utils/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..ae72f4fca681617778028b64da5daed26b3ed3d4 --- /dev/null +++ b/src/models/utils/timm.py @@ -0,0 +1,114 @@ +# code from timm 0.3.2 +import torch +import torch.nn as nn +import math +import warnings + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
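+
+ Examples (illustrative sketch; the tensor shape below is an arbitrary placeholder):
+ >>> drop = DropPath(drop_prob=0.1)
+ >>> x = torch.randn(2, 197, 768)
+ >>> y = drop(x)  # identity in eval mode; in training, zeroes whole samples and rescales survivors by 1/keep_prob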
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x \ No newline at end of file diff --git a/src/modules/autoencoder_wrapper.py b/src/modules/autoencoder_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca144325af0e2642514da91c277b25fe6f52af8 --- /dev/null +++ b/src/modules/autoencoder_wrapper.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from .dac import DAC +from .stable_vae import load_vae + + +class Autoencoder(nn.Module): + def __init__(self, ckpt_path, model_type='dac', quantization_first=False): + super(Autoencoder, self).__init__() + self.model_type = model_type + if self.model_type == 'dac': + model = DAC.load(ckpt_path) + elif self.model_type == 'stable_vae': + model = load_vae(ckpt_path) + else: + raise NotImplementedError(f"Model type not implemented: {self.model_type}") + self.ae = model.eval() + self.quantization_first = quantization_first + print(f'Autoencoder quantization first mode: {quantization_first}') + + @torch.no_grad() + def forward(self, audio=None, embedding=None): + if self.model_type == 'dac': + return self.process_dac(audio, embedding) + elif self.model_type == 'encodec': + return self.process_encodec(audio, embedding) + elif self.model_type == 'stable_vae': + return self.process_stable_vae(audio, embedding) + else: + raise NotImplementedError(f"Model type not implemented: {self.model_type}") + + def process_dac(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + z, *_ = self.ae.quantizer(z, None) + return z + elif embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + z, *_ = self.ae.quantizer(z, None) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") + + def process_encodec(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + code = self.ae.quantizer.encode(z) + z = self.ae.quantizer.decode(code) + return z + elif embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + code = self.ae.quantizer.encode(z) + z = self.ae.quantizer.decode(code) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") + + def process_stable_vae(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + z = self.ae.bottleneck.encode(z) + return z + if embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + z = self.ae.bottleneck.encode(z) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") diff --git a/src/modules/clap_wrapper.py 
b/src/modules/clap_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/modules/dac/__init__.py b/src/modules/dac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5d03388fa63486960c783ebe7f1bd411b95b1d --- /dev/null +++ b/src/modules/dac/__init__.py @@ -0,0 +1,16 @@ +__version__ = "1.0.0" + +# preserved here for legacy reasons +__model_version__ = "latest" + +import audiotools + +audiotools.ml.BaseModel.INTERN += ["dac.**"] +audiotools.ml.BaseModel.EXTERN += ["einops"] + + +from . import nn +from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/src/modules/dac/__main__.py b/src/modules/dac/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..393698e7da671bce1478b4d78b2f685d2640636b --- /dev/null +++ b/src/modules/dac/__main__.py @@ -0,0 +1,36 @@ +import sys + +import argbind + +from dac.utils import download +from dac.utils.decode import decode +from dac.utils.encode import encode + +STAGES = ["encode", "decode", "download"] + + +def run(stage: str): + """Run stages. + + Parameters + ---------- + stage : str + Stage to run + """ + if stage not in STAGES: + raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") + stage_fn = globals()[stage] + + if stage == "download": + stage_fn() + return + + stage_fn() + + +if __name__ == "__main__": + group = sys.argv.pop(1) + args = argbind.parse_args(group=group) + + with argbind.scope(args): + run(group) diff --git a/src/modules/dac/compare/__init__.py b/src/modules/dac/compare/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/modules/dac/compare/encodec.py b/src/modules/dac/compare/encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..f74e4990eacdafaa8a2e10fb62c38bde10816db0 --- /dev/null +++ b/src/modules/dac/compare/encodec.py @@ -0,0 +1,54 @@ +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from encodec import EncodecModel + + +class Encodec(BaseModel): + def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): + super().__init__() + + if sample_rate == 24000: + self.model = EncodecModel.encodec_model_24khz() + else: + self.model = EncodecModel.encodec_model_48khz() + self.model.set_target_bandwidth(bandwidth) + self.sample_rate = 44100 + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = 44100, + n_quantizers: int = None, + ): + signal = AudioSignal(audio_data, sample_rate) + signal.resample(self.model.sample_rate) + recons = self.model(signal.audio_data) + recons = AudioSignal(recons, self.model.sample_rate) + recons.resample(sample_rate) + return {"audio": recons.audio_data} + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = Encodec() + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
+ setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + + print(x.shape, out.shape) diff --git a/src/modules/dac/model/__init__.py b/src/modules/dac/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58b47475d39e4249d2cd45577503bb68ffdacd00 --- /dev/null +++ b/src/modules/dac/model/__init__.py @@ -0,0 +1,4 @@ +from .base import CodecMixin +from .base import DACFile +from .dac import DAC +from .discriminator import Discriminator diff --git a/src/modules/dac/model/base.py b/src/modules/dac/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef5a44074ca2bd5d726fe90fa7c8c87da1b3b7a --- /dev/null +++ b/src/modules/dac/model/base.py @@ -0,0 +1,294 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." 
+ ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. 
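+
+ Examples (illustrative sketch; the checkpoint and audio paths are placeholders, not files shipped with this repo):
+ >>> model = DAC.load("weights.pth")
+ >>> dac_file = model.compress("input.wav", win_duration=5.0)
+ >>> dac_file.save("input.dac")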
+ + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. 
+ verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) + + self.padding = original_padding + return recons diff --git a/src/modules/dac/model/dac.py b/src/modules/dac/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..98c5983644dd4ffdf87730bc92ac3e9bb13e4374 --- /dev/null +++ b/src/modules/dac/model/dac.py @@ -0,0 +1,364 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from ..nn.layers import Snake1d +from ..nn.layers import WNConv1d +from ..nn.layers import WNConvTranspose1d +from ..nn.quantize import ResidualVectorQuantize + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = 
nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
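+
+ Examples (illustrative sketch; assumes `model` is a DAC instance on the same device as `x`, here a one-second mono batch at 44.1 kHz):
+ >>> x = torch.randn(1, 1, 44100)
+ >>> x = model.preprocess(x, 44100)
+ >>> z, codes, latents, commitment_loss, codebook_loss = model.encode(x)
+ Note: as implemented below, these values come back as a tuple in this order rather than as a dict.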
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + z, n_quantizers + ) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers + ) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
+ setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/src/modules/dac/model/discriminator.py b/src/modules/dac/model/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..ade8e6ddc3741e6fb443ba4c88cd179c3f0196eb --- /dev/null +++ b/src/modules/dac/model/discriminator.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import ml +from audiotools import STFTParams +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, 
+ window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(ml.BaseModel): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. 
+ periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = Discriminator() + x = torch.zeros(1, 1, 44100) + results = disc(x) + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print() diff --git a/src/modules/dac/nn/__init__.py b/src/modules/dac/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2f6b972aeb8a21764ff73dae2095eb94bb8ba4 --- /dev/null +++ b/src/modules/dac/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/src/modules/dac/nn/layers.py b/src/modules/dac/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..63c53a88b881deab94fb06b6a395951cf9cc995a --- /dev/null +++ b/src/modules/dac/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/src/modules/dac/nn/loss.py b/src/modules/dac/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7a26e62f885f0bb7c6774e95660a7848e10a8f87 --- /dev/null +++ b/src/modules/dac/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. 
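+
+ Examples (illustrative; random signals stand in for real estimate/reference audio):
+ >>> loss_fn = L1Loss()
+ >>> x = AudioSignal(torch.randn(1, 1, 44100), 44100)
+ >>> y = AudioSignal(torch.randn(1, 1, 44100), 44100)
+ >>> loss = loss_fn(x, y)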
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. 
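+
+ Roughly, for each window length w in `window_lengths`, `forward` accumulates
+     log_weight * loss_fn(log10(clamp(|STFT_w(x)|, clamp_eps) ** pow),
+                          log10(clamp(|STFT_w(y)|, clamp_eps) ** pow))
+     + mag_weight * loss_fn(|STFT_w(x)|, |STFT_w(y)|)
+ where `loss_fn` defaults to an L1 distance. This is a paraphrase of `forward` below, not an exact formula.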
+ + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. 
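+
+ Examples (illustrative; random signals used as placeholders):
+ >>> mel_loss = MelSpectrogramLoss()
+ >>> x = AudioSignal(torch.randn(1, 1, 44100), 44100)
+ >>> y = AudioSignal(torch.randn(1, 1, 44100), 44100)
+ >>> loss = mel_loss(x, y)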
+ + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. 
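+
+ Examples (illustrative sketch; `fake` and `real` are assumed to be AudioSignal batches from the generator and the dataset):
+ >>> gan_loss = GANLoss(Discriminator())
+ >>> loss_d = gan_loss.discriminator_loss(fake, real)
+ >>> loss_g, loss_feature = gan_loss.generator_loss(fake, real)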
+ """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/src/modules/dac/nn/quantize.py b/src/modules/dac/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..499088fbde266dc89a704e45944b8900f2c72952 --- /dev/null +++ b/src/modules/dac/nn/quantize.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): 
+ encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def 
from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/src/modules/dac/utils/__init__.py b/src/modules/dac/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4522f8b226ac1c5cf8abeefc6ba32328c2375d5 --- /dev/null +++ b/src/modules/dac/utils/__init__.py @@ -0,0 +1,122 @@ +from pathlib import Path + +import argbind +from audiotools import ml + +from ..model import DAC + +Accelerator = ml.Accelerator + +__MODEL_LATEST_TAGS__ = { + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", + ( + "16khz", + "0.0.5", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", +} + + +@argbind.bind(group="download", positional=True, without_prefix=True) +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): + """ + Function that downloads the weights file from URL if a local cache is not found. + + Parameters + ---------- + model_type : str + The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". + + Returns + ------- + Path + Directory path required to load model via audiotools. 
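+
+ Examples (illustrative; weights are fetched into ~/.cache/descript/dac on first use):
+ >>> ckpt_path = download(model_type="44khz", model_bitrate="8kbps")
+ >>> model = DAC.load(ckpt_path)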
+ """ + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + "16khz", + ], "model_type must be one of '44khz', '24khz', or '16khz'" + + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] + + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" + ) + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Download the model + import requests + + response = requests.get(download_link) + + if response.status_code != 200: + raise ValueError( + f"Could not download model. Received response code {response.status_code}" + ) + local_path.write_bytes(response.content) + + return local_path + + +def load_model( + model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, +): + if not load_path: + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) + return generator diff --git a/src/modules/dac/utils/decode.py b/src/modules/dac/utils/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..48c25298fd0f9a53f19c1d6d260f1f0563c7b7e4 --- /dev/null +++ b/src/modules/dac/utils/decode.py @@ -0,0 +1,95 @@ +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from tqdm import tqdm + +from dac import DACFile +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="decode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def decode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + device: str = "cuda", + model_type: str = "44khz", + verbose: bool = False, +): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
+ """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + + # Find all .dac files in input directory + _input = Path(input) + input_files = list(_input.glob("**/*.dac")) + + # If input is a .dac file, add it to the list + if _input.suffix == ".dac": + input_files.append(_input) + + # Create output directory + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(input_files)), desc=f"Decoding files"): + # Load file + artifact = DACFile.load(input_files[i]) + + # Reconstruct audio from codes + recons = generator.decompress(artifact, verbose=verbose) + + # Compute output path + relative_path = input_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = input_files[i] + output_name = relative_path.with_suffix(".wav").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to file + recons.write(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + decode() diff --git a/src/modules/dac/utils/encode.py b/src/modules/dac/utils/encode.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9a8b582b24c194ba82c2d280c4df4fa2cfff87 --- /dev/null +++ b/src/modules/dac/utils/encode.py @@ -0,0 +1,94 @@ +import math +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.core import util +from tqdm import tqdm + +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="encode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def encode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + n_quantizers: int = None, + device: str = "cuda", + model_type: str = "44khz", + win_duration: float = 5.0, + verbose: bool = False, +): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
+ """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + kwargs = {"n_quantizers": n_quantizers} + + # Find all audio files in input path + input = Path(input) + audio_files = util.find_audio(input) + + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(audio_files)), desc="Encoding files"): + # Load file + signal = AudioSignal(audio_files[i]) + + # Encode audio to .dac format + artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) + + # Compute output path + relative_path = audio_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = audio_files[i] + output_name = relative_path.with_suffix(".dac").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + artifact.save(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + encode() diff --git a/src/modules/stable_vae/__init__.py b/src/modules/stable_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd2a867025e88625177ccad8d73ce46b807c945 --- /dev/null +++ b/src/modules/stable_vae/__init__.py @@ -0,0 +1,40 @@ +from .models.autoencoders import create_autoencoder_from_config +import os +import json +import torch +from torch.nn.utils import remove_weight_norm + + +def remove_all_weight_norm(model): + for name, module in model.named_modules(): + if hasattr(module, 'weight_g'): + remove_weight_norm(module) + + +def load_vae(ckpt_path, remove_weight_norm=False): + config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json') + + # Load the model configuration + with open(config_file) as f: + model_config = json.load(f) + + # Create the model from the configuration + model = create_autoencoder_from_config(model_config) + + # Load the state dictionary from the checkpoint + model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict'] + + # Strip the "autoencoder." 
prefix from the keys + model_dict = {key[len("autoencoder."):]: value for key, value in model_dict.items() if key.startswith("autoencoder.")} + + # Load the state dictionary into the model + model.load_state_dict(model_dict) + + # Remove weight normalization + if remove_weight_norm: + remove_all_weight_norm(model) + + # Set the model to evaluation mode + model.eval() + + return model diff --git a/src/modules/stable_vae/models/autoencoders.py b/src/modules/stable_vae/models/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..23e1b428ec40767b0ab3838a37b16401fa55ccbd --- /dev/null +++ b/src/modules/stable_vae/models/autoencoders.py @@ -0,0 +1,683 @@ +import torch +import math +import numpy as np + +from torch import nn +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from .nn.layers import WNConv1d, WNConvTranspose1d +from typing import Literal, Dict, Any + +# from .inference.sampling import sample +from .utils import prepare_audio +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .pretransforms import Pretransform + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + 
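+        # Upsampling counterpart to EncoderBlock: a transposed conv (or, if
+        # use_nearest_upsample is set, a nearest-neighbour upsample + conv)
+        # followed by residual units with dilations 1, 3 and 9.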
super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, 
latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + 
decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. + The output will have batch size 1 and be shape (1 x Channels x Length) + ''' + return self.preprocess_audio_list_for_encoder([audio], [in_sr]) + + def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): + ''' + Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. + The audio in that list can be of different lengths and channels. + in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. + All audio will be resampled to the model's sample rate. + Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. + If the model is mono, all audio will be converted to mono. + The output will be a tensor of shape (Batch x Channels x Length) + ''' + batch_size = len(audio_list) + if isinstance(in_sr_list, int): + in_sr_list = [in_sr_list]*batch_size + assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" + new_audio = [] + max_length = 0 + # resample & find the max length + for i in range(batch_size): + audio = audio_list[i] + in_sr = in_sr_list[i] + if len(audio.shape) == 3 and audio.shape[0] == 1: + # batchsize 1 was given by accident. Just squeeze it. + audio = audio.squeeze(0) + elif len(audio.shape) == 1: + # Mono signal, channel dimension is missing, unsqueeze it in + audio = audio.unsqueeze(0) + assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" + # Resample audio + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) + audio = resample_tf(audio) + new_audio.append(audio) + if audio.shape[-1] > max_length: + max_length = audio.shape[-1] + # Pad every audio to the same length, multiple of model's downsampling ratio + padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length + for i in range(batch_size): + # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model + new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, + target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) + # convert to tensor + return torch.stack(new_audio) + + def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Encode audios into latents. 
Audios should already be preprocesed by preprocess_audio_for_encoder. + If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. + Overlap and chunk_size params are both measured in number of latents (not audio samples) + # and therefore you likely could use the same values with decode_audio. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. + y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. 
+ Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = 
decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip + ) \ No newline at end of file diff --git a/src/modules/stable_vae/models/blocks.py b/src/modules/stable_vae/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..a832f5c561a5b8e0a1c58a5c5fb4c05325534561 --- /dev/null +++ b/src/modules/stable_vae/models/blocks.py @@ -0,0 +1,359 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from .nn.layers import Snake1d + + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, 
bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + 
return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias 
is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + + +# jit script make it 1.4x faster and save GPU memory +@torch.jit.script +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: + # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + # self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/src/modules/stable_vae/models/bottleneck.py b/src/modules/stable_vae/models/bottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2254c97fbb5d1f4bbf7150106b5f3a154e3705 --- /dev/null +++ b/src/modules/stable_vae/models/bottleneck.py @@ -0,0 +1,346 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from .nn.quantize import ResidualVectorQuantize as DACResidualVQ + + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + 
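+    # Subclasses override decode_tokens to map integer token indices back to
+    # continuous latents (see the RVQ, DAC-RVQ and FSQ bottlenecks below).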
def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + +@torch.jit.script +def vae_sample_kl(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + + +@torch.jit.script +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + latents = torch.randn_like(mean) * stdev + mean + return latents + + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + mean, scale = x.chunk(2, dim=1) + + if return_info: + info = {} + x, kl = vae_sample_kl(mean, scale) + info["kl"] = kl + return x, info + else: + x = vae_sample(mean, scale) + return x + + def decode(self, x): + return x + + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + mmd = compute_mmd(x) + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, 
**quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, dim, levels): + 
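+        # Finite Scalar Quantization: each of `dim` latent channels is quantized
+        # to `levels` values, so the effective codebook size is levels ** dim.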
super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices") + self.quantizer = FSQ(levels=[levels] * dim) + + def encode(self, x, return_info=False): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, **kwargs) \ No newline at end of file diff --git a/src/modules/stable_vae/models/factory.py b/src/modules/stable_vae/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..46292eca5f5367bef4ac7133beada867c1e43f9a --- /dev/null +++ b/src/modules/stable_vae/models/factory.py @@ -0,0 +1,153 @@ +import json + +def create_model_from_config(model_config): + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) + elif model_type == 'diffusion_uncond': + from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) + elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": + from .diffusion import create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) + elif model_type == 'diffusion_autoencoder': + from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) + elif model_type == 'lm': + from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_model_from_config_path(model_config_path): + with open(model_config_path) as f: + model_config = json.load(f) + + return create_model_from_config(model_config) + +def create_pretransform_from_config(pretransform_config, sample_rate): + pretransform_type = pretransform_config.get('type', None) + + assert pretransform_type is not None, 'type must be specified in pretransform config' + + if pretransform_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + from .pretransforms import AutoencoderPretransform + + # Create fake top-level config to pass sample rate to autoencoder constructor + # This is a bit of a hack but it keeps us from re-defining the sample rate in the config + autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} + autoencoder = create_autoencoder_from_config(autoencoder_config) + + scale = pretransform_config.get("scale", 1.0) + model_half = pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + + pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) + elif pretransform_type == 'wavelet': + from .pretransforms import WaveletPretransform + + wavelet_config = pretransform_config["config"] + channels = wavelet_config["channels"] + levels = wavelet_config["levels"] + wavelet = wavelet_config["wavelet"] + + pretransform = WaveletPretransform(channels, levels, wavelet) + 
elif pretransform_type == 'pqmf': + from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] + pretransform = PQMFPretransform(**pqmf_config) + elif pretransform_type == 'dac_pretrained': + from .pretransforms import PretrainedDACPretransform + pretrained_dac_config = pretransform_config["config"] + pretransform = PretrainedDACPretransform(**pretrained_dac_config) + elif pretransform_type == "audiocraft_pretrained": + from .pretransforms import AudiocraftCompressionPretransform + + audiocraft_config = pretransform_config["config"] + pretransform = AudiocraftCompressionPretransform(**audiocraft_config) + else: + raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') + + enable_grad = pretransform_config.get('enable_grad', False) + pretransform.enable_grad = enable_grad + + pretransform.eval().requires_grad_(pretransform.enable_grad) + + return pretransform + +def create_bottleneck_from_config(bottleneck_config): + bottleneck_type = bottleneck_config.get('type', None) + + assert bottleneck_type is not None, 'type must be specified in bottleneck config' + + if bottleneck_type == 'tanh': + from .bottleneck import TanhBottleneck + bottleneck = TanhBottleneck() + elif bottleneck_type == 'vae': + from .bottleneck import VAEBottleneck + bottleneck = VAEBottleneck() + elif bottleneck_type == 'rvq': + from .bottleneck import RVQBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQBottleneck(**quantizer_params) + elif bottleneck_type == "dac_rvq": + from .bottleneck import DACRVQBottleneck + + bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) + + elif bottleneck_type == 'rvq_vae': + from .bottleneck import RVQVAEBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQVAEBottleneck(**quantizer_params) + + elif bottleneck_type == 'dac_rvq_vae': + from .bottleneck import DACRVQVAEBottleneck + bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) + elif bottleneck_type == 'l2_norm': + from .bottleneck import L2Bottleneck + bottleneck = L2Bottleneck() + elif bottleneck_type == "wasserstein": + from .bottleneck import WassersteinBottleneck + bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) + elif bottleneck_type == "fsq": + from .bottleneck import FSQBottleneck + bottleneck = FSQBottleneck(**bottleneck_config["config"]) + else: + raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') + + requires_grad = bottleneck_config.get('requires_grad', True) + if not requires_grad: + for param in bottleneck.parameters(): + param.requires_grad = False + + return bottleneck diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2f6b972aeb8a21764ff73dae2095eb94bb8ba4 --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . 
import quantize diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..63c53a88b881deab94fb06b6a395951cf9cc995a --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7a26e62f885f0bb7c6774e95660a7848e10a8f87 --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. 
Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
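[Editor's note] The SI-SDR computation in SISDRLoss.forward above reduces to a few lines on plain tensors. The sketch below is a simplified re-derivation (zero-mean, scale-invariant, batch of mono signals), not the audiotools implementation itself:

    import torch

    def si_sdr_loss(estimate: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # Zero-mean both signals, project the estimate onto the reference to find the
        # best scale, then compare the energy of the projection to that of the residual.
        reference = reference - reference.mean(dim=-1, keepdim=True)
        estimate = estimate - estimate.mean(dim=-1, keepdim=True)
        scale = (estimate * reference).sum(-1, keepdim=True) / ((reference ** 2).sum(-1, keepdim=True) + eps)
        e_true = scale * reference
        e_res = estimate - e_true
        return -10 * torch.log10((e_true ** 2).sum(-1) / ((e_res ** 2).sum(-1) + eps) + eps)

    ref = torch.randn(4, 16000)                            # (batch, samples)
    print(si_sdr_loss(3.0 * ref, ref))                     # strongly negative: a pure gain change is not penalized
    print(si_sdr_loss(ref + torch.randn_like(ref), ref))   # near 0 dB: residual energy comparable to the signal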
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. 
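[Editor's note] The MultiScaleSTFTLoss forward pass above boils down to an L1 distance on linear and log STFT magnitudes at each window length. A minimal single-scale version using plain torch.stft (no audiotools), just to make the math concrete:

    import torch

    def stft_l1(x: torch.Tensor, y: torch.Tensor, n_fft: int = 2048, clamp_eps: float = 1e-5) -> torch.Tensor:
        # One scale of the loss: L1 on clamped log power magnitudes plus L1 on raw magnitudes.
        window = torch.hann_window(n_fft)
        X = torch.stft(x, n_fft, hop_length=n_fft // 4, window=window, return_complex=True).abs()
        Y = torch.stft(y, n_fft, hop_length=n_fft // 4, window=window, return_complex=True).abs()
        log_term = (X.clamp(min=clamp_eps).pow(2).log10() - Y.clamp(min=clamp_eps).pow(2).log10()).abs().mean()
        mag_term = (X - Y).abs().mean()
        return log_term + mag_term

    # Summing stft_l1 over window lengths [2048, 512] approximates MultiScaleSTFTLoss with default weights.
    print(stft_l1(torch.randn(44100), torch.randn(44100)))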
+ + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/quantize-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/quantize-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..499088fbde266dc89a704e45944b8900f2c72952 --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/quantize-checkpoint.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
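[Editor's note] A quick round trip through ResidualVectorQuantize as defined above: quantize a latent sequence, then rebuild the continuous representation from the discrete codes with from_codes. Note that forward returns a plain tuple (the docstring describes a dict), so the __main__ check at the bottom of this file, which indexes y["latents"], would not run as written.

    import torch

    rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8)
    z = torch.randn(2, 512, 80)                        # (batch, latent_dim, frames)
    z_q, codes, latents, commit_loss, cb_loss = rvq(z)
    print(codes.shape)                                 # torch.Size([2, 9, 80]): one index per codebook per frame
    z_rec, z_p, _ = rvq.from_codes(codes)              # decoder-side reconstruction from the codes alone
    assert z_rec.shape == z.shape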
+ + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/src/modules/stable_vae/models/nn/__init__.py b/src/modules/stable_vae/models/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2f6b972aeb8a21764ff73dae2095eb94bb8ba4 --- /dev/null +++ b/src/modules/stable_vae/models/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd83e18bd22222ca6b9ce0f0ab056cc026747bb3 Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1c8ae0e314e5ef76519da0314434a8c8ffe772 Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a85d0e28029662f1ecc2790e44f71caa09cd0c Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecebbf1fbbca206d55b66b3476169545e09b115d Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47d3c0daa35a755146c3ebf4f49d430469fe0c6a Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3bc1a6e78fa00b5b451d12827d0f9d504e3db22 Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-311.pyc differ diff --git 
a/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d960f1c6e8f02f6b1c3b72b9d60543ceadb619cf Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1f38a6916548fcdf0e9c2d4caafa7e5b93d9c10 Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/nn/layers.py b/src/modules/stable_vae/models/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..63c53a88b881deab94fb06b6a395951cf9cc995a --- /dev/null +++ b/src/modules/stable_vae/models/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/src/modules/stable_vae/models/nn/loss.py b/src/modules/stable_vae/models/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7a26e62f885f0bb7c6774e95660a7848e10a8f87 --- /dev/null +++ b/src/modules/stable_vae/models/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. 
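[Editor's note] WNConv1d and WNConvTranspose1d in layers.py above are thin wrappers around torch's weight_norm utility (superseded in newer torch by torch.nn.utils.parametrizations.weight_norm, but still available), which reparameterises the kernel as w = g * v / ||v||. A small sketch of what the wrapper actually changes:

    import torch
    import torch.nn as nn
    from torch.nn.utils import weight_norm

    conv = weight_norm(nn.Conv1d(8, 16, kernel_size=3, padding=1))
    print([name for name, _ in conv.named_parameters()])   # 'weight_g' and 'weight_v' replace the single 'weight'
    y = conv(torch.randn(1, 8, 100))                       # the forward pass behaves like a regular Conv1d
    print(y.shape)                                         # torch.Size([1, 16, 100])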
+ + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
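[Editor's note] Both spectral losses in this file operate on audiotools AudioSignal objects rather than raw tensors. A hedged usage sketch, assuming the descript-audiotools package (imported as audiotools above) is installed:

    import torch
    from audiotools import AudioSignal

    stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
    mel_loss = MelSpectrogramLoss(n_mels=[150, 80])

    sr = 44100
    estimate = AudioSignal(torch.randn(1, 1, sr), sample_rate=sr)   # 1 s of noise, (batch, channels, samples)
    reference = AudioSignal(torch.randn(1, 1, sr), sample_rate=sr)
    print(stft_loss(estimate, reference).item(), mel_loss(estimate, reference).item())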
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. 
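[Editor's note] The GANLoss class in this file expects the discriminator to return, per sub-discriminator, a list of intermediate feature maps whose last entry serves as the logit map; the objectives are least-squares GAN terms plus an L1 feature-matching term. A toy sketch of that contract (ToyDisc is invented purely for illustration and is not the project's discriminator):

    import torch
    import torch.nn as nn
    from audiotools import AudioSignal

    class ToyDisc(nn.Module):
        # Stand-in discriminator: one sub-discriminator returning its feature maps,
        # with the final element treated as the logit map.
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Conv1d(1, 8, 15, padding=7), nn.Conv1d(8, 1, 15, padding=7)])

        def forward(self, x):
            feats = []
            for layer in self.layers:
                x = layer(x)
                feats.append(x)
            return [feats]

    gan = GANLoss(ToyDisc())
    fake = AudioSignal(torch.randn(1, 1, 16000), sample_rate=16000)
    real = AudioSignal(torch.randn(1, 1, 16000), sample_rate=16000)
    print(gan.discriminator_loss(fake, real))            # pushes D(real) -> 1 and D(fake) -> 0
    print(gan.generator_loss(fake, real))                # (adversarial term, feature-matching term)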
+ + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/src/modules/stable_vae/models/nn/quantize.py b/src/modules/stable_vae/models/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..499088fbde266dc89a704e45944b8900f2c72952 --- /dev/null +++ b/src/modules/stable_vae/models/nn/quantize.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
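[Editor's note] One detail worth calling out in VectorQuantize.decode_latents above: both the encodings and the codebook are L2-normalised before the distance computation, so the Euclidean nearest neighbour is exactly the entry with the highest cosine similarity. A small numerical check of that equivalence:

    import torch
    import torch.nn.functional as F

    enc = F.normalize(torch.randn(32, 8), dim=-1)       # 32 encodings of dimension 8, unit length
    cb = F.normalize(torch.randn(1024, 8), dim=-1)      # 1024 codebook entries, unit length

    # Same distance expression as in decode_latents: ||e||^2 - 2 e.c + ||c||^2
    dist = enc.pow(2).sum(1, keepdim=True) - 2 * enc @ cb.t() + cb.pow(2).sum(1, keepdim=True).t()
    by_distance = (-dist).max(1).indices
    by_cosine = (enc @ cb.t()).max(1).indices
    assert torch.equal(by_distance, by_cosine)          # nearest-by-distance == nearest-by-cosine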
+ + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/src/modules/stable_vae/models/pretransforms.py b/src/modules/stable_vae/models/pretransforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4ffbbebd379695d123e5155f1387afe4d824f1 --- /dev/null +++ b/src/modules/stable_vae/models/pretransforms.py @@ -0,0 +1,258 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return 
info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + try: + from audiocraft.models import CompressionModel + except ImportError: + raise ImportError("Audiocraft is not 
installed. Please install audiocraft to use Audiocraft models.") + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.sample_rate = self.model.sample_rate + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/src/modules/stable_vae/models/utils.py b/src/modules/stable_vae/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7adc0a9b4d72ed1e1d64262dfd71265fee9c9a4a --- /dev/null +++ b/src/modules/stable_vae/models/utils.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torchaudio import transforms as T + + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, start:end] + return output + + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio \ No newline at end of file diff --git a/src/test.py b/src/test.py new file mode 100644 index 0000000000000000000000000000000000000000..b605f0f29482d06510b618691ff74a718816aba4 --- /dev/null +++ b/src/test.py @@ -0,0 +1,97 @@ +import random +import argparse +import os +import time +import soundfile as sf +import numpy as np +import pandas as pd +from tqdm import tqdm + 
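[Editor's note] As an aside on the prepare_audio helper added in src/modules/stable_vae/models/utils.py above: it resamples, pads or crops to a fixed length, adds a batch dimension and fixes the channel count. A small usage sketch (run from the repository root so the package path added in this diff resolves):

    import torch
    from src.modules.stable_vae.models.utils import prepare_audio

    mono = torch.randn(1, 32000)                       # 2 s of 16 kHz mono audio, (channels, samples)
    batch = prepare_audio(mono, in_sr=16000, target_sr=44100, target_length=44100,
                          target_channels=2, device='cpu')
    print(batch.shape)                                 # torch.Size([1, 2, 44100]): resampled, cropped to 1 s, stereo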
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers import DDIMScheduler +from models.conditioners import MaskDiT +from modules.autoencoder_wrapper import Autoencoder +from transformers import T5Tokenizer, T5EncoderModel +from inference import inference +from utils import scale_shift, get_lr_scheduler, compute_snr, load_yaml_with_includes + + +parser = argparse.ArgumentParser() +# config settings +parser.add_argument('--config-name', type=str, default='configs/udit_ada.yml') +parser.add_argument('--ckpt-path', type=str, default='../ckpts/') +parser.add_argument('--ckpt-id', type=str, default='120') +parser.add_argument('--save_path', type=str, default='../output/') +parser.add_argument('--test-df', type=str, default='audiocaps_test.csv') +# parser.add_argument('--test-split', type=str, default='test') + +parser.add_argument('--device', type=str, default='cuda') +parser.add_argument('--guidance-scale', type=float, default=3) +parser.add_argument('--guidance-rescale', type=float, default=0) +parser.add_argument('--ddim-steps', type=int, default=50) +parser.add_argument('--eta', type=float, default=1) +parser.add_argument('--random-seed', type=int, default=None) + +args = parser.parse_args() +params = load_yaml_with_includes(args.config_name) + +# args.ckpt_path = f"{args.ckpt_path}/{params['model_name']}/{args.ckpt_id}.pt" +args.save_path = f"{args.save_path}/{params['model_name']}/{args.ckpt_id}_{args.ddim_steps}_{args.guidance_scale}_{args.guidance_rescale}/" +args.ckpt_path = f"{args.ckpt_path}/{args.ckpt_id}.pt" + +if __name__ == '__main__': + # Codec Model + autoencoder = Autoencoder(ckpt_path=params['autoencoder']['path'], + model_type=params['autoencoder']['name'], + quantization_first=params['autoencoder']['q_first']) + autoencoder.to(args.device) + autoencoder.eval() + + # text encoder + tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model']) + text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model'], + device_map='cpu').to(args.device) + text_encoder.eval() + + # main U-Net + unet = MaskDiT(**params['model']).to(args.device) + unet.eval() + unet.load_state_dict(torch.load(args.ckpt_path)['model']) + + total_params = sum([param.nelement() for param in unet.parameters()]) + print("Number of parameter: %.2fM" % (total_params / 1e6)) + + noise_scheduler = DDIMScheduler(**params['diff']) + # these steps reset dtype of noise_scheduler params + latents = torch.randn((1, 128, 128), device=args.device) + noise = torch.randn_like(latents) + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=args.device) + _ = noise_scheduler.add_noise(latents, noise, timesteps) + + df = pd.read_csv(args.test_df) + # Wdf = df[df['split'] == args.test_split] + df = df[df['audio_length'] != 0] + # df = df.sample(10) + os.makedirs(args.save_path, exist_ok=True) + audio_frames = params['data']['train_frames'] + + for i in tqdm(range(len(df))): + row = df.iloc[i] + text = row['caption'] + audio_id = row['audiocap_id'] + + pred = inference(autoencoder, unet, None, None, + tokenizer, text_encoder, + params, noise_scheduler, + text, None, + audio_frames, + args.guidance_scale, args.guidance_rescale, + args.ddim_steps, args.eta, args.random_seed, + args.device) + pred = pred.cpu().numpy().squeeze(0).squeeze(0) + + sf.write(f"{args.save_path}/{audio_id}.wav", + pred, samplerate=params['data']['sr']) \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..90f60fdd89ad8575faafe45188bd1d968852fc67 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * \ No newline at end of file diff --git a/src/utils/utils.py b/src/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82a6bc56e9e341e54dc6a136f1f78261dde0f655 --- /dev/null +++ b/src/utils/utils.py @@ -0,0 +1,94 @@ +import torch +import numpy as np +import yaml +import os + + +def load_yaml_with_includes(yaml_file): + def loader_with_include(loader, node): + # Load the included file + include_path = os.path.join(os.path.dirname(yaml_file), loader.construct_scalar(node)) + with open(include_path, 'r') as f: + return yaml.load(f, Loader=yaml.FullLoader) + + yaml.add_constructor('!include', loader_with_include, Loader=yaml.FullLoader) + + with open(yaml_file, 'r') as f: + return yaml.load(f, Loader=yaml.FullLoader) + + +def scale_shift(x, scale, shift): + return (x+shift) * scale + + +def scale_shift_re(x, scale, shift): + return (x/scale) - shift + + +def align_seq(source, target_length, mapping_method='hard'): + source_len = source.shape[1] + if mapping_method == 'hard': + mapping_idx = np.round(np.arange(target_length) * source_len / target_length) + output = source[:, mapping_idx] + else: + # TBD + raise NotImplementedError + + return output + + +def customized_lr_scheduler(optimizer, warmup_steps=-1): + from torch.optim.lr_scheduler import LambdaLR + + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'customized': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion + Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion + # Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +if __name__ == "__main__": + + a = torch.rand(2, 10) + target_len = 15 + + b = align_seq(a, target_len) \ No newline at end of file
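[Editor's note] load_yaml_with_includes above registers a custom !include tag so a parent config can pull in shared fragments by relative path, which is how the YAML configs in this repo are assembled. A runnable sketch of that mechanism (the file names base.yml and model.yml are invented for the demo; the import path follows the package layout added in this diff):

    import os
    import tempfile
    from src.utils import load_yaml_with_includes

    d = tempfile.mkdtemp()
    with open(os.path.join(d, 'base.yml'), 'w') as f:
        f.write('sr: 24000\nlatent_sr: 50\n')
    with open(os.path.join(d, 'model.yml'), 'w') as f:
        f.write('autoencoder: !include base.yml\nmodel:\n  depth: 24\n')

    params = load_yaml_with_includes(os.path.join(d, 'model.yml'))
    print(params['autoencoder']['sr'])   # 24000, pulled in from base.yml via the !include constructor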