diff --git a/api.py b/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..531c8197686ff15be7e0a00b065a3309bd9a781c
--- /dev/null
+++ b/api.py
@@ -0,0 +1,117 @@
+import os
+import torch
+import random
+import numpy as np
+import gradio as gr
+import soundfile as sf
+from transformers import T5Tokenizer, T5EncoderModel
+from diffusers import DDIMScheduler
+from src.models.conditioners import MaskDiT
+from src.modules.autoencoder_wrapper import Autoencoder
+from src.inference import inference
+from src.utils import load_yaml_with_includes
+
+
+# Load model and configs
+def load_models(config_name, ckpt_path, vae_path, device):
+ params = load_yaml_with_includes(config_name)
+
+ # Load codec model
+ autoencoder = Autoencoder(ckpt_path=vae_path,
+ model_type=params['autoencoder']['name'],
+ quantization_first=params['autoencoder']['q_first']).to(device)
+ autoencoder.eval()
+
+ # Load text encoder
+ tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
+ text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
+ text_encoder.eval()
+
+    # Load main diffusion transformer (MaskDiT) model
+ unet = MaskDiT(**params['model']).to(device)
+    unet.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model'])
+ unet.eval()
+
+ # Load noise scheduler
+ noise_scheduler = DDIMScheduler(**params['diff'])
+
+ return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params
+
+MAX_SEED = np.iinfo(np.int32).max
+
+# Model and config paths
+config_name = 'ckpts/ezaudio-xl.yml'
+ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
+vae_path = 'ckpts/vae/1m.pt'
+save_path = 'output/'
+os.makedirs(save_path, exist_ok=True)
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
+ device)
+
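+# One-off dummy scheduler call below (appears intended to warm up / initialize the noise scheduler buffers at startup)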
+latents = torch.randn((1, 128, 128), device=device)
+noise = torch.randn_like(latents)
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
+_ = noise_scheduler.add_noise(latents, noise, timesteps)
+
+
+# Inference function
+def generate_audio(text, length,
+ guidance_scale, guidance_rescale, ddim_steps, eta,
+ random_seed, randomize_seed):
+ neg_text = None
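+    # convert the requested duration in seconds to the number of latent frames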
+ length = length * params['autoencoder']['latent_sr']
+
+ if randomize_seed:
+ random_seed = random.randint(0, MAX_SEED)
+
+ pred = inference(autoencoder, unet, None, None,
+ tokenizer, text_encoder,
+ params, noise_scheduler,
+ text, neg_text,
+ length,
+ guidance_scale, guidance_rescale,
+ ddim_steps, eta, random_seed,
+ device)
+
+ pred = pred.cpu().numpy().squeeze(0).squeeze(0)
+ # output_file = f"{save_path}/{text}.wav"
+ # sf.write(output_file, pred, samplerate=params['autoencoder']['sr'])
+
+ return params['autoencoder']['sr'], pred
+
+
+# Gradio Interface
+def gradio_interface():
+ # Input components
+    text_input = gr.Textbox(label="Text Prompt", value="the sound of a dog barking")
+ length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
+
+ # Advanced settings
+ guidance_scale_input = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5, label="Guidance Scale")
+ guidance_rescale_input = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
+ ddim_steps_input = gr.Slider(minimum=25, maximum=200, step=5, value=100, label="DDIM Steps")
+ eta_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Eta")
+    random_seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")
+
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+
+ # Output component
+    output_audio = gr.Audio(label="Generated Audio", type="numpy")
+
+ # Interface
+ gr.Interface(
+ fn=generate_audio,
+ inputs=[text_input, length_input, guidance_scale_input, guidance_rescale_input, ddim_steps_input, eta_input,
+ random_seed_input, randomize_seed],
+ outputs=output_audio,
+ title="EzAudio Text-to-Audio Generator",
+ description="Generate audio from text using a diffusion model. Adjust advanced settings for more control.",
+ allow_flagging="never"
+ ).launch()
+
+
+if __name__ == "__main__":
+ gradio_interface()
diff --git a/src/.idea/.gitignore b/src/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1c2fda565b94d0f2b94cb65ba7cca866e7a25478
--- /dev/null
+++ b/src/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/src/inference.py b/src/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a9fd81a73480fa348a225a70cba20bcc35dc93
--- /dev/null
+++ b/src/inference.py
@@ -0,0 +1,169 @@
+import os
+import random
+import pandas as pd
+import torch
+import librosa
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+from src.utils import scale_shift_re
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+@torch.no_grad()
+def inference(autoencoder, unet, gt, gt_mask,
+ tokenizer, text_encoder,
+ params, noise_scheduler,
+ text_raw, neg_text=None,
+ audio_frames=500,
+ guidance_scale=3, guidance_rescale=0.0,
+ ddim_steps=50, eta=1, random_seed=2024,
+ device='cuda',
+ ):
+ if neg_text is None:
+ neg_text = [""]
+ if tokenizer is not None:
+ text_batch = tokenizer(text_raw,
+ max_length=params['text_encoder']['max_length'],
+ padding="max_length", truncation=True, return_tensors="pt")
+ text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
+ text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
+
+ uncond_text_batch = tokenizer(neg_text,
+ max_length=params['text_encoder']['max_length'],
+ padding="max_length", truncation=True, return_tensors="pt")
+ uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
+ uncond_text = text_encoder(input_ids=uncond_text,
+ attention_mask=uncond_text_mask).last_hidden_state
+ else:
+ text, text_mask = None, None
+ guidance_scale = None
+
+ codec_dim = params['model']['out_chans']
+ unet.eval()
+
+ if random_seed is not None:
+ generator = torch.Generator(device=device).manual_seed(random_seed)
+ else:
+ generator = torch.Generator(device=device)
+ generator.seed()
+
+ noise_scheduler.set_timesteps(ddim_steps)
+
+ # init noise
+ noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
+ latents = noise
+
+ for t in noise_scheduler.timesteps:
+ latents = noise_scheduler.scale_model_input(latents, t)
+
+ if guidance_scale:
+
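+            # classifier-free guidance: run conditional and unconditional branches in a single batched forward pass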
+ latents_combined = torch.cat([latents, latents], dim=0)
+ text_combined = torch.cat([text, uncond_text], dim=0)
+ text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
+
+ if gt is not None:
+ gt_combined = torch.cat([gt, gt], dim=0)
+ gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
+ else:
+ gt_combined = None
+ gt_mask_combined = None
+
+ output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
+ cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined)
+ output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
+
+ output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
+ if guidance_rescale > 0.0:
+ output_pred = rescale_noise_cfg(output_pred, output_text,
+ guidance_rescale=guidance_rescale)
+ else:
+ output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask,
+ cls_token=None, gt=gt, mae_mask_infer=gt_mask)
+
+ latents = noise_scheduler.step(model_output=output_pred, timestep=t,
+ sample=latents,
+ eta=eta, generator=generator).prev_sample
+
+ pred = scale_shift_re(latents, params['autoencoder']['scale'],
+ params['autoencoder']['shift'])
+ if gt is not None:
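+        # keep ground-truth latents wherever they were not masked (inpainting-style generation)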
+ pred[~gt_mask] = gt[~gt_mask]
+ pred_wav = autoencoder(embedding=pred)
+ return pred_wav
+
+
+@torch.no_grad()
+def eval_udit(autoencoder, unet,
+ tokenizer, text_encoder,
+ params, noise_scheduler,
+ val_df, subset,
+ audio_frames, mae=False,
+ guidance_scale=3, guidance_rescale=0.0,
+ ddim_steps=50, eta=1, random_seed=2023,
+ device='cuda',
+ epoch=0, save_path='logs/eval/', val_num=5):
+ val_df = pd.read_csv(val_df)
+ val_df = val_df[val_df['split'] == subset]
+ if mae:
+ val_df = val_df[val_df['audio_length'] != 0]
+
+ save_path = save_path + str(epoch) + '/'
+ os.makedirs(save_path, exist_ok=True)
+
+ for i in tqdm(range(len(val_df))):
+ row = val_df.iloc[i]
+ text = [row['caption']]
+ if mae:
+ audio_path = params['data']['val_dir'] + str(row['audio_path'])
+ gt, sr = librosa.load(audio_path, sr=params['data']['sr'])
+ gt = gt / (np.max(np.abs(gt)) + 1e-9)
+ sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr'])
+ num_samples = 10 * sr
+ if len(gt) < num_samples:
+ padding = num_samples - len(gt)
+ gt = np.pad(gt, (0, padding), 'constant')
+ else:
+ gt = gt[:num_samples]
+ gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
+ gt = autoencoder(audio=gt)
+ B, D, L = gt.shape
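+            # mask two random spans, each 20% of the latent length, for MAE-style inpainting evaluation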
+ mask_len = int(L * 0.2)
+ gt_mask = torch.zeros(B, D, L).to(device)
+ for _ in range(2):
+ start = random.randint(0, L - mask_len)
+ gt_mask[:, :, start:start + mask_len] = 1
+ gt_mask = gt_mask.bool()
+ else:
+ gt = None
+ gt_mask = None
+
+ pred = inference(autoencoder, unet, gt, gt_mask,
+ tokenizer, text_encoder,
+ params, noise_scheduler,
+ text, neg_text=None,
+ audio_frames=audio_frames,
+ guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
+ ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
+ device=device)
+
+ pred = pred.cpu().numpy().squeeze(0).squeeze(0)
+
+ sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr'])
+
+ if i + 1 >= val_num:
+ break
diff --git a/src/models/blocks.py b/src/models/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..07fb79082e066098cd465c0317fc0fe6c7285266
--- /dev/null
+++ b/src/models/blocks.py
@@ -0,0 +1,325 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from .utils.attention import Attention, JointAttention
+from .utils.modules import unpatchify, FeedForward
+from .utils.modules import film_modulate
+
+
+class AdaLN(nn.Module):
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
+ super().__init__()
+ self.ada_mode = ada_mode
+ self.scale_shift_table = None
+ if ada_mode == 'ada':
+            # nn.SiLU() is applied outside this module, shared across AdaLN blocks
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
+ elif ada_mode == 'ada_single':
+            # adaLN-single as used in PixArt-alpha
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+ elif ada_mode in ['ada_lora', 'ada_lora_bias']:
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
+ self.scaling = alpha / r
+ if ada_mode == 'ada_lora_bias':
+ # take bias out for consistency
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
+ else:
+ raise NotImplementedError
+
+ def forward(self, time_token=None, time_ada=None):
+ if self.ada_mode == 'ada':
+ assert time_ada is None
+ B = time_token.shape[0]
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
+ elif self.ada_mode == 'ada_single':
+ B = time_ada.shape[0]
+ time_ada = time_ada.reshape(B, 6, -1)
+ time_ada = self.scale_shift_table[None] + time_ada
+ elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
+ B = time_ada.shape[0]
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
+ time_ada = time_ada + time_ada_lora
+ time_ada = time_ada.reshape(B, 6, -1)
+ if self.scale_shift_table is not None:
+ time_ada = self.scale_shift_table[None] + time_ada
+ else:
+ raise NotImplementedError
+ return time_ada
+
+
+class DiTBlock(nn.Module):
+ """
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
+ """
+
+ def __init__(self, dim, context_dim=None,
+ num_heads=8, mlp_ratio=4.,
+ qkv_bias=False, qk_scale=None, qk_norm=None,
+ act_layer='gelu', norm_layer=nn.LayerNorm,
+ time_fusion='none',
+ ada_lora_rank=None, ada_lora_alpha=None,
+ skip=False, skip_norm=False,
+ rope_mode='none',
+ context_norm=False,
+ use_checkpoint=False):
+
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim=dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ rope_mode=rope_mode)
+
+ if context_dim is not None:
+ self.use_context = True
+ self.cross_attn = Attention(dim=dim,
+ num_heads=num_heads,
+ context_dim=context_dim,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ rope_mode='none')
+ self.norm2 = norm_layer(dim)
+ if context_norm:
+ self.norm_context = norm_layer(context_dim)
+ else:
+ self.norm_context = nn.Identity()
+ else:
+ self.use_context = False
+
+ self.norm3 = norm_layer(dim)
+ self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
+ activation_fn=act_layer, dropout=0)
+
+ self.use_adanorm = True if time_fusion != 'token' else False
+ if self.use_adanorm:
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
+ r=ada_lora_rank, alpha=ada_lora_alpha)
+ if skip:
+ self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
+ self.skip_linear = nn.Linear(2 * dim, dim)
+ else:
+ self.skip_linear = None
+
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x, time_token=None, time_ada=None,
+ skip=None, context=None,
+ x_mask=None, context_mask=None, extras=None):
+ if self.use_checkpoint:
+ return checkpoint(self._forward, x,
+ time_token, time_ada, skip, context,
+ x_mask, context_mask, extras,
+ use_reentrant=False)
+ else:
+ return self._forward(x,
+ time_token, time_ada, skip, context,
+ x_mask, context_mask, extras)
+
+ def _forward(self, x, time_token=None, time_ada=None,
+ skip=None, context=None,
+ x_mask=None, context_mask=None, extras=None):
+ B, T, C = x.shape
+ if self.skip_linear is not None:
+ assert skip is not None
+ cat = torch.cat([x, skip], dim=-1)
+ cat = self.skip_norm(cat)
+ x = self.skip_linear(cat)
+
+ if self.use_adanorm:
+ time_ada = self.adaln(time_token, time_ada)
+ (shift_msa, scale_msa, gate_msa,
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
+
+ # self attention
+ if self.use_adanorm:
+ x_norm = film_modulate(self.norm1(x), shift=shift_msa,
+ scale=scale_msa)
+ x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
+ context_mask=x_mask,
+ extras=extras)
+ else:
+ x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
+ extras=extras)
+
+ # cross attention
+ if self.use_context:
+ assert context is not None
+ x = x + self.cross_attn(x=self.norm2(x),
+ context=self.norm_context(context),
+ context_mask=context_mask, extras=extras)
+
+ # mlp
+ if self.use_adanorm:
+ x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
+ else:
+ x = x + self.mlp(self.norm3(x))
+
+ return x
+
+
+class JointDiTBlock(nn.Module):
+ """
+    A DiT block that applies joint (concatenated) self-attention over context and latent tokens,
+    with adaptive layer norm (adaLN-single) conditioning.
+ """
+
+ def __init__(self, dim, context_dim=None,
+ num_heads=8, mlp_ratio=4.,
+ qkv_bias=False, qk_scale=None, qk_norm=None,
+ act_layer='gelu', norm_layer=nn.LayerNorm,
+ time_fusion='none',
+ ada_lora_rank=None, ada_lora_alpha=None,
+ skip=(False, False),
+ rope_mode=False,
+ context_norm=False,
+ use_checkpoint=False,):
+
+ super().__init__()
+ # no cross attention
+ assert context_dim is None
+ self.attn_norm_x = norm_layer(dim)
+ self.attn_norm_c = norm_layer(dim)
+ self.attn = JointAttention(dim=dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ qk_norm=qk_norm,
+ rope_mode=rope_mode)
+ self.ffn_norm_x = norm_layer(dim)
+ self.ffn_norm_c = norm_layer(dim)
+ self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
+ activation_fn=act_layer, dropout=0)
+ self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
+ activation_fn=act_layer, dropout=0)
+
+        # time conditioning via AdaLN (its scale/shift parameters are zero-initialized)
+ self.use_adanorm = True if time_fusion != 'token' else False
+ if self.use_adanorm:
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
+ r=ada_lora_rank, alpha=ada_lora_alpha)
+
+ if skip is False:
+ skip_x, skip_c = False, False
+ else:
+ skip_x, skip_c = skip
+
+ self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
+ self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
+
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x, time_token=None, time_ada=None,
+ skip=None, context=None,
+ x_mask=None, context_mask=None, extras=None):
+ if self.use_checkpoint:
+ return checkpoint(self._forward, x,
+ time_token, time_ada, skip,
+ context, x_mask, context_mask, extras,
+ use_reentrant=False)
+ else:
+ return self._forward(x,
+ time_token, time_ada, skip,
+ context, x_mask, context_mask, extras)
+
+ def _forward(self, x, time_token=None, time_ada=None,
+ skip=None, context=None,
+ x_mask=None, context_mask=None, extras=None):
+
+ assert context is None and context_mask is None
+
+ context, x = x[:, :extras, :], x[:, extras:, :]
+ context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
+
+ if skip is not None:
+ skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
+
+ B, T, C = x.shape
+ if self.skip_linear_x is not None:
+ x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
+
+ if self.skip_linear_c is not None:
+ context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
+
+ if self.use_adanorm:
+ time_ada = self.adaln(time_token, time_ada)
+ (shift_msa, scale_msa, gate_msa,
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
+
+ # self attention
+ x_norm = self.attn_norm_x(x)
+ c_norm = self.attn_norm_c(context)
+ if self.use_adanorm:
+ x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
+ x_out, c_out = self.attn(x_norm, context=c_norm,
+ x_mask=x_mask, context_mask=context_mask,
+ extras=extras)
+ if self.use_adanorm:
+ x = x + (1 - gate_msa) * x_out
+ else:
+ x = x + x_out
+ context = context + c_out
+
+ # mlp
+ if self.use_adanorm:
+ x_norm = film_modulate(self.ffn_norm_x(x),
+ shift=shift_mlp, scale=scale_mlp)
+ x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
+ else:
+ x = x + self.mlp_x(self.ffn_norm_x(x))
+
+ c_norm = self.ffn_norm_c(context)
+ context = context + self.mlp_c(c_norm)
+
+ return torch.cat((context, x), dim=1)
+
+
+class FinalBlock(nn.Module):
+ def __init__(self, embed_dim, patch_size, in_chans,
+ img_size,
+ input_type='2d',
+ norm_layer=nn.LayerNorm,
+ use_conv=True,
+ use_adanorm=True):
+ super().__init__()
+ self.in_chans = in_chans
+ self.img_size = img_size
+ self.input_type = input_type
+
+ self.norm = norm_layer(embed_dim)
+ if use_adanorm:
+ self.use_adanorm = True
+ else:
+ self.use_adanorm = False
+
+ if input_type == '2d':
+ self.patch_dim = patch_size ** 2 * in_chans
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+ if use_conv:
+ self.final_layer = nn.Conv2d(self.in_chans, self.in_chans,
+ 3, padding=1)
+ else:
+ self.final_layer = nn.Identity()
+
+ elif input_type == '1d':
+ self.patch_dim = patch_size * in_chans
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
+ if use_conv:
+ self.final_layer = nn.Conv1d(self.in_chans, self.in_chans,
+ 3, padding=1)
+ else:
+ self.final_layer = nn.Identity()
+
+ def forward(self, x, time_ada=None, extras=0):
+ B, T, C = x.shape
+ x = x[:, extras:, :]
+ # only handle generation target
+ if self.use_adanorm:
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
+ x = film_modulate(self.norm(x), shift, scale)
+ else:
+ x = self.norm(x)
+ x = self.linear(x)
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
+ x = self.final_layer(x)
+ return x
\ No newline at end of file
diff --git a/src/models/conditioners.py b/src/models/conditioners.py
new file mode 100644
index 0000000000000000000000000000000000000000..7414cfe2b9dc3dc7dec25e698c2ce43db2765f17
--- /dev/null
+++ b/src/models/conditioners.py
@@ -0,0 +1,180 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import repeat
+import math
+from .udit import UDiT
+from .utils.span_mask import compute_mask_indices
+
+
+class EmbeddingCFG(nn.Module):
+ """
+ Handles label dropout for classifier-free guidance.
+ """
+ # todo: support 2D input
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.cfg_embedding = nn.Parameter(
+ torch.randn(in_channels) / in_channels ** 0.5)
+
+ def token_drop(self, condition, condition_mask, cfg_prob):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ b, t, device = condition.shape[0], condition.shape[1], condition.device
+ drop_ids = torch.rand(b, device=device) < cfg_prob
+ uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
+ condition = torch.where(drop_ids[:, None, None], uncond, condition)
+ if condition_mask is not None:
+ condition_mask[drop_ids] = False
+ condition_mask[drop_ids, 0] = True
+
+ return condition, condition_mask
+
+ def forward(self, condition, condition_mask, cfg_prob=0.0):
+ if condition_mask is not None:
+ condition_mask = condition_mask.clone()
+ if cfg_prob > 0:
+ condition, condition_mask = self.token_drop(condition,
+ condition_mask,
+ cfg_prob)
+ return condition, condition_mask
+
+
+class DiscreteCFG(nn.Module):
+ def __init__(self, replace_id=2):
+ super(DiscreteCFG, self).__init__()
+ self.replace_id = replace_id
+
+ def forward(self, context, context_mask, cfg_prob):
+ context = context.clone()
+ if context_mask is not None:
+ context_mask = context_mask.clone()
+ if cfg_prob > 0:
+ cfg_mask = torch.rand(len(context)) < cfg_prob
+ if torch.any(cfg_mask):
+ context[cfg_mask] = 0
+ context[cfg_mask, 0] = self.replace_id
+ if context_mask is not None:
+ context_mask[cfg_mask] = False
+ context_mask[cfg_mask, 0] = True
+ return context, context_mask
+
+
+class CFGModel(nn.Module):
+ def __init__(self, context_dim, backbone):
+ super().__init__()
+ self.model = backbone
+ self.context_cfg = EmbeddingCFG(context_dim)
+
+ def forward(self, x, timesteps,
+ context, x_mask=None, context_mask=None,
+ cfg_prob=0.0):
+        context, context_mask = self.context_cfg(context, context_mask, cfg_prob=cfg_prob)
+ x = self.model(x=x, timesteps=timesteps,
+ context=context,
+ x_mask=x_mask, context_mask=context_mask)
+ return x
+
+
+class ConcatModel(nn.Module):
+ def __init__(self, backbone, in_dim, stride=[]):
+ super().__init__()
+ self.model = backbone
+
+ self.downsample_layers = nn.ModuleList()
+ for i, s in enumerate(stride):
+ downsample_layer = nn.Conv1d(
+ in_dim,
+ in_dim * 2,
+ kernel_size=2 * s,
+ stride=s,
+ padding=math.ceil(s / 2),
+ )
+ self.downsample_layers.append(downsample_layer)
+ in_dim = in_dim * 2
+
+ self.context_cfg = EmbeddingCFG(in_dim)
+
+ def forward(self, x, timesteps,
+ context, x_mask=None,
+ cfg=False, cfg_prob=0.0):
+
+ # todo: support 2D input
+ # x: B, C, L
+ # context: B, C, L
+
+ for downsample_layer in self.downsample_layers:
+ context = downsample_layer(context)
+
+ context = context.transpose(1, 2)
+        context, _ = self.context_cfg(context, None, cfg_prob=cfg_prob if cfg else 0.0)
+ context = context.transpose(1, 2)
+
+ assert context.shape[-1] == x.shape[-1]
+ x = torch.cat([context, x], dim=1)
+ x = self.model(x=x, timesteps=timesteps,
+ context=None, x_mask=x_mask, context_mask=None)
+ return x
+
+
+class MaskDiT(nn.Module):
+ def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
+ super().__init__()
+ self.model = UDiT(**kwargs)
+ self.mae = mae
+ if self.mae:
+ out_channel = kwargs.pop('out_chans', None)
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
+ self.mae_prob = mae_prob
+ self.mask_ratio = mask_ratio
+ self.mask_span = mask_span
+
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
+ B, D, L = gt.shape
+ if mae_mask_infer is None:
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
+ mask_ratios = mask_ratios.cpu().numpy()
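+            # sample span masks (length self.mask_span) so that roughly mask_ratios of each sequence is masked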
+ mask = compute_mask_indices(shape=[B, L],
+ padding_mask=None,
+ mask_prob=mask_ratios,
+ mask_length=self.mask_span,
+ mask_type="static",
+ mask_other=0.0,
+ min_masks=1,
+ no_overlap=False,
+ min_space=0,)
+ mask = mask.unsqueeze(1).expand_as(gt)
+ else:
+ mask = mae_mask_infer
+ mask = mask.expand_as(gt)
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
+ return gt, mask.type_as(gt)
+
+ def forward(self, x, timesteps, context,
+ x_mask=None, context_mask=None, cls_token=None,
+ gt=None, mae_mask_infer=None):
+ mae_mask = torch.ones_like(x)
+ if self.mae:
+ if gt is not None:
+ B, D, L = gt.shape
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
+ gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
+ # apply mae only to the selected batches
+ if mae_mask_infer is None:
+ # determine mae batch
+ mae_batch = torch.rand(B) < self.mae_prob
+ gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
+ mae_mask[~mae_batch] = 1.0
+ else:
+ B, D, L = x.shape
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
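+            # concatenate input latents, (masked) ground-truth latents, and a single mask channel along the channel dimension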
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
+
+ x = self.model(x=x, timesteps=timesteps, context=context,
+ x_mask=x_mask, context_mask=context_mask,
+ cls_token=cls_token)
+ # print(mae_mask[:, 0, :].sum(dim=-1))
+ return x, mae_mask
diff --git a/src/models/udit.py b/src/models/udit.py
new file mode 100644
index 0000000000000000000000000000000000000000..db86b3e01acaea2de42ae85e3ebbad83f7c19fe9
--- /dev/null
+++ b/src/models/udit.py
@@ -0,0 +1,356 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+import math
+from .utils.modules import PatchEmbed, TimestepEmbedder
+from .utils.modules import PE_wrapper, RMSNorm
+from .blocks import DiTBlock, JointDiTBlock, FinalBlock
+
+
+class UDiT(nn.Module):
+ def __init__(self,
+ img_size=224, patch_size=16, in_chans=3,
+ input_type='2d', out_chans=None,
+ embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
+ qkv_bias=False, qk_scale=None, qk_norm=None,
+ act_layer='gelu', norm_layer='layernorm',
+ context_norm=False,
+ use_checkpoint=False,
+ # time fusion ada or token
+ time_fusion='token',
+ ada_lora_rank=None, ada_lora_alpha=None,
+ cls_dim=None,
+ # max length is only used for concat
+ context_dim=768, context_fusion='concat',
+ context_max_length=128, context_pe_method='sinu',
+ pe_method='abs', rope_mode='none',
+ use_conv=True,
+ skip=True, skip_norm=True):
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ # input
+ self.in_chans = in_chans
+ self.input_type = input_type
+ if self.input_type == '2d':
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+ elif self.input_type == '1d':
+ num_patches = img_size // patch_size
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
+ embed_dim=embed_dim, input_type=input_type)
+ out_chans = in_chans if out_chans is None else out_chans
+ self.out_chans = out_chans
+
+ # position embedding
+ self.rope = rope_mode
+ self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
+ length=num_patches)
+
+ print(f'x position embedding: {pe_method}')
+ print(f'rope mode: {self.rope}')
+
+ # time embed
+ self.time_embed = TimestepEmbedder(embed_dim)
+ self.time_fusion = time_fusion
+ self.use_adanorm = False
+
+ # cls embed
+ if cls_dim is not None:
+ self.cls_embed = nn.Sequential(
+ nn.Linear(cls_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),)
+ else:
+ self.cls_embed = None
+
+ # time fusion
+ if time_fusion == 'token':
+ # put token at the beginning of sequence
+ self.extras = 2 if self.cls_embed else 1
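+            # 'extras' counts prefix tokens (time/cls, plus context when concatenated below) placed before the latent tokens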
+ self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
+ elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
+ self.use_adanorm = True
+            # avoid a repetitive SiLU for each AdaLN block
+ self.time_act = nn.SiLU()
+ self.extras = 0
+ self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
+ if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
+ # shared adaln
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
+ else:
+ self.time_ada = None
+ else:
+ raise NotImplementedError
+ print(f'time fusion mode: {self.time_fusion}')
+
+ # context
+ # use a simple projection
+ self.use_context = False
+ self.context_cross = False
+ self.context_max_length = context_max_length
+ self.context_fusion = 'none'
+ if context_dim is not None:
+ self.use_context = True
+ self.context_embed = nn.Sequential(
+ nn.Linear(context_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),)
+ self.context_fusion = context_fusion
+ if context_fusion == 'concat' or context_fusion == 'joint':
+ self.extras += context_max_length
+ self.context_pe = PE_wrapper(dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length)
+ # no cross attention layers
+ context_dim = None
+ elif context_fusion == 'cross':
+ self.context_pe = PE_wrapper(dim=embed_dim,
+ method=context_pe_method,
+ length=context_max_length)
+ self.context_cross = True
+ context_dim = embed_dim
+ else:
+ raise NotImplementedError
+ print(f'context fusion mode: {context_fusion}')
+ print(f'context position embedding: {context_pe_method}')
+
+ if self.context_fusion == 'joint':
+ Block = JointDiTBlock
+ self.use_skip = skip[0]
+ else:
+ Block = DiTBlock
+ self.use_skip = skip
+
+ # norm layers
+ if norm_layer == 'layernorm':
+ norm_layer = nn.LayerNorm
+ elif norm_layer == 'rmsnorm':
+ norm_layer = RMSNorm
+ else:
+ raise NotImplementedError
+
+ print(f'use long skip connection: {skip}')
+ self.in_blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+ act_layer=act_layer, norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+ skip=False, skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint)
+ for _ in range(depth // 2)])
+
+ self.mid_block = Block(
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+ act_layer=act_layer, norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+ skip=False, skip_norm=False,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint)
+
+ self.out_blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
+ act_layer=act_layer, norm_layer=norm_layer,
+ time_fusion=time_fusion,
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
+ skip=skip, skip_norm=skip_norm,
+ rope_mode=self.rope,
+ context_norm=context_norm,
+ use_checkpoint=use_checkpoint)
+ for _ in range(depth // 2)])
+
+ # FinalLayer block
+ self.use_conv = use_conv
+ self.final_block = FinalBlock(embed_dim=embed_dim,
+ patch_size=patch_size,
+ img_size=img_size,
+ in_chans=out_chans,
+ input_type=input_type,
+ norm_layer=norm_layer,
+ use_conv=use_conv,
+ use_adanorm=self.use_adanorm)
+ self.initialize_weights()
+
+ def _init_ada(self):
+ if self.time_fusion == 'ada':
+ nn.init.constant_(self.time_ada_final.weight, 0)
+ nn.init.constant_(self.time_ada_final.bias, 0)
+ for block in self.in_blocks:
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
+ for block in self.out_blocks:
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
+ elif self.time_fusion == 'ada_single':
+ nn.init.constant_(self.time_ada.weight, 0)
+ nn.init.constant_(self.time_ada.bias, 0)
+ nn.init.constant_(self.time_ada_final.weight, 0)
+ nn.init.constant_(self.time_ada_final.bias, 0)
+ elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
+ nn.init.constant_(self.time_ada.weight, 0)
+ nn.init.constant_(self.time_ada.bias, 0)
+ nn.init.constant_(self.time_ada_final.weight, 0)
+ nn.init.constant_(self.time_ada_final.bias, 0)
+ for block in self.in_blocks:
+ nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
+ a=math.sqrt(5))
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
+ nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
+ a=math.sqrt(5))
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
+ for block in self.out_blocks:
+ nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
+ a=math.sqrt(5))
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
+
+ def initialize_weights(self):
+ # Basic init for all layers
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # init patch Conv like Linear
+ w = self.patch_embed.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
+
+ # Zero-out AdaLN
+ if self.use_adanorm:
+ self._init_ada()
+
+ # Zero-out Cross Attention
+ if self.context_cross:
+ for block in self.in_blocks:
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
+ for block in self.out_blocks:
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
+
+ # Zero-out cls embedding
+ if self.cls_embed:
+ if self.use_adanorm:
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
+
+ # Zero-out Output
+ # might not zero-out this when using v-prediction
+ # it could be good when using noise-prediction
+ # nn.init.constant_(self.final_block.linear.weight, 0)
+ # nn.init.constant_(self.final_block.linear.bias, 0)
+ # if self.use_conv:
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+ # init out Conv
+ if self.use_conv:
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
+
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
+ assert context.shape[-2] == self.context_max_length
+ # Check if either x_mask or context_mask is provided
+ B = x.shape[0]
+ # Create default masks if they are not provided
+ if x_mask is None:
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+ if context_mask is None:
+ context_mask = torch.ones(B, context.shape[-2],
+ device=context.device).bool()
+ # Concatenate the masks along the second dimension (dim=1)
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
+ # Concatenate context and x along the second dimension (dim=1)
+ x = torch.cat((context, x), dim=1)
+ return x, x_mask
+
+ def forward(self, x, timesteps, context,
+ x_mask=None, context_mask=None,
+ cls_token=None
+ ):
+ # make it compatible with int time step during inference
+ if timesteps.dim() == 0:
+ timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
+
+ x = self.patch_embed(x)
+ x = self.x_pe(x)
+
+ B, L, D = x.shape
+
+ if self.use_context:
+ context_token = self.context_embed(context)
+ context_token = self.context_pe(context_token)
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
+ x, x_mask = self._concat_x_context(x=x, context=context_token,
+ x_mask=x_mask,
+ context_mask=context_mask)
+ context_token, context_mask = None, None
+ else:
+ context_token, context_mask = None, None
+
+ time_token = self.time_embed(timesteps)
+ if self.cls_embed:
+ cls_token = self.cls_embed(cls_token)
+ time_ada = None
+ time_ada_final = None
+ if self.use_adanorm:
+ if self.cls_embed:
+ time_token = time_token + cls_token
+ time_token = self.time_act(time_token)
+ time_ada_final = self.time_ada_final(time_token)
+ if self.time_ada is not None:
+ time_ada = self.time_ada(time_token)
+ else:
+ time_token = time_token.unsqueeze(dim=1)
+ if self.cls_embed:
+ cls_token = cls_token.unsqueeze(dim=1)
+ time_token = torch.cat([time_token, cls_token], dim=1)
+ time_token = self.time_pe(time_token)
+ x = torch.cat((time_token, x), dim=1)
+ if x_mask is not None:
+ x_mask = torch.cat(
+ [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
+ x_mask], dim=1)
+ time_token = None
+
+ skips = []
+ for blk in self.in_blocks:
+ x = blk(x=x, time_token=time_token, time_ada=time_ada,
+ skip=None, context=context_token,
+ x_mask=x_mask, context_mask=context_mask,
+ extras=self.extras)
+ if self.use_skip:
+ skips.append(x)
+
+ x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
+ skip=None, context=context_token,
+ x_mask=x_mask, context_mask=context_mask,
+ extras=self.extras)
+
+ for blk in self.out_blocks:
+ skip = skips.pop() if self.use_skip else None
+ x = blk(x=x, time_token=time_token, time_ada=time_ada,
+ skip=skip, context=context_token,
+ x_mask=x_mask, context_mask=context_mask,
+ extras=self.extras)
+
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
+
+ return x
\ No newline at end of file
diff --git a/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py b/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffad4542ada8f1824e93c76647da232db7f2da4e
--- /dev/null
+++ b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py
@@ -0,0 +1,290 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .rotary import RotaryEmbedding
+from .modules import RMSNorm
+
+
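+# prefer PyTorch's fused scaled_dot_product_attention when available, otherwise fall back to a plain matmul/softmax attention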
+if hasattr(nn.functional, 'scaled_dot_product_attention'):
+ ATTENTION_MODE = 'flash'
+else:
+ ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+def add_mask(sim, mask):
+ b, ndim = sim.shape[0], mask.ndim
+ if ndim == 3:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ if ndim == 2:
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
+ max_neg_value = -torch.finfo(sim.dtype).max
+ sim = sim.masked_fill(~mask, max_neg_value)
+ return sim
+
+
+def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
+ def default(val, d):
+ return val if val is not None else (d() if isfunction(d) else d)
+    b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
+ return attn_mask
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, context_dim=None, num_heads=8,
+ qkv_bias=False, qk_scale=None, qk_norm=None,
+ attn_drop=0., proj_drop=0., rope_mode='none'):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ if context_dim is None:
+ self.cross_attn = False
+ else:
+ self.cross_attn = True
+
+ context_dim = dim if context_dim is None else context_dim
+
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
+
+ if qk_norm is None:
+ self.norm_q = nn.Identity()
+ self.norm_k = nn.Identity()
+ elif qk_norm == 'layernorm':
+ self.norm_q = nn.LayerNorm(head_dim)
+ self.norm_k = nn.LayerNorm(head_dim)
+ elif qk_norm == 'rmsnorm':
+ self.norm_q = RMSNorm(head_dim)
+ self.norm_k = RMSNorm(head_dim)
+ else:
+ raise NotImplementedError
+
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ if self.cross_attn:
+ assert rope_mode == 'none'
+ self.rope_mode = rope_mode
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+ self.rotary = RotaryEmbedding(dim=head_dim)
+ elif self.rope_mode == 'dual':
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+ def _rotary(self, q, k, extras):
+ if self.rope_mode == 'shared':
+ q, k = self.rotary(q=q, k=k)
+ elif self.rope_mode == 'x_only':
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'dual':
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'none':
+ pass
+ else:
+ raise NotImplementedError
+ return q, k
+
+ def _attn(self, q, k, v, mask_binary):
+ if ATTENTION_MODE == 'flash':
+ x = F.scaled_dot_product_attention(q, k, v,
+ dropout_p=self.attn_drop_p,
+ attn_mask=mask_binary)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ else:
+ raise NotImplementedError
+ return x
+
+ def forward(self, x, context=None, context_mask=None, extras=0):
+ B, L, C = x.shape
+ if context is None:
+ context = x
+
+ q = self.to_q(x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if context_mask is not None:
+ mask_binary = create_mask(x.shape, context.shape,
+ x.device, None, context_mask)
+ else:
+ mask_binary = None
+
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
+
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+
+ q, k = self._rotary(q, k, extras)
+
+ x = self._attn(q, k, v, mask_binary)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class JointAttention(nn.Module):
+ def __init__(self, dim, num_heads=8,
+ qkv_bias=False, qk_scale=None, qk_norm=None,
+ attn_drop=0., proj_drop=0.,
+ rope_mode='none'):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias)
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias)
+
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
+
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ self.proj_x = nn.Linear(dim, dim)
+ self.proj_drop_x = nn.Dropout(proj_drop)
+
+ self.proj_c = nn.Linear(dim, dim)
+ self.proj_drop_c = nn.Dropout(proj_drop)
+
+ self.rope_mode = rope_mode
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+ self.rotary = RotaryEmbedding(dim=head_dim)
+ elif self.rope_mode == 'dual':
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+ def _make_qkv_layers(self, dim, qkv_bias):
+ return (nn.Linear(dim, dim, bias=qkv_bias),
+ nn.Linear(dim, dim, bias=qkv_bias),
+ nn.Linear(dim, dim, bias=qkv_bias))
+
+ def _make_norm_layers(self, qk_norm, head_dim):
+ if qk_norm is None:
+ norm_q = nn.Identity()
+ norm_k = nn.Identity()
+ elif qk_norm == 'layernorm':
+ norm_q = nn.LayerNorm(head_dim)
+ norm_k = nn.LayerNorm(head_dim)
+ elif qk_norm == 'rmsnorm':
+ norm_q = RMSNorm(head_dim)
+ norm_k = RMSNorm(head_dim)
+ else:
+ raise NotImplementedError
+ return norm_q, norm_k
+
+ def _rotary(self, q, k, extras):
+ if self.rope_mode == 'shared':
+ q, k = self.rotary(q=q, k=k)
+ elif self.rope_mode == 'x_only':
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'dual':
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'none':
+ pass
+ else:
+ raise NotImplementedError
+ return q, k
+
+ def _attn(self, q, k, v, mask_binary):
+ if ATTENTION_MODE == 'flash':
+ x = F.scaled_dot_product_attention(q, k, v,
+ dropout_p=self.attn_drop_p,
+ attn_mask=mask_binary)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ else:
+ raise NotImplementedError
+ return x
+
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
+ B = x.shape[0]
+ if x_mask is None:
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+ if context_mask is None:
+ context_mask = torch.ones(B, context.shape[-2], device=context.device).bool()
+ mask = torch.cat([context_mask, x_mask], dim=1)
+ return mask
+
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
+ B, Lx, C = x.shape
+ _, Lc, _ = context.shape
+ if x_mask is not None or context_mask is not None:
+ mask = self._cat_mask(x, context,
+ x_mask=x_mask,
+ context_mask=context_mask)
+ shape = [B, Lx+Lc, C]
+ mask_binary = create_mask(q_shape=shape, k_shape=shape,
+ device=x.device,
+ q_mask=None, k_mask=mask)
+ else:
+ mask_binary = None
+
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
+ qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context)
+
+ qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+ H=self.num_heads), [qx, kx, vx])
+ qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+ H=self.num_heads), [qc, kc, vc])
+
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
+
+ q, k, v = (torch.cat([qc, qx], dim=2),
+ torch.cat([kc, kx], dim=2),
+ torch.cat([vc, vx], dim=2))
+
+ q, k = self._rotary(q, k, extras)
+
+ x = self._attn(q, k, v, mask_binary)
+
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
+
+ x = self.proj_x(x)
+ x = self.proj_drop_x(x)
+
+ context = self.proj_c(context)
+ context = self.proj_drop_c(context)
+
+ return x, context
\ No newline at end of file
diff --git a/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a2a8c841b62748120d7cb33a6aa10860ecdb674
--- /dev/null
+++ b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py
@@ -0,0 +1,374 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.cuda.amp import autocast
+import math
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .timm import trunc_normal_
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def film_modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256,
+ out_size=None):
+ super().__init__()
+ if out_size is None:
+ out_size = hidden_size
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, out_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
+ self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
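+# patchify/unpatchify convert between channel-first feature maps and (B, num_patches, patch_dim) token sequences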
+def patchify(imgs, patch_size, input_type='2d'):
+ if input_type == '2d':
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
+ elif input_type == '1d':
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
+ return x
+
+
+def unpatchify(x, channels=3, input_type='2d', img_size=None):
+ if input_type == '2d':
+ patch_size = int((x.shape[2] // channels) ** 0.5)
+ # h = w = int(x.shape[1] ** .5)
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
+ assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h,
+ p1=patch_size, p2=patch_size)
+ elif input_type == '1d':
+ patch_size = int((x.shape[2] // channels))
+ h = x.shape[1]
+ assert patch_size * channels == x.shape[2]
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding
+ """
+
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
+ super().__init__()
+ self.patch_size = patch_size
+ self.input_type = input_type
+ if input_type == '2d':
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+ elif input_type == '1d':
+ self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+
+ def forward(self, x):
+ if self.input_type == '2d':
+ B, C, H, W = x.shape
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
+ elif self.input_type == '1d':
+ B, C, H = x.shape
+ assert H % self.patch_size == 0
+
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class PositionalConvEmbedding(nn.Module):
+ """
+ Relative positional embedding used in HuBERT
+ """
+
+ def __init__(self, dim=768, kernel_size=128, groups=16):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ groups=groups,
+ bias=True
+ )
+ self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
+
+ def forward(self, x):
+ # B C T
+ x = self.conv(x)
+ x = F.gelu(x[:, :, :-1])
+ return x
+
+
+class SinusoidalPositionalEncoding(nn.Module):
+ def __init__(self, dim, length):
+ super(SinusoidalPositionalEncoding, self).__init__()
+ self.length = length
+ self.dim = dim
+ self.register_buffer('pe', self._generate_positional_encoding(length, dim))
+
+ def _generate_positional_encoding(self, length, dim):
+ pe = torch.zeros(length, dim)
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ pe = pe.unsqueeze(0)
+ return pe
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)]
+ return x
+
+
+class PE_wrapper(nn.Module):
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
+ super().__init__()
+ self.method = method
+ if method == 'abs':
+ # init absolute pe like UViT
+ self.length = length
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
+ trunc_normal_(self.abs_pe, std=.02)
+ elif method == 'conv':
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
+ elif method == 'sinu':
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
+ elif method == 'none':
+ # skip pe
+ self.id = nn.Identity()
+ else:
+ raise NotImplementedError
+
+ def forward(self, x):
+ if self.method == 'abs':
+ _, L, _ = x.shape
+ assert L <= self.length
+ x = x + self.abs_pe[:, :L, :]
+ elif self.method == 'conv':
+ x = x + self.conv_pe(x)
+ elif self.method == 'sinu':
+ x = self.sinu_pe(x)
+ elif self.method == 'none':
+ x = self.id(x)
+ else:
+ raise NotImplementedError
+ return x
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+class GELU(nn.Module):
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none",
+ bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32),
+ approximate=self.approximate).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def snake_beta(x, alpha, beta):
+ return x + beta * torch.sin(x * alpha).pow(2)
+
+
+class Snake(nn.Module):
+ def __init__(self, dim_in, dim_out, bias,
+ alpha_trainable=True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = snake_beta(x, self.alpha, self.beta)
+ return x
+
+
+class GESnake(nn.Module):
+ def __init__(self, dim_in, dim_out, bias,
+ alpha_trainable=True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ def forward(self, x):
+ x = self.proj(x)
+ x, gate = x.chunk(2, dim=-1)
+ return x * snake_beta(gate, self.alpha, self.beta)
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out=None,
+ mult=4,
+ dropout=0.0,
+ activation_fn="geglu",
+ final_dropout=False,
+ inner_dim=None,
+ bias=True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "snake":
+ act_fn = Snake(dim, inner_dim, bias=bias)
+ elif activation_fn == "gesnake":
+ act_fn = GESnake(dim, inner_dim, bias=bias)
+ else:
+ raise NotImplementedError
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
\ No newline at end of file
diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/utils/attention.py b/src/models/utils/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffad4542ada8f1824e93c76647da232db7f2da4e
--- /dev/null
+++ b/src/models/utils/attention.py
@@ -0,0 +1,290 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .rotary import RotaryEmbedding
+from .modules import RMSNorm
+
+
+if hasattr(nn.functional, 'scaled_dot_product_attention'):
+ ATTENTION_MODE = 'flash'
+else:
+ ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+def add_mask(sim, mask):
+ b, ndim = sim.shape[0], mask.ndim
+ if ndim == 3:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ if ndim == 2:
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
+ max_neg_value = -torch.finfo(sim.dtype).max
+ sim = sim.masked_fill(~mask, max_neg_value)
+ return sim
+
+
+def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
+ def default(val, d):
+ return val if val is not None else (d() if isfunction(d) else d)
+ b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
+ return attn_mask
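+
+# Hedged illustration (not part of the original module): create_mask builds a
+# broadcastable boolean attention mask of shape (B, 1, Lq, Lk) from per-sequence
+# query/key masks, and add_mask fills positions where that mask is False with a
+# large negative value before the softmax. The shapes below are assumptions.
+#
+#   q_mask = torch.ones(2, 100, dtype=torch.bool)
+#   k_mask = torch.ones(2, 32, dtype=torch.bool)
+#   m = create_mask((2, 100, 768), (2, 32, 768), q_mask.device, q_mask, k_mask)  # (2, 1, 100, 32)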
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, context_dim=None, num_heads=8,
+ qkv_bias=False, qk_scale=None, qk_norm=None,
+ attn_drop=0., proj_drop=0., rope_mode='none'):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.cross_attn = context_dim is not None
+
+ context_dim = dim if context_dim is None else context_dim
+
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
+
+ if qk_norm is None:
+ self.norm_q = nn.Identity()
+ self.norm_k = nn.Identity()
+ elif qk_norm == 'layernorm':
+ self.norm_q = nn.LayerNorm(head_dim)
+ self.norm_k = nn.LayerNorm(head_dim)
+ elif qk_norm == 'rmsnorm':
+ self.norm_q = RMSNorm(head_dim)
+ self.norm_k = RMSNorm(head_dim)
+ else:
+ raise NotImplementedError
+
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ if self.cross_attn:
+ assert rope_mode == 'none'
+ self.rope_mode = rope_mode
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+ self.rotary = RotaryEmbedding(dim=head_dim)
+ elif self.rope_mode == 'dual':
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+ def _rotary(self, q, k, extras):
+ if self.rope_mode == 'shared':
+ q, k = self.rotary(q=q, k=k)
+ elif self.rope_mode == 'x_only':
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'dual':
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'none':
+ pass
+ else:
+ raise NotImplementedError
+ return q, k
+
+ def _attn(self, q, k, v, mask_binary):
+ if ATTENTION_MODE == 'flash':
+ x = F.scaled_dot_product_attention(q, k, v,
+ dropout_p=self.attn_drop_p if self.training else 0.,  # apply attention dropout only during training, matching the math path
+ attn_mask=mask_binary)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ else:
+ raise NotImplementedError
+ return x
+
+ def forward(self, x, context=None, context_mask=None, extras=0):
+ B, L, C = x.shape
+ if context is None:
+ context = x
+
+ q = self.to_q(x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if context_mask is not None:
+ mask_binary = create_mask(x.shape, context.shape,
+ x.device, None, context_mask)
+ else:
+ mask_binary = None
+
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
+
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+
+ q, k = self._rotary(q, k, extras)
+
+ x = self._attn(q, k, v, mask_binary)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
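+
+# Hedged usage sketch (illustration only, not part of the original module). Shapes and
+# hyper-parameters below are assumptions: x is (B, L, dim), an optional context is
+# (B, Lc, context_dim) with a boolean (B, Lc) context_mask, and cross-attention
+# requires rope_mode='none' (see the assert in __init__).
+#
+#   self_attn = Attention(dim=768, num_heads=8, qk_norm='rmsnorm', rope_mode='shared')
+#   x = torch.randn(2, 100, 768)
+#   y = self_attn(x)                                           # (2, 100, 768)
+#
+#   cross_attn = Attention(dim=768, context_dim=1024, num_heads=8)
+#   ctx = torch.randn(2, 32, 1024)
+#   ctx_mask = torch.ones(2, 32, dtype=torch.bool)
+#   y = cross_attn(x, context=ctx, context_mask=ctx_mask)      # (2, 100, 768)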
+
+
+class JointAttention(nn.Module):
+ def __init__(self, dim, num_heads=8,
+ qkv_bias=False, qk_scale=None, qk_norm=None,
+ attn_drop=0., proj_drop=0.,
+ rope_mode='none'):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias)
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias)
+
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
+
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ self.proj_x = nn.Linear(dim, dim)
+ self.proj_drop_x = nn.Dropout(proj_drop)
+
+ self.proj_c = nn.Linear(dim, dim)
+ self.proj_drop_c = nn.Dropout(proj_drop)
+
+ self.rope_mode = rope_mode
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
+ self.rotary = RotaryEmbedding(dim=head_dim)
+ elif self.rope_mode == 'dual':
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
+
+ def _make_qkv_layers(self, dim, qkv_bias):
+ return (nn.Linear(dim, dim, bias=qkv_bias),
+ nn.Linear(dim, dim, bias=qkv_bias),
+ nn.Linear(dim, dim, bias=qkv_bias))
+
+ def _make_norm_layers(self, qk_norm, head_dim):
+ if qk_norm is None:
+ norm_q = nn.Identity()
+ norm_k = nn.Identity()
+ elif qk_norm == 'layernorm':
+ norm_q = nn.LayerNorm(head_dim)
+ norm_k = nn.LayerNorm(head_dim)
+ elif qk_norm == 'rmsnorm':
+ norm_q = RMSNorm(head_dim)
+ norm_k = RMSNorm(head_dim)
+ else:
+ raise NotImplementedError
+ return norm_q, norm_k
+
+ def _rotary(self, q, k, extras):
+ if self.rope_mode == 'shared':
+ q, k = self.rotary(q=q, k=k)
+ elif self.rope_mode == 'x_only':
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'dual':
+ q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
+ q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
+ q = torch.cat((q_c, q_x), dim=2)
+ k = torch.cat((k_c, k_x), dim=2)
+ elif self.rope_mode == 'none':
+ pass
+ else:
+ raise NotImplementedError
+ return q, k
+
+ def _attn(self, q, k, v, mask_binary):
+ if ATTENTION_MODE == 'flash':
+ x = F.scaled_dot_product_attention(q, k, v,
+ dropout_p=self.attn_drop_p if self.training else 0.,  # apply attention dropout only during training, matching the math path
+ attn_mask=mask_binary)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ else:
+ raise NotImplementedError
+ return x
+
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
+ B = x.shape[0]
+ if x_mask is None:
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
+ if context_mask is None:
+ context_mask = torch.ones(B, context.shape[-2], device=context.device).bool()
+ mask = torch.cat([context_mask, x_mask], dim=1)
+ return mask
+
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
+ B, Lx, C = x.shape
+ _, Lc, _ = context.shape
+ if x_mask is not None or context_mask is not None:
+ mask = self._cat_mask(x, context,
+ x_mask=x_mask,
+ context_mask=context_mask)
+ shape = [B, Lx+Lc, C]
+ mask_binary = create_mask(q_shape=shape, k_shape=shape,
+ device=x.device,
+ q_mask=None, k_mask=mask)
+ else:
+ mask_binary = None
+
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
+ qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context)
+
+ qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+ H=self.num_heads), [qx, kx, vx])
+ qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
+ H=self.num_heads), [qc, kc, vc])
+
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
+
+ q, k, v = (torch.cat([qc, qx], dim=2),
+ torch.cat([kc, kx], dim=2),
+ torch.cat([vc, vx], dim=2))
+
+ q, k = self._rotary(q, k, extras)
+
+ x = self._attn(q, k, v, mask_binary)
+
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
+
+ x = self.proj_x(x)
+ x = self.proj_drop_x(x)
+
+ context = self.proj_c(context)
+ context = self.proj_drop_c(context)
+
+ return x, context
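+
+# Hedged usage sketch (illustration only): JointAttention attends over the
+# concatenation of context and x (context first) and returns updated versions of
+# both. Shapes are assumptions: x is (B, Lx, dim), context is (B, Lc, dim), and the
+# optional masks are boolean (B, Lx) / (B, Lc).
+#
+#   joint = JointAttention(dim=768, num_heads=8, qk_norm='rmsnorm', rope_mode='x_only')
+#   x = torch.randn(2, 100, 768)
+#   c = torch.randn(2, 32, 768)
+#   x_out, c_out = joint(x, c)                                 # (2, 100, 768), (2, 32, 768)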
\ No newline at end of file
diff --git a/src/models/utils/bk/attention.py b/src/models/utils/bk/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd201997bfe10a721d3473692a820ccde7189797
--- /dev/null
+++ b/src/models/utils/bk/attention.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .rotary import RotaryEmbedding
+
+if hasattr(nn.functional, 'scaled_dot_product_attention'):
+ ATTENTION_MODE = 'flash'
+else:
+ ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+def add_mask(sim, mask):
+ b, ndim = sim.shape[0], mask.ndim
+ if ndim == 3:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ if ndim == 2:
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
+ max_neg_value = -torch.finfo(sim.dtype).max
+ sim = sim.masked_fill(~mask, max_neg_value)
+ return sim
+
+
+def create_mask(q, k, q_mask=None, k_mask=None):
+ def default(val, d):
+ return val if val is not None else (d() if isfunction(d) else d)
+
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
+ return attn_mask
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None,
+ attn_drop=0., proj_drop=0., use_rope=False):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ context_dim = dim if context_dim is None else context_dim
+
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
+ self.attn_drop_p = attn_drop
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.use_rope = use_rope
+ if self.use_rope:
+ self.rotary = RotaryEmbedding(dim=head_dim)
+
+ def forward(self, x, context=None, context_mask=None):
+ B, L, C = x.shape
+ q = self.to_q(x)
+ if context is None:
+ context = x
+ else:
+ assert self.use_rope is False
+
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if context_mask is not None:
+ mask_binary = create_mask(x, context, None, context_mask)
+ else:
+ mask_binary = None
+
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float()
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float()
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float()
+
+ if self.use_rope:
+ q, k = self.rotary(q=q, k=k)
+
+ if ATTENTION_MODE == 'flash':
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v,
+ dropout_p=self.attn_drop_p,
+ attn_mask=mask_binary)
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
+ elif ATTENTION_MODE == 'math':
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+ else:
+ raise NotImplementedError
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
\ No newline at end of file
diff --git a/src/models/utils/bk/llama_rotary.py b/src/models/utils/bk/llama_rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..52d66b7a08773f2aadab74ac1ffd6d35409ff52b
--- /dev/null
+++ b/src/models/utils/bk/llama_rotary.py
@@ -0,0 +1,74 @@
+import torch
+from typing import Tuple
+from rotary import RotaryEmbedding
+import time
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
+ return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor,
+ x: torch.Tensor,):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def compute_rope(q, freqs_cis):
+ return q * freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq1, xq2 = xq.chunk(2, dim=-1)
+ xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float())
+
+ xk1, xk2 = xk.chunk(2, dim=-1)
+ xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float())
+
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3)
+ xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+if __name__ == '__main__':
+ # Move data to CUDA
+ freq_cis = precompute_freqs_cis(4, 5).cuda()
+ x = torch.rand(1, 5, 1, 4).cuda()
+ y = torch.rand(1, 5, 1, 4).cuda()
+
+ # First method
+ start_time = time.time()
+ for _ in range(20000):
+ x1, y1 = apply_rotary_emb(x, y, freq_cis)
+ end_time = time.time()
+ print(f"Method 1 time cost: {end_time - start_time} seconds")
+
+ # Prepare data for the second method
+ x = x.permute(0, 2, 1, 3)
+ y = y.permute(0, 2, 1, 3)
+ rope = RotaryEmbedding(4).cuda()
+
+ # Second method
+ start_time = time.time()
+ for _ in range(20000):
+ x2, y2 = rope(x, y)
+ end_time = time.time()
+ print(f"Method 2 time cost: {end_time - start_time} seconds")
+
+ # Print the results
+ print(x1)
+ print(x2)
\ No newline at end of file
diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a2a8c841b62748120d7cb33a6aa10860ecdb674
--- /dev/null
+++ b/src/models/utils/modules.py
@@ -0,0 +1,374 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.cuda.amp import autocast
+import math
+import einops
+from einops import rearrange, repeat
+from inspect import isfunction
+from .timm import trunc_normal_
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def film_modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256,
+ out_size=None):
+ super().__init__()
+ if out_size is None:
+ out_size = hidden_size
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, out_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
+ self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
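+
+# Hedged usage sketch (illustration only): TimestepEmbedder maps a batch of scalar
+# diffusion timesteps to vectors of size hidden_size (or out_size, if given).
+# The sizes below are assumptions.
+#
+#   t_embedder = TimestepEmbedder(hidden_size=768)
+#   t = torch.randint(0, 1000, (4,))
+#   t_emb = t_embedder(t)                                      # (4, 768)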
+
+
+def patchify(imgs, patch_size, input_type='2d'):
+ if input_type == '2d':
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
+ elif input_type == '1d':
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
+ return x
+
+
+def unpatchify(x, channels=3, input_type='2d', img_size=None):
+ if input_type == '2d':
+ patch_size = int((x.shape[2] // channels) ** 0.5)
+ # h = w = int(x.shape[1] ** .5)
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
+ assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h,
+ p1=patch_size, p2=patch_size)
+ elif input_type == '1d':
+ patch_size = int((x.shape[2] // channels))
+ h = x.shape[1]
+ assert patch_size * channels == x.shape[2]
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
+ return x
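+
+# Hedged usage sketch (illustration only): patchify and unpatchify are exact inverses
+# for the '1d' layout. The sizes below are assumptions, not fixed by the code.
+#
+#   x = torch.randn(2, 8, 128)                                 # (B, C, T), T divisible by patch_size
+#   tokens = patchify(x, patch_size=4, input_type='1d')        # (2, 32, 32) = (B, T//p, p*C)
+#   x_rec = unpatchify(tokens, channels=8, input_type='1d')    # back to (2, 8, 128)
+#   assert torch.equal(x, x_rec)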
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding
+ """
+
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
+ super().__init__()
+ self.patch_size = patch_size
+ self.input_type = input_type
+ if input_type == '2d':
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+ elif input_type == '1d':
+ self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
+
+ def forward(self, x):
+ if self.input_type == '2d':
+ B, C, H, W = x.shape
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
+ elif self.input_type == '1d':
+ B, C, H = x.shape
+ assert H % self.patch_size == 0
+
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class PositionalConvEmbedding(nn.Module):
+ """
+ Relative positional embedding used in HuBERT
+ """
+
+ def __init__(self, dim=768, kernel_size=128, groups=16):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ groups=groups,
+ bias=True
+ )
+ self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
+
+ def forward(self, x):
+ # B C T
+ x = self.conv(x)
+ x = F.gelu(x[:, :, :-1])
+ return x
+
+
+class SinusoidalPositionalEncoding(nn.Module):
+ def __init__(self, dim, length):
+ super(SinusoidalPositionalEncoding, self).__init__()
+ self.length = length
+ self.dim = dim
+ self.register_buffer('pe', self._generate_positional_encoding(length, dim))
+
+ def _generate_positional_encoding(self, length, dim):
+ pe = torch.zeros(length, dim)
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ pe = pe.unsqueeze(0)
+ return pe
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)]
+ return x
+
+
+class PE_wrapper(nn.Module):
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
+ super().__init__()
+ self.method = method
+ if method == 'abs':
+ # init absolute pe like UViT
+ self.length = length
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
+ trunc_normal_(self.abs_pe, std=.02)
+ elif method == 'conv':
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
+ elif method == 'sinu':
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
+ elif method == 'none':
+ # skip pe
+ self.id = nn.Identity()
+ else:
+ raise NotImplementedError
+
+ def forward(self, x):
+ if self.method == 'abs':
+ _, L, _ = x.shape
+ assert L <= self.length
+ x = x + self.abs_pe[:, :L, :]
+ elif self.method == 'conv':
+ x = x + self.conv_pe(x)
+ elif self.method == 'sinu':
+ x = self.sinu_pe(x)
+ elif self.method == 'none':
+ x = self.id(x)
+ else:
+ raise NotImplementedError
+ return x
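+
+# Hedged usage sketch (illustration only): PE_wrapper selects between a learned
+# absolute PE ('abs'), the HuBERT-style convolutional PE ('conv'), a fixed sinusoidal
+# PE ('sinu'), or no PE ('none'). For 'abs' and 'sinu', length is the maximum sequence
+# length. The sizes below are assumptions.
+#
+#   pe = PE_wrapper(dim=768, method='abs', length=1000)
+#   x = torch.randn(2, 100, 768)
+#   x = pe(x)                                                  # (2, 100, 768)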
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
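+
+# Hedged illustration (not part of the original module): RMSNorm rescales each feature
+# vector by the reciprocal of its root mean square and applies a learned per-dimension
+# gain, i.e. y = x * rsqrt(mean(x**2) + eps) * weight. The sizes below are assumptions.
+#
+#   norm = RMSNorm(dim=768)
+#   h = torch.randn(2, 100, 768)
+#   out = norm(h)                                              # (2, 100, 768)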
+
+
+class GELU(nn.Module):
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none",
+ bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32),
+ approximate=self.approximate).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def snake_beta(x, alpha, beta):
+ return x + beta * torch.sin(x * alpha).pow(2)
+
+
+class Snake(nn.Module):
+ def __init__(self, dim_in, dim_out, bias,
+ alpha_trainable=True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = snake_beta(x, self.alpha, self.beta)
+ return x
+
+
+class GESnake(nn.Module):
+ def __init__(self, dim_in, dim_out, bias,
+ alpha_trainable=True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ def forward(self, x):
+ x = self.proj(x)
+ x, gate = x.chunk(2, dim=-1)
+ return x * snake_beta(gate, self.alpha, self.beta)
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out=None,
+ mult=4,
+ dropout=0.0,
+ activation_fn="geglu",
+ final_dropout=False,
+ inner_dim=None,
+ bias=True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+ elif activation_fn == "snake":
+ act_fn = Snake(dim, inner_dim, bias=bias)
+ elif activation_fn == "gesnake":
+ act_fn = GESnake(dim, inner_dim, bias=bias)
+ else:
+ raise NotImplementedError
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
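+
+# Hedged usage sketch (illustration only): FeedForward is a standard transformer MLP
+# with a selectable (optionally gated) activation. With the default 'geglu' and mult=4,
+# the hidden width is 4*dim and the output width matches dim unless dim_out is given.
+# The sizes below are assumptions.
+#
+#   ff = FeedForward(dim=768, mult=4, dropout=0.1, activation_fn='geglu')
+#   h = torch.randn(2, 100, 768)
+#   out = ff(h)                                                # (2, 100, 768)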
\ No newline at end of file
diff --git a/src/models/utils/rotary.py b/src/models/utils/rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..04f8b199ced89d0ed0365b8d74c1088749e7c441
--- /dev/null
+++ b/src/models/utils/rotary.py
@@ -0,0 +1,91 @@
+import torch
+
+"this rope is faster than llama rope with jit script"
+
+
+def rotate_half(x):
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# disable in checkpoint mode
+# @torch.jit.script
+def apply_rotary_pos_emb(x, cos, sin):
+ # NOTE: This could probably be moved to Triton
+ # Handle a possible sequence length mismatch in between q and k
+ cos = cos[:, :, : x.shape[-2], :]
+ sin = sin[:, :, : x.shape[-2], :]
+ return (x * cos) + (rotate_half(x) * sin)
+
+
+class RotaryEmbedding(torch.nn.Module):
+ """
+ The rotary position embeddings from RoFormer_ (Su et al.).
+ A crucial insight from the method is that the query and keys are
+ transformed by rotation matrices which depend on the relative positions.
+
+ Other implementations are available in the Rotary Transformer repo_ and in
+ GPT-NeoX_; GPT-NeoX was an inspiration.
+
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
+ (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
+ """
+
+ def __init__(self, dim: int):
+ super().__init__()
+ # Generate and save the inverse frequency buffer (non trainable)
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self._seq_len_cached = None
+ self._cos_cached = None
+ self._sin_cached = None
+
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
+ # expect input: B, H, L, D
+ seq_len = x.shape[seq_dimension]
+
+ # Reset the tables if the sequence length has changed,
+ # or if we're on a new device (possibly due to tracing for instance)
+ # and also make sure the dtype won't change
+ if (
+ seq_len != self._seq_len_cached
+ or self._cos_cached.device != x.device
+ or self._cos_cached.dtype != x.dtype
+ ):
+ self._seq_len_cached = seq_len
+ t = torch.arange(
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
+ )
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+ return self._cos_cached, self._sin_cached
+
+ def forward(self, q, k):
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+ q.float(), seq_dimension=-2
+ )
+ if k is not None:
+ return (
+ apply_rotary_pos_emb(q.float(),
+ self._cos_cached,
+ self._sin_cached).type_as(q),
+ apply_rotary_pos_emb(k.float(),
+ self._cos_cached,
+ self._sin_cached).type_as(k),
+ )
+ else:
+ return (
+ apply_rotary_pos_emb(q.float(),
+ self._cos_cached,
+ self._sin_cached).type_as(q),
+ None
+ )
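+
+# Hedged usage sketch (illustration only): the module expects q and k shaped
+# (B, H, L, head_dim) with an even head_dim (required by rotate_half) and returns
+# rotated tensors of the same shapes; k may be None for query-only use.
+# The sizes below are arbitrary assumptions.
+#
+#   rope = RotaryEmbedding(dim=64)
+#   q = torch.randn(2, 8, 100, 64)
+#   k = torch.randn(2, 8, 100, 64)
+#   q_rot, k_rot = rope(q, k)                                  # same shapes as q, k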
\ No newline at end of file
diff --git a/src/models/utils/span_mask.py b/src/models/utils/span_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..23f8557e9907c4f9ec17efa36ebd035d8667ff00
--- /dev/null
+++ b/src/models/utils/span_mask.py
@@ -0,0 +1,146 @@
+import numpy as np
+import torch
+from typing import Optional, Tuple
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+        shape: the shape for which to compute masks.
+            should be of size 2, where the first element is the batch size and the second is the number of timesteps
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+ mask_type: how to compute mask lengths
+ static = fixed size
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from a Poisson distribution with lambda = mask_length
+ min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+ """
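+
+    # Worked example (illustrative): with mask_prob=0.65, mask_length=10 and
+    # all_sz=500 timesteps, roughly 0.65 * 500 / 10 ≈ 32 span starts are drawn,
+    # i.e. up to ~325 masked timesteps before overlaps shrink the final count.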
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ # Convert mask_prob to a NumPy array
+ mask_prob = np.array(mask_prob)
+
+ # Calculate all_num_mask for each element in the batch
+ all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int)
+
+ # Apply the max operation with min_masks for each element
+ all_num_mask = np.maximum(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+            # use the per-example probability when mask_prob is given per batch element
+            prob_i = mask_prob[i] if mask_prob.ndim > 0 else float(mask_prob)
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                prob_i * sz / float(mask_length)
+                + np.random.rand()
+            )
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask[i]
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = np.random.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+ span_start = np.random.randint(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - keep_length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,  # np.int was removed in recent NumPy versions
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = np.random.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+ # min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ # if len(mask_idc) > min_len:
+ # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+
+ return torch.tensor(mask)
+
+
+if __name__ == '__main__':
+ mask = compute_mask_indices(
+ shape=[4, 500],
+ padding_mask=None,
+ mask_prob=[0.65, 0.5, 0.65, 0.65],
+ mask_length=10,
+ mask_type="static",
+ mask_other=0.0,
+ min_masks=1,
+ no_overlap=False,
+ min_space=0,
+ )
+ print(mask)
+ print(mask.sum(dim=1))
\ No newline at end of file
diff --git a/src/models/utils/timm.py b/src/models/utils/timm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae72f4fca681617778028b64da5daed26b3ed3d4
--- /dev/null
+++ b/src/models/utils/timm.py
@@ -0,0 +1,114 @@
+# code from timm 0.3.2
+import torch
+import torch.nn as nn
+import math
+import warnings
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
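+# Illustrative numbers (not from the original code): with drop_prob=0.1 each
+# sample's residual branch is zeroed with probability 0.1, and surviving samples
+# are scaled by 1/keep_prob = 1/0.9 so the expected output is unchanged.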
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None,
+ act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
\ No newline at end of file
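+
+
+# Illustrative usage sketch (not from the original file): how these helpers are
+# typically combined in a transformer block; the sizes are arbitrary assumptions.
+if __name__ == "__main__":
+    mlp = Mlp(in_features=64, hidden_features=256, drop=0.1)
+    drop = DropPath(drop_prob=0.1)
+    x = torch.randn(4, 16, 64)          # (batch, tokens, features)
+    y = x + drop(mlp(x))                # residual connection with stochastic depth
+    print(y.shape)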
diff --git a/src/modules/autoencoder_wrapper.py b/src/modules/autoencoder_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca144325af0e2642514da91c277b25fe6f52af8
--- /dev/null
+++ b/src/modules/autoencoder_wrapper.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+from .dac import DAC
+from .stable_vae import load_vae
+
+
+class Autoencoder(nn.Module):
+ def __init__(self, ckpt_path, model_type='dac', quantization_first=False):
+ super(Autoencoder, self).__init__()
+ self.model_type = model_type
+ if self.model_type == 'dac':
+ model = DAC.load(ckpt_path)
+ elif self.model_type == 'stable_vae':
+ model = load_vae(ckpt_path)
+ else:
+ raise NotImplementedError(f"Model type not implemented: {self.model_type}")
+ self.ae = model.eval()
+ self.quantization_first = quantization_first
+ print(f'Autoencoder quantization first mode: {quantization_first}')
+
+ @torch.no_grad()
+ def forward(self, audio=None, embedding=None):
+ if self.model_type == 'dac':
+ return self.process_dac(audio, embedding)
+ elif self.model_type == 'encodec':
+ return self.process_encodec(audio, embedding)
+ elif self.model_type == 'stable_vae':
+ return self.process_stable_vae(audio, embedding)
+ else:
+ raise NotImplementedError(f"Model type not implemented: {self.model_type}")
+
+ def process_dac(self, audio=None, embedding=None):
+ if audio is not None:
+ z = self.ae.encoder(audio)
+ if self.quantization_first:
+ z, *_ = self.ae.quantizer(z, None)
+ return z
+ elif embedding is not None:
+ z = embedding
+ if self.quantization_first:
+ audio = self.ae.decoder(z)
+ else:
+ z, *_ = self.ae.quantizer(z, None)
+ audio = self.ae.decoder(z)
+ return audio
+ else:
+ raise ValueError("Either audio or embedding must be provided.")
+
+ def process_encodec(self, audio=None, embedding=None):
+ if audio is not None:
+ z = self.ae.encoder(audio)
+ if self.quantization_first:
+ code = self.ae.quantizer.encode(z)
+ z = self.ae.quantizer.decode(code)
+ return z
+ elif embedding is not None:
+ z = embedding
+ if self.quantization_first:
+ audio = self.ae.decoder(z)
+ else:
+ code = self.ae.quantizer.encode(z)
+ z = self.ae.quantizer.decode(code)
+ audio = self.ae.decoder(z)
+ return audio
+ else:
+ raise ValueError("Either audio or embedding must be provided.")
+
+ def process_stable_vae(self, audio=None, embedding=None):
+ if audio is not None:
+ z = self.ae.encoder(audio)
+ if self.quantization_first:
+ z = self.ae.bottleneck.encode(z)
+ return z
+        elif embedding is not None:
+ z = embedding
+ if self.quantization_first:
+ audio = self.ae.decoder(z)
+ else:
+ z = self.ae.bottleneck.encode(z)
+ audio = self.ae.decoder(z)
+ return audio
+ else:
+ raise ValueError("Either audio or embedding must be provided.")
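+
+
+# Illustrative usage sketch (assumed checkpoint path, not part of the original
+# module):
+#
+#   ae = Autoencoder(ckpt_path='ckpts/dac/weights.pth', model_type='dac',
+#                    quantization_first=True)
+#   z = ae(audio=waveform)       # waveform: Tensor[B, 1, T] -> latent z
+#   recon = ae(embedding=z)      # latent z -> reconstructed waveform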
diff --git a/src/modules/clap_wrapper.py b/src/modules/clap_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/modules/dac/__init__.py b/src/modules/dac/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a5d03388fa63486960c783ebe7f1bd411b95b1d
--- /dev/null
+++ b/src/modules/dac/__init__.py
@@ -0,0 +1,16 @@
+__version__ = "1.0.0"
+
+# preserved here for legacy reasons
+__model_version__ = "latest"
+
+import audiotools
+
+audiotools.ml.BaseModel.INTERN += ["dac.**"]
+audiotools.ml.BaseModel.EXTERN += ["einops"]
+
+
+from . import nn
+from . import model
+from . import utils
+from .model import DAC
+from .model import DACFile
diff --git a/src/modules/dac/__main__.py b/src/modules/dac/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..393698e7da671bce1478b4d78b2f685d2640636b
--- /dev/null
+++ b/src/modules/dac/__main__.py
@@ -0,0 +1,36 @@
+import sys
+
+import argbind
+
+from dac.utils import download
+from dac.utils.decode import decode
+from dac.utils.encode import encode
+
+STAGES = ["encode", "decode", "download"]
+
+
+def run(stage: str):
+ """Run stages.
+
+ Parameters
+ ----------
+ stage : str
+ Stage to run
+ """
+ if stage not in STAGES:
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
+ stage_fn = globals()[stage]
+
+ if stage == "download":
+ stage_fn()
+ return
+
+ stage_fn()
+
+
+if __name__ == "__main__":
+ group = sys.argv.pop(1)
+ args = argbind.parse_args(group=group)
+
+ with argbind.scope(args):
+ run(group)
diff --git a/src/modules/dac/compare/__init__.py b/src/modules/dac/compare/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/modules/dac/compare/encodec.py b/src/modules/dac/compare/encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..f74e4990eacdafaa8a2e10fb62c38bde10816db0
--- /dev/null
+++ b/src/modules/dac/compare/encodec.py
@@ -0,0 +1,54 @@
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from encodec import EncodecModel
+
+
+class Encodec(BaseModel):
+ def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
+ super().__init__()
+
+ if sample_rate == 24000:
+ self.model = EncodecModel.encodec_model_24khz()
+ else:
+ self.model = EncodecModel.encodec_model_48khz()
+ self.model.set_target_bandwidth(bandwidth)
+ self.sample_rate = 44100
+
+ def forward(
+ self,
+ audio_data: torch.Tensor,
+ sample_rate: int = 44100,
+ n_quantizers: int = None,
+ ):
+ signal = AudioSignal(audio_data, sample_rate)
+ signal.resample(self.model.sample_rate)
+ recons = self.model(signal.audio_data)
+ recons = AudioSignal(recons, self.model.sample_rate)
+ recons.resample(sample_rate)
+ return {"audio": recons.audio_data}
+
+
+if __name__ == "__main__":
+ import numpy as np
+ from functools import partial
+
+ model = Encodec()
+
+ for n, m in model.named_modules():
+ o = m.extra_repr()
+ p = sum([np.prod(p.size()) for p in m.parameters()])
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
+ print(model)
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+ length = 88200 * 2
+ x = torch.randn(1, 1, length).to(model.device)
+ x.requires_grad_(True)
+ x.retain_grad()
+
+ # Make a forward pass
+ out = model(x)["audio"]
+
+ print(x.shape, out.shape)
diff --git a/src/modules/dac/model/__init__.py b/src/modules/dac/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..58b47475d39e4249d2cd45577503bb68ffdacd00
--- /dev/null
+++ b/src/modules/dac/model/__init__.py
@@ -0,0 +1,4 @@
+from .base import CodecMixin
+from .base import DACFile
+from .dac import DAC
+from .discriminator import Discriminator
diff --git a/src/modules/dac/model/base.py b/src/modules/dac/model/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ef5a44074ca2bd5d726fe90fa7c8c87da1b3b7a
--- /dev/null
+++ b/src/modules/dac/model/base.py
@@ -0,0 +1,294 @@
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import tqdm
+from audiotools import AudioSignal
+from torch import nn
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+@dataclass
+class DACFile:
+ codes: torch.Tensor
+
+ # Metadata
+ chunk_length: int
+ original_length: int
+ input_db: float
+ channels: int
+ sample_rate: int
+ padding: bool
+ dac_version: str
+
+ def save(self, path):
+ artifacts = {
+ "codes": self.codes.numpy().astype(np.uint16),
+ "metadata": {
+ "input_db": self.input_db.numpy().astype(np.float32),
+ "original_length": self.original_length,
+ "sample_rate": self.sample_rate,
+ "chunk_length": self.chunk_length,
+ "channels": self.channels,
+ "padding": self.padding,
+ "dac_version": SUPPORTED_VERSIONS[-1],
+ },
+ }
+ path = Path(path).with_suffix(".dac")
+ with open(path, "wb") as f:
+ np.save(f, artifacts)
+ return path
+
+ @classmethod
+ def load(cls, path):
+ artifacts = np.load(path, allow_pickle=True)[()]
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
+ raise RuntimeError(
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
+ )
+ return cls(codes=codes, **artifacts["metadata"])
+
+
+class CodecMixin:
+ @property
+ def padding(self):
+ if not hasattr(self, "_padding"):
+ self._padding = True
+ return self._padding
+
+ @padding.setter
+ def padding(self, value):
+ assert isinstance(value, bool)
+
+ layers = [
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
+ ]
+
+ for layer in layers:
+ if value:
+ if hasattr(layer, "original_padding"):
+ layer.padding = layer.original_padding
+ else:
+ layer.original_padding = layer.padding
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
+
+ self._padding = value
+
+ def get_delay(self):
+ # Any number works here, delay is invariant to input length
+ l_out = self.get_output_length(0)
+ L = l_out
+
+ layers = []
+ for layer in self.modules():
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+ layers.append(layer)
+
+ for layer in reversed(layers):
+ d = layer.dilation[0]
+ k = layer.kernel_size[0]
+ s = layer.stride[0]
+
+ if isinstance(layer, nn.ConvTranspose1d):
+ L = ((L - d * (k - 1) - 1) / s) + 1
+ elif isinstance(layer, nn.Conv1d):
+ L = (L - 1) * s + d * (k - 1) + 1
+
+ L = math.ceil(L)
+
+ l_in = L
+
+ return (l_in - l_out) // 2
+
+ def get_output_length(self, input_length):
+ L = input_length
+ # Calculate output length
+ for layer in self.modules():
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+ d = layer.dilation[0]
+ k = layer.kernel_size[0]
+ s = layer.stride[0]
+
+ if isinstance(layer, nn.Conv1d):
+ L = ((L - d * (k - 1) - 1) / s) + 1
+ elif isinstance(layer, nn.ConvTranspose1d):
+ L = (L - 1) * s + d * (k - 1) + 1
+
+ L = math.floor(L)
+ return L
+
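+    # Worked example (illustrative): with zero padding, a Conv1d with
+    # kernel_size=7, stride=1, dilation=1 maps L -> L - 6, while kernel_size=4,
+    # stride=2 maps L -> floor((L - 4) / 2) + 1, e.g. 100 -> 49.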
+ @torch.no_grad()
+ def compress(
+ self,
+ audio_path_or_signal: Union[str, Path, AudioSignal],
+ win_duration: float = 1.0,
+ verbose: bool = False,
+ normalize_db: float = -16,
+ n_quantizers: int = None,
+ ) -> DACFile:
+ """Processes an audio signal from a file or AudioSignal object into
+ discrete codes. This function processes the signal in short windows,
+ using constant GPU memory.
+
+ Parameters
+ ----------
+ audio_path_or_signal : Union[str, Path, AudioSignal]
+            audio signal to compress into discrete codes
+ win_duration : float, optional
+            window duration in seconds, by default 1.0
+ verbose : bool, optional
+ by default False
+ normalize_db : float, optional
+ normalize db, by default -16
+
+ Returns
+ -------
+ DACFile
+ Object containing compressed codes and metadata
+ required for decompression
+ """
+ audio_signal = audio_path_or_signal
+ if isinstance(audio_signal, (str, Path)):
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+ self.eval()
+ original_padding = self.padding
+ original_device = audio_signal.device
+
+ audio_signal = audio_signal.clone()
+ original_sr = audio_signal.sample_rate
+
+ resample_fn = audio_signal.resample
+ loudness_fn = audio_signal.loudness
+
+        # If the audio is very long (>= 10 hours), use the ffmpeg-based versions
+ if audio_signal.signal_duration >= 10 * 60 * 60:
+ resample_fn = audio_signal.ffmpeg_resample
+ loudness_fn = audio_signal.ffmpeg_loudness
+
+ original_length = audio_signal.signal_length
+ resample_fn(self.sample_rate)
+ input_db = loudness_fn()
+
+ if normalize_db is not None:
+ audio_signal.normalize(normalize_db)
+ audio_signal.ensure_max_of_audio()
+
+ nb, nac, nt = audio_signal.audio_data.shape
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+ win_duration = (
+ audio_signal.signal_duration if win_duration is None else win_duration
+ )
+
+ if audio_signal.signal_duration <= win_duration:
+ # Unchunked compression (used if signal length < win duration)
+ self.padding = True
+ n_samples = nt
+ hop = nt
+ else:
+ # Chunked inference
+ self.padding = False
+ # Zero-pad signal on either side by the delay
+ audio_signal.zero_pad(self.delay, self.delay)
+ n_samples = int(win_duration * self.sample_rate)
+ # Round n_samples to nearest hop length multiple
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+ hop = self.get_output_length(n_samples)
+
+ codes = []
+ range_fn = range if not verbose else tqdm.trange
+
+ for i in range_fn(0, nt, hop):
+ x = audio_signal[..., i : i + n_samples]
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+ audio_data = x.audio_data.to(self.device)
+ audio_data = self.preprocess(audio_data, self.sample_rate)
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+ codes.append(c.to(original_device))
+ chunk_length = c.shape[-1]
+
+ codes = torch.cat(codes, dim=-1)
+
+ dac_file = DACFile(
+ codes=codes,
+ chunk_length=chunk_length,
+ original_length=original_length,
+ input_db=input_db,
+ channels=nac,
+ sample_rate=original_sr,
+ padding=self.padding,
+ dac_version=SUPPORTED_VERSIONS[-1],
+ )
+
+ if n_quantizers is not None:
+ codes = codes[:, :n_quantizers, :]
+
+ self.padding = original_padding
+ return dac_file
+
+ @torch.no_grad()
+ def decompress(
+ self,
+ obj: Union[str, Path, DACFile],
+ verbose: bool = False,
+ ) -> AudioSignal:
+ """Reconstruct audio from a given .dac file
+
+ Parameters
+ ----------
+ obj : Union[str, Path, DACFile]
+ .dac file location or corresponding DACFile object.
+ verbose : bool, optional
+ Prints progress if True, by default False
+
+ Returns
+ -------
+ AudioSignal
+ Object with the reconstructed audio
+ """
+ self.eval()
+ if isinstance(obj, (str, Path)):
+ obj = DACFile.load(obj)
+
+ original_padding = self.padding
+ self.padding = obj.padding
+
+ range_fn = range if not verbose else tqdm.trange
+ codes = obj.codes
+ original_device = codes.device
+ chunk_length = obj.chunk_length
+ recons = []
+
+ for i in range_fn(0, codes.shape[-1], chunk_length):
+ c = codes[..., i : i + chunk_length].to(self.device)
+ z = self.quantizer.from_codes(c)[0]
+ r = self.decode(z)
+ recons.append(r.to(original_device))
+
+ recons = torch.cat(recons, dim=-1)
+ recons = AudioSignal(recons, self.sample_rate)
+
+ resample_fn = recons.resample
+ loudness_fn = recons.loudness
+
+        # If the audio is very long (>= 10 hours), use the ffmpeg-based versions
+ if recons.signal_duration >= 10 * 60 * 60:
+ resample_fn = recons.ffmpeg_resample
+ loudness_fn = recons.ffmpeg_loudness
+
+ recons.normalize(obj.input_db)
+ resample_fn(obj.sample_rate)
+ recons = recons[..., : obj.original_length]
+ loudness_fn()
+ recons.audio_data = recons.audio_data.reshape(
+ -1, obj.channels, obj.original_length
+ )
+
+ self.padding = original_padding
+ return recons
diff --git a/src/modules/dac/model/dac.py b/src/modules/dac/model/dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..98c5983644dd4ffdf87730bc92ac3e9bb13e4374
--- /dev/null
+++ b/src/modules/dac/model/dac.py
@@ -0,0 +1,364 @@
+import math
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from torch import nn
+
+from .base import CodecMixin
+from ..nn.layers import Snake1d
+from ..nn.layers import WNConv1d
+from ..nn.layers import WNConvTranspose1d
+from ..nn.quantize import ResidualVectorQuantize
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1):
+ super().__init__()
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Snake1d(dim),
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+ Snake1d(dim),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+
+ def forward(self, x):
+ y = self.block(x)
+ pad = (x.shape[-1] - y.shape[-1]) // 2
+ if pad > 0:
+ x = x[..., pad:-pad]
+ return x + y
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, dim: int = 16, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ ResidualUnit(dim // 2, dilation=1),
+ ResidualUnit(dim // 2, dilation=3),
+ ResidualUnit(dim // 2, dilation=9),
+ Snake1d(dim // 2),
+ WNConv1d(
+ dim // 2,
+ dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ d_model: int = 64,
+ strides: list = [2, 4, 8, 8],
+ d_latent: int = 64,
+ ):
+ super().__init__()
+ # Create first convolution
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for stride in strides:
+ d_model *= 2
+ self.block += [EncoderBlock(d_model, stride=stride)]
+
+ # Create last convolution
+ self.block += [
+ Snake1d(d_model),
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
+ ]
+
+        # Wrap block into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ Snake1d(input_dim),
+ WNConvTranspose1d(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ ),
+ ResidualUnit(output_dim, dilation=1),
+ ResidualUnit(output_dim, dilation=3),
+ ResidualUnit(output_dim, dilation=9),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ input_channel,
+ channels,
+ rates,
+ d_out: int = 1,
+ ):
+ super().__init__()
+
+ # Add first conv layer
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+ # Add upsampling + MRF blocks
+ for i, stride in enumerate(rates):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+ # Add final conv layer
+ layers += [
+ Snake1d(output_dim),
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class DAC(BaseModel, CodecMixin):
+ def __init__(
+ self,
+ encoder_dim: int = 64,
+ encoder_rates: List[int] = [2, 4, 8, 8],
+ latent_dim: int = None,
+ decoder_dim: int = 1536,
+ decoder_rates: List[int] = [8, 8, 4, 2],
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: bool = False,
+ sample_rate: int = 44100,
+ ):
+ super().__init__()
+
+ self.encoder_dim = encoder_dim
+ self.encoder_rates = encoder_rates
+ self.decoder_dim = decoder_dim
+ self.decoder_rates = decoder_rates
+ self.sample_rate = sample_rate
+
+ if latent_dim is None:
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+ self.latent_dim = latent_dim
+
+ self.hop_length = np.prod(encoder_rates)
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
+
+ self.n_codebooks = n_codebooks
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.quantizer = ResidualVectorQuantize(
+ input_dim=latent_dim,
+ n_codebooks=n_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+
+ self.decoder = Decoder(
+ latent_dim,
+ decoder_dim,
+ decoder_rates,
+ )
+ self.sample_rate = sample_rate
+ self.apply(init_weights)
+
+ self.delay = self.get_delay()
+
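+    # Worked example (illustrative): with the defaults encoder_dim=64 and
+    # encoder_rates=[2, 4, 8, 8], latent_dim = 64 * 2**4 = 1024 and
+    # hop_length = 2 * 4 * 8 * 8 = 512, i.e. one latent frame per 512 samples
+    # (~11.6 ms at 44.1 kHz).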
+ def preprocess(self, audio_data, sample_rate):
+ if sample_rate is None:
+ sample_rate = self.sample_rate
+ assert sample_rate == self.sample_rate
+
+ length = audio_data.shape[-1]
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+ return audio_data
+
+ def encode(
+ self,
+ audio_data: torch.Tensor,
+ n_quantizers: int = None,
+ ):
+ """Encode given audio data and return quantized latent codes
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x 1 x T]
+ Audio data to encode
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None
+ If None, all quantizers are used.
+
+ Returns
+ -------
+        tuple
+            A tuple ``(z, codes, latents, commitment_loss, codebook_loss)``:
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ """
+ z = self.encoder(audio_data)
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
+ z, n_quantizers
+ )
+ return z, codes, latents, commitment_loss, codebook_loss
+
+ def decode(self, z: torch.Tensor):
+ """Decode given latent codes and return audio data
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ Quantized continuous representation of input
+
+ Returns
+ -------
+        Tensor[B x 1 x T]
+ Decoded audio data.
+ """
+ return self.decoder(z)
+
+ def forward(
+ self,
+ audio_data: torch.Tensor,
+ sample_rate: int = None,
+ n_quantizers: int = None,
+ ):
+ """Model forward pass
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x 1 x T]
+ Audio data to encode
+ sample_rate : int, optional
+ Sample rate of audio data in Hz, by default None
+ If None, defaults to `self.sample_rate`
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None.
+ If None, all quantizers are used.
+
+ Returns
+ -------
+ dict
+ A dictionary with the following keys:
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ "audio" : Tensor[B x 1 x length]
+ Decoded audio data.
+ """
+ length = audio_data.shape[-1]
+ audio_data = self.preprocess(audio_data, sample_rate)
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(
+ audio_data, n_quantizers
+ )
+
+ x = self.decode(z)
+ return {
+ "audio": x[..., :length],
+ "z": z,
+ "codes": codes,
+ "latents": latents,
+ "vq/commitment_loss": commitment_loss,
+ "vq/codebook_loss": codebook_loss,
+ }
+
+
+if __name__ == "__main__":
+ import numpy as np
+ from functools import partial
+
+ model = DAC().to("cpu")
+
+ for n, m in model.named_modules():
+ o = m.extra_repr()
+ p = sum([np.prod(p.size()) for p in m.parameters()])
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
+ print(model)
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+ length = 88200 * 2
+ x = torch.randn(1, 1, length).to(model.device)
+ x.requires_grad_(True)
+ x.retain_grad()
+
+ # Make a forward pass
+ out = model(x)["audio"]
+ print("Input shape:", x.shape)
+ print("Output shape:", out.shape)
+
+ # Create gradient variable
+ grad = torch.zeros_like(out)
+ grad[:, :, grad.shape[-1] // 2] = 1
+
+ # Make a backward pass
+ out.backward(grad)
+
+ # Check non-zero values
+ gradmap = x.grad.squeeze(0)
+ gradmap = (gradmap != 0).sum(0) # sum across features
+ rf = (gradmap != 0).sum()
+
+ print(f"Receptive field: {rf.item()}")
+
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
+ model.decompress(model.compress(x, verbose=True), verbose=True)
diff --git a/src/modules/dac/model/discriminator.py b/src/modules/dac/model/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ade8e6ddc3741e6fb443ba4c88cd179c3f0196eb
--- /dev/null
+++ b/src/modules/dac/model/discriminator.py
@@ -0,0 +1,228 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import ml
+from audiotools import STFTParams
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def WNConv2d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+class MPD(nn.Module):
+ def __init__(self, period):
+ super().__init__()
+ self.period = period
+ self.convs = nn.ModuleList(
+ [
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+ ]
+ )
+ self.conv_post = WNConv2d(
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+ )
+
+ def pad_to_period(self, x):
+ t = x.shape[-1]
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+ return x
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.pad_to_period(x)
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+ for layer in self.convs:
+ x = layer(x)
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class MSD(nn.Module):
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
+ super().__init__()
+ self.convs = nn.ModuleList(
+ [
+ WNConv1d(1, 16, 15, 1, padding=7),
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+ WNConv1d(1024, 1024, 5, 1, padding=2),
+ ]
+ )
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+ self.sample_rate = sample_rate
+ self.rate = rate
+
+ def forward(self, x):
+ x = AudioSignal(x, self.sample_rate)
+ x.resample(self.sample_rate // self.rate)
+ x = x.audio_data
+
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+class MRD(nn.Module):
+ def __init__(
+ self,
+ window_length: int,
+ hop_factor: float = 0.25,
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ ):
+ """Complex multi-band spectrogram discriminator.
+ Parameters
+ ----------
+ window_length : int
+ Window length of STFT.
+ hop_factor : float, optional
+            Hop factor of the STFT (hop_length = hop_factor * window_length), by default 0.25.
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run discriminator over.
+ """
+ super().__init__()
+
+ self.window_length = window_length
+ self.hop_factor = hop_factor
+ self.sample_rate = sample_rate
+ self.stft_params = STFTParams(
+ window_length=window_length,
+ hop_length=int(window_length * hop_factor),
+ match_stride=True,
+ )
+
+ n_fft = window_length // 2 + 1
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+ self.bands = bands
+
+ ch = 32
+ convs = lambda: nn.ModuleList(
+ [
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+ ]
+ )
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
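+    # Worked example (illustrative): with window_length=2048 there are
+    # n_fft = 2048 // 2 + 1 = 1025 frequency bins, so band (0.0, 0.1) maps to
+    # bins [0, 102) and (0.75, 1.0) to bins [768, 1025).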
+ def spectrogram(self, x):
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+ x = torch.view_as_real(x.stft())
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+ # Split into bands
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+ return x_bands
+
+ def forward(self, x):
+ x_bands = self.spectrogram(x)
+ fmap = []
+
+ x = []
+ for band, stack in zip(x_bands, self.band_convs):
+ for layer in stack:
+ band = layer(band)
+ fmap.append(band)
+ x.append(band)
+
+ x = torch.cat(x, dim=-1)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class Discriminator(ml.BaseModel):
+ def __init__(
+ self,
+ rates: list = [],
+ periods: list = [2, 3, 5, 7, 11],
+ fft_sizes: list = [2048, 1024, 512],
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ ):
+ """Discriminator that combines multiple discriminators.
+
+ Parameters
+ ----------
+ rates : list, optional
+            rates to run MSD at (the input is resampled to sample_rate // rate), by default []
+ If empty, MSD is not used.
+ periods : list, optional
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+ fft_sizes : list, optional
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run MRD at, by default `BANDS`
+ """
+ super().__init__()
+ discs = []
+ discs += [MPD(p) for p in periods]
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+ self.discriminators = nn.ModuleList(discs)
+
+ def preprocess(self, y):
+ # Remove DC offset
+ y = y - y.mean(dim=-1, keepdims=True)
+ # Peak normalize the volume of input audio
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+ return y
+
+ def forward(self, x):
+ x = self.preprocess(x)
+ fmaps = [d(x) for d in self.discriminators]
+ return fmaps
+
+
+if __name__ == "__main__":
+ disc = Discriminator()
+ x = torch.zeros(1, 1, 44100)
+ results = disc(x)
+ for i, result in enumerate(results):
+ print(f"disc{i}")
+ for i, r in enumerate(result):
+ print(r.shape, r.mean(), r.min(), r.max())
+ print()
diff --git a/src/modules/dac/nn/__init__.py b/src/modules/dac/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec2f6b972aeb8a21764ff73dae2095eb94bb8ba4
--- /dev/null
+++ b/src/modules/dac/nn/__init__.py
@@ -0,0 +1,3 @@
+from . import layers
+from . import loss
+from . import quantize
diff --git a/src/modules/dac/nn/layers.py b/src/modules/dac/nn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c53a88b881deab94fb06b6a395951cf9cc995a
--- /dev/null
+++ b/src/modules/dac/nn/layers.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
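+# The snake activation above computes x + (1 / alpha) * sin^2(alpha * x), with a
+# small epsilon guarding the reciprocal and a learnable per-channel alpha
+# (see Snake1d below).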
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
diff --git a/src/modules/dac/nn/loss.py b/src/modules/dac/nn/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a26e62f885f0bb7c6774e95660a7848e10a8f87
--- /dev/null
+++ b/src/modules/dac/nn/loss.py
@@ -0,0 +1,368 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+ """L1 Loss between AudioSignals. Defaults
+ to comparing ``audio_data``, but any
+ attribute of an AudioSignal can be used.
+
+ Parameters
+ ----------
+ attribute : str, optional
+ Attribute of signal to compare, defaults to ``audio_data``.
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+ self.attribute = attribute
+ self.weight = weight
+ super().__init__(**kwargs)
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate AudioSignal
+ y : AudioSignal
+ Reference AudioSignal
+
+ Returns
+ -------
+ torch.Tensor
+ L1 loss between AudioSignal attributes.
+ """
+ if isinstance(x, AudioSignal):
+ x = getattr(x, self.attribute)
+ y = getattr(y, self.attribute)
+ return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+ """
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+ of estimated and reference audio signals or aligned features.
+
+ Parameters
+ ----------
+ scaling : int, optional
+ Whether to use scale-invariant (True) or
+ signal-to-noise ratio (False), by default True
+ reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or no reduction), by default 'mean'
+ zero_mean : int, optional
+ Zero mean the references and estimates before
+ computing the loss, by default True
+ clip_min : int, optional
+ The minimum possible loss value. Helps network
+ to not focus on making already good examples better, by default None
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(
+ self,
+ scaling: int = True,
+ reduction: str = "mean",
+ zero_mean: int = True,
+ clip_min: int = None,
+ weight: float = 1.0,
+ ):
+ self.scaling = scaling
+ self.reduction = reduction
+ self.zero_mean = zero_mean
+ self.clip_min = clip_min
+ self.weight = weight
+ super().__init__()
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ eps = 1e-8
+ # nb, nc, nt
+ if isinstance(x, AudioSignal):
+ references = x.audio_data
+ estimates = y.audio_data
+ else:
+ references = x
+ estimates = y
+
+ nb = references.shape[0]
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+ # samples now on axis 1
+ if self.zero_mean:
+ mean_reference = references.mean(dim=1, keepdim=True)
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
+ else:
+ mean_reference = 0
+ mean_estimate = 0
+
+ _references = references - mean_reference
+ _estimates = estimates - mean_estimate
+
+ references_projection = (_references**2).sum(dim=-2) + eps
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+ scale = (
+ (references_on_estimates / references_projection).unsqueeze(1)
+ if self.scaling
+ else 1
+ )
+
+ e_true = scale * _references
+ e_res = _estimates - e_true
+
+ signal = (e_true**2).sum(dim=1)
+ noise = (e_res**2).sum(dim=1)
+ sdr = -10 * torch.log10(signal / noise + eps)
+
+ if self.clip_min is not None:
+ sdr = torch.clamp(sdr, min=self.clip_min)
+
+ if self.reduction == "mean":
+ sdr = sdr.mean()
+ elif self.reduction == "sum":
+ sdr = sdr.sum()
+ return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
+ """Computes the multi-scale STFT loss from [1].
+
+ Parameters
+ ----------
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Clamp on the log magnitude, below, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ References
+ ----------
+
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
+ "DDSP: Differentiable Digital Signal Processing."
+       International Conference on Learning Representations. 2020.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.loss_fn = loss_fn
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.clamp_eps = clamp_eps
+ self.weight = weight
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes multi-scale STFT between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Multi-scale STFT loss.
+ """
+ loss = 0.0
+ for s in self.stft_params:
+ x.stft(s.window_length, s.hop_length, s.window_type)
+ y.stft(s.window_length, s.hop_length, s.window_type)
+ loss += self.log_weight * self.loss_fn(
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+ return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+ """Compute distance between mel spectrograms. Can be used
+ in a multi-scale way.
+
+ Parameters
+ ----------
+ n_mels : List[int]
+        Number of mels per STFT, by default [150, 80]
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Clamp on the log magnitude, below, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ n_mels: List[int] = [150, 80],
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ mel_fmin: List[float] = [0.0, 0.0],
+ mel_fmax: List[float] = [None, None],
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.n_mels = n_mels
+ self.loss_fn = loss_fn
+ self.clamp_eps = clamp_eps
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.weight = weight
+ self.mel_fmin = mel_fmin
+ self.mel_fmax = mel_fmax
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes mel loss between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Mel loss.
+ """
+ loss = 0.0
+ for n_mels, fmin, fmax, s in zip(
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+ ):
+ kwargs = {
+ "window_length": s.window_length,
+ "hop_length": s.hop_length,
+ "window_type": s.window_type,
+ }
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+ loss += self.log_weight * self.loss_fn(
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+ return loss
+
+
+class GANLoss(nn.Module):
+ """
+ Computes a discriminator loss, given a discriminator on
+ generated waveforms/spectrograms compared to ground truth
+ waveforms/spectrograms. Computes the loss for both the
+ discriminator and the generator in separate functions.
+ """
+
+ def __init__(self, discriminator):
+ super().__init__()
+ self.discriminator = discriminator
+
+ def forward(self, fake, real):
+ d_fake = self.discriminator(fake.audio_data)
+ d_real = self.discriminator(real.audio_data)
+ return d_fake, d_real
+
+ def discriminator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+ loss_d = 0
+ for x_fake, x_real in zip(d_fake, d_real):
+ loss_d += torch.mean(x_fake[-1] ** 2)
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
+ return loss_d
+
+ def generator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake, real)
+
+ loss_g = 0
+ for x_fake in d_fake:
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+ loss_feature = 0
+
+ for i in range(len(d_fake)):
+ for j in range(len(d_fake[i]) - 1):
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+ return loss_g, loss_feature
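+
+
+# Illustrative usage sketch (not from the original file): how the reconstruction
+# losses are typically combined; the 1.0 / 15.0 weights are arbitrary
+# assumptions for the demo, not a training configuration.
+if __name__ == "__main__":
+    x = AudioSignal(torch.randn(1, 1, 44100), 44100)  # estimate
+    y = AudioSignal(torch.randn(1, 1, 44100), 44100)  # reference
+
+    waveform_loss = L1Loss()
+    stft_loss = MultiScaleSTFTLoss()
+    mel_loss = MelSpectrogramLoss()
+
+    total = waveform_loss(x, y) + 1.0 * stft_loss(x, y) + 15.0 * mel_loss(x, y)
+    print(float(total))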
diff --git a/src/modules/dac/nn/quantize.py b/src/modules/dac/nn/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..499088fbde266dc89a704e45944b8900f2c72952
--- /dev/null
+++ b/src/modules/dac/nn/quantize.py
@@ -0,0 +1,262 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from .layers import WNConv1d
+
+
+class VectorQuantize(nn.Module):
+ """
+ Implementation of VQ similar to Karpathy's repo:
+ https://github.com/karpathy/deep-vector-quantization
+ Additionally uses following tricks from Improved VQGAN
+ (https://arxiv.org/pdf/2110.04627.pdf):
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+ for improved codebook usage
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+ improves training stability
+ """
+
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+ def forward(self, z):
+        """Quantizes the input tensor using a fixed codebook and returns
+ the corresponding codebook vectors
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ Tensor[1]
+ Codebook loss to update the codebook
+ Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+
+        # Factorized codes (ViT-VQGAN): project input into a low-dimensional space
+ z_e = self.in_proj(z) # z_e : (B x D x T)
+ z_q, indices = self.decode_latents(z_e)
+
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+ z_q = (
+ z_e + (z_q - z_e).detach()
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
+
+ z_q = self.out_proj(z_q)
+
+ return z_q, commitment_loss, codebook_loss, indices, z_e
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight # codebook: (N x D)
+
+ # L2 normalize encodings and codebook (ViT-VQGAN)
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance with codebook
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+ return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+ """
+ Introduced in SoundStream: An end2end neural audio codec
+ https://arxiv.org/abs/2107.03312
+ """
+
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: float = 0.0,
+ ):
+ super().__init__()
+ if isinstance(codebook_dim, int):
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+ self.n_codebooks = n_codebooks
+ self.codebook_dim = codebook_dim
+ self.codebook_size = codebook_size
+
+ self.quantizers = nn.ModuleList(
+ [
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+ for i in range(n_codebooks)
+ ]
+ )
+ self.quantizer_dropout = quantizer_dropout
+
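+    # Worked example (illustrative): with the defaults n_codebooks=9 and
+    # codebook_dim=8, the concatenated projected latents returned by forward()
+    # have 9 * 8 = 72 channels, while z_q keeps input_dim (512) channels.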
+ def forward(self, z, n_quantizers: int = None):
+        """Quantizes the input tensor using a fixed set of `n` codebooks and returns
+ the corresponding codebook vectors
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ n_quantizers : int, optional
+ No. of quantizers to use
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
+ when in training mode, and a random number of quantizers is used.
+ Returns
+ -------
+        tuple
+            A tuple ``(z_q, codes, latents, commitment_loss, codebook_loss)``:
+
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ """
+ z_q = 0
+ residual = z
+ commitment_loss = 0
+ codebook_loss = 0
+
+ codebook_indices = []
+ latents = []
+
+ if n_quantizers is None:
+ n_quantizers = self.n_codebooks
+ if self.training:
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(z.device)
+
+ for i, quantizer in enumerate(self.quantizers):
+ if self.training is False and i >= n_quantizers:
+ break
+
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+ residual
+ )
+
+ # Create mask to apply quantizer dropout
+ mask = (
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+ )
+ z_q = z_q + z_q_i * mask[:, None, None]
+ residual = residual - z_q_i
+
+ # Sum losses
+ commitment_loss += (commitment_loss_i * mask).mean()
+ codebook_loss += (codebook_loss_i * mask).mean()
+
+ codebook_indices.append(indices_i)
+ latents.append(z_e_i)
+
+ codes = torch.stack(codebook_indices, dim=1)
+ latents = torch.cat(latents, dim=1)
+
+ return z_q, codes, latents, commitment_loss, codebook_loss
+
+ def from_codes(self, codes: torch.Tensor):
+ """Given the quantized codes, reconstruct the continuous representation
+ Parameters
+ ----------
+ codes : Tensor[B x N x T]
+ Quantized discrete representation of input
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ """
+ z_q = 0.0
+ z_p = []
+ n_codebooks = codes.shape[1]
+ for i in range(n_codebooks):
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+ z_p.append(z_p_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+ return z_q, torch.cat(z_p, dim=1), codes
+
+ def from_latents(self, latents: torch.Tensor):
+ """Given the unquantized latents, reconstruct the
+ continuous representation after quantization.
+
+ Parameters
+ ----------
+ latents : Tensor[B x N x T]
+ Continuous representation of input after projection
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized representation of full-projected space
+ Tensor[B x D x T]
+ Quantized representation of latent space
+ """
+ z_q = 0
+ z_p = []
+ codes = []
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+ 0
+ ]
+ for i in range(n_codebooks):
+ j, k = dims[i], dims[i + 1]
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+ z_p.append(z_p_i)
+ codes.append(codes_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+    rvq = ResidualVectorQuantize(quantizer_dropout=1.0)
+    x = torch.randn(16, 512, 80)
+    # forward() returns a tuple, not a dict
+    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+    print(latents.shape)
diff --git a/src/modules/dac/utils/__init__.py b/src/modules/dac/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4522f8b226ac1c5cf8abeefc6ba32328c2375d5
--- /dev/null
+++ b/src/modules/dac/utils/__init__.py
@@ -0,0 +1,122 @@
+from pathlib import Path
+
+import argbind
+from audiotools import ml
+
+from ..model import DAC
+
+Accelerator = ml.Accelerator
+
+__MODEL_LATEST_TAGS__ = {
+ ("44khz", "8kbps"): "0.0.1",
+ ("24khz", "8kbps"): "0.0.4",
+ ("16khz", "8kbps"): "0.0.5",
+ ("44khz", "16kbps"): "1.0.0",
+}
+
+__MODEL_URLS__ = {
+ (
+ "44khz",
+ "0.0.1",
+ "8kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
+ (
+ "24khz",
+ "0.0.4",
+ "8kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
+ (
+ "16khz",
+ "0.0.5",
+ "8kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
+ (
+ "44khz",
+ "1.0.0",
+ "16kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
+}
+
+
+@argbind.bind(group="download", positional=True, without_prefix=True)
+def download(
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
+):
+ """
+ Function that downloads the weights file from URL if a local cache is not found.
+
+ Parameters
+ ----------
+ model_type : str
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
+ model_bitrate: str
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
+ Only 44khz model supports 16kbps.
+ tag : str
+ The tag of the model to download. Defaults to "latest".
+
+ Returns
+ -------
+ Path
+ Directory path required to load model via audiotools.
+ """
+ model_type = model_type.lower()
+ tag = tag.lower()
+
+ assert model_type in [
+ "44khz",
+ "24khz",
+ "16khz",
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
+
+ assert model_bitrate in [
+ "8kbps",
+ "16kbps",
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
+
+ if tag == "latest":
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
+
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
+
+ if download_link is None:
+ raise ValueError(
+ f"Could not find model with tag {tag} and model type {model_type}"
+ )
+
+ local_path = (
+ Path.home()
+ / ".cache"
+ / "descript"
+ / "dac"
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
+ )
+ if not local_path.exists():
+ local_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Download the model
+ import requests
+
+ response = requests.get(download_link)
+
+ if response.status_code != 200:
+ raise ValueError(
+ f"Could not download model. Received response code {response.status_code}"
+ )
+ local_path.write_bytes(response.content)
+
+ return local_path
+
+
+def load_model(
+ model_type: str = "44khz",
+ model_bitrate: str = "8kbps",
+ tag: str = "latest",
+ load_path: str = None,
+):
+ if not load_path:
+ load_path = download(
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
+ )
+ generator = DAC.load(load_path)
+ return generator
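+
+
+# Usage sketch (illustrative, not part of the upstream package): `load_model` uses a
+# local checkpoint when `load_path` is given and otherwise calls `download`, which
+# caches weights under ~/.cache/descript/dac.
+#
+#     generator = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
+#     generator.eval()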
diff --git a/src/modules/dac/utils/decode.py b/src/modules/dac/utils/decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..48c25298fd0f9a53f19c1d6d260f1f0563c7b7e4
--- /dev/null
+++ b/src/modules/dac/utils/decode.py
@@ -0,0 +1,95 @@
+import warnings
+from pathlib import Path
+
+import argbind
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from tqdm import tqdm
+
+from dac import DACFile
+from dac.utils import load_model
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+@argbind.bind(group="decode", positional=True, without_prefix=True)
+@torch.inference_mode()
+@torch.no_grad()
+def decode(
+ input: str,
+ output: str = "",
+ weights_path: str = "",
+ model_tag: str = "latest",
+ model_bitrate: str = "8kbps",
+ device: str = "cuda",
+ model_type: str = "44khz",
+ verbose: bool = False,
+):
+ """Decode audio from codes.
+
+ Parameters
+ ----------
+ input : str
+ Path to input directory or file
+ output : str, optional
+ Path to output directory, by default "".
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
+ weights_path : str, optional
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
+ model_tag and model_type.
+ model_tag : str, optional
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
+ model_bitrate: str
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
+ device : str, optional
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
+ model_type : str, optional
+        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
+    verbose : bool, optional
+        If True, show progress while decompressing, by default False.
+    """
+ generator = load_model(
+ model_type=model_type,
+ model_bitrate=model_bitrate,
+ tag=model_tag,
+ load_path=weights_path,
+ )
+ generator.to(device)
+ generator.eval()
+
+ # Find all .dac files in input directory
+ _input = Path(input)
+ input_files = list(_input.glob("**/*.dac"))
+
+ # If input is a .dac file, add it to the list
+ if _input.suffix == ".dac":
+ input_files.append(_input)
+
+ # Create output directory
+ output = Path(output)
+ output.mkdir(parents=True, exist_ok=True)
+
+    for i in tqdm(range(len(input_files)), desc="Decoding files"):
+ # Load file
+ artifact = DACFile.load(input_files[i])
+
+ # Reconstruct audio from codes
+ recons = generator.decompress(artifact, verbose=verbose)
+
+ # Compute output path
+ relative_path = input_files[i].relative_to(input)
+ output_dir = output / relative_path.parent
+ if not relative_path.name:
+ output_dir = output
+ relative_path = input_files[i]
+ output_name = relative_path.with_suffix(".wav").name
+ output_path = output_dir / output_name
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Write to file
+ recons.write(output_path)
+
+
+if __name__ == "__main__":
+ args = argbind.parse_args()
+ with argbind.scope(args):
+ decode()
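+
+# Example invocation (illustrative; assumes the `dac` package used by the imports above
+# is importable and that argbind exposes `input` positionally):
+#
+#     python src/modules/dac/utils/decode.py /path/to/codes --output /path/to/wavs --model_type 44khz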
diff --git a/src/modules/dac/utils/encode.py b/src/modules/dac/utils/encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c9a8b582b24c194ba82c2d280c4df4fa2cfff87
--- /dev/null
+++ b/src/modules/dac/utils/encode.py
@@ -0,0 +1,94 @@
+import math
+import warnings
+from pathlib import Path
+
+import argbind
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.core import util
+from tqdm import tqdm
+
+from dac.utils import load_model
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+@argbind.bind(group="encode", positional=True, without_prefix=True)
+@torch.inference_mode()
+@torch.no_grad()
+def encode(
+ input: str,
+ output: str = "",
+ weights_path: str = "",
+ model_tag: str = "latest",
+ model_bitrate: str = "8kbps",
+ n_quantizers: int = None,
+ device: str = "cuda",
+ model_type: str = "44khz",
+ win_duration: float = 5.0,
+ verbose: bool = False,
+):
+ """Encode audio files in input path to .dac format.
+
+ Parameters
+ ----------
+ input : str
+ Path to input audio file or directory
+ output : str, optional
+ Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
+ weights_path : str, optional
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
+ model_tag and model_type.
+ model_tag : str, optional
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
+ model_bitrate: str
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
+ device : str, optional
+ Device to use, by default "cuda"
+    model_type : str, optional
+        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
+    win_duration : float, optional
+        Window duration in seconds used to chunk the audio during compression, by default 5.0.
+    verbose : bool, optional
+        If True, show progress while compressing, by default False.
+    """
+ generator = load_model(
+ model_type=model_type,
+ model_bitrate=model_bitrate,
+ tag=model_tag,
+ load_path=weights_path,
+ )
+ generator.to(device)
+ generator.eval()
+ kwargs = {"n_quantizers": n_quantizers}
+
+ # Find all audio files in input path
+ input = Path(input)
+ audio_files = util.find_audio(input)
+
+ output = Path(output)
+ output.mkdir(parents=True, exist_ok=True)
+
+ for i in tqdm(range(len(audio_files)), desc="Encoding files"):
+ # Load file
+ signal = AudioSignal(audio_files[i])
+
+ # Encode audio to .dac format
+ artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
+
+ # Compute output path
+ relative_path = audio_files[i].relative_to(input)
+ output_dir = output / relative_path.parent
+ if not relative_path.name:
+ output_dir = output
+ relative_path = audio_files[i]
+ output_name = relative_path.with_suffix(".dac").name
+ output_path = output_dir / output_name
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ artifact.save(output_path)
+
+
+if __name__ == "__main__":
+ args = argbind.parse_args()
+ with argbind.scope(args):
+ encode()
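+
+# Example invocation (illustrative; mirrors decode.py above):
+#
+#     python src/modules/dac/utils/encode.py /path/to/audio --output /path/to/codes --model_type 44khz --win_duration 5.0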
diff --git a/src/modules/stable_vae/__init__.py b/src/modules/stable_vae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dd2a867025e88625177ccad8d73ce46b807c945
--- /dev/null
+++ b/src/modules/stable_vae/__init__.py
@@ -0,0 +1,40 @@
+from .models.autoencoders import create_autoencoder_from_config
+import os
+import json
+import torch
+from torch.nn.utils import remove_weight_norm
+
+
+def remove_all_weight_norm(model):
+ for name, module in model.named_modules():
+ if hasattr(module, 'weight_g'):
+ remove_weight_norm(module)
+
+
+def load_vae(ckpt_path, remove_weight_norm=False):
+ config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json')
+
+ # Load the model configuration
+ with open(config_file) as f:
+ model_config = json.load(f)
+
+ # Create the model from the configuration
+ model = create_autoencoder_from_config(model_config)
+
+ # Load the state dictionary from the checkpoint
+ model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
+
+ # Strip the "autoencoder." prefix from the keys
+ model_dict = {key[len("autoencoder."):]: value for key, value in model_dict.items() if key.startswith("autoencoder.")}
+
+ # Load the state dictionary into the model
+ model.load_state_dict(model_dict)
+
+ # Remove weight normalization
+ if remove_weight_norm:
+ remove_all_weight_norm(model)
+
+ # Set the model to evaluation mode
+ model.eval()
+
+ return model
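+
+
+# Usage sketch (the checkpoint path is illustrative; `load_vae` expects a config.json
+# sitting next to the checkpoint file):
+#
+#     vae = load_vae('path/to/vae.pt', remove_weight_norm=True)
+#     latents = vae.encode_audio(waveform)   # waveform: (B, C, T) at vae.sample_rate
+#     recon = vae.decode_audio(latents)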
diff --git a/src/modules/stable_vae/models/autoencoders.py b/src/modules/stable_vae/models/autoencoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e1b428ec40767b0ab3838a37b16401fa55ccbd
--- /dev/null
+++ b/src/modules/stable_vae/models/autoencoders.py
@@ -0,0 +1,683 @@
+import torch
+import math
+import numpy as np
+
+from torch import nn
+from torch.nn import functional as F
+from torchaudio import transforms as T
+from alias_free_torch import Activation1d
+from .nn.layers import WNConv1d, WNConvTranspose1d
+from typing import Literal, Dict, Any
+
+# from .inference.sampling import sample
+from .utils import prepare_audio
+from .blocks import SnakeBeta
+from .bottleneck import Bottleneck, DiscreteBottleneck
+from .factory import create_pretransform_from_config, create_bottleneck_from_config
+from .pretransforms import Pretransform
+
+def checkpoint(function, *args, **kwargs):
+ kwargs.setdefault("use_reentrant", False)
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
+
+def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
+ if activation == "elu":
+ act = nn.ELU()
+ elif activation == "snake":
+ act = SnakeBeta(channels)
+ elif activation == "none":
+ act = nn.Identity()
+ else:
+ raise ValueError(f"Unknown activation {activation}")
+
+ if antialias:
+ act = Activation1d(act)
+
+ return act
+
+class ResidualUnit(nn.Module):
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
+ super().__init__()
+
+ self.dilation = dilation
+
+ padding = (dilation * (7-1)) // 2
+
+ self.layers = nn.Sequential(
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=7, dilation=dilation, padding=padding),
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
+ kernel_size=1)
+ )
+
+ def forward(self, x):
+ res = x
+
+ #x = checkpoint(self.layers, x)
+ x = self.layers(x)
+
+ return x + res
+
+class EncoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
+ super().__init__()
+
+ self.layers = nn.Sequential(
+ ResidualUnit(in_channels=in_channels,
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
+ ResidualUnit(in_channels=in_channels,
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
+ ResidualUnit(in_channels=in_channels,
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+class DecoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
+ super().__init__()
+
+ if use_nearest_upsample:
+ upsample_layer = nn.Sequential(
+ nn.Upsample(scale_factor=stride, mode="nearest"),
+ WNConv1d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2*stride,
+ stride=1,
+ bias=False,
+ padding='same')
+ )
+ else:
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
+
+ self.layers = nn.Sequential(
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+ upsample_layer,
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+ dilation=1, use_snake=use_snake),
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+ dilation=3, use_snake=use_snake),
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+ dilation=9, use_snake=use_snake),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+class OobleckEncoder(nn.Module):
+ def __init__(self,
+ in_channels=2,
+ channels=128,
+ latent_dim=32,
+ c_mults = [1, 2, 4, 8],
+ strides = [2, 4, 8, 8],
+ use_snake=False,
+ antialias_activation=False
+ ):
+ super().__init__()
+
+ c_mults = [1] + c_mults
+
+ self.depth = len(c_mults)
+
+ layers = [
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
+ ]
+
+ for i in range(self.depth-1):
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
+
+ layers += [
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
+ ]
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class OobleckDecoder(nn.Module):
+ def __init__(self,
+ out_channels=2,
+ channels=128,
+ latent_dim=32,
+ c_mults = [1, 2, 4, 8],
+ strides = [2, 4, 8, 8],
+ use_snake=False,
+ antialias_activation=False,
+ use_nearest_upsample=False,
+ final_tanh=True):
+ super().__init__()
+
+ c_mults = [1] + c_mults
+
+ self.depth = len(c_mults)
+
+ layers = [
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
+ ]
+
+ for i in range(self.depth-1, 0, -1):
+ layers += [DecoderBlock(
+ in_channels=c_mults[i]*channels,
+ out_channels=c_mults[i-1]*channels,
+ stride=strides[i-1],
+ use_snake=use_snake,
+ antialias_activation=antialias_activation,
+ use_nearest_upsample=use_nearest_upsample
+ )
+ ]
+
+ layers += [
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
+ nn.Tanh() if final_tanh else nn.Identity()
+ ]
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
+
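+# Shape sketch (illustrative, not part of the original module): with the default strides
+# [2, 4, 8, 8] the encoder downsamples time by 2 * 4 * 8 * 8 = 512 and the decoder
+# upsamples by the same factor.
+#
+#     enc = OobleckEncoder(in_channels=2, latent_dim=32)
+#     dec = OobleckDecoder(out_channels=2, latent_dim=32)
+#     z = enc(torch.randn(1, 2, 512 * 100))   # -> (1, 32, 100)
+#     y = dec(z)                              # -> (1, 2, 51200)
+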
+
+class DACEncoderWrapper(nn.Module):
+ def __init__(self, in_channels=1, **kwargs):
+ super().__init__()
+
+ from dac.model.dac import Encoder as DACEncoder
+
+ latent_dim = kwargs.pop("latent_dim", None)
+
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
+ self.latent_dim = latent_dim
+
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
+
+ if in_channels != 1:
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.proj_out(x)
+ return x
+
+class DACDecoderWrapper(nn.Module):
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
+ super().__init__()
+
+ from dac.model.dac import Decoder as DACDecoder
+
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
+
+ self.latent_dim = latent_dim
+
+ def forward(self, x):
+ return self.decoder(x)
+
+class AudioAutoencoder(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ latent_dim,
+ downsampling_ratio,
+ sample_rate,
+ io_channels=2,
+ bottleneck: Bottleneck = None,
+ pretransform: Pretransform = None,
+ in_channels = None,
+ out_channels = None,
+ soft_clip = False
+ ):
+ super().__init__()
+
+ self.downsampling_ratio = downsampling_ratio
+ self.sample_rate = sample_rate
+
+ self.latent_dim = latent_dim
+ self.io_channels = io_channels
+ self.in_channels = io_channels
+ self.out_channels = io_channels
+
+ self.min_length = self.downsampling_ratio
+
+ if in_channels is not None:
+ self.in_channels = in_channels
+
+ if out_channels is not None:
+ self.out_channels = out_channels
+
+ self.bottleneck = bottleneck
+
+ self.encoder = encoder
+
+ self.decoder = decoder
+
+ self.pretransform = pretransform
+
+ self.soft_clip = soft_clip
+
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
+
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
+
+ info = {}
+
+ if self.pretransform is not None and not skip_pretransform:
+ if self.pretransform.enable_grad:
+ if iterate_batch:
+ audios = []
+ for i in range(audio.shape[0]):
+ audios.append(self.pretransform.encode(audio[i:i+1]))
+ audio = torch.cat(audios, dim=0)
+ else:
+ audio = self.pretransform.encode(audio)
+ else:
+ with torch.no_grad():
+ if iterate_batch:
+ audios = []
+ for i in range(audio.shape[0]):
+ audios.append(self.pretransform.encode(audio[i:i+1]))
+ audio = torch.cat(audios, dim=0)
+ else:
+ audio = self.pretransform.encode(audio)
+
+ if self.encoder is not None:
+ if iterate_batch:
+ latents = []
+ for i in range(audio.shape[0]):
+ latents.append(self.encoder(audio[i:i+1]))
+ latents = torch.cat(latents, dim=0)
+ else:
+ latents = self.encoder(audio)
+ else:
+ latents = audio
+
+ if self.bottleneck is not None:
+ # TODO: Add iterate batch logic, needs to merge the info dicts
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
+
+ info.update(bottleneck_info)
+
+ if return_info:
+ return latents, info
+
+ return latents
+
+ def decode(self, latents, iterate_batch=False, **kwargs):
+
+ if self.bottleneck is not None:
+ if iterate_batch:
+ decoded = []
+ for i in range(latents.shape[0]):
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
+                # keep the bottleneck-decoded result; the decoder below consumes `latents`
+                latents = torch.cat(decoded, dim=0)
+ else:
+ latents = self.bottleneck.decode(latents)
+
+ if iterate_batch:
+ decoded = []
+ for i in range(latents.shape[0]):
+ decoded.append(self.decoder(latents[i:i+1]))
+ decoded = torch.cat(decoded, dim=0)
+ else:
+ decoded = self.decoder(latents, **kwargs)
+
+ if self.pretransform is not None:
+ if self.pretransform.enable_grad:
+ if iterate_batch:
+ decodeds = []
+ for i in range(decoded.shape[0]):
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+ decoded = torch.cat(decodeds, dim=0)
+ else:
+ decoded = self.pretransform.decode(decoded)
+ else:
+ with torch.no_grad():
+ if iterate_batch:
+ decodeds = []
+                        for i in range(decoded.shape[0]):
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+ decoded = torch.cat(decodeds, dim=0)
+ else:
+ decoded = self.pretransform.decode(decoded)
+
+ if self.soft_clip:
+ decoded = torch.tanh(decoded)
+
+ return decoded
+
+ def decode_tokens(self, tokens, **kwargs):
+ '''
+ Decode discrete tokens to audio
+ Only works with discrete autoencoders
+ '''
+
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
+
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
+
+ return self.decode(latents, **kwargs)
+
+
+ def preprocess_audio_for_encoder(self, audio, in_sr):
+ '''
+ Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
+ If the model is mono, stereo audio will be converted to mono.
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
+ Audio will be resampled to the model's sample rate.
+        The output will have batch size 1 and shape (1 x Channels x Length).
+ '''
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
+
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
+ '''
+        Preprocess a list of audio tensors (Channels x Length) into a batch tensor compatible with the encoder.
+        The audio in the list can have different lengths and channel counts.
+        in_sr can be an integer or a list. If it is an integer, it is assumed to be the input sample rate for every audio tensor.
+ All audio will be resampled to the model's sample rate.
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
+ If the model is mono, all audio will be converted to mono.
+ The output will be a tensor of shape (Batch x Channels x Length)
+ '''
+ batch_size = len(audio_list)
+ if isinstance(in_sr_list, int):
+ in_sr_list = [in_sr_list]*batch_size
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
+ new_audio = []
+ max_length = 0
+ # resample & find the max length
+ for i in range(batch_size):
+ audio = audio_list[i]
+ in_sr = in_sr_list[i]
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
+ # batchsize 1 was given by accident. Just squeeze it.
+ audio = audio.squeeze(0)
+ elif len(audio.shape) == 1:
+ # Mono signal, channel dimension is missing, unsqueeze it in
+ audio = audio.unsqueeze(0)
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
+ # Resample audio
+ if in_sr != self.sample_rate:
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
+ audio = resample_tf(audio)
+ new_audio.append(audio)
+ if audio.shape[-1] > max_length:
+ max_length = audio.shape[-1]
+ # Pad every audio to the same length, multiple of model's downsampling ratio
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
+ for i in range(batch_size):
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
+ # convert to tensor
+ return torch.stack(new_audio)
+
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
+ '''
+        Encode audio into latents. The audio should already be preprocessed by preprocess_audio_for_encoder.
+        If chunked is True, the audio is split into chunks of a given maximum size chunk_size, with a given overlap.
+        The overlap and chunk_size parameters are both measured in number of latents (not audio samples),
+        so you can likely reuse the same values with decode_audio.
+        An overlap of zero will cause discontinuity artefacts. The overlap should be >= the receptive field size.
+        Every autoencoder has a different receptive field size, and thus an ideal overlap.
+        You can determine it empirically by diffing unchunked vs. chunked output and looking at the maximum difference.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        A smaller chunk_size uses less memory but more compute.
+        The chunk_size vs. memory tradeoff isn't linear, and may depend on the GPU and CUDA version;
+        for example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it produces more chunks.
+ '''
+ if not chunked:
+ # default behavior. Encode the entire audio in parallel
+ return self.encode(audio, **kwargs)
+ else:
+ # CHUNKED ENCODING
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
+ samples_per_latent = self.downsampling_ratio
+ total_size = audio.shape[2] # in samples
+ batch_size = audio.shape[0]
+ chunk_size *= samples_per_latent # converting metric in latents to samples
+ overlap *= samples_per_latent # converting metric in latents to samples
+ hop_size = chunk_size - overlap
+ chunks = []
+ for i in range(0, total_size - chunk_size + 1, hop_size):
+ chunk = audio[:,:,i:i+chunk_size]
+ chunks.append(chunk)
+ if i+chunk_size != total_size:
+ # Final chunk
+ chunk = audio[:,:,-chunk_size:]
+ chunks.append(chunk)
+ chunks = torch.stack(chunks)
+ num_chunks = chunks.shape[0]
+ # Note: y_size might be a different value from the latent length used in diffusion training
+ # because we can encode audio of varying lengths
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
+ y_size = total_size // samples_per_latent
+ # Create an empty latent, we will populate it with chunks as we encode them
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
+ for i in range(num_chunks):
+ x_chunk = chunks[i,:]
+ # encode the chunk
+ y_chunk = self.encode(x_chunk)
+ # figure out where to put the audio along the time domain
+ if i == num_chunks-1:
+ # final chunk always goes at the end
+ t_end = y_size
+ t_start = t_end - y_chunk.shape[2]
+ else:
+ t_start = i * hop_size // samples_per_latent
+ t_end = t_start + chunk_size // samples_per_latent
+ # remove the edges of the overlaps
+ ol = overlap//samples_per_latent//2
+ chunk_start = 0
+ chunk_end = y_chunk.shape[2]
+ if i > 0:
+ # no overlap for the start of the first chunk
+ t_start += ol
+ chunk_start += ol
+ if i < num_chunks-1:
+ # no overlap for the end of the last chunk
+ t_end -= ol
+ chunk_end -= ol
+ # paste the chunked audio into our y_final output audio
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+ return y_final
+
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
+ '''
+ Decode latents to audio.
+        If chunked is True, the latents are split into chunks of a given maximum size chunk_size, with a given overlap, both measured in number of latents.
+        An overlap of zero will cause discontinuity artefacts. The overlap should be >= the receptive field size.
+        Every autoencoder has a different receptive field size, and thus an ideal overlap.
+        You can determine it empirically by diffing unchunked vs. chunked audio and looking at the maximum difference.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        A smaller chunk_size uses less memory but more compute.
+        The chunk_size vs. memory tradeoff isn't linear, and may depend on the GPU and CUDA version;
+        for example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it produces more chunks.
+ '''
+ if not chunked:
+ # default behavior. Decode the entire latent in parallel
+ return self.decode(latents, **kwargs)
+ else:
+ # chunked decoding
+ hop_size = chunk_size - overlap
+ total_size = latents.shape[2]
+ batch_size = latents.shape[0]
+ chunks = []
+ for i in range(0, total_size - chunk_size + 1, hop_size):
+ chunk = latents[:,:,i:i+chunk_size]
+ chunks.append(chunk)
+ if i+chunk_size != total_size:
+ # Final chunk
+ chunk = latents[:,:,-chunk_size:]
+ chunks.append(chunk)
+ chunks = torch.stack(chunks)
+ num_chunks = chunks.shape[0]
+ # samples_per_latent is just the downsampling ratio
+ samples_per_latent = self.downsampling_ratio
+            # Create an empty waveform; we will populate it with chunks as we decode them
+ y_size = total_size * samples_per_latent
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
+ for i in range(num_chunks):
+ x_chunk = chunks[i,:]
+ # decode the chunk
+ y_chunk = self.decode(x_chunk)
+ # figure out where to put the audio along the time domain
+ if i == num_chunks-1:
+ # final chunk always goes at the end
+ t_end = y_size
+ t_start = t_end - y_chunk.shape[2]
+ else:
+ t_start = i * hop_size * samples_per_latent
+ t_end = t_start + chunk_size * samples_per_latent
+ # remove the edges of the overlaps
+ ol = (overlap//2) * samples_per_latent
+ chunk_start = 0
+ chunk_end = y_chunk.shape[2]
+ if i > 0:
+ # no overlap for the start of the first chunk
+ t_start += ol
+ chunk_start += ol
+ if i < num_chunks-1:
+ # no overlap for the end of the last chunk
+ t_end -= ol
+ chunk_end -= ol
+ # paste the chunked audio into our y_final output audio
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+ return y_final
+
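+# Chunked round-trip sketch (illustrative; `model` stands for an AudioAutoencoder
+# instance): `overlap` and `chunk_size` are measured in latent frames for both
+# directions, so the same values can be reused.
+#
+#     latents = model.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
+#     recon = model.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
+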
+
+# AE factories
+
+def create_encoder_from_config(encoder_config: Dict[str, Any]):
+ encoder_type = encoder_config.get("type", None)
+ assert encoder_type is not None, "Encoder type must be specified"
+
+ if encoder_type == "oobleck":
+ encoder = OobleckEncoder(
+ **encoder_config["config"]
+ )
+
+ elif encoder_type == "seanet":
+ from encodec.modules import SEANetEncoder
+ seanet_encoder_config = encoder_config["config"]
+
+ #SEANet encoder expects strides in reverse order
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
+ encoder = SEANetEncoder(
+ **seanet_encoder_config
+ )
+ elif encoder_type == "dac":
+ dac_config = encoder_config["config"]
+
+ encoder = DACEncoderWrapper(**dac_config)
+ elif encoder_type == "local_attn":
+ from .local_attention import TransformerEncoder1D
+
+ local_attn_config = encoder_config["config"]
+
+ encoder = TransformerEncoder1D(
+ **local_attn_config
+ )
+ else:
+ raise ValueError(f"Unknown encoder type {encoder_type}")
+
+ requires_grad = encoder_config.get("requires_grad", True)
+ if not requires_grad:
+ for param in encoder.parameters():
+ param.requires_grad = False
+
+ return encoder
+
+def create_decoder_from_config(decoder_config: Dict[str, Any]):
+ decoder_type = decoder_config.get("type", None)
+ assert decoder_type is not None, "Decoder type must be specified"
+
+ if decoder_type == "oobleck":
+ decoder = OobleckDecoder(
+ **decoder_config["config"]
+ )
+ elif decoder_type == "seanet":
+ from encodec.modules import SEANetDecoder
+
+ decoder = SEANetDecoder(
+ **decoder_config["config"]
+ )
+ elif decoder_type == "dac":
+ dac_config = decoder_config["config"]
+
+ decoder = DACDecoderWrapper(**dac_config)
+ elif decoder_type == "local_attn":
+ from .local_attention import TransformerDecoder1D
+
+ local_attn_config = decoder_config["config"]
+
+ decoder = TransformerDecoder1D(
+ **local_attn_config
+ )
+ else:
+ raise ValueError(f"Unknown decoder type {decoder_type}")
+
+ requires_grad = decoder_config.get("requires_grad", True)
+ if not requires_grad:
+ for param in decoder.parameters():
+ param.requires_grad = False
+
+ return decoder
+
+def create_autoencoder_from_config(config: Dict[str, Any]):
+
+ ae_config = config["model"]
+
+ encoder = create_encoder_from_config(ae_config["encoder"])
+ decoder = create_decoder_from_config(ae_config["decoder"])
+
+ bottleneck = ae_config.get("bottleneck", None)
+
+ latent_dim = ae_config.get("latent_dim", None)
+ assert latent_dim is not None, "latent_dim must be specified in model config"
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
+ io_channels = ae_config.get("io_channels", None)
+ assert io_channels is not None, "io_channels must be specified in model config"
+ sample_rate = config.get("sample_rate", None)
+ assert sample_rate is not None, "sample_rate must be specified in model config"
+
+ in_channels = ae_config.get("in_channels", None)
+ out_channels = ae_config.get("out_channels", None)
+
+ pretransform = ae_config.get("pretransform", None)
+
+ if pretransform is not None:
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
+
+ if bottleneck is not None:
+ bottleneck = create_bottleneck_from_config(bottleneck)
+
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
+
+ return AudioAutoencoder(
+ encoder,
+ decoder,
+ io_channels=io_channels,
+ latent_dim=latent_dim,
+ downsampling_ratio=downsampling_ratio,
+ sample_rate=sample_rate,
+ bottleneck=bottleneck,
+ pretransform=pretransform,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ soft_clip=soft_clip
+ )
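+
+# Minimal config sketch for `create_autoencoder_from_config` (field values are
+# illustrative, not a shipped preset; with a "vae" bottleneck the encoder emits
+# 2 * latent_dim channels for mean and scale):
+#
+#     config = {
+#         "sample_rate": 44100,
+#         "model": {
+#             "encoder": {"type": "oobleck", "config": {"in_channels": 2, "latent_dim": 128}},
+#             "decoder": {"type": "oobleck", "config": {"out_channels": 2, "latent_dim": 64}},
+#             "bottleneck": {"type": "vae"},
+#             "latent_dim": 64,
+#             "downsampling_ratio": 512,
+#             "io_channels": 2,
+#         },
+#     }
+#     model = create_autoencoder_from_config(config)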
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/blocks.py b/src/modules/stable_vae/models/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..a832f5c561a5b8e0a1c58a5c5fb4c05325534561
--- /dev/null
+++ b/src/modules/stable_vae/models/blocks.py
@@ -0,0 +1,359 @@
+from functools import reduce
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.backends.cuda import sdp_kernel
+from packaging import version
+
+from .nn.layers import Snake1d
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, main, skip=None):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input):
+ return self.main(input) + self.skip(input)
+
+
+class ResConvBlock(ResidualBlock):
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
+ super().__init__([
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
+ nn.GroupNorm(1, c_mid),
+ Snake1d(c_mid) if use_snake else nn.GELU(),
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
+ ], skip)
+
+
+class SelfAttention1d(nn.Module):
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
+ super().__init__()
+ assert c_in % n_head == 0
+ self.norm = nn.GroupNorm(1, c_in)
+ self.n_head = n_head
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
+
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+
+ if not self.use_flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+ if device_properties.major == 8 and device_properties.minor == 0:
+ # Use flash attention for A100 GPUs
+ self.sdp_kernel_config = (True, False, False)
+ else:
+ # Don't use flash attention for other GPUs
+ self.sdp_kernel_config = (False, True, True)
+
+ def forward(self, input):
+ n, c, s = input.shape
+ qkv = self.qkv_proj(self.norm(input))
+ qkv = qkv.view(
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = k.shape[3]**-0.25
+
+ if self.use_flash:
+ with sdp_kernel(*self.sdp_kernel_config):
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
+ else:
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
+
+
+ return input + self.dropout(self.out_proj(y))
+
+
+class SkipBlock(nn.Module):
+ def __init__(self, *main):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+
+ def forward(self, input):
+ return torch.cat([self.main(input), input], dim=1)
+
+
+class FourierFeatures(nn.Module):
+ def __init__(self, in_features, out_features, std=1.):
+ super().__init__()
+ assert out_features % 2 == 0
+ self.weight = nn.Parameter(torch.randn(
+ [out_features // 2, in_features]) * std)
+
+ def forward(self, input):
+ f = 2 * math.pi * input @ self.weight.T
+ return torch.cat([f.cos(), f.sin()], dim=-1)
+
+
+def expand_to_planes(input, shape):
+ return input[..., None].repeat([1, 1, shape[2]])
+
+_kernels = {
+ 'linear':
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+ 'cubic':
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
+ 'lanczos3':
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
+}
+
+
+class Downsample1d(nn.Module):
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel])
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer('kernel', kernel_1d)
+ self.channels_last = channels_last
+
+ def forward(self, x):
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(x.shape[1], device=x.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ x = F.conv1d(x, weight, stride=2)
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ return x
+
+
+class Upsample1d(nn.Module):
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer('kernel', kernel_1d)
+ self.channels_last = channels_last
+
+ def forward(self, x):
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(x.shape[1], device=x.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ return x
+
+
+def Downsample1d_2(
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
+) -> nn.Module:
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
+
+ return nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=factor * kernel_multiplier + 1,
+ stride=factor,
+ padding=factor * (kernel_multiplier // 2),
+ )
+
+
+def Upsample1d_2(
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
+) -> nn.Module:
+
+ if factor == 1:
+ return nn.Conv1d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
+ )
+
+ if use_nearest:
+ return nn.Sequential(
+ nn.Upsample(scale_factor=factor, mode="nearest"),
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ )
+ else:
+ return nn.ConvTranspose1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=factor * 2,
+ stride=factor,
+ padding=factor // 2 + factor % 2,
+ output_padding=factor % 2,
+ )
+
+
+def zero_init(layer):
+ nn.init.zeros_(layer.weight)
+ if layer.bias is not None:
+ nn.init.zeros_(layer.bias)
+ return layer
+
+
+def rms_norm(x, scale, eps):
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+ return x * scale.to(x.dtype)
+
+#rms_norm = torch.compile(rms_norm)
+
+class AdaRMSNorm(nn.Module):
+ def __init__(self, features, cond_features, eps=1e-6):
+ super().__init__()
+ self.eps = eps
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
+
+ def extra_repr(self):
+ return f"eps={self.eps},"
+
+ def forward(self, x, cond):
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
+
+
+def normalize(x, eps=1e-4):
+ dim = list(range(1, x.ndim))
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
+ alpha = np.sqrt(n.numel() / x.numel())
+ return x / torch.add(eps, n, alpha=alpha)
+
+
+class ForcedWNConv1d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
+
+ def forward(self, x):
+ if self.training:
+ with torch.no_grad():
+ self.weight.copy_(normalize(self.weight))
+
+ fan_in = self.weight[0].numel()
+
+ w = normalize(self.weight) / math.sqrt(fan_in)
+
+ return F.conv1d(x, w, padding='same')
+
+# Kernels
+
+use_compile = True
+
+def compile(function, *args, **kwargs):
+ if not use_compile:
+ return function
+ try:
+ return torch.compile(function, *args, **kwargs)
+ except RuntimeError:
+ return function
+
+
+@compile
+def linear_geglu(x, weight, bias=None):
+ x = x @ weight.mT
+ if bias is not None:
+ x = x + bias
+ x, gate = x.chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+@compile
+def rms_norm(x, scale, eps):
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+ return x * scale.to(x.dtype)
+
+# Layers
+
+
+class LinearGEGLU(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True):
+ super().__init__(in_features, out_features * 2, bias=bias)
+ self.out_features = out_features
+
+ def forward(self, x):
+ return linear_geglu(x, self.weight, self.bias)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
+ super().__init__()
+ self.eps = eps
+
+ if fix_scale:
+ self.register_buffer("scale", torch.ones(shape))
+ else:
+ self.scale = nn.Parameter(torch.ones(shape))
+
+ def extra_repr(self):
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
+
+ def forward(self, x):
+ return rms_norm(x, self.scale, self.eps)
+
+
+# torch.jit.script makes it 1.4x faster and saves GPU memory
+@torch.jit.script
+def snake_beta(x, alpha, beta):
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
+
+# try:
+# snake_beta = torch.compile(snake_beta)
+# except RuntimeError:
+# pass
+
+
+# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
+# License available in LICENSES/LICENSE_NVIDIA.txt
+class SnakeBeta(nn.Module):
+
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale:
+ # log scale alphas initialized to zeros
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
+ else:
+ # linear scale alphas initialized to ones
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ # self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+ # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = snake_beta(x, alpha, beta)
+
+ return x
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/bottleneck.py b/src/modules/stable_vae/models/bottleneck.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2254c97fbb5d1f4bbf7150106b5f3a154e3705
--- /dev/null
+++ b/src/modules/stable_vae/models/bottleneck.py
@@ -0,0 +1,346 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from einops import rearrange
+from vector_quantize_pytorch import ResidualVQ, FSQ
+from .nn.quantize import ResidualVectorQuantize as DACResidualVQ
+
+
+class Bottleneck(nn.Module):
+ def __init__(self, is_discrete: bool = False):
+ super().__init__()
+
+ self.is_discrete = is_discrete
+
+ def encode(self, x, return_info=False, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, x):
+ raise NotImplementedError
+
+
+class DiscreteBottleneck(Bottleneck):
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
+ super().__init__(is_discrete=True)
+
+ self.num_quantizers = num_quantizers
+ self.codebook_size = codebook_size
+ self.tokens_id = tokens_id
+
+ def decode_tokens(self, codes, **kwargs):
+ raise NotImplementedError
+
+
+class TanhBottleneck(Bottleneck):
+ def __init__(self):
+ super().__init__(is_discrete=False)
+ self.tanh = nn.Tanh()
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ x = torch.tanh(x)
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+
+@torch.jit.script
+def vae_sample_kl(mean, scale):
+ stdev = nn.functional.softplus(scale) + 1e-4
+ var = stdev * stdev
+ logvar = torch.log(var)
+ latents = torch.randn_like(mean) * stdev + mean
+
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
+
+ return latents, kl
+
+
+@torch.jit.script
+def vae_sample(mean, scale):
+ stdev = nn.functional.softplus(scale) + 1e-4
+ latents = torch.randn_like(mean) * stdev + mean
+ return latents
+
+
+class VAEBottleneck(Bottleneck):
+ def __init__(self):
+ super().__init__(is_discrete=False)
+
+ def encode(self, x, return_info=False, **kwargs):
+ mean, scale = x.chunk(2, dim=1)
+
+ if return_info:
+ info = {}
+ x, kl = vae_sample_kl(mean, scale)
+ info["kl"] = kl
+ return x, info
+ else:
+ x = vae_sample(mean, scale)
+ return x
+
+ def decode(self, x):
+ return x
+
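+# Sketch (illustrative): a VAE bottleneck expects the encoder to emit 2 * latent_dim
+# channels, which are split into mean and scale before reparameterized sampling.
+#
+#     bottleneck = VAEBottleneck()
+#     x = torch.randn(2, 128, 100)                 # (B, 2 * latent_dim, T)
+#     z, info = bottleneck.encode(x, return_info=True)
+#     print(z.shape, info["kl"])                   # (2, 64, 100) and a scalar KL term
+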
+
+def compute_mean_kernel(x, y):
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
+ return torch.exp(-kernel_input).mean()
+
+
+def compute_mmd(latents):
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
+ noise = torch.randn_like(latents_reshaped)
+
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
+ noise_kernel = compute_mean_kernel(noise, noise)
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
+
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
+ return mmd.mean()
+
+
+class WassersteinBottleneck(Bottleneck):
+ def __init__(self, noise_augment_dim: int = 0):
+ super().__init__(is_discrete=False)
+
+ self.noise_augment_dim = noise_augment_dim
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ if self.training and return_info:
+ mmd = compute_mmd(x)
+ info["mmd"] = mmd
+
+ if return_info:
+ return x, info
+
+ return x
+
+ def decode(self, x):
+
+ if self.noise_augment_dim > 0:
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
+ x.shape[-1]).type_as(x)
+ x = torch.cat([x, noise], dim=1)
+
+ return x
+
+
+class L2Bottleneck(Bottleneck):
+ def __init__(self):
+ super().__init__(is_discrete=False)
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ x = F.normalize(x, dim=1)
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return F.normalize(x, dim=1)
+
+
+class RVQBottleneck(DiscreteBottleneck):
+ def __init__(self, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+ def encode(self, x, return_info=False, **kwargs):
+ info = {}
+
+ x = rearrange(x, "b c n -> b n c")
+ x, indices, loss = self.quantizer(x)
+ x = rearrange(x, "b n c -> b c n")
+
+ info["quantizer_indices"] = indices
+ info["quantizer_loss"] = loss.mean()
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents = self.quantizer.get_outputs_from_indices(codes)
+
+ return self.decode(latents, **kwargs)
+
+
+class RVQVAEBottleneck(DiscreteBottleneck):
+ def __init__(self, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+        # vae_sample_kl returns both the sampled latents and the KL term
+        x, kl = vae_sample_kl(*x.chunk(2, dim=1))
+
+ info["kl"] = kl
+
+ x = rearrange(x, "b c n -> b n c")
+ x, indices, loss = self.quantizer(x)
+ x = rearrange(x, "b n c -> b c n")
+
+ info["quantizer_indices"] = indices
+ info["quantizer_loss"] = loss.mean()
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents = self.quantizer.get_outputs_from_indices(codes)
+
+ return self.decode(latents, **kwargs)
+
+
+class DACRVQBottleneck(DiscreteBottleneck):
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
+ self.quantize_on_decode = quantize_on_decode
+
+ def encode(self, x, return_info=False, **kwargs):
+ info = {}
+
+ info["pre_quantizer"] = x
+
+ if self.quantize_on_decode:
+ return x, info if return_info else x
+
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
+
+ output = {
+ "z": z,
+ "codes": codes,
+ "latents": latents,
+ "vq/commitment_loss": commitment_loss,
+ "vq/codebook_loss": codebook_loss,
+ }
+
+ output["vq/commitment_loss"] /= self.num_quantizers
+ output["vq/codebook_loss"] /= self.num_quantizers
+
+ info.update(output)
+
+ if return_info:
+ return output["z"], info
+
+ return output["z"]
+
+ def decode(self, x):
+
+ if self.quantize_on_decode:
+ x = self.quantizer(x)[0]
+
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents, _, _ = self.quantizer.from_codes(codes)
+
+ return self.decode(latents, **kwargs)
+
+
+class DACRVQVAEBottleneck(DiscreteBottleneck):
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
+ self.quantize_on_decode = quantize_on_decode
+
+ def encode(self, x, return_info=False, n_quantizers: int = None):
+ info = {}
+
+ mean, scale = x.chunk(2, dim=1)
+
+        x, kl = vae_sample_kl(mean, scale)
+
+ info["pre_quantizer"] = x
+ info["kl"] = kl
+
+ if self.quantize_on_decode:
+            return (x, info) if return_info else x
+
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
+
+ output = {
+ "z": z,
+ "codes": codes,
+ "latents": latents,
+ "vq/commitment_loss": commitment_loss,
+ "vq/codebook_loss": codebook_loss,
+ }
+
+ output["vq/commitment_loss"] /= self.num_quantizers
+ output["vq/codebook_loss"] /= self.num_quantizers
+
+ info.update(output)
+
+ if return_info:
+ return output["z"], info
+
+ return output["z"]
+
+ def decode(self, x):
+
+ if self.quantize_on_decode:
+ x = self.quantizer(x)[0]
+
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents, _, _ = self.quantizer.from_codes(codes)
+
+ return self.decode(latents, **kwargs)
+
+
+class FSQBottleneck(DiscreteBottleneck):
+ def __init__(self, dim, levels):
+ super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices")
+ self.quantizer = FSQ(levels=[levels] * dim)
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ x = rearrange(x, "b c n -> b n c")
+ x, indices = self.quantizer(x)
+ x = rearrange(x, "b n c -> b c n")
+
+ info["quantizer_indices"] = indices
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+ def decode_tokens(self, tokens, **kwargs):
+ latents = self.quantizer.indices_to_codes(tokens)
+
+ return self.decode(latents, **kwargs)
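+
+
+# Sketch (illustrative): discrete bottlenecks expose their indices in the info dict under
+# `tokens_id`, and `decode_tokens` maps those indices back to continuous latents.
+#
+#     bottleneck = DACRVQBottleneck(input_dim=64, n_codebooks=9, codebook_size=1024)
+#     x = torch.randn(2, 64, 100)
+#     z, info = bottleneck.encode(x, return_info=True)
+#     z_rec = bottleneck.decode_tokens(info[bottleneck.tokens_id])   # (2, 64, 100)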
\ No newline at end of file
diff --git a/src/modules/stable_vae/models/factory.py b/src/modules/stable_vae/models/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..46292eca5f5367bef4ac7133beada867c1e43f9a
--- /dev/null
+++ b/src/modules/stable_vae/models/factory.py
@@ -0,0 +1,153 @@
+import json
+
+def create_model_from_config(model_config):
+ model_type = model_config.get('model_type', None)
+
+ assert model_type is not None, 'model_type must be specified in model config'
+
+ if model_type == 'autoencoder':
+ from .autoencoders import create_autoencoder_from_config
+ return create_autoencoder_from_config(model_config)
+ elif model_type == 'diffusion_uncond':
+ from .diffusion import create_diffusion_uncond_from_config
+ return create_diffusion_uncond_from_config(model_config)
+ elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior":
+ from .diffusion import create_diffusion_cond_from_config
+ return create_diffusion_cond_from_config(model_config)
+ elif model_type == 'diffusion_autoencoder':
+ from .autoencoders import create_diffAE_from_config
+ return create_diffAE_from_config(model_config)
+ elif model_type == 'lm':
+ from .lm import create_audio_lm_from_config
+ return create_audio_lm_from_config(model_config)
+ else:
+ raise NotImplementedError(f'Unknown model type: {model_type}')
+
+def create_model_from_config_path(model_config_path):
+ with open(model_config_path) as f:
+ model_config = json.load(f)
+
+ return create_model_from_config(model_config)
+
+def create_pretransform_from_config(pretransform_config, sample_rate):
+ pretransform_type = pretransform_config.get('type', None)
+
+ assert pretransform_type is not None, 'type must be specified in pretransform config'
+
+ if pretransform_type == 'autoencoder':
+ from .autoencoders import create_autoencoder_from_config
+ from .pretransforms import AutoencoderPretransform
+
+ # Create fake top-level config to pass sample rate to autoencoder constructor
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
+ autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
+
+ scale = pretransform_config.get("scale", 1.0)
+ model_half = pretransform_config.get("model_half", False)
+ iterate_batch = pretransform_config.get("iterate_batch", False)
+ chunked = pretransform_config.get("chunked", False)
+
+ pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
+ elif pretransform_type == 'wavelet':
+ from .pretransforms import WaveletPretransform
+
+ wavelet_config = pretransform_config["config"]
+ channels = wavelet_config["channels"]
+ levels = wavelet_config["levels"]
+ wavelet = wavelet_config["wavelet"]
+
+ pretransform = WaveletPretransform(channels, levels, wavelet)
+ elif pretransform_type == 'pqmf':
+ from .pretransforms import PQMFPretransform
+ pqmf_config = pretransform_config["config"]
+ pretransform = PQMFPretransform(**pqmf_config)
+ elif pretransform_type == 'dac_pretrained':
+ from .pretransforms import PretrainedDACPretransform
+ pretrained_dac_config = pretransform_config["config"]
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
+ elif pretransform_type == "audiocraft_pretrained":
+ from .pretransforms import AudiocraftCompressionPretransform
+
+ audiocraft_config = pretransform_config["config"]
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
+ else:
+ raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
+
+ enable_grad = pretransform_config.get('enable_grad', False)
+ pretransform.enable_grad = enable_grad
+
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
+
+ return pretransform
+
+def create_bottleneck_from_config(bottleneck_config):
+ bottleneck_type = bottleneck_config.get('type', None)
+
+ assert bottleneck_type is not None, 'type must be specified in bottleneck config'
+
+ if bottleneck_type == 'tanh':
+ from .bottleneck import TanhBottleneck
+ bottleneck = TanhBottleneck()
+ elif bottleneck_type == 'vae':
+ from .bottleneck import VAEBottleneck
+ bottleneck = VAEBottleneck()
+ elif bottleneck_type == 'rvq':
+ from .bottleneck import RVQBottleneck
+
+ quantizer_params = {
+ "dim": 128,
+ "codebook_size": 1024,
+ "num_quantizers": 8,
+ "decay": 0.99,
+ "kmeans_init": True,
+ "kmeans_iters": 50,
+ "threshold_ema_dead_code": 2,
+ }
+
+ quantizer_params.update(bottleneck_config["config"])
+
+ bottleneck = RVQBottleneck(**quantizer_params)
+ elif bottleneck_type == "dac_rvq":
+ from .bottleneck import DACRVQBottleneck
+
+ bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
+
+ elif bottleneck_type == 'rvq_vae':
+ from .bottleneck import RVQVAEBottleneck
+
+ quantizer_params = {
+ "dim": 128,
+ "codebook_size": 1024,
+ "num_quantizers": 8,
+ "decay": 0.99,
+ "kmeans_init": True,
+ "kmeans_iters": 50,
+ "threshold_ema_dead_code": 2,
+ }
+
+ quantizer_params.update(bottleneck_config["config"])
+
+ bottleneck = RVQVAEBottleneck(**quantizer_params)
+
+ elif bottleneck_type == 'dac_rvq_vae':
+ from .bottleneck import DACRVQVAEBottleneck
+ bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
+ elif bottleneck_type == 'l2_norm':
+ from .bottleneck import L2Bottleneck
+ bottleneck = L2Bottleneck()
+ elif bottleneck_type == "wasserstein":
+ from .bottleneck import WassersteinBottleneck
+ bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
+ elif bottleneck_type == "fsq":
+ from .bottleneck import FSQBottleneck
+ bottleneck = FSQBottleneck(**bottleneck_config["config"])
+ else:
+ raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
+
+ requires_grad = bottleneck_config.get('requires_grad', True)
+ if not requires_grad:
+ for param in bottleneck.parameters():
+ param.requires_grad = False
+
+ return bottleneck
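+
+
+# Example configs (illustrative sketch only; the project's real configs live in
+# YAML/JSON files and may carry additional keys):
+#
+#     fsq_bottleneck = create_bottleneck_from_config(
+#         {"type": "fsq", "config": {"dim": 6, "levels": 5}})
+#     rvq_bottleneck = create_bottleneck_from_config(
+#         {"type": "rvq", "config": {"dim": 64, "codebook_size": 512}})
+#     # unspecified RVQ keys fall back to the defaults above
+#     # (8 quantizers, EMA decay 0.99, k-means init)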
diff --git a/src/modules/stable_vae/models/nn/__init__.py b/src/modules/stable_vae/models/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec2f6b972aeb8a21764ff73dae2095eb94bb8ba4
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/__init__.py
@@ -0,0 +1,3 @@
+from . import layers
+from . import loss
+from . import quantize
diff --git a/src/modules/stable_vae/models/nn/layers.py b/src/modules/stable_vae/models/nn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c53a88b881deab94fb06b6a395951cf9cc995a
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/layers.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings a ~1.4x model speedup
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
diff --git a/src/modules/stable_vae/models/nn/loss.py b/src/modules/stable_vae/models/nn/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a26e62f885f0bb7c6774e95660a7848e10a8f87
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/loss.py
@@ -0,0 +1,368 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+ """L1 Loss between AudioSignals. Defaults
+ to comparing ``audio_data``, but any
+ attribute of an AudioSignal can be used.
+
+ Parameters
+ ----------
+ attribute : str, optional
+ Attribute of signal to compare, defaults to ``audio_data``.
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+ self.attribute = attribute
+ self.weight = weight
+ super().__init__(**kwargs)
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate AudioSignal
+ y : AudioSignal
+ Reference AudioSignal
+
+ Returns
+ -------
+ torch.Tensor
+ L1 loss between AudioSignal attributes.
+ """
+ if isinstance(x, AudioSignal):
+ x = getattr(x, self.attribute)
+ y = getattr(y, self.attribute)
+ return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+ """
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+ of estimated and reference audio signals or aligned features.
+
+ Parameters
+ ----------
+ scaling : int, optional
+ Whether to use scale-invariant (True) or
+ signal-to-noise ratio (False), by default True
+ reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+ zero_mean : int, optional
+ Zero mean the references and estimates before
+ computing the loss, by default True
+ clip_min : int, optional
+ The minimum possible loss value. Helps network
+ to not focus on making already good examples better, by default None
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(
+ self,
+ scaling: int = True,
+ reduction: str = "mean",
+ zero_mean: int = True,
+ clip_min: int = None,
+ weight: float = 1.0,
+ ):
+ self.scaling = scaling
+ self.reduction = reduction
+ self.zero_mean = zero_mean
+ self.clip_min = clip_min
+ self.weight = weight
+ super().__init__()
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ eps = 1e-8
+ # nb, nc, nt
+ if isinstance(x, AudioSignal):
+ references = x.audio_data
+ estimates = y.audio_data
+ else:
+ references = x
+ estimates = y
+
+ nb = references.shape[0]
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+ # samples now on axis 1
+ if self.zero_mean:
+ mean_reference = references.mean(dim=1, keepdim=True)
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
+ else:
+ mean_reference = 0
+ mean_estimate = 0
+
+ _references = references - mean_reference
+ _estimates = estimates - mean_estimate
+
+ references_projection = (_references**2).sum(dim=-2) + eps
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+ scale = (
+ (references_on_estimates / references_projection).unsqueeze(1)
+ if self.scaling
+ else 1
+ )
+
+ e_true = scale * _references
+ e_res = _estimates - e_true
+
+ signal = (e_true**2).sum(dim=1)
+ noise = (e_res**2).sum(dim=1)
+ sdr = -10 * torch.log10(signal / noise + eps)
+
+ if self.clip_min is not None:
+ sdr = torch.clamp(sdr, min=self.clip_min)
+
+ if self.reduction == "mean":
+ sdr = sdr.mean()
+ elif self.reduction == "sum":
+ sdr = sdr.sum()
+ return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
+ """Computes the multi-scale STFT loss from [1].
+
+ Parameters
+ ----------
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Clamp on the log magnitude, below, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ References
+ ----------
+
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
+ "DDSP: Differentiable Digital Signal Processing."
+ International Conference on Learning Representations. 2019.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.loss_fn = loss_fn
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.clamp_eps = clamp_eps
+ self.weight = weight
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes multi-scale STFT between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Multi-scale STFT loss.
+ """
+ loss = 0.0
+ for s in self.stft_params:
+ x.stft(s.window_length, s.hop_length, s.window_type)
+ y.stft(s.window_length, s.hop_length, s.window_type)
+ loss += self.log_weight * self.loss_fn(
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+ return loss
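+
+# Minimal sketch of using this loss (AudioSignal comes from the `audiotools`
+# import at the top of this file; shapes and sample rate are arbitrary
+# example values):
+#
+#     stft_loss = MultiScaleSTFTLoss()
+#     est = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
+#     ref = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
+#     loss = stft_loss(est, ref)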
+
+
+class MelSpectrogramLoss(nn.Module):
+ """Compute distance between mel spectrograms. Can be used
+ in a multi-scale way.
+
+ Parameters
+ ----------
+ n_mels : List[int]
+        Number of mels per STFT, by default [150, 80]
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Clamp on the log magnitude, below, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ n_mels: List[int] = [150, 80],
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ mel_fmin: List[float] = [0.0, 0.0],
+ mel_fmax: List[float] = [None, None],
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.n_mels = n_mels
+ self.loss_fn = loss_fn
+ self.clamp_eps = clamp_eps
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.weight = weight
+ self.mel_fmin = mel_fmin
+ self.mel_fmax = mel_fmax
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes mel loss between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Mel loss.
+ """
+ loss = 0.0
+ for n_mels, fmin, fmax, s in zip(
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+ ):
+ kwargs = {
+ "window_length": s.window_length,
+ "hop_length": s.hop_length,
+ "window_type": s.window_type,
+ }
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+ loss += self.log_weight * self.loss_fn(
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+ return loss
+
+
+class GANLoss(nn.Module):
+ """
+ Computes a discriminator loss, given a discriminator on
+ generated waveforms/spectrograms compared to ground truth
+ waveforms/spectrograms. Computes the loss for both the
+ discriminator and the generator in separate functions.
+ """
+
+ def __init__(self, discriminator):
+ super().__init__()
+ self.discriminator = discriminator
+
+ def forward(self, fake, real):
+ d_fake = self.discriminator(fake.audio_data)
+ d_real = self.discriminator(real.audio_data)
+ return d_fake, d_real
+
+ def discriminator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+ loss_d = 0
+ for x_fake, x_real in zip(d_fake, d_real):
+ loss_d += torch.mean(x_fake[-1] ** 2)
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
+ return loss_d
+
+ def generator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake, real)
+
+ loss_g = 0
+ for x_fake in d_fake:
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+ loss_feature = 0
+
+ for i in range(len(d_fake)):
+ for j in range(len(d_fake[i]) - 1):
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+ return loss_g, loss_feature
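+
+
+# Typical training-loop usage (sketch; `discriminator`, `fake`, and `real` are
+# supplied by the training script, not this file — `fake`/`real` are
+# AudioSignal instances):
+#
+#     gan = GANLoss(discriminator)
+#     loss_d = gan.discriminator_loss(fake, real)          # step the discriminator
+#     loss_g, loss_feat = gan.generator_loss(fake, real)   # step the generator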
diff --git a/src/modules/stable_vae/models/nn/quantize.py b/src/modules/stable_vae/models/nn/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..499088fbde266dc89a704e45944b8900f2c72952
--- /dev/null
+++ b/src/modules/stable_vae/models/nn/quantize.py
@@ -0,0 +1,262 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from .layers import WNConv1d
+
+
+class VectorQuantize(nn.Module):
+ """
+ Implementation of VQ similar to Karpathy's repo:
+ https://github.com/karpathy/deep-vector-quantization
+ Additionally uses following tricks from Improved VQGAN
+ (https://arxiv.org/pdf/2110.04627.pdf):
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+ for improved codebook usage
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+ improves training stability
+ """
+
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+    def forward(self, z):
+        """Quantizes the input tensor using a fixed codebook and returns
+ the corresponding codebook vectors
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ Tensor[1]
+ Codebook loss to update the codebook
+ Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+ z_e = self.in_proj(z) # z_e : (B x D x T)
+ z_q, indices = self.decode_latents(z_e)
+
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+ z_q = (
+ z_e + (z_q - z_e).detach()
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
+
+ z_q = self.out_proj(z_q)
+
+ return z_q, commitment_loss, codebook_loss, indices, z_e
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight # codebook: (N x D)
+
+ # L2 normalize encodings and codebook (ViT-VQGAN)
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance with codebook
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+ return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+ """
+ Introduced in SoundStream: An end2end neural audio codec
+ https://arxiv.org/abs/2107.03312
+ """
+
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: float = 0.0,
+ ):
+ super().__init__()
+ if isinstance(codebook_dim, int):
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+ self.n_codebooks = n_codebooks
+ self.codebook_dim = codebook_dim
+ self.codebook_size = codebook_size
+
+ self.quantizers = nn.ModuleList(
+ [
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+ for i in range(n_codebooks)
+ ]
+ )
+ self.quantizer_dropout = quantizer_dropout
+
+    def forward(self, z, n_quantizers: int = None):
+        """Quantizes the input tensor using a fixed set of `n` codebooks and returns
+ the corresponding codebook vectors
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ n_quantizers : int, optional
+ No. of quantizers to use
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
+ when in training mode, and a random number of quantizers is used.
+ Returns
+ -------
+        tuple
+            A tuple (z, codes, latents, commitment_loss, codebook_loss) where:
+
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ """
+ z_q = 0
+ residual = z
+ commitment_loss = 0
+ codebook_loss = 0
+
+ codebook_indices = []
+ latents = []
+
+ if n_quantizers is None:
+ n_quantizers = self.n_codebooks
+ if self.training:
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(z.device)
+
+ for i, quantizer in enumerate(self.quantizers):
+ if self.training is False and i >= n_quantizers:
+ break
+
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+ residual
+ )
+
+ # Create mask to apply quantizer dropout
+ mask = (
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+ )
+ z_q = z_q + z_q_i * mask[:, None, None]
+ residual = residual - z_q_i
+
+ # Sum losses
+ commitment_loss += (commitment_loss_i * mask).mean()
+ codebook_loss += (codebook_loss_i * mask).mean()
+
+ codebook_indices.append(indices_i)
+ latents.append(z_e_i)
+
+ codes = torch.stack(codebook_indices, dim=1)
+ latents = torch.cat(latents, dim=1)
+
+ return z_q, codes, latents, commitment_loss, codebook_loss
+
+ def from_codes(self, codes: torch.Tensor):
+ """Given the quantized codes, reconstruct the continuous representation
+ Parameters
+ ----------
+ codes : Tensor[B x N x T]
+ Quantized discrete representation of input
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ """
+ z_q = 0.0
+ z_p = []
+ n_codebooks = codes.shape[1]
+ for i in range(n_codebooks):
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+ z_p.append(z_p_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+ return z_q, torch.cat(z_p, dim=1), codes
+
+ def from_latents(self, latents: torch.Tensor):
+ """Given the unquantized latents, reconstruct the
+ continuous representation after quantization.
+
+ Parameters
+ ----------
+ latents : Tensor[B x N x T]
+ Continuous representation of input after projection
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized representation of full-projected space
+ Tensor[B x D x T]
+ Quantized representation of latent space
+ """
+ z_q = 0
+ z_p = []
+ codes = []
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+ 0
+ ]
+ for i in range(n_codebooks):
+ j, k = dims[i], dims[i + 1]
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+ z_p.append(z_p_i)
+ codes.append(codes_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
+ x = torch.randn(16, 512, 80)
+    # forward() returns a tuple, not a dict
+    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+    print(latents.shape)
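+
+    # Round-trip sketch: rebuild the continuous representation from the
+    # discrete codes produced above and check that the shapes line up.
+    z_q_recon, z_p, _ = rvq.from_codes(codes)
+    print(z_q_recon.shape, z_p.shape, codes.shape)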
diff --git a/src/modules/stable_vae/models/pretransforms.py b/src/modules/stable_vae/models/pretransforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c4ffbbebd379695d123e5155f1387afe4d824f1
--- /dev/null
+++ b/src/modules/stable_vae/models/pretransforms.py
@@ -0,0 +1,258 @@
+import torch
+from einops import rearrange
+from torch import nn
+
+class Pretransform(nn.Module):
+ def __init__(self, enable_grad, io_channels, is_discrete):
+ super().__init__()
+
+ self.is_discrete = is_discrete
+ self.io_channels = io_channels
+ self.encoded_channels = None
+ self.downsampling_ratio = None
+
+ self.enable_grad = enable_grad
+
+ def encode(self, x):
+ raise NotImplementedError
+
+ def decode(self, z):
+ raise NotImplementedError
+
+ def tokenize(self, x):
+ raise NotImplementedError
+
+ def decode_tokens(self, tokens):
+ raise NotImplementedError
+
+class AutoencoderPretransform(Pretransform):
+ def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
+ super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
+ self.model = model
+ self.model.requires_grad_(False).eval()
+        self.scale = scale
+ self.downsampling_ratio = model.downsampling_ratio
+ self.io_channels = model.io_channels
+ self.sample_rate = model.sample_rate
+
+ self.model_half = model_half
+ self.iterate_batch = iterate_batch
+
+ self.encoded_channels = model.latent_dim
+
+ self.chunked = chunked
+ self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
+ self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
+
+ if self.model_half:
+ self.model.half()
+
+ def encode(self, x, **kwargs):
+
+ if self.model_half:
+ x = x.half()
+ self.model.to(torch.float16)
+
+ encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
+
+ if self.model_half:
+ encoded = encoded.float()
+
+ return encoded / self.scale
+
+ def decode(self, z, **kwargs):
+ z = z * self.scale
+
+ if self.model_half:
+ z = z.half()
+ self.model.to(torch.float16)
+
+ decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
+
+ if self.model_half:
+ decoded = decoded.float()
+
+ return decoded
+
+ def tokenize(self, x, **kwargs):
+ assert self.model.is_discrete, "Cannot tokenize with a continuous model"
+
+ _, info = self.model.encode(x, return_info = True, **kwargs)
+
+ return info[self.model.bottleneck.tokens_id]
+
+ def decode_tokens(self, tokens, **kwargs):
+ assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
+
+ return self.model.decode_tokens(tokens, **kwargs)
+
+ def load_state_dict(self, state_dict, strict=True):
+ self.model.load_state_dict(state_dict, strict=strict)
+
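+# Usage sketch for AutoencoderPretransform (assumes `autoencoder` is an audio
+# autoencoder built elsewhere in this package, exposing the encode_audio /
+# decode_audio methods used above):
+#
+#     pre = AutoencoderPretransform(autoencoder, scale=1.0, chunked=True)
+#     latents = pre.encode(audio)   # encoded latents, divided by `scale`
+#     recon = pre.decode(latents)   # multiplied back by `scale`, then decoded
+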
+class WaveletPretransform(Pretransform):
+ def __init__(self, channels, levels, wavelet):
+ super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
+
+ from .wavelets import WaveletEncode1d, WaveletDecode1d
+
+ self.encoder = WaveletEncode1d(channels, levels, wavelet)
+ self.decoder = WaveletDecode1d(channels, levels, wavelet)
+
+ self.downsampling_ratio = 2 ** levels
+ self.io_channels = channels
+ self.encoded_channels = channels * self.downsampling_ratio
+
+ def encode(self, x):
+ return self.encoder(x)
+
+ def decode(self, z):
+ return self.decoder(z)
+
+class PQMFPretransform(Pretransform):
+ def __init__(self, attenuation=100, num_bands=16):
+ # TODO: Fix PQMF to take in in-channels
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
+ from .pqmf import PQMF
+ self.pqmf = PQMF(attenuation, num_bands)
+
+
+ def encode(self, x):
+ # x is (Batch x Channels x Time)
+ x = self.pqmf.forward(x)
+ # pqmf.forward returns (Batch x Channels x Bands x Time)
+ # but Pretransform needs Batch x Channels x Time
+ # so concatenate channels and bands into one axis
+ return rearrange(x, "b c n t -> b (c n) t")
+
+ def decode(self, x):
+ # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
+ x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
+ # returns (Batch x Channels x Time)
+ return self.pqmf.inverse(x)
+
+class PretrainedDACPretransform(Pretransform):
+ def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
+
+ import dac
+
+ model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
+
+ self.model = dac.DAC.load(model_path)
+
+ self.quantize_on_decode = quantize_on_decode
+
+ if model_type == "44khz":
+ self.downsampling_ratio = 512
+ else:
+ self.downsampling_ratio = 320
+
+ self.io_channels = 1
+
+ self.scale = scale
+
+ self.chunked = chunked
+
+ self.encoded_channels = self.model.latent_dim
+
+ self.num_quantizers = self.model.n_codebooks
+
+ self.codebook_size = self.model.codebook_size
+
+ def encode(self, x):
+
+ latents = self.model.encoder(x)
+
+ if self.quantize_on_decode:
+ output = latents
+ else:
+ z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
+ output = z
+
+ if self.scale != 1.0:
+ output = output / self.scale
+
+ return output
+
+ def decode(self, z):
+
+ if self.scale != 1.0:
+ z = z * self.scale
+
+ if self.quantize_on_decode:
+ z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
+
+ return self.model.decode(z)
+
+ def tokenize(self, x):
+ return self.model.encode(x)[1]
+
+ def decode_tokens(self, tokens):
+ latents = self.model.quantizer.from_codes(tokens)
+ return self.model.decode(latents)
+
+class AudiocraftCompressionPretransform(Pretransform):
+ def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
+
+ try:
+ from audiocraft.models import CompressionModel
+ except ImportError:
+ raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
+
+ self.model = CompressionModel.get_pretrained(model_type)
+
+ self.quantize_on_decode = quantize_on_decode
+
+ self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
+
+ self.sample_rate = self.model.sample_rate
+
+ self.io_channels = self.model.channels
+
+ self.scale = scale
+
+ #self.encoded_channels = self.model.latent_dim
+
+ self.num_quantizers = self.model.num_codebooks
+
+ self.codebook_size = self.model.cardinality
+
+ self.model.to(torch.float16).eval().requires_grad_(False)
+
+ def encode(self, x):
+
+        raise NotImplementedError("Audiocraft compression models do not support continuous encoding")
+
+ # latents = self.model.encoder(x)
+
+ # if self.quantize_on_decode:
+ # output = latents
+ # else:
+ # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
+ # output = z
+
+ # if self.scale != 1.0:
+ # output = output / self.scale
+
+ # return output
+
+ def decode(self, z):
+
+        raise NotImplementedError("Audiocraft compression models do not support continuous decoding")
+
+ # if self.scale != 1.0:
+ # z = z * self.scale
+
+ # if self.quantize_on_decode:
+ # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
+
+ # return self.model.decode(z)
+
+ def tokenize(self, x):
+ with torch.cuda.amp.autocast(enabled=False):
+ return self.model.encode(x.to(torch.float16))[0]
+
+ def decode_tokens(self, tokens):
+ with torch.cuda.amp.autocast(enabled=False):
+ return self.model.decode(tokens)
diff --git a/src/modules/stable_vae/models/utils.py b/src/modules/stable_vae/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7adc0a9b4d72ed1e1d64262dfd71265fee9c9a4a
--- /dev/null
+++ b/src/modules/stable_vae/models/utils.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torchaudio import transforms as T
+
+
+class PadCrop(nn.Module):
+ def __init__(self, n_samples, randomize=True):
+ super().__init__()
+ self.n_samples = n_samples
+ self.randomize = randomize
+
+ def __call__(self, signal):
+ n, s = signal.shape
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
+ end = start + self.n_samples
+ output = signal.new_zeros([n, self.n_samples])
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
+ return output
+
+
+def set_audio_channels(audio, target_channels):
+ if target_channels == 1:
+ # Convert to mono
+ audio = audio.mean(1, keepdim=True)
+ elif target_channels == 2:
+ # Convert to stereo
+ if audio.shape[1] == 1:
+ audio = audio.repeat(1, 2, 1)
+ elif audio.shape[1] > 2:
+ audio = audio[:, :2, :]
+ return audio
+
+def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
+
+ audio = audio.to(device)
+
+ if in_sr != target_sr:
+ resample_tf = T.Resample(in_sr, target_sr).to(device)
+ audio = resample_tf(audio)
+
+ audio = PadCrop(target_length, randomize=False)(audio)
+
+ # Add batch dimension
+ if audio.dim() == 1:
+ audio = audio.unsqueeze(0).unsqueeze(0)
+ elif audio.dim() == 2:
+ audio = audio.unsqueeze(0)
+
+ audio = set_audio_channels(audio, target_channels)
+
+ return audio
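+
+
+# Example (sketch; the waveform and sample rate would normally come from
+# torchaudio.load or the dataset pipeline):
+#
+#     wav, sr = torchaudio.load("example.wav")
+#     wav = prepare_audio(wav, in_sr=sr, target_sr=44100,
+#                         target_length=44100 * 10, target_channels=2,
+#                         device="cuda")
+#     # -> tensor of shape (1, 2, 441000) on the GPU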
\ No newline at end of file
diff --git a/src/test.py b/src/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b605f0f29482d06510b618691ff74a718816aba4
--- /dev/null
+++ b/src/test.py
@@ -0,0 +1,97 @@
+import random
+import argparse
+import os
+import time
+import soundfile as sf
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers import DDIMScheduler
+from models.conditioners import MaskDiT
+from modules.autoencoder_wrapper import Autoencoder
+from transformers import T5Tokenizer, T5EncoderModel
+from inference import inference
+from utils import scale_shift, get_lr_scheduler, compute_snr, load_yaml_with_includes
+
+
+parser = argparse.ArgumentParser()
+# config settings
+parser.add_argument('--config-name', type=str, default='configs/udit_ada.yml')
+parser.add_argument('--ckpt-path', type=str, default='../ckpts/')
+parser.add_argument('--ckpt-id', type=str, default='120')
+parser.add_argument('--save_path', type=str, default='../output/')
+parser.add_argument('--test-df', type=str, default='audiocaps_test.csv')
+# parser.add_argument('--test-split', type=str, default='test')
+
+parser.add_argument('--device', type=str, default='cuda')
+parser.add_argument('--guidance-scale', type=float, default=3)
+parser.add_argument('--guidance-rescale', type=float, default=0)
+parser.add_argument('--ddim-steps', type=int, default=50)
+parser.add_argument('--eta', type=float, default=1)
+parser.add_argument('--random-seed', type=int, default=None)
+
+args = parser.parse_args()
+params = load_yaml_with_includes(args.config_name)
+
+# args.ckpt_path = f"{args.ckpt_path}/{params['model_name']}/{args.ckpt_id}.pt"
+args.save_path = f"{args.save_path}/{params['model_name']}/{args.ckpt_id}_{args.ddim_steps}_{args.guidance_scale}_{args.guidance_rescale}/"
+args.ckpt_path = f"{args.ckpt_path}/{args.ckpt_id}.pt"
+
+if __name__ == '__main__':
+ # Codec Model
+ autoencoder = Autoencoder(ckpt_path=params['autoencoder']['path'],
+ model_type=params['autoencoder']['name'],
+ quantization_first=params['autoencoder']['q_first'])
+ autoencoder.to(args.device)
+ autoencoder.eval()
+
+ # text encoder
+ tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
+ text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model'],
+ device_map='cpu').to(args.device)
+ text_encoder.eval()
+
+ # main U-Net
+ unet = MaskDiT(**params['model']).to(args.device)
+ unet.eval()
+ unet.load_state_dict(torch.load(args.ckpt_path)['model'])
+
+ total_params = sum([param.nelement() for param in unet.parameters()])
+    print("Number of parameters: %.2fM" % (total_params / 1e6))
+
+ noise_scheduler = DDIMScheduler(**params['diff'])
+    # dummy add_noise call so the scheduler casts/moves its internal tensors (e.g. alphas_cumprod) to the working device and dtype
+ latents = torch.randn((1, 128, 128), device=args.device)
+ noise = torch.randn_like(latents)
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=args.device)
+ _ = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ df = pd.read_csv(args.test_df)
+ # Wdf = df[df['split'] == args.test_split]
+ df = df[df['audio_length'] != 0]
+ # df = df.sample(10)
+ os.makedirs(args.save_path, exist_ok=True)
+ audio_frames = params['data']['train_frames']
+
+ for i in tqdm(range(len(df))):
+ row = df.iloc[i]
+ text = row['caption']
+ audio_id = row['audiocap_id']
+
+ pred = inference(autoencoder, unet, None, None,
+ tokenizer, text_encoder,
+ params, noise_scheduler,
+ text, None,
+ audio_frames,
+ args.guidance_scale, args.guidance_rescale,
+ args.ddim_steps, args.eta, args.random_seed,
+ args.device)
+ pred = pred.cpu().numpy().squeeze(0).squeeze(0)
+
+ sf.write(f"{args.save_path}/{audio_id}.wav",
+ pred, samplerate=params['data']['sr'])
\ No newline at end of file
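
src/test.py only relies on three columns of the CSV passed via `--test-df`: `audiocap_id` (used as the output file name), `caption` (the text prompt), and `audio_length` (rows equal to 0 are dropped). A minimal sketch of a compatible file, with purely illustrative values:

```python
import pandas as pd

df = pd.DataFrame({
    "audiocap_id": [101, 102],
    "caption": ["a dog barking", "rain falling on a tin roof"],
    "audio_length": [10.0, 10.0],
})
df.to_csv("audiocaps_test.csv", index=False)  # matches the script's default --test-df name
```
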
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90f60fdd89ad8575faafe45188bd1d968852fc67
--- /dev/null
+++ b/src/utils/__init__.py
@@ -0,0 +1 @@
+from .utils import *
\ No newline at end of file
diff --git a/src/utils/utils.py b/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..82a6bc56e9e341e54dc6a136f1f78261dde0f655
--- /dev/null
+++ b/src/utils/utils.py
@@ -0,0 +1,94 @@
+import torch
+import numpy as np
+import yaml
+import os
+
+
+def load_yaml_with_includes(yaml_file):
+ def loader_with_include(loader, node):
+ # Load the included file
+ include_path = os.path.join(os.path.dirname(yaml_file), loader.construct_scalar(node))
+ with open(include_path, 'r') as f:
+ return yaml.load(f, Loader=yaml.FullLoader)
+
+ yaml.add_constructor('!include', loader_with_include, Loader=yaml.FullLoader)
+
+ with open(yaml_file, 'r') as f:
+ return yaml.load(f, Loader=yaml.FullLoader)
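+
+# Illustrative usage (the included file names below are hypothetical): a config such as
+# configs/udit_ada.yml can pull in shared blocks with the custom `!include` tag, e.g.
+#     model: !include model.yml
+#     diff: !include diffusion.yml
+# Included paths are resolved relative to the directory of the including YAML file.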
+
+
+def scale_shift(x, scale, shift):
+ return (x+shift) * scale
+
+
+def scale_shift_re(x, scale, shift):
+ return (x/scale) - shift
+
+
+def align_seq(source, target_length, mapping_method='hard'):
+ source_len = source.shape[1]
+ if mapping_method == 'hard':
+        mapping_idx = np.round(np.arange(target_length) * source_len / target_length).astype(int)
+        mapping_idx = np.clip(mapping_idx, 0, source_len - 1)  # rounding can reach source_len; clamp to a valid index
+        output = source[:, mapping_idx]
+ else:
+ # TBD
+ raise NotImplementedError
+
+ return output
+
+
+def customized_lr_scheduler(optimizer, warmup_steps=-1):
+ from torch.optim.lr_scheduler import LambdaLR
+
+ def fn(step):
+ if warmup_steps > 0:
+ return min(step / warmup_steps, 1)
+ else:
+ return 1
+ return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+ if name == 'customized':
+ return customized_lr_scheduler(optimizer, **kwargs)
+ elif name == 'cosine':
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ return CosineAnnealingLR(optimizer, **kwargs)
+ else:
+ raise NotImplementedError(name)
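+
+# Illustrative usage (optimizer is assumed to exist): get_lr_scheduler(optimizer, 'customized',
+# warmup_steps=1000) ramps the learning rate linearly from 0 to its base value over the first
+# 1000 steps and then holds it; get_lr_scheduler(optimizer, 'cosine', T_max=100000) follows
+# torch's CosineAnnealingLR schedule.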
+
+
+def compute_snr(noise_scheduler, timesteps):
+ """
+ Computes SNR as per
+    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ # Expand the tensors.
+    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+ # Compute SNR.
+ snr = (alpha / sigma) ** 2
+ return snr
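+
+# Quick sanity check (illustrative; scheduler settings are defaults, not the project config):
+#     scheduler = DDIMScheduler(num_train_timesteps=1000)
+#     snr = compute_snr(scheduler, torch.tensor([10, 500, 990]))
+# Since alpha**2 = alphas_cumprod and sigma**2 = 1 - alphas_cumprod, this equals
+# alphas_cumprod[t] / (1 - alphas_cumprod[t]): high SNR at early timesteps, near 0 at late ones.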
+
+
+if __name__ == "__main__":
+
+ a = torch.rand(2, 10)
+ target_len = 15
+
+ b = align_seq(a, target_len)
\ No newline at end of file