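# Gin configuration for the XTransformerPrior pitch model: architecture, dataset task, optimizer, learning-rate schedule, and random seed.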
from __gin__ import dynamic_registration
from gamadhani import src
from gamadhani.src import dataset
from gamadhani.src import model_transformer
from gamadhani.src import task_functions
from gamadhani.utils import utils
import torch.optim
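
# Shared hyperparameters, referenced below as gin macros (%NAME).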
MODEL_DIM = 512
EMB_DIM = 512
NUM_TOKENS = 7928
NUM_QUANTIZERS = 1
DROPOUT_RATE = 0.3
NUM_HEADS = 8
SEQ_LEN = 1200
HEAD_DIM = 32
NUM_LAYERS = 8
LR = 1e-3
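
# Transformer prior architecture.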
model_transformer.XTransformerPrior:
    num_tokens = %NUM_TOKENS
    seq_len = %SEQ_LEN
    model_dim = %MODEL_DIM
    emb_dim = %EMB_DIM
    head_dim = %HEAD_DIM
    num_layers = %NUM_LAYERS
    num_heads = %NUM_HEADS
    dropout_rate = %DROPOUT_RATE
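
# Task: functions to read and invert downsampled pitch sequences, with their keyword arguments.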
src.dataset.Task:
    read_fn = @src.task_functions.pitch_read_downsample
    invert_fn = @src.task_functions.invert_pitch_read_downsample
    kwargs = {"seq_len": %SEQ_LEN,
              "decoder_key": "pitch",
              "min_norm_pitch": -4915,
              "time_downsample": 2,
              "pitch_downsample": 10,
              "base_tonic": 440.}
src.dataset.SequenceDataset:
    task = @dataset.Task()
    apply_transform = False
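
# Optimizer and learning-rate scheduler classes used by configure_optimizers.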
model_transformer.XTransformerPrior.configure_optimizers:
    optimizer_cls = @torch.optim.AdamW
    scheduler_cls = @utils.build_warmed_exponential_lr_scheduler
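
# Learning-rate schedule: warm-up followed by exponential decay.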
utils.build_warmed_exponential_lr_scheduler:
    start_factor = .01
    peak_iteration = 10000
    cycle_length = 394600
    eta_min = 0.1
    eta_max = %LR
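
# Fixed random seed for reproducibility.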
utils.set_seed:
    seed = 2023
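
# AdamW optimizer settings.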
torch.optim.AdamW:
    lr = %LR
    betas = (.9, .98)