File size: 2,544 Bytes
0a3525d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69e8a46
0a3525d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Base configuration for training a model
paths:
  # Root directory for this run's outputs: Hydra's output dir, the trainer's
  # default_root_dir, checkpoints and TensorBoard logs all point under it.
  # NOTE(review): `project` is not defined in this file; presumably it is
  # supplied by the experiment config that extends this one or on the
  # command line — confirm.
  run_dir: "results/${project}"
  # Checkpoints are written inside the run directory.
  ckpt_dir: "${paths.run_dir}/checkpoints"

hydra:
  run:
    # Point Hydra's per-run output directory at the shared run directory.
    dir: "${paths.run_dir}"

# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    # NCCL is not available on Windows; override this to `gloo` there,
    # e.g. `trainer.strategy.process_group_backend=gloo` on the command line.
    process_group_backend: nccl

  precision: bf16-mixed

  # Disable epoch-end validation; instead validate every
  # `val_check_interval` training batches.
  check_val_every_n_epoch: null
  val_check_interval: 5000
  # Plain digits on purpose: `100_000` is an int only under YAML 1.1
  # resolvers (PyYAML/OmegaConf) and is read as a string by YAML 1.2 parsers.
  max_steps: 100000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true

# Callbacks
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    # Do not write an extra `last.ckpt` copy (setting this to true would
    # additionally save one alongside the step checkpoints).
    save_last: false
    save_top_k: 5  # keep the 5 best checkpoints as ranked by `monitor`
    monitor: step  # rank checkpoints by the logged global step
    mode: max  # higher step ranks better, so the 5 most recent are kept
    every_n_epochs: null  # don't save checkpoints by epoch end
    every_n_train_steps: 5000  # save checkpoints every 5000 steps
    # `filename` already spells out `step_`, so don't let Lightning
    # prepend the metric name to the formatted value.
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2  # the maximum depth of layer nesting that the summary will include

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step

# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # Optional Weights & Biases logger: uncomment the block below to enable it.
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: ""  # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: false
  #   id: null  # pass correct id to resume experiment!
  #   anonymous: null  # set to true to enable anonymous logging
  #   project: "fish-speech"
  #   log_model: false  # upload lightning ckpts
  #   prefix: ""  # a string to put at the beginning of metric keys
  #   # entity: ""  # set to name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""

# Loop
# NOTE(review): presumably read by the training entry point to decide whether
# to run the training and/or test stage — confirm against the caller.
train: true
test: false