|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""init_train.py""" |
|
from typing import Tuple, Literal, Any |
|
from copy import deepcopy |
|
import os |
|
import argparse |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from pytorch_lightning.utilities import rank_zero_only |
|
from config.config import shared_cfg as default_shared_cfg |
|
from config.config import audio_cfg as default_audio_cfg |
|
from config.config import model_cfg as default_model_cfg |
|
from config.config import DEEPSPEED_CFG |
|
|
|
|
|
def initialize_trainer(args: argparse.Namespace,
                       stage: Literal['train', 'test'] = 'train') -> Tuple[pl.Trainer, WandbLogger, dict, dict]:
    """Create the Lightning trainer and the W&B logger for one experiment.

    Works on a deep copy of the default shared configuration, so the
    module-level default is not modified by this function.

    Args:
        args: Parsed command-line options. Reads ``exp_id``, ``project``,
            ``wandb_mode``, ``strategy``, ``val_interval``, ``num_gpus``,
            ``num_nodes``, ``precision``, ``max_epochs`` and ``max_steps``;
            in the train stage also ``train_batch_size`` and
            ``sync_batchnorm``. Two attributes are rewritten in place:
            ``exp_id`` loses an ``@<checkpoint name>`` suffix if present, and
            ``strategy`` becomes ``'ddp_find_unused_parameters_true'`` when it
            was ``'ddp'`` and ``stage`` is ``'train'``.
        stage: ``'train'`` tolerates a missing checkpoint and applies the
            training-only overrides; any other value requires the checkpoint
            to exist.

    Returns:
        ``(trainer, wandb_logger, dir_info, shared_cfg)`` where ``dir_info``
        holds ``lightning_dir`` and ``last_ckpt_path`` (``None`` when training
        starts from scratch) and ``shared_cfg`` is the updated copy of the
        shared configuration.

    Raises:
        ValueError: If the checkpoint is missing outside the train stage, if
            ``exp_id`` contains more than one ``@``, or if the local batch
            size is not a multiple of the sub batch size.
    """
    shared_cfg = deepcopy(default_shared_cfg)
    wandb_cfg = shared_cfg["WANDB"]

    # Make sure the log/save directory exists before anything writes to it.
    os.makedirs(wandb_cfg["save_dir"], exist_ok=True)

    # "<exp_id>@<checkpoint>" selects a specific checkpoint file; a plain id
    # falls back to "last.ckpt".
    if "@" in args.exp_id:
        args.exp_id, checkpoint_name = args.exp_id.split("@")
    else:
        checkpoint_name = "last.ckpt"

    lightning_dir = os.path.join(wandb_cfg["save_dir"], args.project, args.exp_id)

    # Logger setup.
    if args.wandb_mode is not None:
        wandb_cfg["mode"] = str(args.wandb_mode)
    # cache_dir is exported through the environment and must never be part of
    # the keyword arguments splatted into WandbLogger below. Popping it also
    # removes a `cache_dir: None` entry, which the previous
    # delete-only-when-set logic would have forwarded to the logger.
    cache_dir = wandb_cfg.pop("cache_dir", None)
    if cache_dir is not None:
        os.environ["WANDB_CACHE_DIR"] = cache_dir
    wandb_logger = WandbLogger(log_model="all",
                               project=args.project,
                               id=args.exp_id,
                               allow_val_change=True,
                               **wandb_cfg)

    # Locate the checkpoint to resume from.
    last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name)
    if os.path.exists(last_ckpt_path):
        print(f'Resuming from {last_ckpt_path}')
    elif stage == 'train':
        print(f'No checkpoint found in {last_ckpt_path}. Starting from scratch')
        last_ckpt_path = None
    else:
        raise ValueError(f'No checkpoint found in {last_ckpt_path}. Quit...')

    dir_info = dict(lightning_dir=lightning_dir, last_ckpt_path=last_ckpt_path)

    # Callbacks.
    checkpoint_callback = ModelCheckpoint(**shared_cfg["CHECKPOINT"])
    lr_monitor = LearningRateMonitor(logging_interval='step')

    # A step-based validation interval replaces the epoch-based one.
    if stage == 'train' and args.val_interval is not None:
        shared_cfg["TRAINER"]["check_val_every_n_epoch"] = None
        shared_cfg["TRAINER"]["val_check_interval"] = int(args.val_interval)

    # Training-only overrides.
    sync_batchnorm = False
    if stage == 'train':
        if args.train_batch_size is not None:
            train_sub_bsz = int(args.train_batch_size[0])
            train_local_bsz = int(args.train_batch_size[1])
            if train_local_bsz % train_sub_bsz != 0:
                raise ValueError(
                    f'Local batch size {train_local_bsz} must be divisible by sub batch size {train_sub_bsz}')
            shared_cfg["BSZ"]["train_sub"] = train_sub_bsz
            shared_cfg["BSZ"]["train_local"] = train_local_bsz

        if args.strategy == 'ddp':
            args.strategy = 'ddp_find_unused_parameters_true'

        sync_batchnorm = args.sync_batchnorm is True

    # Resolve the strategy once, after the ddp rename above: a configured
    # DeepSpeed strategy object for 'deepspeed', the plain string otherwise.
    if args.strategy == 'deepspeed':
        strategy = pl.strategies.DeepSpeedStrategy(config=DEEPSPEED_CFG)
    else:
        strategy = args.strategy

    in_train_stage = stage == 'train'
    train_params = dict(**shared_cfg["TRAINER"],
                        devices=args.num_gpus if args.num_gpus == 'auto' else int(args.num_gpus),
                        num_nodes=int(args.num_nodes),
                        strategy=strategy,
                        precision=args.precision,
                        max_epochs=args.max_epochs if in_train_stage else None,
                        max_steps=args.max_steps if in_train_stage else -1,
                        logger=wandb_logger,
                        callbacks=[checkpoint_callback, lr_monitor],
                        sync_batchnorm=sync_batchnorm)
    trainer = pl.Trainer(**train_params)

    # Only global rank 0 pushes the command-line options to the W&B run config.
    if trainer.global_rank == 0:
        wandb_logger.experiment.config.update(args, allow_val_change=True)

    return trainer, wandb_logger, dir_info, shared_cfg
|
|
|
|
|
# Override tables: (attribute on ``args``, destination config key, cast).
# Entries are applied in the listed order, and only when the argument is not None.
_XAUG_POLICY_OVERRIDES = (
    ("xaug_max_k", "max_k", int),
    ("xaug_tau", "tau", float),
    ("xaug_alpha", "alpha", float),
    ("xaug_no_instr_overlap", "no_instr_overlap", bool),
    ("xaug_no_drum_overlap", "no_drum_overlap", bool),
    ("uhat_intra_stem_augment", "uhat_intra_stem_augment", bool),
)

_AUDIO_OVERRIDES = (
    ("hop_length", "hop_length", int),
    ("n_mels", "n_mels", int),
    ("input_frames", "input_frames", int),
)

_PERCEIVER_TF_OVERRIDES = (
    ("d_latent", "d_latent", int),
    ("num_latents", "num_latents", int),
    ("perceiver_tf_d_model", "d_model", int),
    ("num_perceiver_tf_blocks", "num_blocks", int),
    ("num_perceiver_tf_local_transformers_per_block", "num_local_transformers_per_block", int),
    ("num_perceiver_tf_temporal_transformers_per_block", "num_temporal_transformers_per_block", int),
    ("attention_to_channel", "attention_to_channel", bool),
    ("sca_use_query_residual", "sca_use_query_residual", bool),
    ("layer_norm_type", "layer_norm", str),
    ("ff_layer_type", "ff_layer_type", str),
    ("ff_widening_factor", "ff_widening_factor", int),
    ("moe_num_experts", "moe_num_experts", int),
    ("moe_topk", "moe_topk", int),
    ("hidden_act", "hidden_act", str),
)

_PERCEIVER_TF_ROPE_OVERRIDES = (
    ("rope_apply_to_keys", "rope_apply_to_keys", bool),
    ("rope_partial_pe", "rope_partial_pe", bool),
)

_DECODER_FF_OVERRIDES = (
    ("decoder_ff_layer_type", "ff_layer_type", str),
    ("decoder_ff_widening_factor", "ff_widening_factor", int),
)

_ENCODER_PE_TYPES = ('sinusoidal', 'rope', 'trainable', 'alibi', 'alibit', 'tkd', 'td', 'tk', 'kdt')
_DECODER_PE_TYPES = ('sinusoidal', 'trainable')
_DISABLED_PE_TOKENS = ('None', 'none', '0')


def _apply_overrides(args, overrides, get_target):
    """Store every non-None ``args.<attr>`` as ``get_target()[key] = cast(value)``.

    ``get_target`` is only called when a value is actually written, so a
    missing config section is never touched unless an override requests it.
    """
    for attr, key, cast in overrides:
        value = getattr(args, attr)
        if value is not None:
            get_target()[key] = cast(value)


def update_config(args, shared_cfg, stage: Literal['train', 'test'] = 'train'):
    """Apply command-line overrides to the shared, audio and model configurations.

    Args:
        args: Parsed command-line options; every attribute read below must
            exist on it. ``None`` (or ``'default'`` for the pre-encoder,
            pre-decoder and position-encoding options) means "keep the
            configured value".
        shared_cfg: Shared configuration dict; updated in place.
        stage: Augmentation and dropout overrides are applied only when this
            is ``'train'``.

    Returns:
        ``(shared_cfg, audio_cfg, model_cfg)``. ``shared_cfg`` is the object
        that was passed in; ``audio_cfg`` and ``model_cfg`` are the
        module-level default dicts themselves, updated in place.

    Raises:
        ValueError: For an unsupported encoder/decoder position-encoding type,
            or when the decoder position encoding is disabled.
        AssertionError: For an unknown audio codec, non-boolean task
            conditioning flags, or a ``rotary_type`` whose length is not 3.
    """
    # NOTE(review): these are aliases, not copies -- every override below
    # mutates the imported defaults and persists across calls in the same
    # process. Kept as-is because other modules may read the mutated defaults;
    # confirm before switching to deepcopy.
    audio_cfg = default_audio_cfg
    model_cfg = default_model_cfg

    # NOTE(review): several options are cast with bool(); bool() of any
    # non-empty string (including "False") is True. This assumes the argument
    # parser already yields real booleans -- confirm against the parser.

    # ---- Augmentation (training only) ----
    if stage == 'train':
        augmentation = shared_cfg["AUGMENTATION"]

        if args.random_amp_range is not None:
            augmentation["train_random_amp_range"] = [
                float(args.random_amp_range[0]),
                float(args.random_amp_range[1]),
            ]
        if args.stem_iaug_prob is not None:
            augmentation["train_stem_iaug_prob"] = float(args.stem_iaug_prob)

        _apply_overrides(args, _XAUG_POLICY_OVERRIDES, lambda: augmentation["train_stem_xaug_policy"])

        if args.pitch_shift_range is not None:
            # A (0, 0) range, given as strings or ints, stores None instead of a range.
            if args.pitch_shift_range in [["0", "0"], [0, 0]]:
                augmentation["train_pitch_shift_range"] = None
            else:
                augmentation["train_pitch_shift_range"] = [
                    int(args.pitch_shift_range[0]),
                    int(args.pitch_shift_range[1]),
                ]

        print(f'Random amp range: {augmentation["train_random_amp_range"]}\n'
              f'Intra-stem augmentation probability: {augmentation["train_stem_iaug_prob"]}\n'
              f'Stem augmentation policy: {augmentation["train_stem_xaug_policy"]}\n'
              f'Pitch shift range: {augmentation["train_pitch_shift_range"]}\n')

    # ---- Audio ----
    if args.audio_codec is not None:
        assert args.audio_codec in ['spec', 'melspec']
        audio_cfg["codec"] = str(args.audio_codec)
    _apply_overrides(args, _AUDIO_OVERRIDES, lambda: audio_cfg)

    # ---- Tokenizer: resolve "auto" max_shift_steps from the input length ----
    if shared_cfg["TOKENIZER"]["max_shift_steps"] == "auto":
        shift_steps_ms = shared_cfg["TOKENIZER"]["shift_step_ms"]
        input_frames = audio_cfg["input_frames"]
        fs = audio_cfg["sample_rate"]
        # Number of whole shift steps in the input duration, plus 2.
        max_shift_steps = (input_frames / fs) // (shift_steps_ms / 1000) + 2
        shared_cfg["TOKENIZER"]["max_shift_steps"] = int(max_shift_steps)

    # ---- Model: architecture selection ----
    if args.encoder_type is not None:
        model_cfg["encoder_type"] = str(args.encoder_type)
    if args.decoder_type is not None:
        model_cfg["decoder_type"] = str(args.decoder_type)
    if args.pre_encoder_type != "default":
        model_cfg["pre_encoder_type"] = str(args.pre_encoder_type)
    if args.pre_decoder_type != 'default':
        model_cfg["pre_decoder_type"] = str(args.pre_decoder_type)
    if args.conv_out_channels is not None:
        model_cfg["conv_out_channels"] = int(args.conv_out_channels)

    assert isinstance(args.task_cond_decoder, bool) and isinstance(args.task_cond_encoder, bool)
    model_cfg["use_task_conditional_encoder"] = args.task_cond_encoder
    model_cfg["use_task_conditional_decoder"] = args.task_cond_decoder

    # ---- Model: position encodings (written into the selected encoder/decoder section) ----
    encoder_pe = args.encoder_position_encoding_type
    if encoder_pe != 'default':
        if encoder_pe in _DISABLED_PE_TOKENS:
            model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = None
        elif encoder_pe in _ENCODER_PE_TYPES:
            model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = str(encoder_pe)
        else:
            raise ValueError(f'Encoder PE type {args.encoder_position_encoding_type} not supported')

    decoder_pe = args.decoder_position_encoding_type
    if decoder_pe != 'default':
        if decoder_pe in _DISABLED_PE_TOKENS:
            raise ValueError('Decoder PE type cannot be None')
        elif decoder_pe in _DECODER_PE_TYPES:
            model_cfg["decoder"][model_cfg["decoder_type"]]["position_encoding_type"] = str(decoder_pe)
        else:
            raise ValueError(f'Decoder PE {args.decoder_position_encoding_type} not supported')

    if args.tie_word_embedding is not None:
        model_cfg["tie_word_embedding"] = bool(args.tie_word_embedding)

    if args.d_feat is not None:
        model_cfg["d_feat"] = int(args.d_feat)

    # ---- Model: perceiver-tf encoder (always this section, whatever encoder_type is) ----
    def perceiver_tf_cfg():
        return model_cfg["encoder"]["perceiver-tf"]

    _apply_overrides(args, _PERCEIVER_TF_OVERRIDES, perceiver_tf_cfg)

    if args.rotary_type is not None:
        assert len(
            args.rotary_type
        ) == 3, "rotary_type must be a 3-letter string (e.g. 'ppl': 'pixel' for SCA, 'pixel' for latent, 'lang' for temporal transformer)"
        # One letter per component, in the order SCA, latent, temporal.
        rotary_letters = str(args.rotary_type)
        perceiver_tf_cfg()["rotary_type_sca"] = rotary_letters[0]
        perceiver_tf_cfg()["rotary_type_latent"] = rotary_letters[1]
        perceiver_tf_cfg()["rotary_type_temporal"] = rotary_letters[2]

    _apply_overrides(args, _PERCEIVER_TF_ROPE_OVERRIDES, perceiver_tf_cfg)

    # ---- Model: decoder feed-forward (selected decoder section) ----
    _apply_overrides(args, _DECODER_FF_OVERRIDES, lambda: model_cfg["decoder"][model_cfg["decoder_type"]])

    if args.event_length is not None:
        model_cfg["event_length"] = int(args.event_length)

    # ---- Model: dropout (training only) ----
    if stage == 'train':
        if args.encoder_dropout_rate is not None:
            model_cfg["encoder"][model_cfg["encoder_type"]]["dropout_rate"] = float(args.encoder_dropout_rate)
        if args.decoder_dropout_rate is not None:
            model_cfg["decoder"][model_cfg["decoder_type"]]["dropout_rate"] = float(args.decoder_dropout_rate)

    return shared_cfg, audio_cfg, model_cfg
|
|