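# NOTE: the lines below are file-viewer metadata captured together with this
# config; they are kept as comments because they are not gin syntax.
#   Last commit: 98eb218 by Nithya
#   Message: "updated parent repo and restructured things"
#   Viewer links: raw / history / blame -- file size 2.91 kB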
from __gin__ import dynamic_registration
from gamadhani import src
from gamadhani.src import dataset
from gamadhani.src import model_diffusion
from gamadhani.src import task_functions
from gamadhani.utils import utils
import torch
# Macros:
# ==============================================================================
# Shared values; other bindings in this file refer to them as %NAME.
# LR is used for torch.optim.AdamW.lr and for the scheduler's eta_max.
LR = 1e-4
# SEQ_LEN is used for the dataset Task kwargs ("seq_len") and UNet.seq_len.
SEQ_LEN = 1200
# TRANSPOSE_VALUE is referenced only by a commented-out binding in this file
# (train/dataset.Task.kwargs, key "transpose_pitch").
TRANSPOSE_VALUE = 400
# Parameters for torch.optim.AdamW:
# ==============================================================================
# Learning rate comes from the shared %LR macro.
torch.optim.AdamW.lr = %LR
# (beta1, beta2): coefficients for the running averages of the gradient and of
# its square.
torch.optim.AdamW.betas = (0.9, 0.99)
# Parameters for utils.build_warmed_exponential_lr_scheduler:
# ==============================================================================
# NOTE(review): the groupings below are inferred from parameter names only;
# the scheduler implementation is not part of this file -- confirm against
# gamadhani.utils.utils.
# Presumably the warm-up phase.
utils.build_warmed_exponential_lr_scheduler.start_factor = 0.01
utils.build_warmed_exponential_lr_scheduler.peak_iteration = 10000
# Presumably the decay phase that follows warm-up.
utils.build_warmed_exponential_lr_scheduler.cycle_length = 200000
# Upper bound is tied to the optimizer learning rate through the %LR macro.
utils.build_warmed_exponential_lr_scheduler.eta_max = %LR
# NOTE(review): eta_min (0.1) is larger than eta_max (%LR = 0.0001), so it is
# presumably a fraction of eta_max rather than an absolute learning rate --
# confirm before changing either value.
utils.build_warmed_exponential_lr_scheduler.eta_min = 0.1
# Parameters for model_diffusion.UNetBase.configure_optimizers:
# ==============================================================================
# Both values are configurable references (@name without call parentheses),
# so configure_optimizers receives the configured callables themselves and
# invokes them itself; their own parameters are bound in the AdamW and
# build_warmed_exponential_lr_scheduler sections of this file.
model_diffusion.UNetBase.configure_optimizers.scheduler_cls = \
    @utils.build_warmed_exponential_lr_scheduler
model_diffusion.UNetBase.configure_optimizers.optimizer_cls = @torch.optim.AdamW
# Parameters for dataset.Task:
# ==============================================================================
# All active Task bindings are collected in this one section. They are written
# with the `src.dataset` / `src.task_functions` selectors (reachable through
# the `from gamadhani import src` import) and none of them is scoped: there is
# no `train/` prefix on any active binding.
# Reader function and its inverse, passed as configurable references.
src.dataset.Task.read_fn = @src.task_functions.pitch_read_downsample_diff
src.dataset.Task.invert_fn = @src.task_functions.invert_pitch_read_downsample_diff
# Keyword arguments bound as a single dict. "seq_len" shares the %SEQ_LEN macro
# with model_diffusion.UNet.seq_len.
# NOTE(review): the meaning and units of the remaining keys are defined by the
# Task / read_fn code, which is not visible from this file -- confirm there.
src.dataset.Task.kwargs = {
    "decoder_key": "pitch",
    "max_clip": 600,
    "min_clip": 200,
    "min_norm_pitch": -4915,
    "pitch_downsample": 10,
    "seq_len": %SEQ_LEN,
    "time_downsample": 2
}
# Disabled: a `train/`-scoped kwargs binding that would use %TRANSPOSE_VALUE.
# NOTE(review): as written it binds the whole kwargs parameter for the train/
# scope, i.e. it would presumably replace the dict above in that scope rather
# than merge into it -- confirm before re-enabling.
# train/dataset.Task.kwargs = {"transpose_pitch": %TRANSPOSE_VALUE}
# Parameters for model_diffusion.UNet:
# ==============================================================================
# NOTE(review): the groupings below follow the parameter names only; the UNet
# implementation is not visible from this file -- confirm meanings in
# gamadhani.src.model_diffusion.
# Input and sequence size; seq_len shares the %SEQ_LEN macro with the dataset
# Task kwargs.
model_diffusion.UNet.inp_dim = 1
model_diffusion.UNet.seq_len = %SEQ_LEN
# Per-level lists: `features` and `strides` each have three entries.
model_diffusion.UNet.features = [512, 640, 1024]
model_diffusion.UNet.strides = [4, 2, 2]
# Convolution-related settings.
model_diffusion.UNet.kernel_size = 5
model_diffusion.UNet.num_convs = 4
model_diffusion.UNet.nonlinearity = 'mish'
model_diffusion.UNet.norm = True
model_diffusion.UNet.dropout = 0.3
# Attention-related settings.
model_diffusion.UNet.num_attns = 4
model_diffusion.UNet.num_heads = 8
# Embedding / projection sizes.
model_diffusion.UNet.project_dim = 256
model_diffusion.UNet.time_dim = 128