# fish_speech/configs/vqgan_pretrain.yaml
defaults:
  - base
  - _self_

project: vq-gan-pretrain
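# Composed by Hydra on top of configs/base.yaml (listed first in `defaults`); values here
# override the base because `_self_` comes last. `${...}` references are OmegaConf
# interpolations resolved when the config is composed.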
# Lightning Trainer
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  max_steps: 1_000_000
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true
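# Note: ddp_find_unused_parameters_true is DistributedDataParallel with
# find_unused_parameters=True; GAN-style training typically needs it because generator
# and discriminator updates alternate, so not every parameter gets a gradient each step.

# Global audio / spectrogram parameters, referenced below via ${...} interpolation.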
sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048
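# At 44100 Hz with a hop of 512 samples, one mel frame spans ~11.6 ms
# (44100 / 512 ≈ 86 frames per second).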
# Dataset Configuration
train_dataset:
  _target_: torch.utils.data.ConcatDataset
  datasets:
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/gigaspeech/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/sft/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
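# slice_frames: 512 means each training example is a 512-frame mel slice, i.e.
# 512 * 512 = 262,144 samples, roughly 5.9 s of audio at 44.1 kHz (assuming slices
# are cut on hop-length boundaries).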
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/sft/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  val_batch_size: 32
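# batch_size is per process; with DDP the effective global batch is 32 * number of GPUs.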
# Model Configuration
model:
  _target_: fish_speech.models.vqgan.VQGAN
  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false
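  # weight_adv / weight_vq / weight_mel presumably weight the adversarial, VQ, and
  # mel-reconstruction loss terms; freeze_encoder stays false so the encoder is trained
  # from scratch during pretraining.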
  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]
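  # Finite Scalar Quantization: levels [8, 5, 5, 5] give 8 * 5 * 5 * 5 = 1000 codes per
  # codebook; n_groups: 2 presumably splits the 768-dim features into two independently
  # quantized groups.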
  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768
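  # condition_channels: 768 matches the quantizer dimensionality, so the decoder is
  # presumably conditioned on the quantized latents when reconstructing mels.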
  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null  # You may download the pretrained vocoder and set the path here
  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0
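  # The encoder's mel input is band-limited to 8 kHz (f_max); gt_mel_transform below
  # leaves f_max at the class default, presumably the full Nyquist range (22.05 kHz).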
  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    betas: [0.8, 0.99]
    eps: 1e-5
    weight_decay: 0.01
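  # _partial_: true makes hydra.utils.instantiate return a functools.partial, so the
  # training module can bind the model parameters to the optimizer itself.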
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0
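  # Judging by the scheduler name, the learning rate warms up for 100 steps and then
  # decays along a cosine curve toward final_lr_ratio * lr = 0 over the full
  # 1,000,000 training steps.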
callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator
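# model_checkpoint and grad_norm_monitor carry no _target_ here, so their full definitions
# presumably come from the base config. Checkpoints are written every 5000 training steps,
# in lockstep with validation (${trainer.val_check_interval}).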