# Spaces: Running on A10G
# Hydra defaults list: `base` is composed first, then this file (`_self_`),
# so keys set here take precedence over the same keys from `base`.
defaults:
  - base
  - _self_

project: vq-gan-pretrain
# Lightning Trainer
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  # Referenced elsewhere in this file as ${trainer.max_steps}.
  max_steps: 1_000_000
  # Referenced elsewhere in this file as ${trainer.val_check_interval}.
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true
# Shared top-level scalars. Other sections of this file read them through the
# root-relative interpolations ${sample_rate}, ${hop_length}, ${num_mels},
# ${n_fft} and ${win_length}, so they must stay at the top level.
sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048
# Dataset Configuration
# Training data: a ConcatDataset over two VQGANDataset entries that differ
# only in their filelist.
train_dataset:
  _target_: torch.utils.data.ConcatDataset
  datasets:
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/gigaspeech/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/sft/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
# Validation data: a single VQGANDataset. Unlike the training entries, no
# slice_frames value is set here.
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/sft/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
# Data module: receives the top-level train_dataset / val_dataset nodes via
# interpolation, plus loader settings.
data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  val_batch_size: 32
# Model Configuration
# NOTE(review): the source this was recovered from had lost all indentation.
# The nesting below (every sub-component, including optimizer and
# lr_scheduler, as a child of `model`) is reconstructed from the _target_ /
# _partial_ layout — confirm against the VQGAN constructor signature.
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false

  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    # Same value as the encoder's residual_channels above.
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null # You may download the pretrained vocoder and set the path here

  # Same class as gt_mel_transform below, but with an explicit f_min / f_max.
  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  # No f_min / f_max set here; the class defaults apply.
  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  # _partial_: true makes Hydra build a functools.partial instead of calling
  # the target immediately; the remaining arguments are supplied by the caller.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    betas: [0.8, 0.99]
    eps: 1e-5
    weight_decay: 0.01

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      # Schedule length follows the trainer's step budget.
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0
# Callbacks.
# NOTE(review): model_checkpoint and grad_norm_monitor carry no _target_ here;
# presumably they are merged onto entries defined in the `base` config — confirm.
callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  model_checkpoint:
    # Checkpoint at the same step interval as validation.
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator