# Hydra configuration for the `vq-gan-pretrain` project.
#
# Composes the `base` config first, then applies the values in this file
# (`_self_` is listed last, so the keys below take precedence over `base`).
#
# NOTE(review): the block structure below was reconstructed from key
# semantics; `optimizer` and `lr_scheduler` are assumed to be arguments of
# `model` and `callbacks` a top-level group -- confirm against `base` and
# the VQGAN constructor.
defaults:
  - base
  - _self_

project: vq-gan-pretrain

# Lightning Trainer
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  max_steps: 1_000_000
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true

# Shared audio parameters. The dataset, model and mel-transform sections
# below reference these through `${...}` interpolation, so each value is
# defined exactly once.
sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048

# Dataset Configuration
# Training data is the concatenation of two file lists.
train_dataset:
  _target_: torch.utils.data.ConcatDataset
  datasets:
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/gigaspeech/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/sft/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512

# The validation dataset sets no `slice_frames`, unlike the training
# datasets above.
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/sft/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}

data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  val_batch_size: 32

# Model Configuration
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false

  # Encoder input width and decoder output width are both tied to
  # `num_mels`.
  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  # `input_dim` equals the encoder's `residual_channels` (768).
  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null  # You may download the pretrained vocoder and set the path here

  # Same transform as `gt_mel_transform` below, plus an explicit
  # 0-8000 Hz band limit.
  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  # `_partial_: true` makes Hydra build a partially-applied callable, so the
  # remaining arguments are supplied by the caller at instantiation time.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    # Written with an explicit mantissa (`1.0e-4`, not `1e-4`) so YAML 1.1
    # loaders also resolve these as floats rather than strings.
    lr: 1.0e-4
    betas: [0.8, 0.99]
    eps: 1.0e-5
    weight_decay: 0.01

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      # Schedule length follows the trainer's step budget.
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  # No `_target_` here or on `grad_norm_monitor`: presumably these override
  # fields of callbacks declared in `base` -- confirm.
  model_checkpoint:
    # Checkpoint at the same step interval as validation.
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator