File size: 4,608 Bytes
2dbdb17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Model section. Every `target` names a class path and the sibling `params`
# mapping holds that class's settings; the three nested *_config entries
# (UNet, first stage, conditioning stage) repeat the same target/params layout.
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    # Presumably the sample keys for the image ("jpg") and the caption ("txt");
    # "jpg" matches `image_key` in the data section — confirm against the
    # data module.
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: hybrid
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    finetune_keys: null
    use_ema: false

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: true
        image_size: 32  # unused
        # NOTE(review): 9 input channels vs 4 output channels — presumably the
        # extra inputs carry the inpainting mask and masked-image latents;
        # confirm against LatentInpaintDiffusion.
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64  # need to fix for flash-attn
        use_spatial_transformer: true
        use_linear_in_transformer: true
        transformer_depth: 1
        context_dim: 1024
        legacy: false

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          # attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        # torch.nn.Identity is configured in the loss slot of the autoencoder.
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: "penultimate"


# Data section: one data module with a `train` and a `validation` split.
# Both splits resize to 512, crop to 512 (random crop for train, center crop
# for validation) and then run the AddMask postprocess step.
data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: null  # for concat as in LAION-A
    p_unsafe_threshold: 0.1
    filter_word_list: "data/filters.yaml"
    max_pwatermark: 0.45
    batch_size: 8
    num_workers: 6
    multinode: true
    min_size: 512
    train:
      # Each entry is a `pipe:` command string; the brace range is left for
      # the consumer to expand.
      shards:
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -"  # {00000-94333}.tar"
      shuffle: 10000
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            # presumably the integer code for bicubic resampling — TODO confirm
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25
    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards:
        # The trailing space inside the quotes is kept exactly as in the
        # original value.
        - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
      shuffle: 0
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.CenterCrop
          params:
            size: 512
      # NOTE(review): validation reuses the train mask mode and p_drop —
      # confirm this is intended.
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25

# Lightning section: checkpoint cadence, logging callbacks and trainer flags.
lightning:
  find_unused_parameters: true
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 10000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: false
        disabled: false
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: false
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          inpaint: false
          plot_progressive_rows: false
          plot_diffusion_rows: false
          N: 4
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]
          # NOTE(review): the TODOs below mention depth2img, but this config's
          # model target is LatentInpaintDiffusion — possibly copied from
          # another config; confirm.
          ddim_steps: 50  # TODO: check these out for depth2img
          ddim_eta: 0.0  # TODO: check these out for depth2img

  trainer:
    benchmark: true
    # Very large interval combined with zero sanity-check steps; presumably
    # this keeps validation from running during normal training — confirm.
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1