forSubAnony committed on
Commit
1cae162
1 Parent(s): ffcc5c1
Files changed (28)
  1. code_for_ade20k/crack_config_utils/__init__.py +0 -0
  2. code_for_ade20k/crack_config_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  3. code_for_ade20k/crack_config_utils/__pycache__/parse_args.cpython-39.pyc +0 -0
  4. code_for_ade20k/crack_config_utils/__pycache__/parse_args_ade.cpython-39.pyc +0 -0
  5. code_for_ade20k/crack_config_utils/__pycache__/utils.cpython-39.pyc +0 -0
  6. code_for_ade20k/crack_config_utils/__pycache__/utils_ade.cpython-39.pyc +0 -0
  7. code_for_ade20k/crack_config_utils/parse_args_ade.py +253 -0
  8. code_for_ade20k/crack_config_utils/utils_ade.py +98 -0
  9. code_for_ade20k/dataset/__init__.py +0 -0
  10. code_for_ade20k/dataset/ade20k.py +288 -0
  11. code_for_ade20k/diffusion_module/__pycache__/nn.cpython-39.pyc +0 -0
  12. code_for_ade20k/diffusion_module/__pycache__/unet.cpython-39.pyc +0 -0
  13. code_for_ade20k/diffusion_module/__pycache__/unet_2d_blocks.cpython-39.pyc +0 -0
  14. code_for_ade20k/diffusion_module/__pycache__/unet_2d_sdm.cpython-39.pyc +0 -0
  15. code_for_ade20k/diffusion_module/nn.py +183 -0
  16. code_for_ade20k/diffusion_module/unet.py +1260 -0
  17. code_for_ade20k/diffusion_module/unet_2d_blocks.py +0 -0
  18. code_for_ade20k/diffusion_module/unet_2d_sdm.py +357 -0
  19. code_for_ade20k/diffusion_module/utils/LSDMPipeline_expandDataset.py +164 -0
  20. code_for_ade20k/diffusion_module/utils/Pipline.py +361 -0
  21. code_for_ade20k/diffusion_module/utils/__pycache__/LSDMPipeline_expandDataset.cpython-39.pyc +0 -0
  22. code_for_ade20k/diffusion_module/utils/__pycache__/Pipline.cpython-310.pyc +0 -0
  23. code_for_ade20k/diffusion_module/utils/__pycache__/Pipline.cpython-39.pyc +0 -0
  24. code_for_ade20k/diffusion_module/utils/__pycache__/loss.cpython-39.pyc +0 -0
  25. code_for_ade20k/diffusion_module/utils/loss.py +149 -0
  26. code_for_ade20k/diffusion_module/utils/noise_sampler.py +16 -0
  27. code_for_ade20k/diffusion_module/utils/scheduler_factory.py +300 -0
  28. code_for_ade20k/train_SDM_LDM_ade.py +509 -0
code_for_ade20k/crack_config_utils/__init__.py ADDED
File without changes
code_for_ade20k/crack_config_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (149 Bytes).
 
code_for_ade20k/crack_config_utils/__pycache__/parse_args.cpython-39.pyc ADDED
Binary file (7.08 kB).
 
code_for_ade20k/crack_config_utils/__pycache__/parse_args_ade.cpython-39.pyc ADDED
Binary file (6.73 kB).
 
code_for_ade20k/crack_config_utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.13 kB).
 
code_for_ade20k/crack_config_utils/__pycache__/utils_ade.cpython-39.pyc ADDED
Binary file (2.71 kB).
 
code_for_ade20k/crack_config_utils/parse_args_ade.py ADDED
@@ -0,0 +1,253 @@
+import argparse, os
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default="HHRI-SSL/LDM-unconditioned",
+        required=False,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--data_root",
+        default="/data/leiqin/diffusion/huggingface_diffusers/dataset/ADE2016/ADEChallengeData2016",
+        help=(
+            "Root directory of the ADE20K dataset (ADEChallengeData2016)."
+        )
+    )
+
+    parser.add_argument(
+        "--resume_dir",
+        type=str,
+        # default="/data/leiqin/diffusion/Data_generation/SLDM/VQVAE-official-SDM-learnvar-cityspace/checkpoint-1200/unet",
+        default=None,
+        required=False,
+        help="Resume from this checkpoint directory.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+
+    parser.add_argument(
+        "--segmap_channels",
+        type=int,
+        default=151,
+        help=(
+            "Number of semantic mask classes (151 for ADE20K: 150 classes plus one for unlabeled pixels)."
+        ),
+    )
+    parser.add_argument(
+        "--max_train_samples",
+        type=int,
+        default=None,
+        help=(
+            "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        ),
+    )
+
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="SLDM-VAE-15-ade20k",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=512,
+        help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+            " resolution."
+        ),
+    )
+
+    parser.add_argument(
+        "--train_batch_size", type=int, default=54, help="Batch size (per device) for the training dataloader."
+    )
+
+    parser.add_argument("--num_train_epochs", type=int, default=2000)
+
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=5.0,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
+
+    parser.add_argument("--use_ema", action="store_true", default=True, help="Whether to use an EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=64,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="fp16",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=400,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=None,
+        help=(
+            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+            " for more docs"
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        # default="checkpoint-4950",
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", default=True, action="store_true", help="Whether or not to use xformers."
+    )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=2,
+        help="Run validation every X epochs.",
+    )
+
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="SLDM-from-scratch",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers. For"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )
+
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
+    return args
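
For reference, a minimal usage sketch of the parser added above (not part of the diff; it assumes code_for_ade20k is on PYTHONPATH so that crack_config_utils.parse_args_ade is importable, and the data path is a placeholder):

# Hypothetical invocation; every flag not given falls back to the defaults above.
import sys
from crack_config_utils.parse_args_ade import parse_args

sys.argv = [
    "train_SDM_LDM_ade.py",
    "--data_root", "/path/to/ADEChallengeData2016",
    "--train_batch_size", "8",
]
args = parse_args()
print(args.resolution)        # 512
print(args.segmap_channels)   # 151
print(args.non_ema_revision)  # mirrors --revision when not set explicitly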
code_for_ade20k/crack_config_utils/utils_ade.py ADDED
@@ -0,0 +1,98 @@
+import torch
+import os
+from PIL import Image
+import numpy as np
+from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
+from diffusion_module.utils.Pipline import SDMLDMPipeline
+
+def log_validation(vae, unet, noise_scheduler, accelerator, weight_dtype, data_ld,
+                   resolution=512, g_step=2, save_dir="cityspace_test"):
+    scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
+    pipeline = SDMLDMPipeline(
+        vae=accelerator.unwrap_model(vae),
+        unet=accelerator.unwrap_model(unet),
+        scheduler=scheduler,
+        torch_dtype=weight_dtype,
+        resolution=resolution,
+        resolution_type="crack"
+    )
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=False)
+    pipeline.enable_xformers_memory_efficient_attention()
+
+    generator = None
+    for i, batch in enumerate(data_ld):
+        if i > 2:
+            break
+        images = []
+        with torch.autocast("cuda"):
+            segmap = preprocess_input(batch[1]['label'], num_classes=151)
+            segmap = segmap.to("cuda").to(torch.float16)
+            # Skipped for now: drawing a colorized label map would need a plotting helper,
+            # which is too cumbersome with this many classes.
+            # segmap_clr = batch[1]['label_ori'][0].permute(0, 3, 1, 2) / 255.
+
+            image = pipeline(segmap=segmap[0][None, :], generator=generator, batch_size=1,
+                             num_inference_steps=50, s=1.5).images
+
+            # segmap_clr = segmap_clr.cpu()
+            # segmap_clr = segmap_clr[0].permute(1, 2, 0).numpy()
+            # segmap_clr = (segmap_clr * 255).astype('uint8')
+            # pil_image = Image.fromarray(segmap_clr)
+            # images.append(pil_image)
+            # print(image)
+            # image = pipeline(args.validation_prompts[i], num_inference_steps=50, generator=generator).images[0]
+
+        images.extend(image)
+        merge_images(images, i, accelerator, g_step)
+    del pipeline
+    torch.cuda.empty_cache()
+
+
+def merge_images(images, val_step, accelerator, step):
+    for k, image in enumerate(images):
+        """
+        if k == 0:
+            filename = "{}_condition.png".format(val_step)
+        else:
+            filename = "{}_{}.png".format(val_step, k)
+        """
+        filename = "{}_{}.png".format(val_step, k)
+        # Updated path, now under a 'singles' folder
+        path = os.path.join(accelerator.logging_dir, "step_{}".format(step), "singles", filename)
+        os.makedirs(os.path.split(path)[0], exist_ok=True)
+
+        image.save(path)
+
+    # Create a new canvas to merge all images
+    total_width = sum(img.width for img in images)
+    max_height = max(img.height for img in images)
+    combined_image = Image.new('RGB', (total_width, max_height))
+
+    # Paste each image onto the canvas
+    x_offset = 0
+    for img in images:
+        # Convert grayscale images to RGB
+        if img.mode != 'RGB':
+            img = img.convert('RGB')
+        combined_image.paste(img, (x_offset, 0))
+        x_offset += img.width
+
+    # Save the merged image under a 'merges' folder
+    merge_filename = "{}_merge.png".format(val_step)
+    merge_path = os.path.join(accelerator.logging_dir, "step_{}".format(step), "merges", merge_filename)
+    os.makedirs(os.path.split(merge_path)[0], exist_ok=True)
+    combined_image.save(merge_path)
+
+def preprocess_input(data, num_classes):
+    # move to GPU and change data types
+    data = data.to(dtype=torch.int64)
+
+    # create one-hot label map
+    label_map = data
+    bs, _, h, w = label_map.size()
+    input_label = torch.FloatTensor(bs, num_classes, h, w).zero_().to(data.device)
+    input_semantics = input_label.scatter_(1, label_map, 1.0)
+
+    return input_semantics
+
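
A standalone sketch of what preprocess_input does to a label map (not part of the diff; no repo imports, shapes chosen only for illustration): integer class ids are turned into one-hot channels via scatter_.

import torch

num_classes = 151                                        # matches --segmap_channels above
label = torch.randint(0, num_classes, (2, 1, 64, 64))    # [B, 1, H, W] integer class ids

one_hot = torch.zeros(2, num_classes, 64, 64)
one_hot.scatter_(1, label, 1.0)                          # channel c is 1 where label == c

assert one_hot.sum(dim=1).eq(1).all()                    # exactly one active class per pixel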
code_for_ade20k/dataset/__init__.py ADDED
File without changes
code_for_ade20k/dataset/ade20k.py ADDED
@@ -0,0 +1,288 @@
+import os
+import math
+import random
+
+from PIL import Image
+import blobfile as bf
+import numpy as np
+from torch.utils.data import DataLoader, Dataset
+
+
+def load_data(
+    *,
+    dataset_mode,
+    data_dir,
+    batch_size,
+    image_size,
+    class_cond=False,
+    deterministic=False,
+    random_crop=True,
+    random_flip=True,
+    is_train=True,
+):
+    """
+    For a dataset, create a generator over (images, kwargs) pairs.
+
+    Each image is an NCHW float tensor, and the kwargs dict contains zero or
+    more keys, each of which map to a batched Tensor of their own.
+    The kwargs dict can be used for class labels, in which case the key is "y"
+    and the values are integer tensors of class labels.
+
+    :param data_dir: a dataset directory.
+    :param batch_size: the batch size of each returned pair.
+    :param image_size: the size to which images are resized.
+    :param class_cond: if True, include a "y" key in returned dicts for class
+                       label. If classes are not available and this is true, an
+                       exception will be raised.
+    :param deterministic: if True, yield results in a deterministic order.
+    :param random_crop: if True, randomly crop the images for augmentation.
+    :param random_flip: if True, randomly flip the images for augmentation.
+    """
+    if not data_dir:
+        raise ValueError("unspecified data directory")
+
+    if dataset_mode == 'cityscapes':
+        all_files = _list_image_files_recursively(os.path.join(data_dir, 'leftImg8bit', 'train' if is_train else 'val'))
+        labels_file = _list_image_files_recursively(os.path.join(data_dir, 'gtFine', 'train' if is_train else 'val'))
+        classes = [x for x in labels_file if x.endswith('_labelIds.png')]
+        instances = [x for x in labels_file if x.endswith('_instanceIds.png')]
+    elif dataset_mode == 'ade20k':
+        all_files = _list_image_files_recursively(os.path.join(data_dir, 'images', 'training' if is_train else 'validation'))
+        classes = _list_image_files_recursively(os.path.join(data_dir, 'annotations', 'training' if is_train else 'validation'))
+        instances = None
+    elif dataset_mode == 'celeba':
+        # The edge is computed from the instances.
+        # However, the edges obtained from the labels and from the instances are the same on CelebA,
+        # so you can take either as instance input.
+        all_files = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'test', 'images'))
+        classes = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'test', 'labels'))
+        instances = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'test', 'labels'))
+    elif dataset_mode == "crack500":
+        all_files = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'validation', 'images'))
+        classes = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'validation', 'annotations'))
+        instances = None
+    elif dataset_mode == "thincrack":
+        all_files = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'train', 'images'))
+        classes = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'train', 'annotations'))
+        instances = None
+    else:
+        raise NotImplementedError('{} not implemented'.format(dataset_mode))
+
+    print("Len of Dataset:", len(all_files))
+
+    dataset = ImageDataset(
+        dataset_mode,
+        image_size,
+        all_files,
+        classes=classes,
+        instances=instances,
+        random_crop=random_crop,
+        random_flip=random_flip,
+        is_train=is_train
+    )
+
+    if deterministic:
+        loader = DataLoader(
+            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
+        )
+    else:
+        loader = DataLoader(
+            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
+        )
+    return loader, dataset
+
+
+def _list_image_files_recursively(data_dir):
+    results = []
+    for entry in sorted(bf.listdir(data_dir)):
+        full_path = bf.join(data_dir, entry)
+        ext = entry.split(".")[-1]
+        if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
+            results.append(full_path)
+        elif bf.isdir(full_path):
+            results.extend(_list_image_files_recursively(full_path))
+    return results
+
+
+class ImageDataset(Dataset):
+    def __init__(
+        self,
+        dataset_mode,
+        resolution,
+        image_paths,
+        classes=None,
+        instances=None,
+        shard=0,
+        num_shards=1,
+        random_crop=False,
+        random_flip=True,
+        is_train=True
+    ):
+        super().__init__()
+        self.is_train = is_train
+        self.dataset_mode = dataset_mode
+        self.resolution = resolution
+        self.local_images = image_paths[shard:][::num_shards]
+        self.local_classes = None if classes is None else classes[shard:][::num_shards]
+        self.local_instances = None if instances is None else instances[shard:][::num_shards]
+        self.random_crop = random_crop
+        self.random_flip = random_flip
+
+    def __len__(self):
+        return len(self.local_images)
+
+    def __getitem__(self, idx):
+        path = self.local_images[idx]
+        with bf.BlobFile(path, "rb") as f:
+            pil_image = Image.open(f)
+            pil_image.load()
+        pil_image = pil_image.convert("RGB")
+
+        out_dict = {}
+        class_path = self.local_classes[idx]
+        with bf.BlobFile(class_path, "rb") as f:
+            pil_class = Image.open(f)
+            pil_class.load()
+        pil_class = pil_class.convert("L")
+
+        if self.local_instances is not None:
+            instance_path = self.local_instances[idx]  # DEBUG: from classes to instances, may affect CelebA
+            with bf.BlobFile(instance_path, "rb") as f:
+                pil_instance = Image.open(f)
+                pil_instance.load()
+            pil_instance = pil_instance.convert("L")
+        else:
+            pil_instance = None
+
+        if self.dataset_mode == 'cityscapes':
+            arr_image, arr_class, arr_instance = resize_arr([pil_image, pil_class, pil_instance], self.resolution)
+        else:
+            if self.is_train:
+                if self.random_crop:
+                    arr_image, arr_class, arr_instance = random_crop_arr([pil_image, pil_class, pil_instance], self.resolution)
+                else:
+                    arr_image, arr_class, arr_instance = center_crop_arr([pil_image, pil_class, pil_instance], self.resolution)
+            else:
+                arr_image, arr_class, arr_instance = resize_arr([pil_image, pil_class, pil_instance], self.resolution, keep_aspect=False)
+
+        if self.random_flip and random.random() < 0.5:
+            arr_image = arr_image[:, ::-1].copy()
+            arr_class = arr_class[:, ::-1].copy()
+            arr_instance = arr_instance[:, ::-1].copy() if arr_instance is not None else None
+
+        arr_image = arr_image.astype(np.float32) / 127.5 - 1
+
+        out_dict['path'] = path
+        out_dict['label_ori'] = arr_class.copy()
+
+        if self.dataset_mode == 'ade20k':
+            arr_class = arr_class - 1
+            arr_class[arr_class == 255] = 150
+        elif self.dataset_mode == 'coco':
+            arr_class[arr_class == 255] = 182
+        elif self.dataset_mode == 'crack500':
+            arr_class[arr_class == 255] = 1
+        elif self.dataset_mode == 'thincrack':
+            arr_class[arr_class == 255] = 1
+
+        out_dict['label'] = arr_class[None, ]
+
+        if arr_instance is not None:
+            out_dict['instance'] = arr_instance[None, ]
+
+        return np.transpose(arr_image, [2, 0, 1]), out_dict
+
+
+def resize_arr(pil_list, image_size, keep_aspect=True):
+    # We are not on a new enough PIL to support the `reducing_gap`
+    # argument, which uses BOX downsampling at powers of two first.
+    # Thus, we do it by hand to improve downsample quality.
+    pil_image, pil_class, pil_instance = pil_list
+
+    while min(*pil_image.size) >= 2 * image_size:
+        pil_image = pil_image.resize(
+            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+        )
+
+    if keep_aspect:
+        scale = image_size / min(*pil_image.size)
+        pil_image = pil_image.resize(
+            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+        )
+    else:
+        pil_image = pil_image.resize((image_size, image_size), resample=Image.BICUBIC)
+
+    pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST)
+    if pil_instance is not None:
+        pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST)
+
+    arr_image = np.array(pil_image)
+    arr_class = np.array(pil_class)
+    arr_instance = np.array(pil_instance) if pil_instance is not None else None
+    return arr_image, arr_class, arr_instance
+
+
+def center_crop_arr(pil_list, image_size):
+    # We are not on a new enough PIL to support the `reducing_gap`
+    # argument, which uses BOX downsampling at powers of two first.
+    # Thus, we do it by hand to improve downsample quality.
+    pil_image, pil_class, pil_instance = pil_list
+
+    while min(*pil_image.size) >= 2 * image_size:
+        pil_image = pil_image.resize(
+            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+        )
+
+    scale = image_size / min(*pil_image.size)
+    pil_image = pil_image.resize(
+        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+    )
+
+    pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST)
+    if pil_instance is not None:
+        pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST)
+
+    arr_image = np.array(pil_image)
+    arr_class = np.array(pil_class)
+    arr_instance = np.array(pil_instance) if pil_instance is not None else None
+    crop_y = (arr_image.shape[0] - image_size) // 2
+    crop_x = (arr_image.shape[1] - image_size) // 2
+    return arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size],\
+           arr_class[crop_y : crop_y + image_size, crop_x : crop_x + image_size],\
+           arr_instance[crop_y : crop_y + image_size, crop_x : crop_x + image_size] if arr_instance is not None else None
+
+
+def random_crop_arr(pil_list, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
+    min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
+    max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
+    smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
+
+    # We are not on a new enough PIL to support the `reducing_gap`
+    # argument, which uses BOX downsampling at powers of two first.
+    # Thus, we do it by hand to improve downsample quality.
+    pil_image, pil_class, pil_instance = pil_list
+
+    while min(*pil_image.size) >= 2 * smaller_dim_size:
+        pil_image = pil_image.resize(
+            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+        )
+
+    scale = smaller_dim_size / min(*pil_image.size)
+    pil_image = pil_image.resize(
+        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+    )
+
+    pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST)
+    if pil_instance is not None:
+        pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST)
+
+    arr_image = np.array(pil_image)
+    arr_class = np.array(pil_class)
+    arr_instance = np.array(pil_instance) if pil_instance is not None else None
+    crop_y = random.randrange(arr_image.shape[0] - image_size + 1)
+    crop_x = random.randrange(arr_image.shape[1] - image_size + 1)
+    return arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size],\
+           arr_class[crop_y : crop_y + image_size, crop_x : crop_x + image_size],\
+           arr_instance[crop_y : crop_y + image_size, crop_x : crop_x + image_size] if arr_instance is not None else None
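
A usage sketch for load_data on ADE20K (not part of the diff; the dataset path is a placeholder and the split layout must match images/{training,validation} and annotations/{training,validation}, with code_for_ade20k on PYTHONPATH):

from dataset.ade20k import load_data

loader, dataset = load_data(
    dataset_mode="ade20k",
    data_dir="/path/to/ADEChallengeData2016",
    batch_size=4,
    image_size=512,
    random_crop=True,
    random_flip=True,
    is_train=True,
)
image, cond = next(iter(loader))
# image: roughly [4, 3, 512, 512] floats in [-1, 1]
# cond["label"]: [4, 1, 512, 512] class ids, with the unlabeled value remapped to 150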
code_for_ade20k/diffusion_module/__pycache__/nn.cpython-39.pyc ADDED
Binary file (6.24 kB).
 
code_for_ade20k/diffusion_module/__pycache__/unet.cpython-39.pyc ADDED
Binary file (29.1 kB).
 
code_for_ade20k/diffusion_module/__pycache__/unet_2d_blocks.cpython-39.pyc ADDED
Binary file (57.3 kB).
 
code_for_ade20k/diffusion_module/__pycache__/unet_2d_sdm.cpython-39.pyc ADDED
Binary file (10.7 kB).
 
code_for_ade20k/diffusion_module/nn.py ADDED
@@ -0,0 +1,183 @@
+"""
+Various utilities for neural networks.
+"""
+
+import math
+
+import torch as th
+import torch.nn as nn
+
+
+def convert_module_to_f16(l):
+    """
+    Convert primitive modules to float16.
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.half()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.half()
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+    def forward(self, x):
+        return x * th.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        # print(x.float().dtype)
+        return super().forward(x).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+    """
+    Update target parameters to be closer to those of source parameters using
+    an exponential moving average.
+
+    :param target_params: the target parameter sequence.
+    :param source_params: the source parameter sequence.
+    :param rate: the EMA rate (closer to 1 means slower).
+    """
+    for targ, src in zip(target_params, source_params):
+        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = th.exp(
+        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+    ).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        # return th.utils.checkpoint.checkpoint.apply(func, inputs)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(th.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        # breakpoint()  # debugging leftover; disabled so the forward pass can run
+        with th.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with th.enable_grad():
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            # breakpoint()  # debugging leftover; disabled
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = th.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
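
A quick shape check for the helpers above (not part of the diff; it assumes code_for_ade20k is on PYTHONPATH so diffusion_module.nn is importable):

import torch
from diffusion_module.nn import timestep_embedding, update_ema, zero_module

t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=256)
print(emb.shape)                     # torch.Size([3, 256]), sinusoidal values in [-1, 1]

layer = torch.nn.Linear(4, 4)
ema_layer = torch.nn.Linear(4, 4)
update_ema(ema_layer.parameters(), layer.parameters(), rate=0.99)   # one in-place EMA step

zeroed = zero_module(torch.nn.Linear(4, 4))
print(all(p.abs().sum() == 0 for p in zeroed.parameters()))         # True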
code_for_ade20k/diffusion_module/unet.py ADDED
@@ -0,0 +1,1260 @@
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .nn import (
11
+ SiLU,
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ convert_module_to_f16
20
+ )
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from dataclasses import dataclass
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ Args:
31
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
32
+ Hidden states output. Output of last layer of model.
33
+ """
34
+
35
+ sample: th.FloatTensor
36
+
37
+
38
+ class AttentionPool2d(nn.Module):
39
+ """
40
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ spacial_dim: int,
46
+ embed_dim: int,
47
+ num_heads_channels: int,
48
+ output_dim: int = None,
49
+ ):
50
+ super().__init__()
51
+ self.positional_embedding = nn.Parameter(
52
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
53
+ )
54
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
55
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
56
+ self.num_heads = embed_dim // num_heads_channels
57
+ self.attention = QKVAttention(self.num_heads)
58
+
59
+ def forward(self, x):
60
+ b, c, *_spatial = x.shape
61
+ x = x.reshape(b, c, -1) # NC(HW)
62
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
63
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
64
+ x = self.qkv_proj(x)
65
+ x = self.attention(x)
66
+ x = self.c_proj(x)
67
+ return x[:, :, 0]
68
+
69
+
70
+ class TimestepBlock(nn.Module):
71
+ """
72
+ Any module where forward() takes timestep embeddings as a second argument.
73
+ """
74
+
75
+ @abstractmethod
76
+ def forward(self, x, emb):
77
+ """
78
+ Apply the module to `x` given `emb` timestep embeddings.
79
+ """
80
+
81
+ class CondTimestepBlock(nn.Module):
82
+ """
83
+ Any module where forward() takes timestep embeddings as a second argument.
84
+ """
85
+
86
+ @abstractmethod
87
+ def forward(self, x, cond, emb):
88
+ """
89
+ Apply the module to `x` given `emb` timestep embeddings.
90
+ """
91
+
92
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock, CondTimestepBlock):
93
+ """
94
+ A sequential module that passes timestep embeddings to the children that
95
+ support it as an extra input.
96
+ """
97
+
98
+ def forward(self, x, cond, emb):
99
+ for layer in self:
100
+ if isinstance(layer, CondTimestepBlock):
101
+ x = layer(x, cond, emb)
102
+ elif isinstance(layer, TimestepBlock):
103
+ x = layer(x, emb)
104
+ else:
105
+ x = layer(x)
106
+ return x
107
+
108
+
109
+ class Upsample(nn.Module):
110
+ """
111
+ An upsampling layer with an optional convolution.
112
+
113
+ :param channels: channels in the inputs and outputs.
114
+ :param use_conv: a bool determining if a convolution is applied.
115
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
116
+ upsampling occurs in the inner-two dimensions.
117
+ """
118
+
119
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.out_channels = out_channels or channels
123
+ self.use_conv = use_conv
124
+ self.dims = dims
125
+ if use_conv:
126
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
127
+
128
+ def forward(self, x):
129
+ assert x.shape[1] == self.channels
130
+ if self.dims == 3:
131
+ x = F.interpolate(
132
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
133
+ )
134
+ else:
135
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
136
+ if self.use_conv:
137
+ x = self.conv(x)
138
+ return x
139
+
140
+
141
+ class Downsample(nn.Module):
142
+ """
143
+ A downsampling layer with an optional convolution.
144
+
145
+ :param channels: channels in the inputs and outputs.
146
+ :param use_conv: a bool determining if a convolution is applied.
147
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
148
+ downsampling occurs in the inner-two dimensions.
149
+ """
150
+
151
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
152
+ super().__init__()
153
+ self.channels = channels
154
+ self.out_channels = out_channels or channels
155
+ self.use_conv = use_conv
156
+ self.dims = dims
157
+ stride = 2 if dims != 3 else (1, 2, 2)
158
+ if use_conv:
159
+ self.op = conv_nd(
160
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
161
+ )
162
+ else:
163
+ assert self.channels == self.out_channels
164
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
165
+
166
+ def forward(self, x):
167
+ assert x.shape[1] == self.channels
168
+ return self.op(x)
169
+
170
+
171
+ class SPADEGroupNorm(nn.Module):
172
+ def __init__(self, norm_nc, label_nc, eps = 1e-5):
173
+ super().__init__()
174
+
175
+ self.norm = nn.GroupNorm(32, norm_nc, affine=False) # 32/16
176
+
177
+ self.eps = eps
178
+ nhidden = 128
179
+ self.mlp_shared = nn.Sequential(
180
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
181
+ nn.ReLU(),
182
+ )
183
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
184
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
185
+
186
+ def forward(self, x, segmap):
187
+ # Part 1. generate parameter-free normalized activations
188
+ x = self.norm(x)
189
+
190
+ # Part 2. produce scaling and bias conditioned on semantic map
191
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
192
+ actv = self.mlp_shared(segmap)
193
+ gamma = self.mlp_gamma(actv)
194
+ beta = self.mlp_beta(actv)
195
+
196
+ # apply scale and bias
197
+ return x * (1 + gamma) + beta
198
+
199
+ class AdaIN(nn.Module):
200
+ def __init__(self, num_features):
201
+ super().__init__()
202
+ self.instance_norm = th.nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
203
+
204
+ def forward(self, x, alpha, gamma):
205
+ assert x.shape[:2] == alpha.shape[:2] == gamma.shape[:2]
206
+ norm = self.instance_norm(x)
207
+ return alpha * norm + gamma
208
+
209
+ class RESAILGroupNorm(nn.Module):
210
+ def __init__(self, norm_nc, label_nc, guidance_nc, eps = 1e-5):
211
+ super().__init__()
212
+
213
+ self.norm = nn.GroupNorm(32, norm_nc, affine=False) # 32/16
214
+
215
+ # SPADE
216
+ self.eps = eps
217
+ nhidden = 128
218
+ self.mask_mlp_shared = nn.Sequential(
219
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
220
+ nn.ReLU(),
221
+ )
222
+
223
+ self.mask_mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
224
+ self.mask_mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
225
+
226
+
227
+ # Guidance
228
+
229
+ self.conv_s = th.nn.Conv2d(label_nc, nhidden * 2, 3, 2)
230
+ self.pool_s = th.nn.AdaptiveAvgPool2d(1)
231
+ self.conv_s2 = th.nn.Conv2d(nhidden * 2, nhidden * 2, 1, 1)
232
+
233
+ self.conv1 = th.nn.Conv2d(guidance_nc, nhidden, 3, 1, padding=1)
234
+ self.adaIn1 = AdaIN(norm_nc * 2)
235
+ self.relu1 = nn.ReLU()
236
+
237
+ self.conv2 = th.nn.Conv2d(nhidden, nhidden, 3, 1, padding=1)
238
+ self.adaIn2 = AdaIN(norm_nc * 2)
239
+ self.relu2 = nn.ReLU()
240
+ self.conv3 = th.nn.Conv2d(nhidden, nhidden, 3, 1, padding=1)
241
+
242
+ self.guidance_mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
243
+ self.guidance_mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
244
+
245
+ self.blending_gamma = nn.Parameter(th.zeros(1), requires_grad=True)
246
+ self.blending_beta = nn.Parameter(th.zeros(1), requires_grad=True)
247
+ self.norm_nc = norm_nc
248
+
249
+ def forward(self, x, segmap, guidance):
250
+ # Part 1. generate parameter-free normalized activations
251
+ x = self.norm(x)
252
+ # Part 2. produce scaling and bias conditioned on semantic map
253
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
254
+ mask_actv = self.mask_mlp_shared(segmap)
255
+ mask_gamma = self.mask_mlp_gamma(mask_actv)
256
+ mask_beta = self.mask_mlp_beta(mask_actv)
257
+
258
+
259
+ # Part 3. produce scaling and bias conditioned on feature guidance
260
+ guidance = F.interpolate(guidance, size=x.size()[2:], mode='bilinear')
261
+
262
+ f_s_1 = self.conv_s(segmap)
263
+ c1 = self.pool_s(f_s_1)
264
+ c2 = self.conv_s2(c1)
265
+
266
+ f1 = self.conv1(guidance)
267
+
268
+ f1 = self.adaIn1(f1, c1[:, : 128, ...], c1[:, 128:, ...])
269
+ f2 = self.relu1(f1)
270
+
271
+ f2 = self.conv2(f2)
272
+ f2 = self.adaIn2(f2, c2[:, : 128, ...], c2[:, 128:, ...])
273
+ f2 = self.relu2(f2)
274
+ guidance_actv = self.conv3(f2)
275
+
276
+ guidance_gamma = self.guidance_mlp_gamma(guidance_actv)
277
+ guidance_beta = self.guidance_mlp_beta(guidance_actv)
278
+
279
+ gamma_alpha = F.sigmoid(self.blending_gamma)
280
+ beta_alpha = F.sigmoid(self.blending_beta)
281
+
282
+ gamma_final = gamma_alpha * guidance_gamma + (1 - gamma_alpha) * mask_gamma
283
+ beta_final = beta_alpha * guidance_beta + (1 - beta_alpha) * mask_beta
284
+ out = x * (1 + gamma_final) + beta_final
285
+
286
+ # apply scale and bias
287
+ return out
288
+
289
+ class SPMGroupNorm(nn.Module):
290
+ def __init__(self, norm_nc, label_nc, feature_nc, eps = 1e-5):
291
+ super().__init__()
292
+ print("use SPM")
293
+
294
+ self.norm = nn.GroupNorm(32, norm_nc, affine=False) # 32/16
295
+
296
+ # SPADE
297
+ self.eps = eps
298
+ nhidden = 128
299
+ self.mask_mlp_shared = nn.Sequential(
300
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
301
+ nn.ReLU(),
302
+ )
303
+
304
+ self.mask_mlp_gamma1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
305
+ self.mask_mlp_beta1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
306
+
307
+ self.mask_mlp_gamma2 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
308
+ self.mask_mlp_beta2 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
309
+
310
+
311
+ # Feature
312
+ self.feature_mlp_shared = nn.Sequential(
313
+ nn.Conv2d(feature_nc, nhidden, kernel_size=3, padding=1),
314
+ nn.ReLU(),
315
+ )
316
+
317
+ self.feature_mlp_gamma1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
318
+ self.feature_mlp_beta1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
319
+
320
+
321
+ def forward(self, x, segmap, guidance):
322
+ # Part 1. generate parameter-free normalized activations
323
+ x = self.norm(x)
324
+ # Part 2. produce scaling and bias conditioned on semantic map
325
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
326
+ mask_actv = self.mask_mlp_shared(segmap)
327
+ mask_gamma1 = self.mask_mlp_gamma1(mask_actv)
328
+ mask_beta1 = self.mask_mlp_beta1(mask_actv)
329
+
330
+ mask_gamma2 = self.mask_mlp_gamma2(mask_actv)
331
+ mask_beta2 = self.mask_mlp_beta2(mask_actv)
332
+
333
+
334
+ # Part 3. produce scaling and bias conditioned on feature guidance
335
+ guidance = F.interpolate(guidance, size=x.size()[2:], mode='bilinear')
336
+ feature_actv = self.feature_mlp_shared(guidance)
337
+ feature_gamma1 = self.feature_mlp_gamma1(feature_actv)
338
+ feature_beta1 = self.feature_mlp_beta1(feature_actv)
339
+
340
+ gamma_final = feature_gamma1 * (1 + mask_gamma1) + mask_beta1
341
+ beta_final = feature_beta1 * (1 + mask_gamma2) + mask_beta2
342
+
343
+ out = x * (1 + gamma_final) + beta_final
344
+
345
+ # apply scale and bias
346
+ return out
347
+
348
+
349
+ class ResBlock(TimestepBlock):
350
+ """
351
+ A residual block that can optionally change the number of channels.
352
+
353
+ :param channels: the number of input channels.
354
+ :param emb_channels: the number of timestep embedding channels.
355
+ :param dropout: the rate of dropout.
356
+ :param out_channels: if specified, the number of out channels.
357
+ :param use_conv: if True and out_channels is specified, use a spatial
358
+ convolution instead of a smaller 1x1 convolution to change the
359
+ channels in the skip connection.
360
+ :param dims: determines if the signal is 1D, 2D, or 3D.
361
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
362
+ :param up: if True, use this block for upsampling.
363
+ :param down: if True, use this block for downsampling.
364
+ """
365
+
366
+ def __init__(
367
+ self,
368
+ channels,
369
+ emb_channels,
370
+ dropout,
371
+ out_channels=None,
372
+ use_conv=False,
373
+ use_scale_shift_norm=False,
374
+ dims=2,
375
+ use_checkpoint=False,
376
+ up=False,
377
+ down=False,
378
+ ):
379
+ super().__init__()
380
+ self.channels = channels
381
+ self.emb_channels = emb_channels
382
+ self.dropout = dropout
383
+ self.out_channels = out_channels or channels
384
+ self.use_conv = use_conv
385
+ self.use_checkpoint = use_checkpoint
386
+ self.use_scale_shift_norm = use_scale_shift_norm
387
+
388
+ self.in_layers = nn.Sequential(
389
+ normalization(channels),
390
+ SiLU(),
391
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
392
+ )
393
+
394
+ self.updown = up or down
395
+
396
+ if up:
397
+ self.h_upd = Upsample(channels, False, dims)
398
+ self.x_upd = Upsample(channels, False, dims)
399
+ elif down:
400
+ self.h_upd = Downsample(channels, False, dims)
401
+ self.x_upd = Downsample(channels, False, dims)
402
+ else:
403
+ self.h_upd = self.x_upd = nn.Identity()
404
+
405
+ self.emb_layers = nn.Sequential(
406
+ SiLU(),
407
+ linear(
408
+ emb_channels,
409
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
410
+ ),
411
+ )
412
+ self.out_layers = nn.Sequential(
413
+ normalization(self.out_channels),
414
+ SiLU(),
415
+ nn.Dropout(p=dropout),
416
+ zero_module(
417
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
418
+ ),
419
+ )
420
+
421
+ if self.out_channels == channels:
422
+ self.skip_connection = nn.Identity()
423
+ elif use_conv:
424
+ self.skip_connection = conv_nd(
425
+ dims, channels, self.out_channels, 3, padding=1
426
+ )
427
+ else:
428
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
429
+
430
+ def forward(self, x, emb):
431
+ """
432
+ Apply the block to a Tensor, conditioned on a timestep embedding.
433
+
434
+ :param x: an [N x C x ...] Tensor of features.
435
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
436
+ :return: an [N x C x ...] Tensor of outputs.
437
+ """
438
+
439
+ return th.utils.checkpoint.checkpoint(self._forward, x ,emb)
440
+ # return checkpoint(
441
+ # self._forward, (x, emb), self.parameters(), self.use_checkpoint
442
+ # )
443
+
444
+ def _forward(self, x, emb):
445
+ if self.updown:
446
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
447
+ h = in_rest(x)
448
+ h = self.h_upd(h)
449
+ x = self.x_upd(x)
450
+ h = in_conv(h)
451
+ else:
452
+ h = self.in_layers(x)
453
+ emb_out = self.emb_layers(emb)#.type(h.dtype)
454
+ while len(emb_out.shape) < len(h.shape):
455
+ emb_out = emb_out[..., None]
456
+ if self.use_scale_shift_norm:
457
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
458
+ scale, shift = th.chunk(emb_out, 2, dim=1)
459
+ h = out_norm(h) * (1 + scale) + shift
460
+ h = out_rest(h)
461
+ else:
462
+ h = h + emb_out
463
+ h = self.out_layers(h)
464
+ return self.skip_connection(x) + h
465
+
466
+ class SDMResBlock(CondTimestepBlock):
467
+ """
468
+ A residual block that can optionally change the number of channels.
469
+
470
+ :param channels: the number of input channels.
471
+ :param emb_channels: the number of timestep embedding channels.
472
+ :param dropout: the rate of dropout.
473
+ :param out_channels: if specified, the number of out channels.
474
+ :param use_conv: if True and out_channels is specified, use a spatial
475
+ convolution instead of a smaller 1x1 convolution to change the
476
+ channels in the skip connection.
477
+ :param dims: determines if the signal is 1D, 2D, or 3D.
478
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
479
+ :param up: if True, use this block for upsampling.
480
+ :param down: if True, use this block for downsampling.
481
+ """
482
+
483
+ def __init__(
484
+ self,
485
+ channels,
486
+ emb_channels,
487
+ dropout,
488
+ c_channels=3,
489
+ out_channels=None,
490
+ use_conv=False,
491
+ use_scale_shift_norm=False,
492
+ dims=2,
493
+ use_checkpoint=False,
494
+ up=False,
495
+ down=False,
496
+ SPADE_type = "spade",
497
+ guidance_nc = None
498
+ ):
499
+ super().__init__()
500
+ self.channels = channels
501
+ self.guidance_nc = guidance_nc
502
+ self.emb_channels = emb_channels
503
+ self.dropout = dropout
504
+ self.out_channels = out_channels or channels
505
+ self.use_conv = use_conv
506
+ self.use_checkpoint = use_checkpoint
507
+ self.use_scale_shift_norm = use_scale_shift_norm
508
+ self.SPADE_type = SPADE_type
509
+ if self.SPADE_type == "spade":
510
+ self.in_norm = SPADEGroupNorm(channels, c_channels)
511
+ elif self.SPADE_type == "RESAIL":
512
+ self.in_norm = RESAILGroupNorm(channels, c_channels, guidance_nc)
513
+ elif self.SPADE_type == "SPM":
514
+ self.in_norm = SPMGroupNorm(channels, c_channels, guidance_nc)
515
+ self.in_layers = nn.Sequential(
516
+ SiLU(),
517
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
518
+ )
519
+
520
+ self.updown = up or down
521
+
522
+ if up:
523
+ self.h_upd = Upsample(channels, False, dims)
524
+ self.x_upd = Upsample(channels, False, dims)
525
+ elif down:
526
+ self.h_upd = Downsample(channels, False, dims)
527
+ self.x_upd = Downsample(channels, False, dims)
528
+ else:
529
+ self.h_upd = self.x_upd = nn.Identity()
530
+
531
+ self.emb_layers = nn.Sequential(
532
+ SiLU(),
533
+ linear(
534
+ emb_channels,
535
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
536
+ ),
537
+ )
538
+
539
+ if self.SPADE_type == "spade":
540
+ self.out_norm = SPADEGroupNorm(self.out_channels, c_channels)
541
+ elif self.SPADE_type == "RESAIL":
542
+ self.out_norm = RESAILGroupNorm(self.out_channels, c_channels, guidance_nc)
543
+ elif self.SPADE_type == "SPM":
544
+ self.out_norm = SPMGroupNorm(self.out_channels, c_channels, guidance_nc)
545
+
546
+ self.out_layers = nn.Sequential(
547
+ SiLU(),
548
+ nn.Dropout(p=dropout),
549
+ zero_module(
550
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
551
+ ),
552
+ )
553
+
554
+ if self.out_channels == channels:
555
+ self.skip_connection = nn.Identity()
556
+ elif use_conv:
557
+ self.skip_connection = conv_nd(
558
+ dims, channels, self.out_channels, 3, padding=1
559
+ )
560
+ else:
561
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
562
+
563
+ def forward(self, x, cond, emb):
564
+ """
565
+ Apply the block to a Tensor, conditioned on a timestep embedding.
566
+
567
+ :param x: an [N x C x ...] Tensor of features.
568
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
569
+ :return: an [N x C x ...] Tensor of outputs.
570
+ """
571
+ return th.utils.checkpoint.checkpoint(self._forward, x, cond, emb)
572
+ # return checkpoint(
573
+ # self._forward, (x, cond, emb), self.parameters(), self.use_checkpoint
574
+ # )
575
+
576
+ def _forward(self, x, cond, emb):
577
+ if self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
578
+ assert self.guidance_nc is not None, "Please set guidance_nc when you use RESAIL or SPM"
579
+ guidance = x[: ,x.shape[1] - self.guidance_nc:, ...]
580
+ else:
581
+ guidance = None
582
+ if self.updown:
583
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
584
+ if self.SPADE_type == "spade":
585
+ h = self.in_norm(x, cond)
586
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
587
+ h = self.in_norm(x, cond, guidance)
588
+
589
+ h = in_rest(h)
590
+ h = self.h_upd(h)
591
+ x = self.x_upd(x)
592
+ h = in_conv(h)
593
+ else:
594
+ if self.SPADE_type == "spade":
595
+ h = self.in_norm(x, cond)
596
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
597
+ h = self.in_norm(x, cond, guidance)
598
+ h = self.in_layers(h)
599
+
600
+ emb_out = self.emb_layers(emb)#.type(h.dtype)
601
+ while len(emb_out.shape) < len(h.shape):
602
+ emb_out = emb_out[..., None]
603
+ if self.use_scale_shift_norm:
604
+ scale, shift = th.chunk(emb_out, 2, dim=1)
605
+ if self.SPADE_type == "spade":
606
+ h = self.out_norm(h, cond)
607
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
608
+ h = self.out_norm(h, cond, guidance)
609
+
610
+ h = h * (1 + scale) + shift
611
+ h = self.out_layers(h)
612
+ else:
613
+ h = h + emb_out
614
+ if self.SPADE_type == "spade":
615
+ h = self.out_norm(h, cond)
616
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
617
+ h = self.out_norm(h, cond, guidance)  # normalize h (not x) so this branch matches the SPADE branch above
618
+
619
+ h = self.out_layers(h)
620
+ return self.skip_connection(x) + h
621
+
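The three normalization variants selected by `SPADE_type` (SPADEGroupNorm, RESAILGroupNorm, SPMGroupNorm) are defined elsewhere in this module and are not part of this diff. For orientation only, a minimal sketch of the SPADE-style variant the block expects might look as follows; the hidden width, group count and layer names are illustrative assumptions, not the repo's implementation (norm_nc must be divisible by the group count).

import torch.nn as nn
import torch.nn.functional as F

class SPADEGroupNormSketch(nn.Module):
    """Illustrative only: parameter-free GroupNorm modulated by per-pixel
    scale/shift predicted from the (resized) semantic label map."""
    def __init__(self, norm_nc, label_nc, nhidden=128, groups=32, eps=1e-6):
        super().__init__()
        self.norm = nn.GroupNorm(groups, norm_nc, eps=eps, affine=False)
        self.shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(nhidden, norm_nc, 3, padding=1)
        self.beta = nn.Conv2d(nhidden, norm_nc, 3, padding=1)

    def forward(self, x, segmap):
        x = self.norm(x)
        # resize the label map to the current feature resolution
        segmap = F.interpolate(segmap, size=x.shape[-2:], mode="nearest")
        actv = self.shared(segmap)
        return x * (1 + self.gamma(actv)) + self.beta(actv)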
622
+ class AttentionBlock(nn.Module):
623
+ """
624
+ An attention block that allows spatial positions to attend to each other.
625
+
626
+ Originally ported from here, but adapted to the N-d case.
627
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
628
+ """
629
+
630
+ def __init__(
631
+ self,
632
+ channels,
633
+ num_heads=1,
634
+ num_head_channels=-1,
635
+ use_checkpoint=False,
636
+ use_new_attention_order=False,
637
+ ):
638
+ super().__init__()
639
+ self.channels = channels
640
+ if num_head_channels == -1:
641
+ self.num_heads = num_heads
642
+ else:
643
+ assert (
644
+ channels % num_head_channels == 0
645
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
646
+ self.num_heads = channels // num_head_channels
647
+ self.use_checkpoint = use_checkpoint
648
+ self.norm = normalization(channels)
649
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
650
+ if use_new_attention_order:
651
+ # split qkv before split heads
652
+ self.attention = QKVAttention(self.num_heads)
653
+ else:
654
+ # split heads before split qkv
655
+ self.attention = QKVAttentionLegacy(self.num_heads)
656
+
657
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
658
+
659
+ def forward(self, x):
660
+ return th.utils.checkpoint.checkpoint(self._forward, x)
661
+ #return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
662
+
663
+ def _forward(self, x):
664
+ b, c, *spatial = x.shape
665
+ x = x.reshape(b, c, -1)
666
+ qkv = self.qkv(self.norm(x))
667
+ h = self.attention(qkv)
668
+ h = self.proj_out(h)
669
+ return (x + h).reshape(b, c, *spatial)
670
+
671
+
672
+ def count_flops_attn(model, _x, y):
673
+ """
674
+ A counter for the `thop` package to count the operations in an
675
+ attention operation.
676
+ Meant to be used like:
677
+ macs, params = thop.profile(
678
+ model,
679
+ inputs=(inputs, timestamps),
680
+ custom_ops={QKVAttention: QKVAttention.count_flops},
681
+ )
682
+ """
683
+ b, c, *spatial = y[0].shape
684
+ num_spatial = int(np.prod(spatial))
685
+ # We perform two matmuls with the same number of ops.
686
+ # The first computes the weight matrix, the second computes
687
+ # the combination of the value vectors.
688
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
689
+ model.total_ops += th.DoubleTensor([matmul_ops])
690
+
691
+
692
+ class QKVAttentionLegacy(nn.Module):
693
+ """
694
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
695
+ """
696
+
697
+ def __init__(self, n_heads):
698
+ super().__init__()
699
+ self.n_heads = n_heads
700
+
701
+ def forward(self, qkv):
702
+ """
703
+ Apply QKV attention.
704
+
705
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
706
+ :return: an [N x (H * C) x T] tensor after attention.
707
+ """
708
+ bs, width, length = qkv.shape
709
+ assert width % (3 * self.n_heads) == 0
710
+ ch = width // (3 * self.n_heads)
711
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
712
+ scale = 1 / math.sqrt(math.sqrt(ch))
713
+ weight = th.einsum(
714
+ "bct,bcs->bts", q * scale, k * scale
715
+ ) # More stable with f16 than dividing afterwards
716
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
717
+ a = th.einsum("bts,bcs->bct", weight, v)
718
+ return a.reshape(bs, -1, length)
719
+
720
+ @staticmethod
721
+ def count_flops(model, _x, y):
722
+ return count_flops_attn(model, _x, y)
723
+
724
+
725
+ class QKVAttention(nn.Module):
726
+ """
727
+ A module which performs QKV attention and splits in a different order.
728
+ """
729
+
730
+ def __init__(self, n_heads):
731
+ super().__init__()
732
+ self.n_heads = n_heads
733
+
734
+ def forward(self, qkv):
735
+ """
736
+ Apply QKV attention.
737
+
738
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
739
+ :return: an [N x (H * C) x T] tensor after attention.
740
+ """
741
+ bs, width, length = qkv.shape
742
+ assert width % (3 * self.n_heads) == 0
743
+ ch = width // (3 * self.n_heads)
744
+ q, k, v = qkv.chunk(3, dim=1)
745
+ scale = 1 / math.sqrt(math.sqrt(ch))
746
+ weight = th.einsum(
747
+ "bct,bcs->bts",
748
+ (q * scale).view(bs * self.n_heads, ch, length),
749
+ (k * scale).view(bs * self.n_heads, ch, length),
750
+ ) # More stable with f16 than dividing afterwards
751
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
752
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
753
+ return a.reshape(bs, -1, length)
754
+
755
+ @staticmethod
756
+ def count_flops(model, _x, y):
757
+ return count_flops_attn(model, _x, y)
758
+
759
+
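Both attention modules expect the queries, keys and values packed along the channel axis and return features with the head dimension folded back in. A quick shape check using this file's `QKVAttentionLegacy`, with arbitrary example sizes:

import torch as th

n_heads, ch, length, batch = 4, 16, 64, 2
attn = QKVAttentionLegacy(n_heads)
qkv = th.randn(batch, n_heads * 3 * ch, length)      # [N, H*3*C, T]
out = attn(qkv)
assert out.shape == (batch, n_heads * ch, length)    # [N, H*C, T]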
760
+ class UNetModel(ModelMixin, ConfigMixin):
761
+ """
762
+ The full UNet model with attention and timestep embedding.
763
+
764
+ :param in_channels: channels in the input Tensor.
765
+ :param model_channels: base channel count for the model.
766
+ :param out_channels: channels in the output Tensor.
767
+ :param num_res_blocks: number of residual blocks per downsample.
768
+ :param attention_resolutions: a collection of downsample rates at which
769
+ attention will take place. May be a set, list, or tuple.
770
+ For example, if this contains 4, then at 4x downsampling, attention
771
+ will be used.
772
+ :param dropout: the dropout probability.
773
+ :param channel_mult: channel multiplier for each level of the UNet.
774
+ :param conv_resample: if True, use learned convolutions for upsampling and
775
+ downsampling.
776
+ :param dims: determines if the signal is 1D, 2D, or 3D.
777
+ :param num_classes: if specified (as an int), then this model will be
778
+ class-conditional with `num_classes` classes.
779
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
780
+ :param num_heads: the number of attention heads in each attention layer.
781
+ :param num_heads_channels: if specified, ignore num_heads and instead use
782
+ a fixed channel width per attention head.
783
+ :param num_heads_upsample: works with num_heads to set a different number
784
+ of heads for upsampling. Deprecated.
785
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
786
+ :param resblock_updown: use residual blocks for up/downsampling.
787
+ :param use_new_attention_order: use a different attention pattern for potentially
788
+ increased efficiency.
789
+ """
790
+
791
+ _supports_gradient_checkpointing = True
792
+ @register_to_config
793
+ def __init__(
794
+ self,
795
+ image_size,
796
+ in_channels,
797
+ model_channels,
798
+ out_channels,
799
+ num_res_blocks,
800
+ attention_resolutions,
801
+ dropout=0,
802
+ channel_mult=(1, 2, 4, 8),
803
+ conv_resample=True,
804
+ dims=2,
805
+ num_classes=None,
806
+ use_checkpoint=False,
807
+ use_fp16=True,
808
+ num_heads=1,
809
+ num_head_channels=-1,
810
+ num_heads_upsample=-1,
811
+ use_scale_shift_norm=False,
812
+ resblock_updown=False,
813
+ use_new_attention_order=False,
814
+ mask_emb="resize",
815
+ SPADE_type="spade",
816
+ ):
817
+ super().__init__()
818
+
819
+ if num_heads_upsample == -1:
820
+ num_heads_upsample = num_heads
821
+
822
+ self.sample_size = image_size
823
+ self.in_channels = in_channels
824
+ self.model_channels = model_channels
825
+ self.out_channels = out_channels
826
+ self.num_res_blocks = num_res_blocks
827
+ self.attention_resolutions = attention_resolutions
828
+ self.dropout = dropout
829
+ self.channel_mult = channel_mult
830
+ self.conv_resample = conv_resample
831
+ self.num_classes = num_classes
832
+ self.use_checkpoint = use_checkpoint
833
+ self.num_heads = num_heads
834
+ self.num_head_channels = num_head_channels
835
+ self.num_heads_upsample = num_heads_upsample
836
+
837
+ self.mask_emb = mask_emb
838
+
839
+ time_embed_dim = model_channels * 4
840
+ self.time_embed = nn.Sequential(
841
+ linear(model_channels, time_embed_dim),
842
+ SiLU(),
843
+ linear(time_embed_dim, time_embed_dim),
844
+ )
845
+
846
+ ch = input_ch = int(channel_mult[0] * model_channels)
847
+ self.input_blocks = nn.ModuleList(
848
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] #ch=256
849
+ )
850
+ self._feature_size = ch
851
+ input_block_chans = [ch]
852
+ ds = 1
853
+ for level, mult in enumerate(channel_mult):
854
+ for _ in range(num_res_blocks):
855
+ layers = [
856
+ ResBlock(
857
+ ch,
858
+ time_embed_dim,
859
+ dropout,
860
+ out_channels=int(mult * model_channels),
861
+ dims=dims,
862
+ use_checkpoint=use_checkpoint,
863
+ use_scale_shift_norm=use_scale_shift_norm,
864
+ )
865
+ ]
866
+ ch = int(mult * model_channels)
867
+ #print(ds)
868
+ if ds in attention_resolutions:
869
+ layers.append(
870
+ AttentionBlock(
871
+ ch,
872
+ use_checkpoint=use_checkpoint,
873
+ num_heads=num_heads,
874
+ num_head_channels=num_head_channels,
875
+ use_new_attention_order=use_new_attention_order,
876
+ )
877
+ )
878
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
879
+ self._feature_size += ch
880
+ input_block_chans.append(ch)
881
+ if level != len(channel_mult) - 1:
882
+ out_ch = ch
883
+ self.input_blocks.append(
884
+ TimestepEmbedSequential(
885
+ ResBlock(
886
+ ch,
887
+ time_embed_dim,
888
+ dropout,
889
+ out_channels=out_ch,
890
+ dims=dims,
891
+ use_checkpoint=use_checkpoint,
892
+ use_scale_shift_norm=use_scale_shift_norm,
893
+ down=True,
894
+ )
895
+ if resblock_updown
896
+ else Downsample(
897
+ ch, conv_resample, dims=dims, out_channels=out_ch
898
+ )
899
+ )
900
+ )
901
+ ch = out_ch
902
+ input_block_chans.append(ch)
903
+ ds *= 2
904
+ self._feature_size += ch
905
+ self.middle_block = TimestepEmbedSequential(
906
+ SDMResBlock(
907
+ ch,
908
+ time_embed_dim,
909
+ dropout,
910
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4,
911
+ dims=dims,
912
+ use_checkpoint=use_checkpoint,
913
+ use_scale_shift_norm=use_scale_shift_norm,
914
+ ),
915
+ AttentionBlock(
916
+ ch,
917
+ use_checkpoint=use_checkpoint,
918
+ num_heads=num_heads,
919
+ num_head_channels=num_head_channels,
920
+ use_new_attention_order=use_new_attention_order,
921
+ ),
922
+ SDMResBlock(
923
+ ch,
924
+ time_embed_dim,
925
+ dropout,
926
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4 ,
927
+ dims=dims,
928
+ use_checkpoint=use_checkpoint,
929
+ use_scale_shift_norm=use_scale_shift_norm,
930
+ ),
931
+ )
932
+ self._feature_size += ch
933
+
934
+ self.output_blocks = nn.ModuleList([])
935
+ for level, mult in list(enumerate(channel_mult))[::-1]:
936
+ for i in range(num_res_blocks + 1):
937
+ ich = input_block_chans.pop()
938
+ #print(ch, ich)
939
+ layers = [
940
+ SDMResBlock(
941
+ ch + ich,
942
+ time_embed_dim,
943
+ dropout,
944
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4,
945
+ out_channels=int(model_channels * mult),
946
+ dims=dims,
947
+ use_checkpoint=use_checkpoint,
948
+ use_scale_shift_norm=use_scale_shift_norm,
949
+ SPADE_type=SPADE_type,
950
+ guidance_nc = ich,
951
+ )
952
+ ]
953
+ ch = int(model_channels * mult)
954
+ #print(ds)
955
+ if ds in attention_resolutions:
956
+ layers.append(
957
+ AttentionBlock(
958
+ ch,
959
+ use_checkpoint=use_checkpoint,
960
+ num_heads=num_heads_upsample,
961
+ num_head_channels=num_head_channels,
962
+ use_new_attention_order=use_new_attention_order,
963
+ )
964
+ )
965
+ if level and i == num_res_blocks:
966
+ out_ch = ch
967
+ layers.append(
968
+ SDMResBlock(
969
+ ch,
970
+ time_embed_dim,
971
+ dropout,
972
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4,
973
+ out_channels=out_ch,
974
+ dims=dims,
975
+ use_checkpoint=use_checkpoint,
976
+ use_scale_shift_norm=use_scale_shift_norm,
977
+ up=True,
978
+ )
979
+ if resblock_updown
980
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
981
+ )
982
+ ds //= 2
983
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
984
+ self._feature_size += ch
985
+
986
+ self.out = nn.Sequential(
987
+ normalization(ch),
988
+ SiLU(),
989
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
990
+ )
991
+ def _set_gradient_checkpointing(self, module, value=False):
992
+ #if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
993
+ module.gradient_checkpointing = value
994
+ def forward(self, x, y=None, timesteps=None ):
995
+ """
996
+ Apply the model to an input batch.
997
+
998
+ :param x: an [N x C x ...] Tensor of inputs.
999
+ :param timesteps: a 1-D batch of timesteps.
1000
+ :param y: an [N x num_classes x H x W] Tensor of semantic label maps used as SPADE conditioning, required when the model is class-conditional.
1001
+ :return: an [N x C x ...] Tensor of outputs.
1002
+ """
1003
+ assert (y is not None) == (
1004
+ self.num_classes is not None
1005
+ ), "must specify y if and only if the model is class-conditional"
1006
+
1007
+ hs = []
1008
+ if not th.is_tensor(timesteps):
1009
+ timesteps = th.tensor([timesteps], dtype=th.long, device=x.device)
1010
+ elif th.is_tensor(timesteps) and len(timesteps.shape) == 0:
1011
+ timesteps = timesteps[None].to(x.device)
1012
+
1013
+ timesteps = timestep_embedding(timesteps, self.model_channels).type(x.dtype).to(x.device)
1014
+ emb = self.time_embed(timesteps)
1015
+
1016
+ y = y.type(self.dtype)
1017
+ h = x.type(self.dtype)
1018
+ for module in self.input_blocks:
1019
+ # input_blocks apply no operations to y; it is only passed through for the SDM blocks
1020
+ h = module(h, y, emb)
1021
+ #print(h.shape)
1022
+ hs.append(h)
1023
+
1024
+ h = self.middle_block(h, y, emb)
1025
+
1026
+ for module in self.output_blocks:
1027
+ temp = hs.pop()
1028
+
1029
+ #print("before:", h.shape, temp.shape)
1030
+ # copy padding to match the downsample size
1031
+ if h.shape[2] != temp.shape[2]:
1032
+ p1d = (0, 0, 0, 1)
1033
+ h = F.pad(h, p1d, "replicate")
1034
+
1035
+ if h.shape[3] != temp.shape[3]:
1036
+ p2d = (0, 1, 0, 0)
1037
+ h = F.pad(h, p2d, "replicate")
1038
+ #print("after:", h.shape, temp.shape)
1039
+
1040
+ h = th.cat([h, temp], dim=1)
1041
+ h = module(h, y, emb)
1042
+
1043
+ h = h.type(x.dtype)
1044
+ return UNet2DOutput(sample=self.out(h))
1045
+
1046
+
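For orientation, a minimal smoke test of `UNetModel` could look like the sketch below. All hyperparameters (channel counts, attention resolutions, 151 label channels) are placeholder assumptions chosen to keep the example small, not the configuration used by the training scripts, and it assumes the helper layers defined earlier in this file (TimestepEmbedSequential, SPADEGroupNorm, ...) behave as in guided-diffusion.

import torch as th

unet = UNetModel(
    image_size=64,
    in_channels=3,
    model_channels=64,
    out_channels=3,
    num_res_blocks=1,
    attention_resolutions=(4,),
    channel_mult=(1, 2),
    num_classes=151,                # channels of the one-hot label map fed as y
    use_scale_shift_norm=True,
)
x = th.randn(1, 3, 64, 64)          # noisy latents/images
y = th.randn(1, 151, 64, 64)        # stands in for a one-hot segmentation map
t = th.tensor([10])
out = unet(x, y, t).sample          # same shape as x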
1047
+ class SuperResModel(UNetModel):
1048
+ """
1049
+ A UNetModel that performs super-resolution.
1050
+
1051
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
1052
+ """
1053
+
1054
+ def __init__(self, image_size, in_channels, *args, **kwargs):
1055
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
1056
+
1057
+ def forward(self, x, cond, timesteps, low_res=None, **kwargs):
1058
+ _, _, new_height, new_width = x.shape
1059
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
1060
+ x = th.cat([x, upsampled], dim=1)
1061
+ return super().forward(x, cond, timesteps, **kwargs)
1062
+
1063
+
1064
+ class EncoderUNetModel(nn.Module):
1065
+ """
1066
+ The half UNet model with attention and timestep embedding.
1067
+
1068
+ For usage, see UNet.
1069
+ """
1070
+
1071
+ def __init__(
1072
+ self,
1073
+ image_size,
1074
+ in_channels,
1075
+ model_channels,
1076
+ out_channels,
1077
+ num_res_blocks,
1078
+ attention_resolutions,
1079
+ dropout=0,
1080
+ channel_mult=(1, 2, 4, 8),
1081
+ conv_resample=True,
1082
+ dims=2,
1083
+ use_checkpoint=False,
1084
+ use_fp16=False,
1085
+ num_heads=1,
1086
+ num_head_channels=-1,
1087
+ num_heads_upsample=-1,
1088
+ use_scale_shift_norm=False,
1089
+ resblock_updown=False,
1090
+ use_new_attention_order=False,
1091
+ pool="adaptive",
1092
+ ):
1093
+ super().__init__()
1094
+
1095
+ if num_heads_upsample == -1:
1096
+ num_heads_upsample = num_heads
1097
+
1098
+ self.in_channels = in_channels
1099
+ self.model_channels = model_channels
1100
+ self.out_channels = out_channels
1101
+ self.num_res_blocks = num_res_blocks
1102
+ self.attention_resolutions = attention_resolutions
1103
+ self.dropout = dropout
1104
+ self.channel_mult = channel_mult
1105
+ self.conv_resample = conv_resample
1106
+ self.use_checkpoint = use_checkpoint
1107
+ self.num_heads = num_heads
1108
+ self.num_head_channels = num_head_channels
1109
+ self.num_heads_upsample = num_heads_upsample
1110
+
1111
+ time_embed_dim = model_channels * 4
1112
+ self.time_embed = nn.Sequential(
1113
+ linear(model_channels, time_embed_dim),
1114
+ SiLU(),
1115
+ linear(time_embed_dim, time_embed_dim),
1116
+ )
1117
+
1118
+ ch = int(channel_mult[0] * model_channels)
1119
+ self.input_blocks = nn.ModuleList(
1120
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
1121
+ )
1122
+ self._feature_size = ch
1123
+ input_block_chans = [ch]
1124
+ ds = 1
1125
+ for level, mult in enumerate(channel_mult):
1126
+ for _ in range(num_res_blocks):
1127
+ layers = [
1128
+ ResBlock(
1129
+ ch,
1130
+ time_embed_dim,
1131
+ dropout,
1132
+ out_channels=int(mult * model_channels),
1133
+ dims=dims,
1134
+ use_checkpoint=use_checkpoint,
1135
+ use_scale_shift_norm=use_scale_shift_norm,
1136
+ )
1137
+ ]
1138
+ ch = int(mult * model_channels)
1139
+ if ds in attention_resolutions:
1140
+ layers.append(
1141
+ AttentionBlock(
1142
+ ch,
1143
+ use_checkpoint=use_checkpoint,
1144
+ num_heads=num_heads,
1145
+ num_head_channels=num_head_channels,
1146
+ use_new_attention_order=use_new_attention_order,
1147
+ )
1148
+ )
1149
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1150
+ self._feature_size += ch
1151
+ input_block_chans.append(ch)
1152
+ if level != len(channel_mult) - 1:
1153
+ out_ch = ch
1154
+ self.input_blocks.append(
1155
+ TimestepEmbedSequential(
1156
+ ResBlock(
1157
+ ch,
1158
+ time_embed_dim,
1159
+ dropout,
1160
+ out_channels=out_ch,
1161
+ dims=dims,
1162
+ use_checkpoint=use_checkpoint,
1163
+ use_scale_shift_norm=use_scale_shift_norm,
1164
+ down=True,
1165
+ )
1166
+ if resblock_updown
1167
+ else Downsample(
1168
+ ch, conv_resample, dims=dims, out_channels=out_ch
1169
+ )
1170
+ )
1171
+ )
1172
+ ch = out_ch
1173
+ input_block_chans.append(ch)
1174
+ ds *= 2
1175
+ self._feature_size += ch
1176
+
1177
+ self.middle_block = TimestepEmbedSequential(
1178
+ ResBlock(
1179
+ ch,
1180
+ time_embed_dim,
1181
+ dropout,
1182
+ dims=dims,
1183
+ use_checkpoint=use_checkpoint,
1184
+ use_scale_shift_norm=use_scale_shift_norm,
1185
+ ),
1186
+ AttentionBlock(
1187
+ ch,
1188
+ use_checkpoint=use_checkpoint,
1189
+ num_heads=num_heads,
1190
+ num_head_channels=num_head_channels,
1191
+ use_new_attention_order=use_new_attention_order,
1192
+ ),
1193
+ ResBlock(
1194
+ ch,
1195
+ time_embed_dim,
1196
+ dropout,
1197
+ dims=dims,
1198
+ use_checkpoint=use_checkpoint,
1199
+ use_scale_shift_norm=use_scale_shift_norm,
1200
+ ),
1201
+ )
1202
+ self._feature_size += ch
1203
+ self.pool = pool
1204
+ if pool == "adaptive":
1205
+ self.out = nn.Sequential(
1206
+ normalization(ch),
1207
+ SiLU(),
1208
+ nn.AdaptiveAvgPool2d((1, 1)),
1209
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1210
+ nn.Flatten(),
1211
+ )
1212
+ elif pool == "attention":
1213
+ assert num_head_channels != -1
1214
+ self.out = nn.Sequential(
1215
+ normalization(ch),
1216
+ SiLU(),
1217
+ AttentionPool2d(
1218
+ (image_size // ds), ch, num_head_channels, out_channels
1219
+ ),
1220
+ )
1221
+ elif pool == "spatial":
1222
+ self.out = nn.Sequential(
1223
+ nn.Linear(self._feature_size, 2048),
1224
+ nn.ReLU(),
1225
+ nn.Linear(2048, self.out_channels),
1226
+ )
1227
+ elif pool == "spatial_v2":
1228
+ self.out = nn.Sequential(
1229
+ nn.Linear(self._feature_size, 2048),
1230
+ normalization(2048),
1231
+ SiLU(),
1232
+ nn.Linear(2048, self.out_channels),
1233
+ )
1234
+ else:
1235
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1236
+ def forward(self, x, timesteps):
1237
+ """
1238
+ Apply the model to an input batch.
1239
+
1240
+ :param x: an [N x C x ...] Tensor of inputs.
1241
+ :param timesteps: a 1-D batch of timesteps.
1242
+ :return: an [N x K] Tensor of outputs.
1243
+ """
1244
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1245
+
1246
+ results = []
1247
+ h = x.type(self.dtype)
1248
+ for module in self.input_blocks:
1249
+ h = module(h, emb)
1250
+ if self.pool.startswith("spatial"):
1251
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1252
+ h = self.middle_block(h, emb)
1253
+ if self.pool.startswith("spatial"):
1254
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1255
+ h = th.cat(results, axis=-1)
1256
+ return self.out(h)
1257
+ else:
1258
+ h = h.type(x.dtype)
1259
+ return self.out(h)
1260
+
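Both `UNetModel` and `EncoderUNetModel` call `timestep_embedding` from `diffusion_module/nn.py`, which is not part of this diff. For reference, the standard sinusoidal embedding it is expected to implement looks roughly like this (an illustrative re-implementation, not the repo's code):

import math
import torch as th

def timestep_embedding_sketch(timesteps, dim, max_period=10000):
    """Sinusoidal timestep embeddings: [N] -> [N, dim]."""
    half = dim // 2
    freqs = th.exp(-math.log(max_period) * th.arange(half, dtype=th.float32) / half)
    freqs = freqs.to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    emb = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:                      # pad odd dims with a zero column
        emb = th.cat([emb, th.zeros_like(emb[:, :1])], dim=-1)
    return emb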
code_for_ade20k/diffusion_module/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
code_for_ade20k/diffusion_module/unet_2d_sdm.py ADDED
@@ -0,0 +1,357 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ from .unet_2d_blocks import UNetSDMMidBlock2D, get_down_block, get_up_block
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ Args:
31
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
32
+ Hidden states output. Output of last layer of model.
33
+ """
34
+
35
+ sample: torch.FloatTensor
36
+
37
+ class SDMUNet2DModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
38
+ r"""
39
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
40
+
41
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
42
+ implements for all the model (such as downloading or saving, etc.)
43
+
44
+ Parameters:
45
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
46
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
47
+ 1)`.
48
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
49
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
50
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
51
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
52
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
53
+ flip_sin_to_cos (`bool`, *optional*, defaults to :
54
+ obj:`True`): Whether to flip sin to cos for fourier time embedding.
55
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
56
+ obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
57
+ types.
58
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
59
+ The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
60
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
61
+ obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
62
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
63
+ obj:`(224, 448, 672, 896)`): Tuple of block output channels.
64
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
65
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
66
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
67
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
68
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
69
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
70
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
71
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
72
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
73
+ class_embed_type (`str`, *optional*, defaults to None):
74
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
75
+ `"timestep"`, or `"identity"`.
76
+ num_class_embeds (`int`, *optional*, defaults to None):
77
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
78
+ class conditioning with `class_embed_type` equal to `None`.
79
+ """
80
+
81
+ _supports_gradient_checkpointing = True
82
+ @register_to_config
83
+ def __init__(
84
+ self,
85
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
86
+ in_channels: int = 3,
87
+ out_channels: int = 3,
88
+ center_input_sample: bool = False,
89
+ time_embedding_type: str = "positional",
90
+ freq_shift: int = 0,
91
+ flip_sin_to_cos: bool = True,
92
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
93
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
94
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
95
+ layers_per_block: int = 2,
96
+ mid_block_scale_factor: float = 1,
97
+ downsample_padding: int = 1,
98
+ act_fn: str = "silu",
99
+ attention_head_dim: Optional[int] = 8,
100
+ norm_num_groups: int = 32,
101
+ norm_eps: float = 1e-5,
102
+ resnet_time_scale_shift: str = "scale_shift",
103
+ add_attention: bool = True,
104
+ class_embed_type: Optional[str] = None,
105
+ num_class_embeds: Optional[int] = None,
106
+ segmap_channels: int = 34,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.sample_size = sample_size
111
+ self.segmap_channels = segmap_channels
112
+ time_embed_dim = block_out_channels[0] * 4
113
+
114
+ # Check inputs
115
+ if len(down_block_types) != len(up_block_types):
116
+ raise ValueError(
117
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
118
+ )
119
+
120
+ if len(block_out_channels) != len(down_block_types):
121
+ raise ValueError(
122
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
123
+ )
124
+
125
+ # input
126
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
127
+
128
+ # time
129
+ if time_embedding_type == "fourier":
130
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
131
+ timestep_input_dim = 2 * block_out_channels[0]
132
+ elif time_embedding_type == "positional":
133
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
134
+ timestep_input_dim = block_out_channels[0]
135
+
136
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
137
+
138
+ # class embedding
139
+ if class_embed_type is None and num_class_embeds is not None:
140
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
141
+ elif class_embed_type == "timestep":
142
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
143
+ elif class_embed_type == "identity":
144
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
145
+ else:
146
+ self.class_embedding = None
147
+
148
+ self.down_blocks = nn.ModuleList([])
149
+ self.mid_block = None
150
+ self.up_blocks = nn.ModuleList([])
151
+
152
+ # down
153
+ output_channel = block_out_channels[0]
154
+ for i, down_block_type in enumerate(down_block_types):
155
+ input_channel = output_channel
156
+ output_channel = block_out_channels[i]
157
+ is_final_block = i == len(block_out_channels) - 1
158
+
159
+ down_block = get_down_block(
160
+ down_block_type,
161
+ num_layers=layers_per_block,
162
+ in_channels=input_channel,
163
+ out_channels=output_channel,
164
+ temb_channels=time_embed_dim,
165
+ add_downsample=not is_final_block,
166
+ resnet_eps=norm_eps,
167
+ resnet_act_fn=act_fn,
168
+ resnet_groups=norm_num_groups,
169
+ attn_num_head_channels=attention_head_dim,
170
+ downsample_padding=downsample_padding,
171
+ resnet_time_scale_shift=resnet_time_scale_shift,
172
+ )
173
+ self.down_blocks.append(down_block)
174
+
175
+ # mid
176
+ self.mid_block = UNetSDMMidBlock2D(
177
+ in_channels=block_out_channels[-1],
178
+ temb_channels=time_embed_dim,
179
+ resnet_eps=norm_eps,
180
+ resnet_act_fn=act_fn,
181
+ output_scale_factor=mid_block_scale_factor,
182
+ resnet_time_scale_shift=resnet_time_scale_shift,
183
+ attn_num_head_channels=attention_head_dim,
184
+ resnet_groups=norm_num_groups,
185
+ add_attention=add_attention,
186
+ segmap_channels=segmap_channels,
187
+ )
188
+
189
+ # up
190
+ reversed_block_out_channels = list(reversed(block_out_channels))
191
+ output_channel = reversed_block_out_channels[0]
192
+ for i, up_block_type in enumerate(up_block_types):
193
+ prev_output_channel = output_channel
194
+ output_channel = reversed_block_out_channels[i]
195
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
196
+
197
+ is_final_block = i == len(block_out_channels) - 1
198
+
199
+ up_block = get_up_block(
200
+ up_block_type,
201
+ num_layers=layers_per_block + 1,
202
+ in_channels=input_channel,
203
+ out_channels=output_channel,
204
+ prev_output_channel=prev_output_channel,
205
+ temb_channels=time_embed_dim,
206
+ add_upsample=not is_final_block,
207
+ resnet_eps=norm_eps,
208
+ resnet_act_fn=act_fn,
209
+ resnet_groups=norm_num_groups,
210
+ attn_num_head_channels=attention_head_dim,
211
+ resnet_time_scale_shift=resnet_time_scale_shift,
212
+ segmap_channels=segmap_channels,
213
+ )
214
+ self.up_blocks.append(up_block)
215
+ prev_output_channel = output_channel
216
+
217
+ # out
218
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
219
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
220
+ self.conv_act = nn.SiLU()
221
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
222
+ def _set_gradient_checkpointing(self, module, value=False):
223
+ #if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
224
+ module.gradient_checkpointing = value
225
+ def forward(
226
+ self,
227
+ sample: torch.FloatTensor,
228
+ segmap: torch.FloatTensor,
229
+ timestep: Union[torch.Tensor, float, int],
230
+ class_labels: Optional[torch.Tensor] = None,
231
+ return_dict: bool = True,
232
+ ) -> Union[UNet2DOutput, Tuple]:
233
+ r"""
234
+ Args:
235
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
236
+ timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
237
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
238
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
239
+ return_dict (`bool`, *optional*, defaults to `True`):
240
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
241
+
242
+ Returns:
243
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
244
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
245
+ """
246
+ # 0. center input if necessary
247
+ if self.config.center_input_sample:
248
+ sample = 2 * sample - 1.0
249
+
250
+ # 1. time
251
+ #print(timestep.shape)
252
+ timesteps = timestep
253
+ if not torch.is_tensor(timesteps):
254
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
255
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
256
+ timesteps = timesteps[None].to(sample.device)
257
+
258
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
259
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
260
+
261
+ t_emb = self.time_proj(timesteps)
262
+
263
+ # timesteps does not contain any weights and will always return f32 tensors
264
+ # but time_embedding might actually be running in fp16. so we need to cast here.
265
+ # there might be better ways to encapsulate this.
266
+ t_emb = t_emb.to(dtype=self.dtype)
267
+ emb = self.time_embedding(t_emb)
268
+
269
+ if self.class_embedding is not None:
270
+ if class_labels is None:
271
+ raise ValueError("class_labels should be provided when doing class conditioning")
272
+
273
+ if self.config.class_embed_type == "timestep":
274
+ class_labels = self.time_proj(class_labels)
275
+
276
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
277
+ emb = emb + class_emb
278
+
279
+ # 2. pre-process
280
+ skip_sample = sample
281
+ sample = self.conv_in(sample)
282
+
283
+ # 3. down
284
+ down_block_res_samples = (sample,)
285
+ for downsample_block in self.down_blocks:
286
+ if hasattr(downsample_block, "skip_conv"):
287
+ sample, res_samples, skip_sample = downsample_block(
288
+ hidden_states=sample, temb=emb, skip_sample=skip_sample,segmap=segmap
289
+ )
290
+ else:
291
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
292
+
293
+ down_block_res_samples += res_samples
294
+
295
+ # 4. mid
296
+ sample = self.mid_block(sample, segmap, emb)
297
+
298
+ # 5. up
299
+ skip_sample = None
300
+ for upsample_block in self.up_blocks:
301
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
302
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
303
+
304
+ if hasattr(upsample_block, "skip_conv"):
305
+ sample, skip_sample = upsample_block(sample, segmap, res_samples, emb, skip_sample)
306
+ else:
307
+ sample = upsample_block(sample, segmap, res_samples, emb)
308
+
309
+ # 6. post-process
310
+ sample = self.conv_norm_out(sample)
311
+ sample = self.conv_act(sample)
312
+ sample = self.conv_out(sample)
313
+
314
+ if skip_sample is not None:
315
+ sample += skip_sample
316
+
317
+ if self.config.time_embedding_type == "fourier":
318
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
319
+ sample = sample / timesteps
320
+
321
+ if not return_dict:
322
+ return (sample,)
323
+
324
+ return UNet2DOutput(sample=sample)
325
+
326
+
327
+ if __name__ == "__main__":
328
+ path = 'output.txt'
329
+ f = open(path, 'w')
330
+
331
+ unet = SDMUNet2DModel(
332
+ sample_size=270,
333
+ in_channels=3,
334
+ out_channels=3,
335
+ layers_per_block=2,
336
+ block_out_channels=(256, 256, 512, 1024, 1024),
337
+ down_block_types=(
338
+ "ResnetDownsampleBlock2D",
339
+ "ResnetDownsampleBlock2D",
340
+ "ResnetDownsampleBlock2D",
341
+ "AttnDownBlock2D",
342
+ "AttnDownBlock2D",
343
+ ),
344
+ up_block_types=(
345
+ "SDMAttnUpBlock2D",
346
+ "SDMAttnUpBlock2D",
347
+ "SDMResnetUpsampleBlock2D",
348
+ "SDMResnetUpsampleBlock2D",
349
+ "SDMResnetUpsampleBlock2D",
350
+ ),
351
+ segmap_channels=34+1
352
+ )
353
+
354
+ print(unet,file=f)
355
+ f.close()
356
+
357
+ #summary(unet, [(1, 3, 270, 360), (1, 3, 270, 360), (2,)], device="cpu")
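Continuing the `__main__` sanity check above, a forward pass through the constructed model could be sketched as below. The spatial size and the random `segmap` are illustrative, and this assumes the blocks imported from `unet_2d_blocks.py` (whose diff is too large to render) follow the calling conventions of their diffusers counterparts.

import torch

sample = torch.randn(1, 3, 64, 64)          # noisy input
segmap = torch.randn(1, 35, 64, 64)         # must match segmap_channels (34 + 1)
t = torch.tensor([10])
out = unet(sample, segmap, t).sample        # same spatial size as `sample`
print(out.shape)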
code_for_ade20k/diffusion_module/utils/LSDMPipeline_expandDataset.py ADDED
@@ -0,0 +1,164 @@
1
+ import inspect
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+ from diffusers.models import UNet2DModel, VQModel
7
+ from diffusers.schedulers import DDIMScheduler
8
+ from diffusers.utils import randn_tensor
9
+ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
10
+ import copy
11
+
12
+ class SDMLDMPipeline(DiffusionPipeline):
13
+ r"""
14
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
15
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
16
+
17
+ Parameters:
18
+ vae ([`VQModel`]):
19
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
20
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
21
+ scheduler ([`SchedulerMixin`]):
22
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
23
+ """
24
+
25
+ def __init__(self, vae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16, resolution=512, resolution_type="city"):
26
+ super().__init__()
27
+ self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
28
+ self.torch_dtype = torch_dtype
29
+ self.resolution = resolution
30
+ self.resolution_type = resolution_type
31
+ @torch.no_grad()
32
+ def __call__(
33
+ self,
34
+ segmap = None,
35
+ batch_size: int = 8,
36
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
37
+ eta: float = 0.0,
38
+ num_inference_steps: int = 1000,
39
+ output_type: Optional[str] = "pil",
40
+ return_dict: bool = True,
41
+ every_step_save: int = None,
42
+ s: int = 1,
43
+ num_evolution_per_mask = 10,
44
+ **kwargs,
45
+ ) -> Union[Tuple, ImagePipelineOutput]:
46
+ r"""
47
+ Args:
48
+ batch_size (`int`, *optional*, defaults to 1):
49
+ Number of images to generate.
50
+ generator (`torch.Generator`, *optional*):
51
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
52
+ to make generation deterministic.
53
+ num_inference_steps (`int`, *optional*, defaults to 50):
54
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
55
+ expense of slower inference.
56
+ output_type (`str`, *optional*, defaults to `"pil"`):
57
+ The output format of the generated image. Choose between
58
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
59
+ return_dict (`bool`, *optional*, defaults to `True`):
60
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
61
+
62
+ Returns:
63
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
64
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
65
+ """
66
+ # self.unet.config.sample_size = (64, 64) # (135,180)
67
+ # self.unet.config.sample_size = (135,180)
68
+ if self.resolution_type == "crack":
69
+ self.unet.config.sample_size = (64,64)
70
+ elif self.resolution_type == "crack_256":
71
+ self.unet.config.sample_size = (256,256)
72
+ else:
73
+ sc = 1080 // self.resolution
74
+ latent_size = (self.resolution // 4, 1440 // (sc*4))
75
+ self.unet.config.sample_size = latent_size
76
+ #
77
+ if not isinstance(self.unet.config.sample_size, tuple):
78
+ self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
79
+
80
+ if segmap is None:
81
+ print("Didn't inpute any segmap, use the empty as the input")
82
+ segmap = torch.zeros(batch_size,self.unet.config.segmap_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1])
83
+ segmap = segmap.to(self.device).type(self.torch_dtype)
84
+ if batch_size == 1 and num_evolution_per_mask > batch_size:
85
+ latents = randn_tensor(
86
+ (num_evolution_per_mask, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
87
+ generator=generator,
88
+ )
89
+ else:
90
+ latents = randn_tensor(
91
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
92
+ generator=generator,
93
+ )
94
+ latents = latents.to(self.device).type(self.torch_dtype)
95
+
96
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
97
+ latents = latents * self.scheduler.init_noise_sigma
98
+
99
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps)
100
+
101
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
102
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
103
+
104
+ extra_kwargs = {}
105
+ if accepts_eta:
106
+ extra_kwargs["eta"] = eta
107
+
108
+ step_latent = []
109
+ learn_sigma = True if hasattr(self.scheduler, "variance_type") else False
110
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
111
+
112
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
113
+ # predict the noise residual
114
+ noise_prediction = self.unet(latent_model_input, segmap, t).sample
115
+ # compute the previous noisy sample x_t -> x_t-1
116
+
117
+
118
+ if learn_sigma and "learn" in self.scheduler.variance_type:
119
+ model_pred, var_pred = torch.split(noise_prediction, latents.shape[1], dim=1)
120
+ else:
121
+ model_pred = noise_prediction
122
+ if s > 1.0:
123
+ model_output_zero = self.unet(latent_model_input, torch.zeros_like(segmap), t).sample
124
+ if learn_sigma and "learn" in self.scheduler.variance_type:
125
+ model_output_zero,_ = torch.split(model_output_zero, latents.shape[1], dim=1)
126
+ model_pred = model_pred + s * (model_pred - model_output_zero)
127
+ if learn_sigma and "learn" in self.scheduler.variance_type:
128
+ recombined = torch.cat((model_pred, var_pred), dim=1)
129
+ # when using a scheduler without learned variance, step with the mean (noise) prediction only
130
+ if learn_sigma and "learn" in self.scheduler.variance_type:
131
+ latents = self.scheduler.step(recombined, t, latents, **extra_kwargs).prev_sample
132
+ else:
133
+ latents = self.scheduler.step(model_pred, t, latents, **extra_kwargs).prev_sample  # use model_pred so the guidance scale s is not silently discarded
134
+
135
+ if every_step_save is not None:
136
+ if (i+1) % every_step_save == 0:
137
+ step_latent.append(copy.deepcopy(latents))
138
+
139
+ # decode the image latents with the VAE
140
+ if every_step_save is not None:
141
+ image = []
142
+ for i, l in enumerate(step_latent):
143
+ l /= self.vae.config.scaling_factor # (0.18215)
144
+ #latents /= 7.706491063029163
145
+ l = self.vae.decode(l, segmap)
146
+ l = (l / 2 + 0.5).clamp(0, 1)
147
+ l = l.cpu().permute(0, 2, 3, 1).numpy()
148
+ if output_type == "pil":
149
+ l = self.numpy_to_pil(l)
150
+ image.append(l)
151
+ else:
152
+ latents /= self.vae.config.scaling_factor#(0.18215)
153
+ #latents /= 7.706491063029163
154
+ # image = self.vae.decode(latents, segmap).sample
155
+ image = self.vae.decode(latents, return_dict=False)[0]
156
+ image = (image / 2 + 0.5).clamp(0, 1)
157
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
158
+ if output_type == "pil":
159
+ image = self.numpy_to_pil(image)
160
+
161
+ if not return_dict:
162
+ return (image,)
163
+
164
+ return ImagePipelineOutput(images=image)
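A hedged usage sketch for this pipeline; `vae`, `unet`, `scheduler` and `segmap` are placeholders that would normally come from the trained checkpoint and the dataset loader, and the argument values are only examples.

import torch

pipe = SDMLDMPipeline(
    vae=vae, unet=unet, scheduler=scheduler,
    torch_dtype=torch.float16, resolution_type="crack",   # 64x64 latent grid
).to("cuda")
images = pipe(
    segmap=segmap,               # one-hot label map on the latent grid
    batch_size=1,
    num_inference_steps=50,
    s=1.5,                       # classifier-free guidance scale
    num_evolution_per_mask=4,    # several samples from a single mask
).images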
code_for_ade20k/diffusion_module/utils/Pipline.py ADDED
@@ -0,0 +1,361 @@
1
+ import inspect
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+ from diffusers.models import UNet2DModel, VQModel
7
+ from diffusers.schedulers import DDIMScheduler
8
+ from diffusers.utils import randn_tensor
9
+ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
10
+ import copy
11
+
12
+ class LDMPipeline(DiffusionPipeline):
13
+ r"""
14
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
15
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
16
+
17
+ Parameters:
18
+ vae ([`VQModel`]):
19
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
20
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
21
+ scheduler ([`SchedulerMixin`]):
22
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
23
+ """
24
+
25
+ def __init__(self, vae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16):
26
+ super().__init__()
27
+ self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
28
+ self.torch_dtype = torch_dtype
29
+
30
+ @torch.no_grad()
31
+ def __call__(
32
+ self,
33
+ batch_size: int = 8,
34
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
35
+ eta: float = 0.0,
36
+ num_inference_steps: int = 1000,
37
+ output_type: Optional[str] = "pil",
38
+ return_dict: bool = True,
39
+ **kwargs,
40
+ ) -> Union[Tuple, ImagePipelineOutput]:
41
+ r"""
42
+ Args:
43
+ batch_size (`int`, *optional*, defaults to 1):
44
+ Number of images to generate.
45
+ generator (`torch.Generator`, *optional*):
46
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
47
+ to make generation deterministic.
48
+ num_inference_steps (`int`, *optional*, defaults to 50):
49
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
50
+ expense of slower inference.
51
+ output_type (`str`, *optional*, defaults to `"pil"`):
52
+ The output format of the generated image. Choose between
53
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
54
+ return_dict (`bool`, *optional*, defaults to `True`):
55
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
56
+
57
+ Returns:
58
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
59
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
60
+ """
61
+ if not isinstance(self.unet.config.sample_size,tuple):
62
+ self.unet.config.sample_size = (self.unet.config.sample_size,self.unet.config.sample_size)
63
+
64
+ latents = randn_tensor(
65
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
66
+ generator=generator,
67
+ )
68
+ latents = latents.to(self.device).type(self.torch_dtype)
69
+
70
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
71
+ latents = latents * self.scheduler.init_noise_sigma
72
+
73
+ self.scheduler.set_timesteps(num_inference_steps)
74
+
75
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
76
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
77
+
78
+ extra_kwargs = {}
79
+ if accepts_eta:
80
+ extra_kwargs["eta"] = eta
81
+
82
+ for t in self.progress_bar(self.scheduler.timesteps):
83
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
84
+ # predict the noise residual
85
+ noise_prediction = self.unet(latent_model_input, t).sample
86
+ # compute the previous noisy sample x_t -> x_t-1
87
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
88
+
89
+ # decode the image latents with the VAE
90
+ latents /= self.vae.config.scaling_factor#(0.18215)
91
+ image = self.vae.decode(latents).sample
92
+
93
+ image = (image / 2 + 0.5).clamp(0, 1)
94
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
95
+ if output_type == "pil":
96
+ image = self.numpy_to_pil(image)
97
+
98
+ if not return_dict:
99
+ return (image,)
100
+
101
+ return ImagePipelineOutput(images=image)
102
+
103
+
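For the unconditional `LDMPipeline` above, usage reduces to sampling latents and decoding them; again `vae`, `unet` and `scheduler` are placeholders from a trained checkpoint.

import torch

pipe = LDMPipeline(vae=vae, unet=unet, scheduler=scheduler, torch_dtype=torch.float16).to("cuda")
images = pipe(batch_size=4, num_inference_steps=50, output_type="pil").images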
104
+ class SDMLDMPipeline(DiffusionPipeline):
105
+ r"""
106
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
107
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
108
+
109
+ Parameters:
110
+ vae ([`VQModel`]):
111
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
112
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
113
+ scheduler ([`SchedulerMixin`]):
114
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
115
+ """
116
+
117
+ def __init__(self, vae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16, resolution=512, resolution_type="city"):
118
+ super().__init__()
119
+ self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
120
+ self.torch_dtype = torch_dtype
121
+ self.resolution = resolution
122
+ self.resolution_type = resolution_type
123
+ @torch.no_grad()
124
+ def __call__(
125
+ self,
126
+ segmap = None,
127
+ batch_size: int = 8,
128
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
129
+ eta: float = 0.0,
130
+ num_inference_steps: int = 1000,
131
+ output_type: Optional[str] = "pil",
132
+ return_dict: bool = True,
133
+ every_step_save: int = None,
134
+ s: int = 1,
135
+ **kwargs,
136
+ ) -> Union[Tuple, ImagePipelineOutput]:
137
+ r"""
138
+ Args:
139
+ batch_size (`int`, *optional*, defaults to 1):
140
+ Number of images to generate.
141
+ generator (`torch.Generator`, *optional*):
142
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
143
+ to make generation deterministic.
144
+ num_inference_steps (`int`, *optional*, defaults to 50):
145
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
146
+ expense of slower inference.
147
+ output_type (`str`, *optional*, defaults to `"pil"`):
148
+ The output format of the generated image. Choose between
149
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
150
+ return_dict (`bool`, *optional*, defaults to `True`):
151
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
152
+
153
+ Returns:
154
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
155
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
156
+ """
157
+ # self.unet.config.sample_size = (64, 64) # (135,180)
158
+ # self.unet.config.sample_size = (135,180)
159
+ if self.resolution_type == "crack":
160
+ self.unet.config.sample_size = (64,64)
161
+ elif self.resolution_type == "crack_256":
162
+ self.unet.config.sample_size = (256,256)
163
+ else:
164
+ sc = 1080 // self.resolution
165
+ latent_size = (self.resolution // 4, 1440 // (sc*4))
166
+ self.unet.config.sample_size = latent_size
167
+ #
168
+ if not isinstance(self.unet.config.sample_size, tuple):
169
+ self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
170
+
171
+ if segmap is None:
172
+ print("Didn't inpute any segmap, use the empty as the input")
173
+ segmap = torch.zeros(batch_size,self.unet.config.segmap_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1])
174
+ segmap = segmap.to(self.device).type(self.torch_dtype)
175
+ latents = randn_tensor(
176
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
177
+ generator=generator,
178
+ )
179
+ latents = latents.to(self.device).type(self.torch_dtype)
180
+
181
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
182
+ latents = latents * self.scheduler.init_noise_sigma
183
+
184
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps)
185
+
186
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
187
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
188
+
189
+ extra_kwargs = {}
190
+ if accepts_eta:
191
+ extra_kwargs["eta"] = eta
192
+
193
+ step_latent = []
194
+ learn_sigma = True if hasattr(self.scheduler, "variance_type") else False
195
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
196
+
197
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
198
+ # predict the noise residual
199
+ noise_prediction = self.unet(latent_model_input, segmap, t).sample
200
+ # compute the previous noisy sample x_t -> x_t-1
201
+
202
+
203
+ if learn_sigma and "learn" in self.scheduler.variance_type:
204
+ model_pred, var_pred = torch.split(noise_prediction, latents.shape[1], dim=1)
205
+ else:
206
+ model_pred = noise_prediction
207
+ if s > 1.0:
208
+ model_output_zero = self.unet(latent_model_input, torch.zeros_like(segmap), t).sample
209
+ if learn_sigma and "learn" in self.scheduler.variance_type:
210
+ model_output_zero,_ = torch.split(model_output_zero, latents.shape[1], dim=1)
211
+ model_pred = model_pred + s * (model_pred - model_output_zero)
212
+ if learn_sigma and "learn" in self.scheduler.variance_type:
213
+ recombined = torch.cat((model_pred, var_pred), dim=1)
214
+ # when apply different scheduler, mean only !!
215
+ if learn_sigma and "learn" in self.scheduler.variance_type:
216
+ latents = self.scheduler.step(recombined, t, latents, **extra_kwargs).prev_sample
217
+ else:
218
+ latents = self.scheduler.step(model_pred, t, latents, **extra_kwargs).prev_sample  # use the (possibly guided) mean prediction
219
+
220
+ if every_step_save is not None:
221
+ if (i+1) % every_step_save == 0:
222
+ step_latent.append(copy.deepcopy(latents))
223
+
224
+ # decode the image latents with the VAE
225
+ if every_step_save is not None:
226
+ image = []
227
+ for i, l in enumerate(step_latent):
228
+ l /= self.vae.config.scaling_factor # (0.18215)
229
+ #latents /= 7.706491063029163
230
+ l = self.vae.decode(l, return_dict=False)[0]  # match the main decode path below
231
+ l = (l / 2 + 0.5).clamp(0, 1)
232
+ l = l.cpu().permute(0, 2, 3, 1).numpy()
233
+ if output_type == "pil":
234
+ l = self.numpy_to_pil(l)
235
+ image.append(l)
236
+ else:
237
+ latents /= self.vae.config.scaling_factor#(0.18215)
238
+ #latents /= 7.706491063029163
239
+ # image = self.vae.decode(latents, segmap).sample
240
+ image = self.vae.decode(latents, return_dict=False)[0]
241
+ image = (image / 2 + 0.5).clamp(0, 1)
242
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
243
+ if output_type == "pil":
244
+ image = self.numpy_to_pil(image)
245
+
246
+ if not return_dict:
247
+ return (image,)
248
+
249
+ return ImagePipelineOutput(images=image)
250
+
251
+
252
+ class SDMPipeline(DiffusionPipeline):
253
+ r"""
254
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
255
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
256
+
257
+ Parameters:
258
+ vae (optional):
260
+ Not used by this pipeline; the denoised model output is returned directly without VAE decoding.
260
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
261
+ scheduler ([`SchedulerMixin`]):
262
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
263
+ """
264
+
265
+ def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16, vae=None):
266
+ super().__init__()
267
+ self.register_modules(unet=unet, scheduler=scheduler)
268
+ self.torch_dtype = torch_dtype
269
+
270
+ @torch.no_grad()
271
+ def __call__(
272
+ self,
273
+ segmap = None,
274
+ batch_size: int = 8,
275
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
276
+ eta: float = 0.0,
277
+ num_inference_steps: int = 1000,
278
+ output_type: Optional[str] = "pil",
279
+ return_dict: bool = True,
280
+ s: int = 1,
281
+ **kwargs,
282
+ ) -> Union[Tuple, ImagePipelineOutput]:
283
+ r"""
284
+ Args:
285
+ batch_size (`int`, *optional*, defaults to 1):
286
+ Number of images to generate.
287
+ generator (`torch.Generator`, *optional*):
288
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
289
+ to make generation deterministic.
290
+ num_inference_steps (`int`, *optional*, defaults to 50):
291
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
292
+ expense of slower inference.
293
+ output_type (`str`, *optional*, defaults to `"pil"`):
294
+ The output format of the generated image. Choose between
295
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
296
+ return_dict (`bool`, *optional*, defaults to `True`):
297
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
298
+
299
+ Returns:
300
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
301
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
302
+ """
303
+ self.unet.config.sample_size = (270,360)
304
+ if not isinstance(self.unet.config.sample_size, tuple):
305
+ self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
306
+
307
+ if segmap is None:
308
+ print("Didn't inpute any segmap, use the empty as the input")
309
+ segmap = torch.zeros(batch_size,self.unet.config.segmap_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1])
310
+ segmap = segmap.to(self.device).type(self.torch_dtype)
311
+ latents = randn_tensor(
312
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
313
+ generator=generator,
314
+ )
315
+
316
+ latents = latents.to(self.device).type(self.torch_dtype)
317
+
318
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
319
+ latents = latents * self.scheduler.init_noise_sigma
320
+
321
+ self.scheduler.set_timesteps(num_inference_steps)
322
+
323
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
324
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
325
+
326
+ extra_kwargs = {}
327
+ if accepts_eta:
328
+ extra_kwargs["eta"] = eta
329
+
330
+ for t in self.progress_bar(self.scheduler.timesteps):
331
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
332
+ # predict the noise residual
333
+ noise_prediction = self.unet(latent_model_input, segmap, t).sample
334
+
335
+ #noise_prediction = noise_prediction[]
336
+
337
+ if s > 1.0:
338
+ model_output_zero = self.unet(latent_model_input, torch.zeros_like(segmap), t).sample
339
+ noise_prediction[:, :3] = model_output_zero[:, :3] + s * (noise_prediction[:, :3] - model_output_zero[:, :3])
340
+
341
+ #noise_prediction = noise_prediction[:, :3]
342
+
343
+ # compute the previous noisy sample x_t -> x_t-1
344
+ #breakpoint()
345
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
346
+
347
+ # decode the image latents with the VAE
348
+ # latents /= self.vae.config.scaling_factor#(0.18215)
349
+ # image = self.vae.decode(latents).sample
350
+ image = latents
351
+ #image = (image + 1) / 2.0
352
+ image = (image / 2 + 0.5).clamp(0, 1)
353
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
354
+ if output_type == "pil":
355
+ image = self.numpy_to_pil(image)
356
+
357
+ if not return_dict:
358
+ return (image,)
359
+
360
+ return ImagePipelineOutput(images=image)
361
+
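For reference, a minimal usage sketch for the pixel-space `SDMPipeline` defined above. The checkpoint path, the default DDIM settings, and the 151-channel segmap are placeholders (the channel count must match the `num_classes`/`segmap_channels` the UNet was trained with), so treat this as an illustration rather than a tested recipe:

import torch
from diffusers import DDIMScheduler
from diffusion_module.unet import UNetModel
from diffusion_module.utils.Pipline import SDMPipeline

# Hypothetical checkpoint directory from a pixel-space SDM training run.
unet = UNetModel.from_pretrained("path/to/sdm_checkpoint", subfolder="unet").to("cuda").to(torch.float16)
scheduler = DDIMScheduler()

pipe = SDMPipeline(unet=unet, scheduler=scheduler, torch_dtype=torch.float16)

# One-hot semantic map; 151 channels is a placeholder for the trained segmap_channels.
segmap = torch.zeros(1, 151, 270, 360).to("cuda").to(torch.float16)
segmap[:, 0] = 1.0  # everything labelled as class 0, purely for illustration

out = pipe(segmap=segmap, batch_size=1, generator=torch.manual_seed(0), num_inference_steps=50, s=1.5)
out.images[0].save("sample.png")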
code_for_ade20k/diffusion_module/utils/__pycache__/LSDMPipeline_expandDataset.cpython-39.pyc ADDED
Binary file (5.62 kB). View file
 
code_for_ade20k/diffusion_module/utils/__pycache__/Pipline.cpython-310.pyc ADDED
Binary file (8.22 kB). View file
 
code_for_ade20k/diffusion_module/utils/__pycache__/Pipline.cpython-39.pyc ADDED
Binary file (8.52 kB). View file
 
code_for_ade20k/diffusion_module/utils/__pycache__/loss.cpython-39.pyc ADDED
Binary file (4.06 kB). View file
 
code_for_ade20k/diffusion_module/utils/loss.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Helpers for various likelihood-based losses. These are ported from the original
3
+ Ho et al. diffusion models codebase:
4
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ import torch as th
10
+
11
+
12
+ def normal_kl(mean1, logvar1, mean2, logvar2):
13
+ """
14
+ Compute the KL divergence between two gaussians.
15
+
16
+ Shapes are automatically broadcasted, so batches can be compared to
17
+ scalars, among other use cases.
18
+ """
19
+ tensor = None
20
+ for obj in (mean1, logvar1, mean2, logvar2):
21
+ if isinstance(obj, th.Tensor):
22
+ tensor = obj
23
+ break
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ return 0.5 * (
34
+ -1.0
35
+ + logvar2
36
+ - logvar1
37
+ + th.exp(logvar1 - logvar2)
38
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39
+ )
40
+
41
+
42
+ def approx_standard_normal_cdf(x):
43
+ """
44
+ A fast approximation of the cumulative distribution function of the
45
+ standard normal.
46
+ """
47
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48
+
49
+
50
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51
+ """
52
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
53
+ given image.
54
+
55
+ :param x: the target images. It is assumed that this was uint8 values,
56
+ rescaled to the range [-1, 1].
57
+ :param means: the Gaussian mean Tensor.
58
+ :param log_scales: the Gaussian log stddev Tensor.
59
+ :return: a tensor like x of log probabilities (in nats).
60
+ """
61
+ assert x.shape == means.shape == log_scales.shape
62
+ centered_x = x - means
63
+ inv_stdv = th.exp(-log_scales)
64
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65
+ cdf_plus = approx_standard_normal_cdf(plus_in)
66
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67
+ cdf_min = approx_standard_normal_cdf(min_in)
68
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70
+ cdf_delta = cdf_plus - cdf_min
71
+ log_probs = th.where(
72
+ x < -0.999,
73
+ log_cdf_plus,
74
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75
+ )
76
+ assert log_probs.shape == x.shape
77
+ return log_probs
78
+
79
+ def variance_KL_loss(latents, noisy_latents, timesteps, model_pred_mean, model_pred_var, noise_scheduler,posterior_mean_coef1, posterior_mean_coef2, posterior_log_variance_clipped):
80
+ model_pred_mean = model_pred_mean.detach()
81
+ true_mean = (
82
+ posterior_mean_coef1.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * latents
83
+ + posterior_mean_coef2.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * noisy_latents
84
+ )
85
+
86
+ true_log_variance_clipped = posterior_log_variance_clipped.to(device=timesteps.device)[timesteps].float()[
87
+ ..., None, None, None]
88
+
89
+ if noise_scheduler.variance_type == "learned":
90
+ model_log_variance = model_pred_var
91
+ #model_pred_var = th.exp(model_log_variance)
92
+ else:
93
+ min_log = true_log_variance_clipped
94
+ max_log = th.log(noise_scheduler.betas.to(device=timesteps.device)[timesteps].float()[..., None, None, None])
95
+ frac = (model_pred_var + 1) / 2
96
+ model_log_variance = frac * max_log + (1 - frac) * min_log
97
+ #model_pred_var = th.exp(model_log_variance)
98
+
99
+ sqrt_recip_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod)
100
+ sqrt_recipm1_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod - 1)
101
+
102
+ pred_xstart = (sqrt_recip_alphas_cumprod.to(device=timesteps.device)[timesteps].float()[
103
+ ..., None, None, None] * noisy_latents
104
+ - sqrt_recipm1_alphas_cumprod.to(device=timesteps.device)[timesteps].float()[
105
+ ..., None, None, None] * model_pred_mean)
106
+
107
+ model_mean = (
108
+ posterior_mean_coef1.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * pred_xstart
109
+ + posterior_mean_coef2.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * noisy_latents
110
+ )
111
+
112
+ # model_mean = out["mean"] model_log_variance = out["log_variance"]
113
+ kl = normal_kl(
114
+ true_mean, true_log_variance_clipped, model_mean, model_log_variance
115
+ )
116
+ kl = kl.mean() / np.log(2.0)
117
+
118
+ decoder_nll = -discretized_gaussian_log_likelihood(
119
+ latents, means=model_mean, log_scales=0.5 * model_log_variance
120
+ )
121
+ assert decoder_nll.shape == latents.shape
122
+ decoder_nll = decoder_nll.mean() / np.log(2.0)
123
+
124
+ # At the first timestep return the decoder NLL,
125
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
126
+ kl_loss = th.where((timesteps == 0), decoder_nll, kl).mean()
127
+ return kl_loss
128
+
129
+ def get_variance(noise_scheduler):
130
+ alphas_cumprod_prev = th.cat([th.tensor([1.0]), noise_scheduler.alphas_cumprod[:-1]])
131
+
132
+ posterior_mean_coef1 = (
133
+ noise_scheduler.betas * th.sqrt(alphas_cumprod_prev) / (1.0 - noise_scheduler.alphas_cumprod)
134
+ )
135
+
136
+ posterior_mean_coef2 = (
137
+ (1.0 - alphas_cumprod_prev)
138
+ * th.sqrt(noise_scheduler.alphas)
139
+ / (1.0 - noise_scheduler.alphas_cumprod)
140
+ )
141
+
142
+ posterior_variance = (
143
+ noise_scheduler.betas * (1.0 - alphas_cumprod_prev) / (1.0 - noise_scheduler.alphas_cumprod)
144
+ )
145
+ posterior_log_variance_clipped = th.log(
146
+ th.cat([posterior_variance[1][..., None], posterior_variance[1:]])
147
+ )
148
+ #res = posterior_log_variance_clipped.to(device=timesteps.device)[timesteps].float()
149
+ return posterior_mean_coef1, posterior_mean_coef2, posterior_log_variance_clipped #res[..., None, None, None]
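The two helpers at the end are meant to be wired into a training step when the UNet also predicts a variance channel. Below is a minimal sketch of that wiring, assuming `noise_scheduler`, `unet`, `latents`, `noisy_latents`, `segmap`, `noise`, and `timesteps` are set up as in `train_SDM_LDM_ade.py` further down and that the UNet was built with `out_channels = 2 * latent_channels` (the ADE20K script currently keeps the mean-only configuration, so this path is illustrative):

import torch
import torch.nn.functional as F
from diffusion_module.utils.loss import get_variance, variance_KL_loss

# Posterior coefficients depend only on the scheduler, so compute them once per run.
coef1, coef2, log_var_clipped = get_variance(noise_scheduler)

# Inside the training loop: split the doubled output into mean (epsilon) and variance halves.
model_output = unet(noisy_latents, segmap, timesteps).sample
model_pred, model_var = torch.split(model_output, latents.shape[1], dim=1)

mse_loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
kl_loss = variance_KL_loss(latents, noisy_latents, timesteps,
                           model_pred, model_var, noise_scheduler,
                           coef1, coef2, log_var_clipped)
loss = mse_loss + kl_loss  # optionally down-weight the KL term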
code_for_ade20k/diffusion_module/utils/noise_sampler.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from sklearn.mixture import GaussianMixture  # reserved for the planned 'gmm' sampler
+
+ def get_noise_sampler(sample_type='gau'):
+     """Return a callable that draws noise with the same shape, device and dtype as its input tensor."""
+     if sample_type == 'gau':
+         sampler = lambda latents: torch.randn_like(latents)
+     elif sample_type == 'gau_offset':
+         # a second Gaussian draw added on top of the base noise
+         sampler = lambda latents: torch.randn_like(latents) + torch.randn_like(latents)
+     elif sample_type == 'gmm':
+         raise NotImplementedError("GMM noise sampling is not implemented yet")
+     else:
+         raise ValueError(f"Unknown sample_type: {sample_type}")
+     return sampler
+
+ if __name__ == "__main__":
+     ...
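A quick usage sketch for `get_noise_sampler` (only the Gaussian branches exist; the tensor shape is illustrative):

import torch
from diffusion_module.utils.noise_sampler import get_noise_sampler

latents = torch.zeros(4, 4, 64, 64)        # placeholder latent batch
noise = get_noise_sampler('gau')(latents)  # standard Gaussian noise with the same shape as latents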
code_for_ade20k/diffusion_module/utils/scheduler_factory.py ADDED
@@ -0,0 +1,300 @@
1
+ from functools import partial
2
+ from diffusers import DDPMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler, DPMSolverSinglestepScheduler
3
+ from diffusers.pipeline_utils import DiffusionPipeline
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from typing import List, Optional, Tuple, Union
7
+ import numpy as np
8
+ from diffusers.schedulers.scheduling_utils import SchedulerOutput
9
+ from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
10
+ from diffusers.utils import randn_tensor, BaseOutput
11
+
12
+
13
+ ### Testing the DDPM Scheduler for Variant
14
+ class ModifiedDDPMScheduler(DDPMScheduler):
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__(*args, **kwargs)
17
+
18
+ def step(
19
+ self,
20
+ model_output: torch.FloatTensor,
21
+ timestep: int,
22
+ sample: torch.FloatTensor,
23
+ generator=None,
24
+ return_dict: bool = True,
25
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
26
+ """
27
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
28
+ process from the learned model outputs (most often the predicted noise).
29
+
30
+ Args:
31
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
32
+ timestep (`int`): current discrete timestep in the diffusion chain.
33
+ sample (`torch.FloatTensor`):
34
+ current instance of sample being created by diffusion process.
35
+ generator: random number generator.
36
+ return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
37
+
38
+ Returns:
39
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
40
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
41
+ returning a tuple, the first element is the sample tensor.
42
+
43
+ """
44
+ t = timestep
45
+
46
+ prev_t = self.previous_timestep(t)
47
+
48
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
49
+ print("Conidtion is trigger")
50
+
51
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
52
+ # [2,3, 64, 128]
53
+ else:
54
+ predicted_variance = None
55
+
56
+ # 1. compute alphas, betas
57
+ alpha_prod_t = self.alphas_cumprod[t]
58
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
59
+ beta_prod_t = 1 - alpha_prod_t
60
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
61
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
62
+ current_beta_t = 1 - current_alpha_t
63
+
64
+ # 2. compute predicted original sample from predicted noise also called
65
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
66
+ if self.config.prediction_type == "epsilon":
67
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
68
+
69
+ elif self.config.prediction_type == "sample":
70
+ pred_original_sample = model_output
71
+ elif self.config.prediction_type == "v_prediction":
72
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
73
+ else:
74
+ raise ValueError(
75
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
76
+ " `v_prediction` for the DDPMScheduler."
77
+ )
78
+
79
+ # 3. Clip or threshold "predicted x_0"
80
+ if self.config.thresholding:
81
+ pred_original_sample = self._threshold_sample(pred_original_sample)
82
+ elif self.config.clip_sample:
83
+ pred_original_sample = pred_original_sample.clamp(
84
+ -self.config.clip_sample_range, self.config.clip_sample_range
85
+ )
86
+
87
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
88
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
89
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
90
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
91
+
92
+ # 5. Compute predicted previous sample µ_t
93
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
94
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
95
+
96
+ # 6. Add noise
97
+ variance = 0
98
+ if t > 0:
99
+ device = model_output.device
100
+ variance_noise = randn_tensor(
101
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
102
+ )
103
+ if self.variance_type == "fixed_small_log":
104
+ variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
105
+
106
+ elif self.variance_type == "learned_range":
107
+ variance = self._get_variance(t, predicted_variance=predicted_variance)
108
+ variance = torch.exp(0.5 * variance) * variance_noise
109
+
110
+ else:
111
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
112
+
113
+ pred_prev_sample = pred_prev_sample + variance
114
+ print(pred_prev_sample.shape)
115
+ if not return_dict:
116
+ return (pred_prev_sample,)
117
+
118
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
119
+
120
+
121
+ class ModifiedUniPCScheduler(UniPCMultistepScheduler):
122
+ '''
123
+ Modified version of UniPCMultistepScheduler; identical to the stock scheduler except for the added variance handling (`previous_timestep`, `_get_variance`, and a `step` override for learned-variance model outputs).
124
+ '''
125
+ def __init__(self, variance_type: str = "fixed_small", *args, **kwargs):
126
+ super().__init__(*args, **kwargs)
127
+ self.custom_timesteps = False
128
+ self.variance_type=variance_type
129
+ self.config.timestep_spacing="leading"
130
+ def previous_timestep(self, timestep):
131
+ if self.custom_timesteps:
132
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
133
+ if index == self.timesteps.shape[0] - 1:
134
+ prev_t = torch.tensor(-1)
135
+ else:
136
+ prev_t = self.timesteps[index + 1]
137
+ else:
138
+ num_inference_steps = (
139
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
140
+ )
141
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
142
+
143
+ return prev_t
144
+
145
+ def _get_variance(self, t, predicted_variance=None, variance_type="learned_range"):
146
+ prev_t = self.previous_timestep(t)
147
+
148
+ alpha_prod_t = self.alphas_cumprod[t]
149
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
150
+ current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
151
+
152
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
153
+
154
+ variance = torch.clamp(variance, min=1e-20)
155
+
156
+ if variance_type is None:
157
+ variance_type = self.config.variance_type
158
+
159
+ if variance_type == "fixed_small":
160
+ variance = variance
161
+ elif variance_type == "fixed_small_log":
162
+ variance = torch.log(variance)
163
+ variance = torch.exp(0.5 * variance)
164
+ elif variance_type == "fixed_large":
165
+ variance = current_beta_t
166
+ elif variance_type == "fixed_large_log":
167
+ variance = torch.log(current_beta_t)
168
+ elif variance_type == "learned":
169
+ return predicted_variance
170
+ elif variance_type == "learned_range":
171
+ min_log = torch.log(variance)
172
+ max_log = torch.log(current_beta_t)
173
+ frac = (predicted_variance + 1) / 2
174
+ variance = frac * max_log + (1 - frac) * min_log
175
+
176
+ return variance
177
+
178
+ def step(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True) -> Union[SchedulerOutput, Tuple]:
179
+
180
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
181
+ print("condition using predicted_variance is trigger")
182
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
183
+ else:
184
+ predicted_variance = None
185
+
186
+ super_output = super().step(model_output, timestep, sample, return_dict=False)
187
+ prev_sample = super_output[0]
188
+ # breakpoint()
189
+ variance = 0
190
+ if timestep > 0:
191
+ device = model_output.device
192
+ variance_noise = randn_tensor(
193
+ model_output.shape, generator=None, device=device, dtype=model_output.dtype
194
+ )
195
+ if self.variance_type == "fixed_small_log":
196
+ variance = self._get_variance(timestep, predicted_variance=predicted_variance) * variance_noise
197
+ elif self.variance_type == "learned_range":
198
+ # breakpoint()
199
+ variance = self._get_variance(timestep, predicted_variance=predicted_variance)
200
+ variance = torch.exp(0.5 * variance) * variance_noise
201
+ # breakpoint()
202
+ else:
203
+ variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * variance_noise
204
+
205
+
206
+ # breakpoint()
207
+ print("time step is ", timestep)
208
+ prev_sample = prev_sample + variance
209
+
210
+ if not return_dict:
211
+ return (prev_sample,)
212
+
213
+ return DDPMSchedulerOutput(prev_sample=prev_sample,pred_original_sample=prev_sample)
214
+
215
+ #return SchedulerOutput(prev_sample=prev_sample)
216
+
217
+
218
+ def build_proc(sch_cfg=None, _sch=None, **kwargs):
219
+ if kwargs:
220
+ return _sch(**kwargs)
221
+
222
+ type_str = str(type(sch_cfg))
223
+ if 'dict' in type_str:
224
+ return _sch.from_config(**sch_cfg)
225
+ return _sch.from_config(sch_cfg, subfolder="scheduler")
226
+
227
+ scheduler_factory = {
228
+ 'UniPC' : partial(build_proc, _sch=UniPCMultistepScheduler),
229
+ 'modifiedUniPC' : partial(build_proc, _sch=ModifiedUniPCScheduler),
230
+ # DPM family
231
+ 'DDPM' : partial(build_proc, _sch=DDPMScheduler),
232
+ 'DPMSolver' : partial(build_proc, _sch=DPMSolverMultistepScheduler, algorithm_type='dpmsolver'),
233
+ 'DPMSolver++' : partial(build_proc, _sch=DPMSolverMultistepScheduler),
234
+ 'DPMSolverSingleStep' : partial(build_proc, _sch=DPMSolverSinglestepScheduler)
235
+
236
+ }
237
+
238
+ def scheduler_setup(pipe : DiffusionPipeline = None, scheduler_type : str = 'UniPC', from_config=None, **kwargs):
239
+ if not isinstance(pipe, DiffusionPipeline):
240
+ raise TypeError(f'pipe should be DiffusionPipeline, but given {type(pipe)}\n')
241
+
242
+ sch_cfg = from_config if from_config else pipe.scheduler.config
243
+ #sch_cfg = diffusers.configuration_utils.FrozenDict({**sch_cfg, 'solver_order':3})
244
+ #pipe.scheduler = scheduler_factory[scheduler_type](**kwargs) if kwargs \
245
+ # else scheduler_factory[scheduler_type](sch_cfg)
246
+
247
+ # pipe.scheduler = DPMSolverSinglestepScheduler()
248
+ # #pipe.scheduler = DDPMScheduler(beta_schedule="linear", variance_type="learned_range")
249
+ # print(pipe.scheduler)
250
+ print("Scheduler type in Scheduler_factory.py is Hard-coded to modifyUniPC, Please change it back to AutoDetect functionality if you want to change scheudler")
251
+ pipe.scheduler = ModifiedUniPCScheduler(variance_type="learned_range", )
252
+ # pipe.scheduler = ModifiedDDPMScheduler(beta_schedule="linear", variance_type="learned_range")
253
+
254
+ #pipe.scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
255
+ #pipe.scheduler._get_variance = _get_variance
256
+ return pipe
257
+
258
+ # unittest of scheduler..
259
+ if __name__ == "__main__":
260
+ def ld_mod():
261
+ noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
262
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to("cuda").to(torch.float16)
263
+ unet = SDMUNet2DModel.from_pretrained("/data/harry/Data_generation/diffusers-main/examples/VAESDM/LDM-sdm-model/checkpoint-46000", subfolder="unet").to("cuda").to(torch.float16)
264
+ return noise_scheduler, vae, unet
265
+
266
+ from Pipline import SDMLDMPipeline
267
+ from diffusers import StableDiffusionPipeline
268
+ import torch
269
+
270
+ path = "CompVis/stable-diffusion-v1-4"
271
+ pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
272
+
273
+ # change scheduler
274
+ # customized args : once you customized, customize forever ~ no from_config
275
+ #pipe = scheduler_setup(pipe, 'DPMSolver++', thresholding=True)
276
+ # from_config
277
+ pipe = scheduler_setup(pipe, 'DPMSolverSingleStep')
278
+
279
+ pipe = pipe.to("cuda")
280
+ prompt = "a highly realistic photo of green turtle"
281
+ generator = torch.manual_seed(0)
282
+ # only 15 steps are needed for good results => 2-4 seconds on GPU
283
+ image = pipe(prompt, generator=generator, num_inference_steps=15).images[0]
284
+ # save image
285
+ image.save("turtle.png")
286
+
287
+ '''
288
+ # load & wrap submodules into pipe-API
289
+ noise_scheduler, vae, unet = ld_mod()
290
+ pipe = SDMLDMPipeline(
291
+ unet=unet,
292
+ vqvae=vae,
293
+ scheduler=noise_scheduler,
294
+ torch_dtype=torch.float16
295
+ )
296
+
297
+ # change scheduler
298
+ pipe = scheduler_setup(pipe, 'DPMSolverSingleStep')
299
+ pipe = pipe.to("cuda")
300
+ '''
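Both modified schedulers above expect the "learned variance" output layout from Improved DDPM: the UNet emits 2*C channels, where the first C are the noise prediction and the last C parameterize the variance interpolation. A minimal illustration of that contract (shapes are arbitrary placeholders, not values from this repo):

import torch

C = 4                                          # latent channels, e.g. the SD VAE used in training below
sample = torch.randn(2, C, 64, 64)             # current noisy latents x_t
model_output = torch.randn(2, 2 * C, 64, 64)   # UNet output when trained with variance_type="learned_range"

# The same split both step() overrides perform before denoising:
noise_pred, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
assert noise_pred.shape == sample.shape == predicted_variance.shape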
code_for_ade20k/train_SDM_LDM_ade.py ADDED
@@ -0,0 +1,509 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+
17
+ import logging
18
+ import math
19
+ import os
20
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
21
+ from pathlib import Path
22
+
23
+ import accelerate
24
+ import datasets
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.utils import ProjectConfiguration, set_seed
33
+ from huggingface_hub import create_repo, upload_folder
34
+ from packaging import version
35
+ from tqdm.auto import tqdm
36
+
37
+ import diffusers
38
+ from diffusers import AutoencoderKL, DDPMScheduler
39
+ from diffusers.optimization import get_scheduler
40
+ from diffusers.training_utils import EMAModel
41
+ from diffusers.utils import deprecate
42
+ from diffusers.utils.import_utils import is_xformers_available
43
+ from diffusion_module.utils.Pipline import SDMLDMPipeline
44
+ from diffusion_module.unet_2d_sdm import SDMUNet2DModel
45
+ from diffusion_module.unet import UNetModel
46
+ from diffusers.schedulers import DDIMScheduler,UniPCMultistepScheduler
47
+
48
+ # from taming.models.vqvae import VQSub
49
+ from diffusion_module.utils.loss import get_variance, variance_KL_loss
50
+
51
+ from dataset.ade20k import load_data
52
+ from crack_config_utils.parse_args_ade import parse_args
53
+ from crack_config_utils.utils_ade import log_validation, preprocess_input
54
+ import datetime
55
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
56
+
57
+ logger = get_logger(__name__, log_level="INFO")
58
+
59
+
60
+ def main():
61
+
62
+ args = parse_args()
63
+
64
+ if args.non_ema_revision is not None:
65
+ deprecate(
66
+ "non_ema_revision!=None",
67
+ "0.15.0",
68
+ message=(
69
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
70
+ " use `--variant=non_ema` instead."
71
+ ),
72
+ )
73
+
74
+ current_time = datetime.datetime.now()
75
+ timestamp = current_time.strftime("%Y-%m-%d-%H%M")
76
+ output_dir = os.path.join(args.output_dir, timestamp)
77
+ logging_dir = os.path.join(output_dir, args.logging_dir)
78
+
79
+ accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir,
80
+ total_limit=args.checkpoints_total_limit)
81
+
82
+ accelerator = Accelerator(
83
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
84
+ mixed_precision=args.mixed_precision,
85
+ log_with=args.report_to,
86
+ project_config=accelerator_project_config,
87
+ )
88
+
89
+ # Make one log on every process with the configuration for debugging.
90
+ logging.basicConfig(
91
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
92
+ datefmt="%m/%d/%Y %H:%M:%S",
93
+ level=logging.INFO,
94
+ )
95
+ logger.info(accelerator.state, main_process_only=False)
96
+ if accelerator.is_local_main_process:
97
+ datasets.utils.logging.set_verbosity_warning()
98
+ transformers.utils.logging.set_verbosity_warning()
99
+ diffusers.utils.logging.set_verbosity_info()
100
+ else:
101
+ datasets.utils.logging.set_verbosity_error()
102
+ transformers.utils.logging.set_verbosity_error()
103
+ diffusers.utils.logging.set_verbosity_error()
104
+
105
+ # If passed along, set the training seed now.
106
+ if args.seed is not None:
107
+ set_seed(args.seed)
108
+
109
+ # Handle the repository creation
110
+ if accelerator.is_main_process:
111
+ if args.output_dir is not None:
112
+ os.makedirs(args.output_dir, exist_ok=True)
113
+
114
+ if args.push_to_hub:
115
+ repo_id = create_repo(
116
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
117
+ ).repo_id
118
+
119
+ # Load scheduler and models.
120
+ # noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
121
+ # noise_scheduler.variance_type = "learned_range"
122
+ # noise_scheduler = DDPMScheduler(variance_type="learned_range")
123
+ noise_scheduler = UniPCMultistepScheduler()
124
+ # noise_scheduler = DDPMScheduler()
125
+ # noise_scheduler = DDPMScheduler(variance_type="learned_range", beta_end=0.012,beta_start=0.00085
126
+ # , beta_schedule="scaled_linear",num_train_timesteps=1000, skip_prk_steps=True
127
+ # , steps_offset=1,trained_betas=None,clip_sample=False)
128
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
129
+ # vae = VQModel.from_pretrained("CompVis/ldm-super-resolution-4x-openimages", subfolder="vqvae", revision=args.revision)
130
+ # vae = VQSub.from_pretrained("/data/harry/Data_generation/diffusers-main/VQVAE/SPADE_VQ_model_V2/99ep", subfolder="vqvae")
131
+ # vae = VQSub.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
132
+ # Freeze vae
133
+ vae.requires_grad_(False)
134
+
135
+
136
+ latent_size = (64, 64)
137
+ print(latent_size)
138
+ unet = UNetModel(
139
+ image_size = latent_size,
140
+ in_channels=vae.config.latent_channels,
141
+ model_channels=256,
142
+ # out_channels=vae.config.latent_channels*2 if "learned" in noise_scheduler.variance_type else vae.config.latent_channels,
143
+ out_channels=vae.config.latent_channels,
144
+ num_res_blocks=2,
145
+ # attention_resolutions=(8, 16, 32),
146
+ attention_resolutions=(2, 4, 8),
147
+ dropout=0,
148
+ # channel_mult=(1, 1, 2, 2, 4, 4),
149
+ channel_mult=(1, 2, 3, 4),
150
+ num_heads=8,
151
+ num_head_channels=-1,
152
+ num_heads_upsample=-1,
153
+ use_scale_shift_norm=True,
154
+ resblock_updown=True,
155
+ use_new_attention_order=False,
156
+ num_classes=args.segmap_channels,
157
+ mask_emb="resize",
158
+ use_checkpoint=True,
159
+ SPADE_type="spade",
160
+ )
161
+
162
+ if args.resume_dir is not None:
163
+ unet = unet.from_pretrained(args.resume_dir)
164
+
165
+ # Create EMA for the unet.
166
+ if args.use_ema:
167
+ ema_unet = EMAModel(
168
+ unet.parameters(),
169
+ decay=args.ema_max_decay,
170
+ use_ema_warmup=True,
171
+ inv_gamma=args.ema_inv_gamma,
172
+ power=args.ema_power,
173
+ model_cls=UNetModel,
174
+ model_config=unet.config,
175
+ )
176
+
177
+ if args.enable_xformers_memory_efficient_attention:
178
+ if is_xformers_available():
179
+ import xformers
180
+
181
+ xformers_version = version.parse(xformers.__version__)
182
+ if xformers_version == version.parse("0.0.16"):
183
+ logger.warn(
184
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
185
+ )
186
+ unet.enable_xformers_memory_efficient_attention()
187
+ else:
188
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
189
+
190
+ def compute_snr(timesteps):
191
+ """
192
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
193
+ """
194
+ alphas_cumprod = noise_scheduler.alphas_cumprod
195
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
196
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
197
+
198
+ # Expand the tensors.
199
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
200
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
201
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
202
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
203
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
204
+
205
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
206
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
207
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
208
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
209
+
210
+ # Compute SNR.
211
+ snr = (alpha / sigma) ** 2
212
+ return snr
213
+
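`compute_snr` feeds the Min-SNR loss weighting used further down (Section 3.4 of arXiv:2303.09556): each sample's MSE is scaled by min(SNR, γ)/SNR. A small numeric illustration of the weight computation, independent of any scheduler:

import torch

snr = torch.tensor([0.05, 1.0, 25.0])   # per-timestep signal-to-noise ratios
snr_gamma = 5.0
weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
# -> tensor([1.0000, 1.0000, 0.2000]): noisy (low-SNR) timesteps keep full weight,
#    nearly-clean (high-SNR) timesteps are down-weighted.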
214
+ # `accelerate` 0.16.0 will have better support for customized saving
215
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
216
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
217
+ def save_model_hook(models, weights, output_dir):
218
+ if args.use_ema:
219
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
220
+
221
+ for i, model in enumerate(models):
222
+ model.save_pretrained(os.path.join(output_dir, "unet"))
223
+
224
+ # make sure to pop weight so that corresponding model is not saved again
225
+ weights.pop()
226
+
227
+ def load_model_hook(models, input_dir):
228
+ if args.use_ema:
229
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), SDMUNet2DModel)
230
+ ema_unet.load_state_dict(load_model.state_dict())
231
+ ema_unet.to(accelerator.device)
232
+ del load_model
233
+
234
+ for i in range(len(models)):
235
+ # pop models so that they are not loaded again
236
+ model = models.pop()
237
+
238
+ # load diffusers style into model
239
+ load_model = UNetModel.from_pretrained(input_dir, subfolder="unet")
240
+ model.register_to_config(**load_model.config)
241
+
242
+ model.load_state_dict(load_model.state_dict())
243
+ del load_model
244
+
245
+ accelerator.register_save_state_pre_hook(save_model_hook)
246
+ accelerator.register_load_state_pre_hook(load_model_hook)
247
+
248
+ if args.gradient_checkpointing:
249
+ unet.enable_gradient_checkpointing()
250
+
251
+ if args.scale_lr:
252
+ args.learning_rate = (
253
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
254
+ )
255
+
256
+
257
+ optimizer_cls = torch.optim.AdamW
258
+
259
+ optimizer = optimizer_cls(
260
+ unet.parameters(),
261
+ lr=args.learning_rate,
262
+ betas=(args.adam_beta1, args.adam_beta2),
263
+ weight_decay=args.adam_weight_decay,
264
+ eps=args.adam_epsilon,
265
+ )
266
+
267
+ train_dataloader, train_dataset = load_data(
268
+ dataset_mode="ade20k",
269
+ data_dir=args.data_root,
270
+ batch_size=args.train_batch_size,
271
+ image_size= args.resolution,
272
+ is_train=True)
273
+
274
+ val_dataloader, _ = load_data(
275
+ dataset_mode="ade20k",
276
+ data_dir=args.data_root,
277
+ batch_size=1,
278
+ image_size= args.resolution,
279
+ is_train=False)
280
+
281
+ # Scheduler and math around the number of training steps.
282
+ overrode_max_train_steps = False
283
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
284
+ if args.max_train_steps is None:
285
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
286
+ overrode_max_train_steps = True
287
+
288
+ lr_scheduler = get_scheduler(
289
+ args.lr_scheduler,
290
+ optimizer=optimizer,
291
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
292
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
293
+ )
294
+
295
+ # Prepare everything with our `accelerator`.
296
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
297
+ unet, optimizer, train_dataloader, lr_scheduler
298
+ )
299
+
300
+ if args.use_ema:
301
+ ema_unet.to(accelerator.device)
302
+
303
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
304
+ # as these models are only used for inference, keeping weights in full precision is not required.
305
+ weight_dtype = torch.float32
306
+ if accelerator.mixed_precision == "fp16":
307
+ weight_dtype = torch.float16
308
+ elif accelerator.mixed_precision == "bf16":
309
+ weight_dtype = torch.bfloat16
310
+
311
+ # Move vae to gpu and cast to weight_dtype
312
+ vae.to(accelerator.device, dtype=weight_dtype)
313
+
314
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
315
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
316
+ if overrode_max_train_steps:
317
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
318
+ # Afterwards we recalculate our number of training epochs
319
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
320
+
321
+ # We need to initialize the trackers we use, and also store our configuration.
322
+ # The trackers initializes automatically on the main process.
323
+ if accelerator.is_main_process:
324
+ tracker_config = dict(vars(args))
325
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
326
+
327
+ # Train!
328
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
329
+
330
+ logger.info("***** Running training *****")
331
+ logger.info(f" Num examples = {len(train_dataset)}")
332
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
333
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
334
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
335
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
336
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
337
+ global_step = 0
338
+ first_epoch = 0
339
+
340
+ # Potentially load in the weights and states from a previous save
341
+ if args.resume_from_checkpoint:
342
+ if args.resume_from_checkpoint != "latest":
343
+ path = os.path.basename(args.resume_from_checkpoint)
344
+ else:
345
+ # Get the most recent checkpoint
346
+ dirs = os.listdir(args.output_dir)
347
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
348
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
349
+ path = dirs[-1] if len(dirs) > 0 else None
350
+
351
+ if path is None:
352
+ accelerator.print(
353
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
354
+ )
355
+ args.resume_from_checkpoint = None
356
+ else:
357
+ accelerator.print(f"Resuming from checkpoint {path}")
358
+ accelerator.load_state(os.path.join(args.output_dir, path))
359
+ global_step = int(path.split("-")[1])
360
+
361
+ resume_global_step = global_step * args.gradient_accumulation_steps
362
+ first_epoch = global_step // num_update_steps_per_epoch
363
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
364
+
365
+ # Only show the progress bar once on each machine.
366
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
367
+ progress_bar.set_description("Steps")
368
+
369
+ for epoch in range(first_epoch, args.num_train_epochs):
370
+ unet.train()
371
+ train_loss = 0.0
372
+ for step, batch in enumerate(train_dataloader):
373
+ # Skip steps until we reach the resumed step
374
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
375
+ if step % args.gradient_accumulation_steps == 0:
376
+ progress_bar.update(1)
377
+ continue
378
+
379
+ with accelerator.accumulate(unet):
380
+ # Convert images to latent space
381
+ images = batch[0]
382
+ labels = batch[1]['label']
383
+ latents = vae.encode(images.to(weight_dtype)).latent_dist.sample()
384
+ latents = latents * vae.config.scaling_factor
385
+ segmap = preprocess_input(labels, args.segmap_channels)
386
+
387
+ # TODO : Support GMM noise distribution
388
+ # Sample noise that we'll add to the latents
389
+ noise = torch.randn_like(latents)
390
+ # TODO : move this into noise_sampler.py
391
+ if args.noise_offset:
392
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
393
+ noise += args.noise_offset * torch.randn(
394
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
395
+ )
396
+
397
+ bsz = latents.shape[0]
398
+ # Sample a random timestep for each image
399
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
400
+ timesteps = timesteps.long()
401
+
402
+ # Add noise to the latents according to the noise magnitude at each timestep
403
+ # (this is the forward diffusion process)
404
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
405
+
406
+ # Get the target for loss depending on the prediction type
407
+ if noise_scheduler.config.prediction_type == "epsilon":
408
+ target = noise
409
+ elif noise_scheduler.config.prediction_type == "v_prediction":
410
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
411
+ else:
412
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
413
+
414
+ # Predict the noise residual and compute loss
415
+ model_pred = unet(noisy_latents, segmap, timesteps).sample
416
+
417
+ if args.snr_gamma is None:
418
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
419
+
420
+ else:
421
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
422
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
423
+ # This is discussed in Section 4.2 of the same paper.
424
+ snr = compute_snr(timesteps)
425
+ mse_loss_weights = (
426
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
427
+ )
428
+ # We first calculate the original loss. Then we mean over the non-batch dimensions and
429
+ # rebalance the sample-wise losses with their respective loss weights.
430
+ # Finally, we take the mean of the rebalanced loss.
431
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
432
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
433
+ loss = loss.mean()
434
+
435
+ # Gather the losses across all processes for logging (if we use distributed training).
436
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
437
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
438
+
439
+ # Backpropagate
440
+ accelerator.backward(loss)
441
+ if accelerator.sync_gradients:
442
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
443
+ optimizer.step()
444
+ lr_scheduler.step()
445
+ optimizer.zero_grad()
446
+
447
+ # Checks if the accelerator has performed an optimization step behind the scenes
448
+ if accelerator.sync_gradients:
449
+ if args.use_ema:
450
+ ema_unet.step(unet.parameters())
451
+ progress_bar.update(1)
452
+ global_step += 1
453
+ log_dic = {"train_loss": train_loss}
454
+ accelerator.log(log_dic, step=global_step)
455
+ train_loss = 0.0
456
+
457
+ if global_step % args.checkpointing_steps == 0:
458
+ if accelerator.is_main_process:
459
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
460
+ accelerator.save_state(save_path)
461
+ logger.info(f"Saved state to {save_path}")
462
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
463
+ progress_bar.set_postfix(**logs)
464
+
465
+ if global_step >= args.max_train_steps:
466
+ break
467
+
468
+ if accelerator.is_main_process:
469
+ if epoch % args.validation_epochs == 0:
470
+ if args.use_ema:
471
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
472
+ ema_unet.store(unet.parameters())
473
+ ema_unet.copy_to(unet.parameters())
474
+ log_validation(vae, unet, noise_scheduler,
475
+ accelerator, weight_dtype, val_dataloader,
476
+ save_dir = args.output_dir,resolution=args.resolution, g_step=global_step)
477
+ if args.use_ema:
478
+ # Switch back to the original UNet parameters.
479
+ ema_unet.restore(unet.parameters())
480
+
481
+ # Create the pipeline using the trained modules and save it.
482
+ accelerator.wait_for_everyone()
483
+ if accelerator.is_main_process:
484
+ unet = accelerator.unwrap_model(unet)
485
+ if args.use_ema:
486
+ ema_unet.copy_to(unet.parameters())
487
+
488
+ pipeline = SDMLDMPipeline(
489
+ vae=vae,
490
+ unet=unet,
491
+ scheduler=noise_scheduler,
492
+ torch_dtype=weight_dtype,
493
+ )
494
+ pipeline.save_pretrained(args.output_dir)
495
+
496
+ if args.push_to_hub:
497
+ upload_folder(
498
+ repo_id=repo_id,
499
+ folder_path=args.output_dir,
500
+ commit_message="End of training",
501
+ ignore_patterns=["step_*", "epoch_*"],
502
+ )
503
+
504
+ accelerator.end_training()
505
+
506
+
507
+ if __name__ == "__main__":
508
+
509
+ main()
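For context, `preprocess_input` (imported from `crack_config_utils.utils_ade`, which is not shown in this diff) is expected to turn the integer label map from the ADE20K dataloader into the one-hot segmentation tensor the SPADE-conditioned UNet consumes. A rough sketch of that behaviour, assuming SPADE-style one-hot encoding over `segmap_channels` classes — the actual helper in `utils_ade.py` may differ:

import torch

def preprocess_input_sketch(label_map: torch.Tensor, segmap_channels: int) -> torch.Tensor:
    """One-hot encode an integer label map of shape (B, 1, H, W) into (B, segmap_channels, H, W)."""
    label_map = label_map.long()
    bs, _, h, w = label_map.size()
    segmap = torch.zeros(bs, segmap_channels, h, w, device=label_map.device)
    # scatter_ writes 1.0 at the channel index given by each pixel's class id
    return segmap.scatter_(1, label_map, 1.0)

# labels from the ADE20K dataloader are (B, 1, H, W) integer class maps:
# segmap = preprocess_input_sketch(batch[1]['label'], args.segmap_channels)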