# Gin configuration for the gamadhani pitch diffusion model.
#
# Binds the optimizer, the learning-rate scheduler, the dataset task and the
# UNet hyperparameters. Configurables are resolved through the imports below
# (gin dynamic registration), so every `module.Name` used in a binding must be
# reachable from one of these imports.
from __gin__ import dynamic_registration

from gamadhani import src
from gamadhani.src import dataset
from gamadhani.src import model_diffusion
from gamadhani.src import task_functions
from gamadhani.utils import utils
import torch

# Macros:
# ==============================================================================
# LR feeds both the optimizer learning rate and the scheduler's eta_max.
LR = 0.0001
# SEQ_LEN is shared by the dataset task kwargs and the UNet so they stay in sync.
SEQ_LEN = 1200
# Only referenced by the disabled train-scoped binding further down.
TRANSPOSE_VALUE = 400

# Parameters for torch.optim.AdamW:
# ==============================================================================
torch.optim.AdamW.betas = (0.9, 0.99)
torch.optim.AdamW.lr = %LR

# Parameters for utils.build_warmed_exponential_lr_scheduler:
# ==============================================================================
# NOTE(review): eta_min (0.1) is numerically larger than eta_max (%LR = 0.0001).
# Confirm whether the scheduler interprets eta_min as a fraction of eta_max
# rather than as an absolute learning rate.
utils.build_warmed_exponential_lr_scheduler.cycle_length = 200000
utils.build_warmed_exponential_lr_scheduler.eta_max = %LR
utils.build_warmed_exponential_lr_scheduler.eta_min = 0.1
utils.build_warmed_exponential_lr_scheduler.peak_iteration = 10000
utils.build_warmed_exponential_lr_scheduler.start_factor = 0.01

# Parameters for model_diffusion.UNetBase.configure_optimizers:
# ==============================================================================
model_diffusion.UNetBase.configure_optimizers.optimizer_cls = @torch.optim.AdamW
model_diffusion.UNetBase.configure_optimizers.scheduler_cls = \
    @utils.build_warmed_exponential_lr_scheduler

# Parameters for dataset.Task (unscoped, so they apply to every scope):
# ==============================================================================
src.dataset.Task.kwargs = {
    "decoder_key": 'pitch',
    "max_clip": 600,
    "min_clip": 200,
    "min_norm_pitch": -4915,
    "pitch_downsample": 10,
    "seq_len": %SEQ_LEN,
    "time_downsample": 2,
}
src.dataset.Task.read_fn = \
    @src.task_functions.pitch_read_downsample_diff
src.dataset.Task.invert_fn = \
    @src.task_functions.invert_pitch_read_downsample_diff

# Parameters for train/dataset.Task (disabled):
# ==============================================================================
# Kept for reference only.
# NOTE(review): if re-enabled as written, this binding would supply the whole
# `kwargs` value for the `train` scope instead of merging into the unscoped
# dict above - confirm the intended keys before turning it back on.
# train/dataset.Task.kwargs = {"transpose_pitch": %TRANSPOSE_VALUE}

# Parameters for model_diffusion.UNet:
# ==============================================================================
# features and strides both list three entries; presumably one per resolution
# level - TODO confirm against model_diffusion.UNet.
model_diffusion.UNet.dropout = 0.3
model_diffusion.UNet.features = [512, 640, 1024]
model_diffusion.UNet.inp_dim = 1
model_diffusion.UNet.kernel_size = 5
model_diffusion.UNet.nonlinearity = 'mish'
model_diffusion.UNet.norm = True
model_diffusion.UNet.num_attns = 4
model_diffusion.UNet.num_convs = 4
model_diffusion.UNet.num_heads = 8
model_diffusion.UNet.project_dim = 256
model_diffusion.UNet.seq_len = %SEQ_LEN
model_diffusion.UNet.strides = [4, 2, 2]
model_diffusion.UNet.time_dim = 128