# Page header captured together with this file — Spaces: running on A10G.
# File size: 3,376 bytes. Identifier shown by the viewer: 0a3525d
# (presumably the revision hash — confirm). The viewer's line-number gutter
# (1-140) was not part of the configuration and is omitted.
# Hydra defaults list. `base` is listed before `_self_`, so (per Hydra's
# defaults-list ordering) values in this file take precedence over `base`.
defaults:
  - base
  - _self_

# Project name for this configuration.
project: vq-gan-pretrain
# Lightning Trainer
# The settings below belong under `trainer`; other parts of this file
# interpolate them as ${trainer.max_steps} and ${trainer.val_check_interval}.
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  # Written without `_` digit separators so that YAML 1.2 loaders also
  # resolve this as an integer (underscores are a YAML 1.1-only form).
  max_steps: 1000000
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true
# Shared audio / spectrogram parameters. These stay at the top level because
# the rest of this file reads them through ${sample_rate}, ${hop_length},
# ${num_mels}, ${n_fft} and ${win_length} interpolations.
sample_rate: 44100
num_mels: 128

# Framing parameters used by the mel-spectrogram transforms.
n_fft: 2048
win_length: 2048
hop_length: 512
# Dataset Configuration
# Training data: a ConcatDataset built from two VQGANDataset entries that
# differ only in their file list.
# NOTE(review): nesting reconstructed after indentation was lost in this
# copy — confirm each entry's parameters against VQGANDataset.
train_dataset:
  _target_: torch.utils.data.ConcatDataset
  datasets:
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/gigaspeech/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/sft/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
# Validation data: a single VQGANDataset. Unlike the training entries, no
# `slice_frames` value is set here.
# NOTE(review): nesting reconstructed after indentation was lost — confirm.
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/sft/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
# Data module. `train_dataset` / `val_dataset` here are parameters of the
# data module and point at the top-level dataset nodes of the same name via
# interpolation; they must be nested under `data`, otherwise they would
# redefine (and self-reference) those top-level keys.
data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  val_batch_size: 32
# Model Configuration
# NOTE(review): indentation was lost in this copy. The nesting below is
# reconstructed from the `_target_` markers and parameter names. Placing
# `optimizer` and `lr_scheduler` under `model` is an assumption — confirm
# against the VQGAN constructor.
model:
  _target_: fish_speech.models.vqgan.VQGAN
  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false

  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    # Same value as the encoder's and decoder's residual_channels.
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null  # You may download the pretrained vocoder and set the path here

  # Unlike gt_mel_transform below, this transform sets f_min / f_max
  # explicitly.
  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  # `_partial_: true` makes Hydra build a functools.partial instead of
  # calling the target immediately.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    # Exponent floats carry an explicit mantissa dot so that plain YAML 1.1
    # loaders (e.g. PyYAML) resolve them as floats rather than strings.
    lr: 1.0e-4
    betas: [0.8, 0.99]
    eps: 1.0e-5
    weight_decay: 0.01

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      # Tied to the trainer's step budget via interpolation.
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0
# Callbacks.
# NOTE(review): nesting reconstructed after indentation was lost in this
# copy. `model_checkpoint` and `grad_norm_monitor` carry no `_target_` here,
# so they presumably extend entries defined in the `base` config — confirm.
callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  model_checkpoint:
    # Follows the trainer's validation interval via interpolation.
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    # Names match the sub-configs listed under `model` in this file.
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator
|