File size: 3,183 Bytes
0a3525d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Compose on top of the shared `base` config. `_self_` is listed last, so
# values defined in this file override those coming from `base`.
defaults:
  - base
  - _self_

# Run name and the pretrained VQ-GAN checkpoint that fine-tuning starts from.
project: 'vq-gan-finetune'
ckpt_path: 'checkpoints/vq-gan-group-fsq-2x1024.pth'
# Presumably restores only the model weights from ckpt_path (not optimizer or
# trainer state) — TODO confirm against the training entry point.
resume_weights_only: true

# PyTorch Lightning Trainer arguments
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  # Written without digit separators: `100_000` is an int only under YAML 1.1
  # resolvers, while a YAML 1.2 core parser would load it as a string.
  # Also reused below as the LR schedule length.
  max_steps: 100000
  # Validation cadence in training steps; checkpointing reuses this value.
  val_check_interval: 5000
  # DDP variant that tolerates parameters receiving no gradient in a step.
  strategy: ddp_find_unused_parameters_true

# Audio / spectrogram parameters, shared with the dataset and model sections
# through `${...}` interpolation.
sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048
# NOTE(review): this flag only has an effect where it is interpolated as
# `${freeze_encoder}`; check that `model.freeze_encoder` references it.
freeze_encoder: true

# Dataset Configuration
# Training split: samples listed in the file list, using the shared
# sample_rate / hop_length settings.
train_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: 'data/filelist.train.txt'
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
  # Training crop length in frames (presumably frames of hop_length samples
  # each — TODO confirm in VQGANDataset).
  slice_frames: 512

# Validation split: same dataset class as training but without
# `slice_frames` (the effect of leaving it unset depends on VQGANDataset's
# default — TODO confirm).
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: 'data/filelist.val.txt'
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}

# DataModule wrapping the train/validation dataset configs defined above.
data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  # Presumably per-device batch sizes under DDP — verify in VQGANDataModule.
  batch_size: 16
  val_batch_size: 16

# Model Configuration
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  # Presumably weights for the adversarial, VQ and mel-reconstruction losses.
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  # Follow the top-level `freeze_encoder` switch rather than hard-coding a
  # value; a literal `false` here silently contradicted that setting.
  freeze_encoder: ${freeze_encoder}

  # Encoder: takes num_mels input channels.
  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  # Finite scalar quantizer; input_dim (768) matches encoder.residual_channels.
  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  # Decoder: produces num_mels output channels, conditioned on a 768-channel
  # input (matching the quantizer's input_dim).
  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null  # Download the pretrained vocoder and set its path here.

  # Mel transform for the encoder input, band-limited to 0-8 kHz.
  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  # Ground-truth mel transform; f_min/f_max are left to LogMelSpectrogram's
  # defaults (not visible here).
  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  # `_partial_: true` makes Hydra return a functools.partial; the model
  # supplies the parameters. Exponents carry a decimal point so every YAML
  # loader types them as floats: YAML 1.1 (e.g. plain PyYAML) reads `4e-5`
  # as a string.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 4.0e-5
    betas: [0.8, 0.99]
    eps: 1.0e-5
    weight_decay: 0.01

  # Presumably linear warmup for 100 steps followed by cosine decay to
  # final_lr_ratio * lr at trainer.max_steps — TODO confirm in
  # fish_speech.scheduler.
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

# Callback settings; entries without `_target_` are presumably completed by
# the callback definitions in `base` — TODO confirm.
callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  # Save checkpoints at the same step cadence as validation.
  model_checkpoint:
    every_n_train_steps: ${trainer.val_check_interval}

  # Presumably the submodules whose gradient norms are logged.
  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator