LiangbinXie commited on
Commit
0177fec
β€’
1 Parent(s): aa0bbd7

add composable adapter

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitignore +1 -5
  2. app.py +277 -63
  3. {models β†’ configs/mm}/faster_rcnn_r50_fpn_coco.py +182 -182
  4. {models β†’ configs/mm}/hrnet_w48_coco_256x192.py +169 -169
  5. configs/stable-diffusion/sd-v1-inference.yaml +65 -0
  6. configs/stable-diffusion/sd-v1-train.yaml +86 -0
  7. configs/stable-diffusion/train_keypose.yaml +87 -0
  8. configs/stable-diffusion/train_mask.yaml +87 -0
  9. configs/stable-diffusion/train_sketch.yaml +87 -0
  10. demo/demos.py +0 -309
  11. demo/model.py +0 -979
  12. dist_util.py +91 -0
  13. docs/AdapterZoo.md +16 -0
  14. docs/FAQ.md +5 -0
  15. docs/examples.md +41 -0
  16. environment.yaml +0 -31
  17. ldm/modules/structure_condition/midas/__init__.py β†’ experiments/README.md +0 -0
  18. ldm/data/base.py +0 -23
  19. ldm/data/dataset_coco.py +36 -0
  20. ldm/data/dataset_depth.py +35 -0
  21. ldm/data/dataset_laion.py +130 -0
  22. ldm/data/dataset_wikiart.py +67 -0
  23. ldm/data/imagenet.py +0 -394
  24. ldm/data/lsun.py +0 -92
  25. ldm/data/utils.py +40 -0
  26. ldm/inference_base.py +282 -0
  27. ldm/models/autoencoder.py +43 -275
  28. ldm/models/diffusion/classifier.py +0 -267
  29. ldm/models/diffusion/ddim.py +68 -17
  30. ldm/models/diffusion/ddpm.py +251 -384
  31. ldm/models/diffusion/dpm_solver/dpm_solver.py +152 -119
  32. ldm/models/diffusion/dpm_solver/sampler.py +8 -3
  33. ldm/models/diffusion/plms.py +23 -48
  34. ldm/modules/attention.py +4 -0
  35. ldm/modules/diffusionmodules/openaimodel.py +85 -263
  36. ldm/modules/diffusionmodules/util.py +5 -2
  37. ldm/modules/ema.py +12 -8
  38. ldm/modules/encoders/adapter.py +84 -76
  39. ldm/modules/encoders/modules.py +349 -142
  40. ldm/modules/{structure_condition β†’ extra_condition}/__init__.py +0 -0
  41. ldm/modules/extra_condition/api.py +269 -0
  42. ldm/modules/{structure_condition/midas β†’ extra_condition}/midas/__init__.py +0 -0
  43. ldm/modules/{structure_condition β†’ extra_condition}/midas/api.py +4 -4
  44. ldm/modules/{structure_condition/openpose β†’ extra_condition/midas/midas}/__init__.py +0 -0
  45. ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/base_model.py +0 -0
  46. ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/blocks.py +0 -0
  47. ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/dpt_depth.py +0 -0
  48. ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/midas_net.py +0 -0
  49. ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/midas_net_custom.py +0 -0
  50. ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/transforms.py +0 -0
.gitignore CHANGED
@@ -1,6 +1,3 @@
1
- # ignored folders
2
- models
3
-
4
  # ignored folders
5
  tmp/*
6
 
@@ -23,7 +20,6 @@ version.py
23
 
24
  # Byte-compiled / optimized / DLL files
25
  __pycache__/
26
- *.pyc
27
  *.py[cod]
28
  *$py.class
29
 
@@ -125,4 +121,4 @@ venv.bak/
125
  /site
126
 
127
  # mypy
128
- .mypy_cache/
 
 
 
 
1
  # ignored folders
2
  tmp/*
3
 
 
20
 
21
  # Byte-compiled / optimized / DLL files
22
  __pycache__/
 
23
  *.py[cod]
24
  *$py.class
25
 
 
121
  /site
122
 
123
  # mypy
124
+ .mypy_cache/
app.py CHANGED
@@ -1,29 +1,44 @@
 
 
 
1
  import os
2
- # os.system('pip3 install openmim')
3
- os.system('mim install mmcv-full==1.7.0')
4
- # os.system('pip3 install mmpose')
5
- # os.system('pip3 install mmdet')
6
- # os.system('pip3 install gradio==3.19.1')
7
- #os.system('pip3 install psutil')
8
-
9
- from demo.model import Model_all
10
  import gradio as gr
11
- from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose, create_demo_color, create_demo_color_sketch, create_demo_openpose, create_demo_style_sketch, create_demo_canny
12
  import torch
13
- import subprocess
14
- import shlex
15
  from huggingface_hub import hf_hub_url
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  urls = {
18
- 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth', 'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth','third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth", "models/t2iadapter_canny_sd14v1.pth"],
19
- 'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
20
- 'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
 
 
 
 
 
 
21
  }
22
- urls_mmpose = [
23
- 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
24
- 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
25
- 'https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth'
26
- ]
27
  if os.path.exists('models') == False:
28
  os.mkdir('models')
29
  for repo in urls:
@@ -31,58 +46,257 @@ for repo in urls:
31
  for file in files:
32
  url = hf_hub_url(repo, file)
33
  name_ckp = url.split('/')[-1]
34
- save_path = os.path.join('models',name_ckp)
35
  if os.path.exists(save_path) == False:
36
  subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
37
 
38
- for url in urls_mmpose:
39
- name_ckp = url.split('/')[-1]
40
- save_path = os.path.join('models',name_ckp)
41
- if os.path.exists(save_path) == False:
42
- subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
45
- model = Model_all(device)
 
 
 
 
46
 
47
- DESCRIPTION = '''# T2I-Adapter
 
48
 
49
- Gradio demo for **T2I-Adapter**: [[GitHub]](https://github.com/TencentARC/T2I-Adapter), [[Paper]](https://arxiv.org/abs/2302.08453).
 
 
50
 
51
- It also supports **multiple adapters** in the follwing tabs showing **"A adapter + B adapter"**.
 
 
 
 
52
 
53
- If T2I-Adapter is helpful, please help to ⭐ the [Github Repo](https://github.com/TencentARC/T2I-Adapter) and recommend it to your friends 😊
54
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  with gr.Blocks(css='style.css') as demo:
57
  gr.Markdown(DESCRIPTION)
58
-
59
- gr.HTML("""<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
60
- <br/>
61
- <a href="https://huggingface.co/spaces/Adapter/T2I-Adapter?duplicate=true">
62
- <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
63
- <p/>""")
64
-
65
- with gr.Tabs():
66
- with gr.TabItem('Openpose'):
67
- create_demo_openpose(model.process_openpose)
68
- with gr.TabItem('Keypose'):
69
- create_demo_keypose(model.process_keypose)
70
- with gr.TabItem('Canny'):
71
- create_demo_canny(model.process_canny)
72
- with gr.TabItem('Sketch'):
73
- create_demo_sketch(model.process_sketch)
74
- with gr.TabItem('Draw'):
75
- create_demo_draw(model.process_draw)
76
- with gr.TabItem('Depth'):
77
- create_demo_depth(model.process_depth)
78
- with gr.TabItem('Depth + Keypose'):
79
- create_demo_depth_keypose(model.process_depth_keypose)
80
- with gr.TabItem('Color'):
81
- create_demo_color(model.process_color)
82
- with gr.TabItem('Color + Sketch'):
83
- create_demo_color_sketch(model.process_color_sketch)
84
- with gr.TabItem('Style + Sketch'):
85
- create_demo_style_sketch(model.process_style_sketch)
86
- with gr.TabItem('Segmentation'):
87
- create_demo_seg(model.process_seg)
88
- demo.queue().launch(debug=True, server_name='0.0.0.0')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo inspired by https://huggingface.co/spaces/lambdalabs/image-mixer-demo
2
+ import argparse
3
+ import copy
4
  import os
5
+ import shlex
6
+ import subprocess
7
+ from functools import partial
8
+ from itertools import chain
9
+
10
+ import cv2
 
 
11
  import gradio as gr
 
12
  import torch
13
+ from basicsr.utils import tensor2img
 
14
  from huggingface_hub import hf_hub_url
15
+ from pytorch_lightning import seed_everything
16
+ from torch import autocast
17
 
18
+ from ldm.inference_base import (DEFAULT_NEGATIVE_PROMPT, diffusion_inference,
19
+ get_adapters, get_sd_models)
20
+ from ldm.modules.extra_condition import api
21
+ from ldm.modules.extra_condition.api import (ExtraCondition,
22
+ get_adapter_feature,
23
+ get_cond_model)
24
+
25
+ torch.set_grad_enabled(False)
26
+
27
+ supported_cond = ['style', 'color', 'canny', 'sketch', 'openpose', 'depth']
28
+
29
+ # download the checkpoints
30
  urls = {
31
+ 'TencentARC/T2I-Adapter': [
32
+ 'models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth',
33
+ 'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth',
34
+ 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth',
35
+ 'third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth",
36
+ "models/t2iadapter_canny_sd14v1.pth", "third-party-models/table5_pidinet.pth"
37
+ ],
38
+ 'runwayml/stable-diffusion-v1-5': ['v1-5-pruned-emaonly.ckpt'],
39
+ 'andite/anything-v4.0': ['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
40
  }
41
+
 
 
 
 
42
  if os.path.exists('models') == False:
43
  os.mkdir('models')
44
  for repo in urls:
 
46
  for file in files:
47
  url = hf_hub_url(repo, file)
48
  name_ckp = url.split('/')[-1]
49
+ save_path = os.path.join('models', name_ckp)
50
  if os.path.exists(save_path) == False:
51
  subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
52
 
53
+ # config
54
+ parser = argparse.ArgumentParser()
55
+ parser.add_argument(
56
+ '--sd_ckpt',
57
+ type=str,
58
+ default='models/v1-5-pruned-emaonly.ckpt',
59
+ help='path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
60
+ )
61
+ parser.add_argument(
62
+ '--vae_ckpt',
63
+ type=str,
64
+ default=None,
65
+ help='vae checkpoint, anime SD models usually have seperate vae ckpt that need to be loaded',
66
+ )
67
+ global_opt = parser.parse_args()
68
+ global_opt.config = 'configs/stable-diffusion/sd-v1-inference.yaml'
69
+ for cond_name in supported_cond:
70
+ setattr(global_opt, f'{cond_name}_adapter_ckpt', f'models/t2iadapter_{cond_name}_sd14v1.pth')
71
+ global_opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
72
+ global_opt.max_resolution = 512 * 512
73
+ global_opt.sampler = 'ddim'
74
+ global_opt.cond_weight = 1.0
75
+ global_opt.C = 4
76
+ global_opt.f = 8
77
+
78
+ # stable-diffusion model
79
+ sd_model, sampler = get_sd_models(global_opt)
80
+ # adapters and models to processing condition inputs
81
+ adapters = {}
82
+ cond_models = {}
83
+ torch.cuda.empty_cache()
84
+
85
+
86
+ def run(*args):
87
+ with torch.inference_mode(), \
88
+ sd_model.ema_scope(), \
89
+ autocast('cuda'):
90
+
91
+ inps = []
92
+ for i in range(0, len(args) - 8, len(supported_cond)):
93
+ inps.append(args[i:i + len(supported_cond)])
94
+
95
+ opt = copy.deepcopy(global_opt)
96
+ opt.prompt, opt.neg_prompt, opt.scale, opt.n_samples, opt.seed, opt.steps, opt.resize_short_edge, opt.cond_tau \
97
+ = args[-8:]
98
+
99
+ conds = []
100
+ activated_conds = []
101
+
102
+ ims1 = []
103
+ ims2 = []
104
+ for idx, (b, im1, im2, cond_weight) in enumerate(zip(*inps)):
105
+ if idx > 1:
106
+ if im1 is not None or im2 is not None:
107
+ if im1 is not None:
108
+ h, w, _ = im1.shape
109
+ else:
110
+ h, w, _ = im2.shape
111
+ break
112
+ # resize all the images to the same size
113
+ for idx, (b, im1, im2, cond_weight) in enumerate(zip(*inps)):
114
+ if idx == 0:
115
+ ims1.append(im1)
116
+ ims2.append(im2)
117
+ continue
118
+ if im1 is not None:
119
+ im1 = cv2.resize(im1, (w, h), interpolation=cv2.INTER_CUBIC)
120
+ if im2 is not None:
121
+ im2 = cv2.resize(im2, (w, h), interpolation=cv2.INTER_CUBIC)
122
+ ims1.append(im1)
123
+ ims2.append(im2)
124
+
125
+ for idx, (b, _, _, cond_weight) in enumerate(zip(*inps)):
126
+ cond_name = supported_cond[idx]
127
+ if b == 'Nothing':
128
+ if cond_name in adapters:
129
+ adapters[cond_name]['model'] = adapters[cond_name]['model'].cpu()
130
+ else:
131
+ activated_conds.append(cond_name)
132
+ if cond_name in adapters:
133
+ adapters[cond_name]['model'] = adapters[cond_name]['model'].to(opt.device)
134
+ else:
135
+ adapters[cond_name] = get_adapters(opt, getattr(ExtraCondition, cond_name))
136
+ adapters[cond_name]['cond_weight'] = cond_weight
137
+
138
+ process_cond_module = getattr(api, f'get_cond_{cond_name}')
139
 
140
+ if b == 'Image':
141
+ if cond_name not in cond_models:
142
+ cond_models[cond_name] = get_cond_model(opt, getattr(ExtraCondition, cond_name))
143
+ conds.append(process_cond_module(opt, ims1[idx], 'image', cond_models[cond_name]))
144
+ else:
145
+ conds.append(process_cond_module(opt, ims2[idx], cond_name, None))
146
 
147
+ adapter_features, append_to_context = get_adapter_feature(
148
+ conds, [adapters[cond_name] for cond_name in activated_conds])
149
 
150
+ output_conds = []
151
+ for cond in conds:
152
+ output_conds.append(tensor2img(cond, rgb2bgr=False))
153
 
154
+ ims = []
155
+ seed_everything(opt.seed)
156
+ for _ in range(opt.n_samples):
157
+ result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
158
+ ims.append(tensor2img(result, rgb2bgr=False))
159
 
160
+ # Clear GPU memory cache so less likely to OOM
161
+ torch.cuda.empty_cache()
162
+ return ims, output_conds
163
+
164
+
165
+ def change_visible(im1, im2, val):
166
+ outputs = {}
167
+ if val == "Image":
168
+ outputs[im1] = gr.update(visible=True)
169
+ outputs[im2] = gr.update(visible=False)
170
+ elif val == "Nothing":
171
+ outputs[im1] = gr.update(visible=False)
172
+ outputs[im2] = gr.update(visible=False)
173
+ else:
174
+ outputs[im1] = gr.update(visible=False)
175
+ outputs[im2] = gr.update(visible=True)
176
+ return outputs
177
+
178
+
179
+ DESCRIPTION = '# [Composable T2I-Adapter](https://github.com/TencentARC/T2I-Adapter)'
180
+
181
+ DESCRIPTION += f'<p>Gradio demo for **T2I-Adapter**: [[GitHub]](https://github.com/TencentARC/T2I-Adapter), [[Paper]](https://arxiv.org/abs/2302.08453). If T2I-Adapter is helpful, please help to ⭐ the [Github Repo](https://github.com/TencentARC/T2I-Adapter) and recommend it to your friends 😊 </p>'
182
+
183
+ DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/Adapter/T2I-Adapter?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
184
 
185
  with gr.Blocks(css='style.css') as demo:
186
  gr.Markdown(DESCRIPTION)
187
+
188
+ btns = []
189
+ ims1 = []
190
+ ims2 = []
191
+ cond_weights = []
192
+
193
+ with gr.Row():
194
+ with gr.Column(scale=1.9):
195
+ with gr.Box():
196
+ gr.Markdown("<h5><center>Style & Color</center></h5>")
197
+ with gr.Row():
198
+ for cond_name in supported_cond[:2]:
199
+ with gr.Box():
200
+ with gr.Column():
201
+ if cond_name == 'style':
202
+ btn1 = gr.Radio(
203
+ choices=["Image", "Nothing"],
204
+ label=f"Input type for {cond_name}",
205
+ interactive=True,
206
+ value="Nothing",
207
+ )
208
+ else:
209
+ btn1 = gr.Radio(
210
+ choices=["Image", cond_name, "Nothing"],
211
+ label=f"Input type for {cond_name}",
212
+ interactive=True,
213
+ value="Nothing",
214
+ )
215
+ im1 = gr.Image(
216
+ source='upload', label="Image", interactive=True, visible=False, type="numpy")
217
+ im2 = gr.Image(
218
+ source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
219
+ cond_weight = gr.Slider(
220
+ label="Condition weight",
221
+ minimum=0,
222
+ maximum=5,
223
+ step=0.05,
224
+ value=1,
225
+ interactive=True)
226
+
227
+ fn = partial(change_visible, im1, im2)
228
+ btn1.change(fn=fn, inputs=[btn1], outputs=[im1, im2], queue=False)
229
+
230
+ btns.append(btn1)
231
+ ims1.append(im1)
232
+ ims2.append(im2)
233
+ cond_weights.append(cond_weight)
234
+ with gr.Column(scale=4):
235
+ with gr.Box():
236
+ gr.Markdown("<h5><center>Structure</center></h5>")
237
+ with gr.Row():
238
+ for cond_name in supported_cond[2:6]:
239
+ with gr.Box():
240
+ with gr.Column():
241
+ if cond_name == 'openpose':
242
+ btn1 = gr.Radio(
243
+ choices=["Image", 'pose', "Nothing"],
244
+ label=f"Input type for {cond_name}",
245
+ interactive=True,
246
+ value="Nothing",
247
+ )
248
+ else:
249
+ btn1 = gr.Radio(
250
+ choices=["Image", cond_name, "Nothing"],
251
+ label=f"Input type for {cond_name}",
252
+ interactive=True,
253
+ value="Nothing",
254
+ )
255
+ im1 = gr.Image(
256
+ source='upload', label="Image", interactive=True, visible=False, type="numpy")
257
+ im2 = gr.Image(
258
+ source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
259
+ cond_weight = gr.Slider(
260
+ label="Condition weight",
261
+ minimum=0,
262
+ maximum=5,
263
+ step=0.05,
264
+ value=1,
265
+ interactive=True)
266
+
267
+ fn = partial(change_visible, im1, im2)
268
+ btn1.change(fn=fn, inputs=[btn1], outputs=[im1, im2], queue=False)
269
+
270
+ btns.append(btn1)
271
+ ims1.append(im1)
272
+ ims2.append(im2)
273
+ cond_weights.append(cond_weight)
274
+
275
+ with gr.Column():
276
+ prompt = gr.Textbox(label="Prompt")
277
+
278
+ with gr.Accordion('Advanced options', open=False):
279
+ neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
280
+ scale = gr.Slider(
281
+ label="Guidance Scale (Classifier free guidance)", value=7.5, minimum=1, maximum=20, step=0.1)
282
+ n_samples = gr.Slider(label="Num samples", value=1, minimum=1, maximum=8, step=1)
283
+ seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=10000, step=1)
284
+ steps = gr.Slider(label="Steps", value=50, minimum=10, maximum=100, step=1)
285
+ resize_short_edge = gr.Slider(label="Image resolution", value=512, minimum=320, maximum=1024, step=1)
286
+ cond_tau = gr.Slider(
287
+ label="timestamp parameter that determines until which step the adapter is applied",
288
+ value=1.0,
289
+ minimum=0.1,
290
+ maximum=1.0,
291
+ step=0.05)
292
+
293
+ with gr.Row():
294
+ submit = gr.Button("Generate")
295
+ output = gr.Gallery().style(grid=2, height='auto')
296
+ cond = gr.Gallery().style(grid=2, height='auto')
297
+
298
+ inps = list(chain(btns, ims1, ims2, cond_weights))
299
+
300
+ inps.extend([prompt, neg_prompt, scale, n_samples, seed, steps, resize_short_edge, cond_tau])
301
+ submit.click(fn=run, inputs=inps, outputs=[output, cond])
302
+ demo.launch(server_name='0.0.0.0', share=False, server_port=47313)
{models β†’ configs/mm}/faster_rcnn_r50_fpn_coco.py RENAMED
@@ -1,182 +1,182 @@
1
- checkpoint_config = dict(interval=1)
2
- # yapf:disable
3
- log_config = dict(
4
- interval=50,
5
- hooks=[
6
- dict(type='TextLoggerHook'),
7
- # dict(type='TensorboardLoggerHook')
8
- ])
9
- # yapf:enable
10
- dist_params = dict(backend='nccl')
11
- log_level = 'INFO'
12
- load_from = None
13
- resume_from = None
14
- workflow = [('train', 1)]
15
- # optimizer
16
- optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
17
- optimizer_config = dict(grad_clip=None)
18
- # learning policy
19
- lr_config = dict(
20
- policy='step',
21
- warmup='linear',
22
- warmup_iters=500,
23
- warmup_ratio=0.001,
24
- step=[8, 11])
25
- total_epochs = 12
26
-
27
- model = dict(
28
- type='FasterRCNN',
29
- pretrained='torchvision://resnet50',
30
- backbone=dict(
31
- type='ResNet',
32
- depth=50,
33
- num_stages=4,
34
- out_indices=(0, 1, 2, 3),
35
- frozen_stages=1,
36
- norm_cfg=dict(type='BN', requires_grad=True),
37
- norm_eval=True,
38
- style='pytorch'),
39
- neck=dict(
40
- type='FPN',
41
- in_channels=[256, 512, 1024, 2048],
42
- out_channels=256,
43
- num_outs=5),
44
- rpn_head=dict(
45
- type='RPNHead',
46
- in_channels=256,
47
- feat_channels=256,
48
- anchor_generator=dict(
49
- type='AnchorGenerator',
50
- scales=[8],
51
- ratios=[0.5, 1.0, 2.0],
52
- strides=[4, 8, 16, 32, 64]),
53
- bbox_coder=dict(
54
- type='DeltaXYWHBBoxCoder',
55
- target_means=[.0, .0, .0, .0],
56
- target_stds=[1.0, 1.0, 1.0, 1.0]),
57
- loss_cls=dict(
58
- type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
59
- loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
60
- roi_head=dict(
61
- type='StandardRoIHead',
62
- bbox_roi_extractor=dict(
63
- type='SingleRoIExtractor',
64
- roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
65
- out_channels=256,
66
- featmap_strides=[4, 8, 16, 32]),
67
- bbox_head=dict(
68
- type='Shared2FCBBoxHead',
69
- in_channels=256,
70
- fc_out_channels=1024,
71
- roi_feat_size=7,
72
- num_classes=80,
73
- bbox_coder=dict(
74
- type='DeltaXYWHBBoxCoder',
75
- target_means=[0., 0., 0., 0.],
76
- target_stds=[0.1, 0.1, 0.2, 0.2]),
77
- reg_class_agnostic=False,
78
- loss_cls=dict(
79
- type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
80
- loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
81
- # model training and testing settings
82
- train_cfg=dict(
83
- rpn=dict(
84
- assigner=dict(
85
- type='MaxIoUAssigner',
86
- pos_iou_thr=0.7,
87
- neg_iou_thr=0.3,
88
- min_pos_iou=0.3,
89
- match_low_quality=True,
90
- ignore_iof_thr=-1),
91
- sampler=dict(
92
- type='RandomSampler',
93
- num=256,
94
- pos_fraction=0.5,
95
- neg_pos_ub=-1,
96
- add_gt_as_proposals=False),
97
- allowed_border=-1,
98
- pos_weight=-1,
99
- debug=False),
100
- rpn_proposal=dict(
101
- nms_pre=2000,
102
- max_per_img=1000,
103
- nms=dict(type='nms', iou_threshold=0.7),
104
- min_bbox_size=0),
105
- rcnn=dict(
106
- assigner=dict(
107
- type='MaxIoUAssigner',
108
- pos_iou_thr=0.5,
109
- neg_iou_thr=0.5,
110
- min_pos_iou=0.5,
111
- match_low_quality=False,
112
- ignore_iof_thr=-1),
113
- sampler=dict(
114
- type='RandomSampler',
115
- num=512,
116
- pos_fraction=0.25,
117
- neg_pos_ub=-1,
118
- add_gt_as_proposals=True),
119
- pos_weight=-1,
120
- debug=False)),
121
- test_cfg=dict(
122
- rpn=dict(
123
- nms_pre=1000,
124
- max_per_img=1000,
125
- nms=dict(type='nms', iou_threshold=0.7),
126
- min_bbox_size=0),
127
- rcnn=dict(
128
- score_thr=0.05,
129
- nms=dict(type='nms', iou_threshold=0.5),
130
- max_per_img=100)
131
- # soft-nms is also supported for rcnn testing
132
- # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
133
- ))
134
-
135
- dataset_type = 'CocoDataset'
136
- data_root = 'data/coco'
137
- img_norm_cfg = dict(
138
- mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
139
- train_pipeline = [
140
- dict(type='LoadImageFromFile'),
141
- dict(type='LoadAnnotations', with_bbox=True),
142
- dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
143
- dict(type='RandomFlip', flip_ratio=0.5),
144
- dict(type='Normalize', **img_norm_cfg),
145
- dict(type='Pad', size_divisor=32),
146
- dict(type='DefaultFormatBundle'),
147
- dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
148
- ]
149
- test_pipeline = [
150
- dict(type='LoadImageFromFile'),
151
- dict(
152
- type='MultiScaleFlipAug',
153
- img_scale=(1333, 800),
154
- flip=False,
155
- transforms=[
156
- dict(type='Resize', keep_ratio=True),
157
- dict(type='RandomFlip'),
158
- dict(type='Normalize', **img_norm_cfg),
159
- dict(type='Pad', size_divisor=32),
160
- dict(type='DefaultFormatBundle'),
161
- dict(type='Collect', keys=['img']),
162
- ])
163
- ]
164
- data = dict(
165
- samples_per_gpu=2,
166
- workers_per_gpu=2,
167
- train=dict(
168
- type=dataset_type,
169
- ann_file=f'{data_root}/annotations/instances_train2017.json',
170
- img_prefix=f'{data_root}/train2017/',
171
- pipeline=train_pipeline),
172
- val=dict(
173
- type=dataset_type,
174
- ann_file=f'{data_root}/annotations/instances_val2017.json',
175
- img_prefix=f'{data_root}/val2017/',
176
- pipeline=test_pipeline),
177
- test=dict(
178
- type=dataset_type,
179
- ann_file=f'{data_root}/annotations/instances_val2017.json',
180
- img_prefix=f'{data_root}/val2017/',
181
- pipeline=test_pipeline))
182
- evaluation = dict(interval=1, metric='bbox')
 
1
+ checkpoint_config = dict(interval=1)
2
+ # yapf:disable
3
+ log_config = dict(
4
+ interval=50,
5
+ hooks=[
6
+ dict(type='TextLoggerHook'),
7
+ # dict(type='TensorboardLoggerHook')
8
+ ])
9
+ # yapf:enable
10
+ dist_params = dict(backend='nccl')
11
+ log_level = 'INFO'
12
+ load_from = None
13
+ resume_from = None
14
+ workflow = [('train', 1)]
15
+ # optimizer
16
+ optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
17
+ optimizer_config = dict(grad_clip=None)
18
+ # learning policy
19
+ lr_config = dict(
20
+ policy='step',
21
+ warmup='linear',
22
+ warmup_iters=500,
23
+ warmup_ratio=0.001,
24
+ step=[8, 11])
25
+ total_epochs = 12
26
+
27
+ model = dict(
28
+ type='FasterRCNN',
29
+ pretrained='torchvision://resnet50',
30
+ backbone=dict(
31
+ type='ResNet',
32
+ depth=50,
33
+ num_stages=4,
34
+ out_indices=(0, 1, 2, 3),
35
+ frozen_stages=1,
36
+ norm_cfg=dict(type='BN', requires_grad=True),
37
+ norm_eval=True,
38
+ style='pytorch'),
39
+ neck=dict(
40
+ type='FPN',
41
+ in_channels=[256, 512, 1024, 2048],
42
+ out_channels=256,
43
+ num_outs=5),
44
+ rpn_head=dict(
45
+ type='RPNHead',
46
+ in_channels=256,
47
+ feat_channels=256,
48
+ anchor_generator=dict(
49
+ type='AnchorGenerator',
50
+ scales=[8],
51
+ ratios=[0.5, 1.0, 2.0],
52
+ strides=[4, 8, 16, 32, 64]),
53
+ bbox_coder=dict(
54
+ type='DeltaXYWHBBoxCoder',
55
+ target_means=[.0, .0, .0, .0],
56
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
57
+ loss_cls=dict(
58
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
59
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
60
+ roi_head=dict(
61
+ type='StandardRoIHead',
62
+ bbox_roi_extractor=dict(
63
+ type='SingleRoIExtractor',
64
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
65
+ out_channels=256,
66
+ featmap_strides=[4, 8, 16, 32]),
67
+ bbox_head=dict(
68
+ type='Shared2FCBBoxHead',
69
+ in_channels=256,
70
+ fc_out_channels=1024,
71
+ roi_feat_size=7,
72
+ num_classes=80,
73
+ bbox_coder=dict(
74
+ type='DeltaXYWHBBoxCoder',
75
+ target_means=[0., 0., 0., 0.],
76
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
77
+ reg_class_agnostic=False,
78
+ loss_cls=dict(
79
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
80
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
81
+ # model training and testing settings
82
+ train_cfg=dict(
83
+ rpn=dict(
84
+ assigner=dict(
85
+ type='MaxIoUAssigner',
86
+ pos_iou_thr=0.7,
87
+ neg_iou_thr=0.3,
88
+ min_pos_iou=0.3,
89
+ match_low_quality=True,
90
+ ignore_iof_thr=-1),
91
+ sampler=dict(
92
+ type='RandomSampler',
93
+ num=256,
94
+ pos_fraction=0.5,
95
+ neg_pos_ub=-1,
96
+ add_gt_as_proposals=False),
97
+ allowed_border=-1,
98
+ pos_weight=-1,
99
+ debug=False),
100
+ rpn_proposal=dict(
101
+ nms_pre=2000,
102
+ max_per_img=1000,
103
+ nms=dict(type='nms', iou_threshold=0.7),
104
+ min_bbox_size=0),
105
+ rcnn=dict(
106
+ assigner=dict(
107
+ type='MaxIoUAssigner',
108
+ pos_iou_thr=0.5,
109
+ neg_iou_thr=0.5,
110
+ min_pos_iou=0.5,
111
+ match_low_quality=False,
112
+ ignore_iof_thr=-1),
113
+ sampler=dict(
114
+ type='RandomSampler',
115
+ num=512,
116
+ pos_fraction=0.25,
117
+ neg_pos_ub=-1,
118
+ add_gt_as_proposals=True),
119
+ pos_weight=-1,
120
+ debug=False)),
121
+ test_cfg=dict(
122
+ rpn=dict(
123
+ nms_pre=1000,
124
+ max_per_img=1000,
125
+ nms=dict(type='nms', iou_threshold=0.7),
126
+ min_bbox_size=0),
127
+ rcnn=dict(
128
+ score_thr=0.05,
129
+ nms=dict(type='nms', iou_threshold=0.5),
130
+ max_per_img=100)
131
+ # soft-nms is also supported for rcnn testing
132
+ # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
133
+ ))
134
+
135
+ dataset_type = 'CocoDataset'
136
+ data_root = 'data/coco'
137
+ img_norm_cfg = dict(
138
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
139
+ train_pipeline = [
140
+ dict(type='LoadImageFromFile'),
141
+ dict(type='LoadAnnotations', with_bbox=True),
142
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
143
+ dict(type='RandomFlip', flip_ratio=0.5),
144
+ dict(type='Normalize', **img_norm_cfg),
145
+ dict(type='Pad', size_divisor=32),
146
+ dict(type='DefaultFormatBundle'),
147
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
148
+ ]
149
+ test_pipeline = [
150
+ dict(type='LoadImageFromFile'),
151
+ dict(
152
+ type='MultiScaleFlipAug',
153
+ img_scale=(1333, 800),
154
+ flip=False,
155
+ transforms=[
156
+ dict(type='Resize', keep_ratio=True),
157
+ dict(type='RandomFlip'),
158
+ dict(type='Normalize', **img_norm_cfg),
159
+ dict(type='Pad', size_divisor=32),
160
+ dict(type='DefaultFormatBundle'),
161
+ dict(type='Collect', keys=['img']),
162
+ ])
163
+ ]
164
+ data = dict(
165
+ samples_per_gpu=2,
166
+ workers_per_gpu=2,
167
+ train=dict(
168
+ type=dataset_type,
169
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
170
+ img_prefix=f'{data_root}/train2017/',
171
+ pipeline=train_pipeline),
172
+ val=dict(
173
+ type=dataset_type,
174
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
175
+ img_prefix=f'{data_root}/val2017/',
176
+ pipeline=test_pipeline),
177
+ test=dict(
178
+ type=dataset_type,
179
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
180
+ img_prefix=f'{data_root}/val2017/',
181
+ pipeline=test_pipeline))
182
+ evaluation = dict(interval=1, metric='bbox')
{models β†’ configs/mm}/hrnet_w48_coco_256x192.py RENAMED
@@ -1,169 +1,169 @@
1
- # _base_ = [
2
- # '../../../../_base_/default_runtime.py',
3
- # '../../../../_base_/datasets/coco.py'
4
- # ]
5
- evaluation = dict(interval=10, metric='mAP', save_best='AP')
6
-
7
- optimizer = dict(
8
- type='Adam',
9
- lr=5e-4,
10
- )
11
- optimizer_config = dict(grad_clip=None)
12
- # learning policy
13
- lr_config = dict(
14
- policy='step',
15
- warmup='linear',
16
- warmup_iters=500,
17
- warmup_ratio=0.001,
18
- step=[170, 200])
19
- total_epochs = 210
20
- channel_cfg = dict(
21
- num_output_channels=17,
22
- dataset_joints=17,
23
- dataset_channel=[
24
- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
25
- ],
26
- inference_channel=[
27
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
28
- ])
29
-
30
- # model settings
31
- model = dict(
32
- type='TopDown',
33
- pretrained='https://download.openmmlab.com/mmpose/'
34
- 'pretrain_models/hrnet_w48-8ef0771d.pth',
35
- backbone=dict(
36
- type='HRNet',
37
- in_channels=3,
38
- extra=dict(
39
- stage1=dict(
40
- num_modules=1,
41
- num_branches=1,
42
- block='BOTTLENECK',
43
- num_blocks=(4, ),
44
- num_channels=(64, )),
45
- stage2=dict(
46
- num_modules=1,
47
- num_branches=2,
48
- block='BASIC',
49
- num_blocks=(4, 4),
50
- num_channels=(48, 96)),
51
- stage3=dict(
52
- num_modules=4,
53
- num_branches=3,
54
- block='BASIC',
55
- num_blocks=(4, 4, 4),
56
- num_channels=(48, 96, 192)),
57
- stage4=dict(
58
- num_modules=3,
59
- num_branches=4,
60
- block='BASIC',
61
- num_blocks=(4, 4, 4, 4),
62
- num_channels=(48, 96, 192, 384))),
63
- ),
64
- keypoint_head=dict(
65
- type='TopdownHeatmapSimpleHead',
66
- in_channels=48,
67
- out_channels=channel_cfg['num_output_channels'],
68
- num_deconv_layers=0,
69
- extra=dict(final_conv_kernel=1, ),
70
- loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
71
- train_cfg=dict(),
72
- test_cfg=dict(
73
- flip_test=True,
74
- post_process='default',
75
- shift_heatmap=True,
76
- modulate_kernel=11))
77
-
78
- data_cfg = dict(
79
- image_size=[192, 256],
80
- heatmap_size=[48, 64],
81
- num_output_channels=channel_cfg['num_output_channels'],
82
- num_joints=channel_cfg['dataset_joints'],
83
- dataset_channel=channel_cfg['dataset_channel'],
84
- inference_channel=channel_cfg['inference_channel'],
85
- soft_nms=False,
86
- nms_thr=1.0,
87
- oks_thr=0.9,
88
- vis_thr=0.2,
89
- use_gt_bbox=False,
90
- det_bbox_thr=0.0,
91
- bbox_file='data/coco/person_detection_results/'
92
- 'COCO_val2017_detections_AP_H_56_person.json',
93
- )
94
-
95
- train_pipeline = [
96
- dict(type='LoadImageFromFile'),
97
- dict(type='TopDownGetBboxCenterScale', padding=1.25),
98
- dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
99
- dict(type='TopDownRandomFlip', flip_prob=0.5),
100
- dict(
101
- type='TopDownHalfBodyTransform',
102
- num_joints_half_body=8,
103
- prob_half_body=0.3),
104
- dict(
105
- type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
106
- dict(type='TopDownAffine'),
107
- dict(type='ToTensor'),
108
- dict(
109
- type='NormalizeTensor',
110
- mean=[0.485, 0.456, 0.406],
111
- std=[0.229, 0.224, 0.225]),
112
- dict(type='TopDownGenerateTarget', sigma=2),
113
- dict(
114
- type='Collect',
115
- keys=['img', 'target', 'target_weight'],
116
- meta_keys=[
117
- 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
118
- 'rotation', 'bbox_score', 'flip_pairs'
119
- ]),
120
- ]
121
-
122
- val_pipeline = [
123
- dict(type='LoadImageFromFile'),
124
- dict(type='TopDownGetBboxCenterScale', padding=1.25),
125
- dict(type='TopDownAffine'),
126
- dict(type='ToTensor'),
127
- dict(
128
- type='NormalizeTensor',
129
- mean=[0.485, 0.456, 0.406],
130
- std=[0.229, 0.224, 0.225]),
131
- dict(
132
- type='Collect',
133
- keys=['img'],
134
- meta_keys=[
135
- 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
136
- 'flip_pairs'
137
- ]),
138
- ]
139
-
140
- test_pipeline = val_pipeline
141
-
142
- data_root = 'data/coco'
143
- data = dict(
144
- samples_per_gpu=32,
145
- workers_per_gpu=2,
146
- val_dataloader=dict(samples_per_gpu=32),
147
- test_dataloader=dict(samples_per_gpu=32),
148
- train=dict(
149
- type='TopDownCocoDataset',
150
- ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
151
- img_prefix=f'{data_root}/train2017/',
152
- data_cfg=data_cfg,
153
- pipeline=train_pipeline,
154
- dataset_info={{_base_.dataset_info}}),
155
- val=dict(
156
- type='TopDownCocoDataset',
157
- ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
158
- img_prefix=f'{data_root}/val2017/',
159
- data_cfg=data_cfg,
160
- pipeline=val_pipeline,
161
- dataset_info={{_base_.dataset_info}}),
162
- test=dict(
163
- type='TopDownCocoDataset',
164
- ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
165
- img_prefix=f'{data_root}/val2017/',
166
- data_cfg=data_cfg,
167
- pipeline=test_pipeline,
168
- dataset_info={{_base_.dataset_info}}),
169
- )
 
1
+ # _base_ = [
2
+ # '../../../../_base_/default_runtime.py',
3
+ # '../../../../_base_/datasets/coco.py'
4
+ # ]
5
+ evaluation = dict(interval=10, metric='mAP', save_best='AP')
6
+
7
+ optimizer = dict(
8
+ type='Adam',
9
+ lr=5e-4,
10
+ )
11
+ optimizer_config = dict(grad_clip=None)
12
+ # learning policy
13
+ lr_config = dict(
14
+ policy='step',
15
+ warmup='linear',
16
+ warmup_iters=500,
17
+ warmup_ratio=0.001,
18
+ step=[170, 200])
19
+ total_epochs = 210
20
+ channel_cfg = dict(
21
+ num_output_channels=17,
22
+ dataset_joints=17,
23
+ dataset_channel=[
24
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
25
+ ],
26
+ inference_channel=[
27
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
28
+ ])
29
+
30
+ # model settings
31
+ model = dict(
32
+ type='TopDown',
33
+ pretrained='https://download.openmmlab.com/mmpose/'
34
+ 'pretrain_models/hrnet_w48-8ef0771d.pth',
35
+ backbone=dict(
36
+ type='HRNet',
37
+ in_channels=3,
38
+ extra=dict(
39
+ stage1=dict(
40
+ num_modules=1,
41
+ num_branches=1,
42
+ block='BOTTLENECK',
43
+ num_blocks=(4, ),
44
+ num_channels=(64, )),
45
+ stage2=dict(
46
+ num_modules=1,
47
+ num_branches=2,
48
+ block='BASIC',
49
+ num_blocks=(4, 4),
50
+ num_channels=(48, 96)),
51
+ stage3=dict(
52
+ num_modules=4,
53
+ num_branches=3,
54
+ block='BASIC',
55
+ num_blocks=(4, 4, 4),
56
+ num_channels=(48, 96, 192)),
57
+ stage4=dict(
58
+ num_modules=3,
59
+ num_branches=4,
60
+ block='BASIC',
61
+ num_blocks=(4, 4, 4, 4),
62
+ num_channels=(48, 96, 192, 384))),
63
+ ),
64
+ keypoint_head=dict(
65
+ type='TopdownHeatmapSimpleHead',
66
+ in_channels=48,
67
+ out_channels=channel_cfg['num_output_channels'],
68
+ num_deconv_layers=0,
69
+ extra=dict(final_conv_kernel=1, ),
70
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
71
+ train_cfg=dict(),
72
+ test_cfg=dict(
73
+ flip_test=True,
74
+ post_process='default',
75
+ shift_heatmap=True,
76
+ modulate_kernel=11))
77
+
78
+ data_cfg = dict(
79
+ image_size=[192, 256],
80
+ heatmap_size=[48, 64],
81
+ num_output_channels=channel_cfg['num_output_channels'],
82
+ num_joints=channel_cfg['dataset_joints'],
83
+ dataset_channel=channel_cfg['dataset_channel'],
84
+ inference_channel=channel_cfg['inference_channel'],
85
+ soft_nms=False,
86
+ nms_thr=1.0,
87
+ oks_thr=0.9,
88
+ vis_thr=0.2,
89
+ use_gt_bbox=False,
90
+ det_bbox_thr=0.0,
91
+ bbox_file='data/coco/person_detection_results/'
92
+ 'COCO_val2017_detections_AP_H_56_person.json',
93
+ )
94
+
95
+ train_pipeline = [
96
+ dict(type='LoadImageFromFile'),
97
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
98
+ dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
99
+ dict(type='TopDownRandomFlip', flip_prob=0.5),
100
+ dict(
101
+ type='TopDownHalfBodyTransform',
102
+ num_joints_half_body=8,
103
+ prob_half_body=0.3),
104
+ dict(
105
+ type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
106
+ dict(type='TopDownAffine'),
107
+ dict(type='ToTensor'),
108
+ dict(
109
+ type='NormalizeTensor',
110
+ mean=[0.485, 0.456, 0.406],
111
+ std=[0.229, 0.224, 0.225]),
112
+ dict(type='TopDownGenerateTarget', sigma=2),
113
+ dict(
114
+ type='Collect',
115
+ keys=['img', 'target', 'target_weight'],
116
+ meta_keys=[
117
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
118
+ 'rotation', 'bbox_score', 'flip_pairs'
119
+ ]),
120
+ ]
121
+
122
+ val_pipeline = [
123
+ dict(type='LoadImageFromFile'),
124
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
125
+ dict(type='TopDownAffine'),
126
+ dict(type='ToTensor'),
127
+ dict(
128
+ type='NormalizeTensor',
129
+ mean=[0.485, 0.456, 0.406],
130
+ std=[0.229, 0.224, 0.225]),
131
+ dict(
132
+ type='Collect',
133
+ keys=['img'],
134
+ meta_keys=[
135
+ 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
136
+ 'flip_pairs'
137
+ ]),
138
+ ]
139
+
140
+ test_pipeline = val_pipeline
141
+
142
+ data_root = 'data/coco'
143
+ data = dict(
144
+ samples_per_gpu=32,
145
+ workers_per_gpu=2,
146
+ val_dataloader=dict(samples_per_gpu=32),
147
+ test_dataloader=dict(samples_per_gpu=32),
148
+ train=dict(
149
+ type='TopDownCocoDataset',
150
+ ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
151
+ img_prefix=f'{data_root}/train2017/',
152
+ data_cfg=data_cfg,
153
+ pipeline=train_pipeline,
154
+ dataset_info={{_base_.dataset_info}}),
155
+ val=dict(
156
+ type='TopDownCocoDataset',
157
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
158
+ img_prefix=f'{data_root}/val2017/',
159
+ data_cfg=data_cfg,
160
+ pipeline=val_pipeline,
161
+ dataset_info={{_base_.dataset_info}}),
162
+ test=dict(
163
+ type='TopDownCocoDataset',
164
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
165
+ img_prefix=f'{data_root}/val2017/',
166
+ data_cfg=data_cfg,
167
+ pipeline=test_pipeline,
168
+ dataset_info={{_base_.dataset_info}}),
169
+ )
configs/stable-diffusion/sd-v1-inference.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ unet_config:
21
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22
+ params:
23
+ use_fp16: True
24
+ image_size: 32 # unused
25
+ in_channels: 4
26
+ out_channels: 4
27
+ model_channels: 320
28
+ attention_resolutions: [ 4, 2, 1 ]
29
+ num_res_blocks: 2
30
+ channel_mult: [ 1, 2, 4, 4 ]
31
+ num_heads: 8
32
+ use_spatial_transformer: True
33
+ transformer_depth: 1
34
+ context_dim: 768
35
+ use_checkpoint: True
36
+ legacy: False
37
+
38
+ first_stage_config:
39
+ target: ldm.models.autoencoder.AutoencoderKL
40
+ params:
41
+ embed_dim: 4
42
+ monitor: val/rec_loss
43
+ ddconfig:
44
+ double_z: true
45
+ z_channels: 4
46
+ resolution: 512
47
+ in_channels: 3
48
+ out_ch: 3
49
+ ch: 128
50
+ ch_mult:
51
+ - 1
52
+ - 2
53
+ - 4
54
+ - 4
55
+ num_res_blocks: 2
56
+ attn_resolutions: []
57
+ dropout: 0.0
58
+ lossconfig:
59
+ target: torch.nn.Identity
60
+
61
+ cond_stage_config:
62
+ target: ldm.modules.encoders.modules.WebUIFrozenCLIPEmebedder
63
+ params:
64
+ version: openai/clip-vit-large-patch14
65
+ layer: last
configs/stable-diffusion/sd-v1-train.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config: #__is_unconditional__
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71
+ params:
72
+ version: openai/clip-vit-large-patch14
73
+
74
+ logger:
75
+ print_freq: 100
76
+ save_checkpoint_freq: !!float 1e4
77
+ use_tb_logger: true
78
+ wandb:
79
+ project: ~
80
+ resume_id: ~
81
+ dist_params:
82
+ backend: nccl
83
+ port: 29500
84
+ training:
85
+ lr: !!float 1e-5
86
+ save_freq: 1e4
configs/stable-diffusion/train_keypose.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: train_keypose
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: openai/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/train_mask.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: train_mask
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: openai/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/train_sketch.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: train_sketch
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: openai/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
demo/demos.py DELETED
@@ -1,309 +0,0 @@
1
- import gradio as gr
2
- import numpy as np
3
- import psutil
4
-
5
- def create_map():
6
- return np.zeros(shape=(512, 512), dtype=np.uint8)+255
7
-
8
- def get_system_memory():
9
- memory = psutil.virtual_memory()
10
- memory_percent = memory.percent
11
- memory_used = memory.used / (1024.0 ** 3)
12
- memory_total = memory.total / (1024.0 ** 3)
13
- return {"percent": f"{memory_percent}%", "used": f"{memory_used:.3f}GB", "total": f"{memory_total:.3f}GB"}
14
-
15
-
16
-
17
- def create_demo_keypose(process):
18
- with gr.Blocks() as demo:
19
- with gr.Row():
20
- gr.Markdown('## T2I-Adapter (Keypose)')
21
- with gr.Row():
22
- with gr.Column():
23
- input_img = gr.Image(source='upload', type="numpy")
24
- prompt = gr.Textbox(label="Prompt")
25
- neg_prompt = gr.Textbox(label="Negative Prompt",
26
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
27
- pos_prompt = gr.Textbox(label="Positive Prompt",
28
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
29
- with gr.Row():
30
- type_in = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a keypose map)')
31
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
32
- run_button = gr.Button(label="Run")
33
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the keypose to the result)", minimum=0, maximum=1, value=1, step=0.1)
34
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
35
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
36
- with gr.Column():
37
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
38
- ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
39
- run_button.click(fn=process, inputs=ips, outputs=[result])
40
- return demo
41
-
42
- def create_demo_openpose(process):
43
- with gr.Blocks() as demo:
44
- with gr.Row():
45
- gr.Markdown('## T2I-Adapter (Openpose)')
46
- with gr.Row():
47
- with gr.Column():
48
- input_img = gr.Image(source='upload', type="numpy")
49
- prompt = gr.Textbox(label="Prompt")
50
- neg_prompt = gr.Textbox(label="Negative Prompt",
51
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
52
- pos_prompt = gr.Textbox(label="Positive Prompt",
53
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
54
- with gr.Row():
55
- type_in = gr.inputs.Radio(['Openpose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a openpose map)')
56
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
57
- run_button = gr.Button(label="Run")
58
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the openpose to the result)", minimum=0, maximum=1, value=1, step=0.1)
59
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
60
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
61
- with gr.Column():
62
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
63
- ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
64
- run_button.click(fn=process, inputs=ips, outputs=[result])
65
- return demo
66
-
67
- def create_demo_sketch(process):
68
- with gr.Blocks() as demo:
69
- with gr.Row():
70
- gr.Markdown('## T2I-Adapter (Sketch)')
71
- with gr.Row():
72
- with gr.Column():
73
- input_img = gr.Image(source='upload', type="numpy")
74
- prompt = gr.Textbox(label="Prompt")
75
- neg_prompt = gr.Textbox(label="Negative Prompt",
76
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
77
- pos_prompt = gr.Textbox(label="Positive Prompt",
78
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
79
- with gr.Row():
80
- type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a sketch)')
81
- color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
82
- run_button = gr.Button(label="Run")
83
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
84
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
85
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
86
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
87
- with gr.Column():
88
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
89
- ips = [input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
90
- run_button.click(fn=process, inputs=ips, outputs=[result])
91
- return demo
92
-
93
- def create_demo_canny(process):
94
- with gr.Blocks() as demo:
95
- with gr.Row():
96
- gr.Markdown('## T2I-Adapter (Canny)')
97
- with gr.Row():
98
- with gr.Column():
99
- input_img = gr.Image(source='upload', type="numpy")
100
- prompt = gr.Textbox(label="Prompt")
101
- neg_prompt = gr.Textbox(label="Negative Prompt",
102
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
103
- pos_prompt = gr.Textbox(label="Positive Prompt",
104
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
105
- with gr.Row():
106
- type_in = gr.inputs.Radio(['Canny', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a canny map)')
107
- color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the canny background\n (Only work for canny input)')
108
- run_button = gr.Button(label="Run")
109
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the canny to the result)", minimum=0, maximum=1, value=1, step=0.1)
110
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
111
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
112
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
113
- with gr.Column():
114
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
115
- ips = [input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
116
- run_button.click(fn=process, inputs=ips, outputs=[result])
117
- return demo
118
-
119
- def create_demo_color_sketch(process):
120
- with gr.Blocks() as demo:
121
- with gr.Row():
122
- gr.Markdown('## T2I-Adapter (Color + Sketch)')
123
- with gr.Row():
124
- with gr.Column():
125
- with gr.Row():
126
- input_img_sketch = gr.Image(source='upload', type="numpy", label='Sketch guidance')
127
- input_img_color = gr.Image(source='upload', type="numpy", label='Color guidance')
128
- prompt = gr.Textbox(label="Prompt")
129
- neg_prompt = gr.Textbox(label="Negative Prompt",
130
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
131
- pos_prompt = gr.Textbox(label="Positive Prompt",
132
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
133
- type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
134
- with gr.Row():
135
- type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types of Sketch\n (You can input an image or a sketch)')
136
- color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
137
- with gr.Row():
138
- w_sketch = gr.Slider(label="Sketch guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
139
- w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1.2, step=0.1)
140
- run_button = gr.Button(label="Run")
141
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
142
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
143
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
144
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
145
- with gr.Column():
146
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
147
- ips = [input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
148
- run_button.click(fn=process, inputs=ips, outputs=[result])
149
- return demo
150
-
151
- def create_demo_style_sketch(process):
152
- with gr.Blocks() as demo:
153
- with gr.Row():
154
- gr.Markdown('## T2I-Adapter (Style + Sketch)')
155
- with gr.Row():
156
- with gr.Column():
157
- with gr.Row():
158
- input_img_sketch = gr.Image(source='upload', type="numpy", label='Sketch guidance')
159
- input_img_style = gr.Image(source='upload', type="numpy", label='Style guidance')
160
- prompt = gr.Textbox(label="Prompt")
161
- neg_prompt = gr.Textbox(label="Negative Prompt",
162
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
163
- pos_prompt = gr.Textbox(label="Positive Prompt",
164
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
165
- with gr.Row():
166
- type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types of Sketch\n (You can input an image or a sketch)')
167
- color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
168
- run_button = gr.Button(label="Run")
169
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=1, step=0.1)
170
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
171
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
172
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
173
- with gr.Column():
174
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
175
- ips = [input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
176
- run_button.click(fn=process, inputs=ips, outputs=[result])
177
- return demo
178
-
179
- def create_demo_color(process):
180
- with gr.Blocks() as demo:
181
- with gr.Row():
182
- gr.Markdown('## T2I-Adapter (Color)')
183
- with gr.Row():
184
- with gr.Column():
185
- input_img = gr.Image(source='upload', type="numpy", label='Color guidance')
186
- prompt = gr.Textbox(label="Prompt")
187
- neg_prompt = gr.Textbox(label="Negative Prompt",
188
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
189
- pos_prompt = gr.Textbox(label="Positive Prompt",
190
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
191
- type_in_color = gr.inputs.Radio(['ColorMap', 'Image'], type="value", default='Image', label='Input Types of Color\n (You can input an image or a color map)')
192
- w_color = gr.Slider(label="Color guidance weight", minimum=0, maximum=2, value=1, step=0.1)
193
- run_button = gr.Button(label="Run")
194
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=1, step=0.1)
195
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
196
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
197
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
198
- with gr.Column():
199
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
200
- ips = [input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model]
201
- run_button.click(fn=process, inputs=ips, outputs=[result])
202
- return demo
203
-
204
- def create_demo_seg(process):
205
- with gr.Blocks() as demo:
206
- with gr.Row():
207
- gr.Markdown('## T2I-Adapter (Segmentation)')
208
- with gr.Row():
209
- with gr.Column():
210
- input_img = gr.Image(source='upload', type="numpy")
211
- prompt = gr.Textbox(label="Prompt")
212
- neg_prompt = gr.Textbox(label="Negative Prompt",
213
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
214
- pos_prompt = gr.Textbox(label="Positive Prompt",
215
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
216
- with gr.Row():
217
- type_in = gr.inputs.Radio(['Segmentation', 'Image'], type="value", default='Image', label='You can input an image or a segmentation. If you choose to input a segmentation, it must correspond to the coco-stuff')
218
- run_button = gr.Button(label="Run")
219
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=1, step=0.1)
220
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
221
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
222
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
223
- with gr.Column():
224
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
225
- ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
226
- run_button.click(fn=process, inputs=ips, outputs=[result])
227
- return demo
228
-
229
- def create_demo_depth(process):
230
- with gr.Blocks() as demo:
231
- with gr.Row():
232
- gr.Markdown('## T2I-Adapter (Depth)')
233
- with gr.Row():
234
- with gr.Column():
235
- input_img = gr.Image(source='upload', type="numpy")
236
- prompt = gr.Textbox(label="Prompt")
237
- neg_prompt = gr.Textbox(label="Negative Prompt",
238
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
239
- pos_prompt = gr.Textbox(label="Positive Prompt",
240
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
241
- with gr.Row():
242
- type_in = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
243
- run_button = gr.Button(label="Run")
244
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the depth map to the result)", minimum=0, maximum=1, value=1, step=0.1)
245
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
246
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
247
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
248
- with gr.Column():
249
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
250
- ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
251
- run_button.click(fn=process, inputs=ips, outputs=[result])
252
- return demo
253
-
254
- def create_demo_depth_keypose(process):
255
- with gr.Blocks() as demo:
256
- with gr.Row():
257
- gr.Markdown('## T2I-Adapter (Depth & Keypose)')
258
- with gr.Row():
259
- with gr.Column():
260
- with gr.Row():
261
- input_img_depth = gr.Image(source='upload', type="numpy", label='Depth guidance')
262
- input_img_keypose = gr.Image(source='upload', type="numpy", label='Keypose guidance')
263
-
264
- prompt = gr.Textbox(label="Prompt")
265
- neg_prompt = gr.Textbox(label="Negative Prompt",
266
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
267
- pos_prompt = gr.Textbox(label="Positive Prompt",
268
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
269
- with gr.Row():
270
- type_in_depth = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
271
- type_in_keypose = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='You can input an image or a keypose map (mmpose style)')
272
- with gr.Row():
273
- w_depth = gr.Slider(label="Depth guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
274
- w_keypose = gr.Slider(label="Keypose guidance weight", minimum=0, maximum=2, value=1.5, step=0.1)
275
- run_button = gr.Button(label="Run")
276
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the multi-guidance to the result)", minimum=0, maximum=1, value=1, step=0.1)
277
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
278
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
279
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
280
- with gr.Column():
281
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
282
- ips = [input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
283
- run_button.click(fn=process, inputs=ips, outputs=[result])
284
- return demo
285
-
286
- def create_demo_draw(process):
287
- with gr.Blocks() as demo:
288
- with gr.Row():
289
- gr.Markdown('## T2I-Adapter (Hand-free drawing)')
290
- with gr.Row():
291
- with gr.Column():
292
- create_button = gr.Button(label="Start", value='Hand-free drawing')
293
- input_img = gr.Image(source='upload', type="numpy",tool='sketch')
294
- create_button.click(fn=create_map, outputs=[input_img], queue=False)
295
- prompt = gr.Textbox(label="Prompt")
296
- neg_prompt = gr.Textbox(label="Negative Prompt",
297
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
298
- pos_prompt = gr.Textbox(label="Positive Prompt",
299
- value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
300
- run_button = gr.Button(label="Run")
301
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
302
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
303
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
304
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
305
- with gr.Column():
306
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
307
- ips = [input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
308
- run_button.click(fn=process, inputs=ips, outputs=[result])
309
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo/model.py DELETED
@@ -1,979 +0,0 @@
1
- import torch
2
- from basicsr.utils import img2tensor, tensor2img
3
- from pytorch_lightning import seed_everything
4
- from ldm.models.diffusion.plms import PLMSSampler
5
- from ldm.modules.encoders.adapter import Adapter, Adapter_light, StyleAdapter
6
- from ldm.util import instantiate_from_config
7
- from ldm.modules.structure_condition.model_edge import pidinet
8
- from ldm.modules.structure_condition.model_seg import seger, Colorize
9
- from ldm.modules.structure_condition.midas.api import MiDaSInference
10
- import gradio as gr
11
- from omegaconf import OmegaConf
12
- import mmcv
13
- from mmdet.apis import inference_detector, init_detector
14
- from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
15
- import os
16
- import cv2
17
- import numpy as np
18
- import torch.nn.functional as F
19
- from transformers import CLIPProcessor, CLIPVisionModel
20
- from PIL import Image
21
-
22
-
23
- def preprocessing(image, device):
24
- # Resize
25
- scale = 640 / max(image.shape[:2])
26
- image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
27
- raw_image = image.astype(np.uint8)
28
-
29
- # Subtract mean values
30
- image = image.astype(np.float32)
31
- image -= np.array(
32
- [
33
- float(104.008),
34
- float(116.669),
35
- float(122.675),
36
- ]
37
- )
38
-
39
- # Convert to torch.Tensor and add "batch" axis
40
- image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
41
- image = image.to(device)
42
-
43
- return image, raw_image
44
-
45
-
46
- def imshow_keypoints(img,
47
- pose_result,
48
- skeleton=None,
49
- kpt_score_thr=0.1,
50
- pose_kpt_color=None,
51
- pose_link_color=None,
52
- radius=4,
53
- thickness=1):
54
- """Draw keypoints and links on an image.
55
-
56
- Args:
57
- img (ndarry): The image to draw poses on.
58
- pose_result (list[kpts]): The poses to draw. Each element kpts is
59
- a set of K keypoints as an Kx3 numpy.ndarray, where each
60
- keypoint is represented as x, y, score.
61
- kpt_score_thr (float, optional): Minimum score of keypoints
62
- to be shown. Default: 0.3.
63
- pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
64
- the keypoint will not be drawn.
65
- pose_link_color (np.array[Mx3]): Color of M links. If None, the
66
- links will not be drawn.
67
- thickness (int): Thickness of lines.
68
- """
69
-
70
- img_h, img_w, _ = img.shape
71
- img = np.zeros(img.shape)
72
-
73
- for idx, kpts in enumerate(pose_result):
74
- if idx > 1:
75
- continue
76
- kpts = kpts['keypoints']
77
- kpts = np.array(kpts, copy=False)
78
-
79
- # draw each point on image
80
- if pose_kpt_color is not None:
81
- assert len(pose_kpt_color) == len(kpts)
82
-
83
- for kid, kpt in enumerate(kpts):
84
- x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
85
-
86
- if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
87
- # skip the point that should not be drawn
88
- continue
89
-
90
- color = tuple(int(c) for c in pose_kpt_color[kid])
91
- cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
92
-
93
- # draw links
94
- if skeleton is not None and pose_link_color is not None:
95
- assert len(pose_link_color) == len(skeleton)
96
-
97
- for sk_id, sk in enumerate(skeleton):
98
- pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
99
- pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
100
-
101
- if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
102
- or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
103
- or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
104
- # skip the link that should not be drawn
105
- continue
106
- color = tuple(int(c) for c in pose_link_color[sk_id])
107
- cv2.line(img, pos1, pos2, color, thickness=thickness)
108
-
109
- return img
110
-
111
-
112
- def load_model_from_config(config, ckpt, verbose=False):
113
- print(f"Loading model from {ckpt}")
114
- pl_sd = torch.load(ckpt, map_location="cpu")
115
- if "global_step" in pl_sd:
116
- print(f"Global Step: {pl_sd['global_step']}")
117
- if "state_dict" in pl_sd:
118
- sd = pl_sd["state_dict"]
119
- else:
120
- sd = pl_sd
121
- model = instantiate_from_config(config.model)
122
- _, _ = model.load_state_dict(sd, strict=False)
123
-
124
- model.cuda()
125
- model.eval()
126
- return model
127
-
128
-
129
- class Model_all:
130
- def __init__(self, device='cpu'):
131
- # common part
132
- self.device = device
133
- self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
134
- self.config.model.params.cond_stage_config.params.device = device
135
- self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
136
- self.current_base = 'sd-v1-4.ckpt'
137
- self.sampler = PLMSSampler(self.base_model)
138
-
139
- # sketch part
140
- self.model_canny = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
141
- use_conv=False).to(device)
142
- self.model_canny.load_state_dict(torch.load("models/t2iadapter_canny_sd14v1.pth", map_location=device))
143
- self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
144
- use_conv=False).to(device)
145
- self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
146
- self.model_edge = pidinet().to(device)
147
- self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in
148
- torch.load('models/table5_pidinet.pth', map_location=device)[
149
- 'state_dict'].items()})
150
-
151
- # segmentation part
152
- self.model_seger = seger().to(device)
153
- self.model_seger.eval()
154
- self.coler = Colorize(n=182)
155
- self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
156
- use_conv=False).to(device)
157
- self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
158
-
159
- # depth part
160
- self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
161
- self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
162
- use_conv=False).to(device)
163
- self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
164
-
165
- # keypose part
166
- self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
167
- use_conv=False).to(device)
168
- self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))
169
-
170
- # openpose part
171
- self.model_openpose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
172
- use_conv=False).to(device)
173
- self.model_openpose.load_state_dict(torch.load("models/t2iadapter_openpose_sd14v1.pth", map_location=device))
174
-
175
- # color part
176
- self.model_color = Adapter_light(cin=int(3 * 64), channels=[320, 640, 1280, 1280], nums_rb=4).to(device)
177
- self.model_color.load_state_dict(torch.load("models/t2iadapter_color_sd14v1.pth", map_location=device))
178
-
179
- # style part
180
- self.model_style = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(device)
181
- self.model_style.load_state_dict(torch.load("models/t2iadapter_style_sd14v1.pth", map_location=device))
182
- self.clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
183
- self.clip_vision_model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14').to(device)
184
-
185
- device = 'cpu'
186
- ## mmpose
187
- det_config = 'models/faster_rcnn_r50_fpn_coco.py'
188
- det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
189
- pose_config = 'models/hrnet_w48_coco_256x192.py'
190
- pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
191
- self.det_cat_id = 1
192
- self.bbox_thr = 0.2
193
- ## detector
194
- det_config_mmcv = mmcv.Config.fromfile(det_config)
195
- self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
196
- pose_config_mmcv = mmcv.Config.fromfile(pose_config)
197
- self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
198
- ## color
199
- self.skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
200
- [7, 9], [8, 10],
201
- [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
202
- self.pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
203
- [0, 255, 0],
204
- [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
205
- [255, 128, 0],
206
- [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
207
- self.pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
208
- [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
209
- [255, 128, 0],
210
- [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
211
- [51, 153, 255],
212
- [51, 153, 255], [51, 153, 255], [51, 153, 255]]
213
-
214
- def load_vae(self):
215
- vae_sd = torch.load(os.path.join('models', 'anything-v4.0.vae.pt'), map_location="cuda")
216
- sd = vae_sd["state_dict"]
217
- self.base_model.first_stage_model.load_state_dict(sd, strict=False)
218
-
219
- @torch.no_grad()
220
- def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
221
- con_strength, base_model):
222
- if self.current_base != base_model:
223
- ckpt = os.path.join("models", base_model)
224
- pl_sd = torch.load(ckpt, map_location="cuda")
225
- if "state_dict" in pl_sd:
226
- sd = pl_sd["state_dict"]
227
- else:
228
- sd = pl_sd
229
- self.base_model.load_state_dict(sd, strict=False)
230
- self.current_base = base_model
231
- if 'anything' in base_model.lower():
232
- self.load_vae()
233
-
234
- con_strength = int((1 - con_strength) * 50)
235
- if fix_sample == 'True':
236
- seed_everything(42)
237
- im = cv2.resize(input_img, (512, 512))
238
-
239
- if type_in == 'Sketch':
240
- if color_back == 'White':
241
- im = 255 - im
242
- im_edge = im.copy()
243
- im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
244
- im = im > 0.5
245
- im = im.float()
246
- elif type_in == 'Image':
247
- im = img2tensor(im).unsqueeze(0) / 255.
248
- im = self.model_edge(im.to(self.device))[-1]
249
- im = im > 0.5
250
- im = im.float()
251
- im_edge = tensor2img(im)
252
-
253
- # extract condition features
254
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
255
- nc = self.base_model.get_learned_conditioning([neg_prompt])
256
- features_adapter = self.model_sketch(im.to(self.device))
257
- shape = [4, 64, 64]
258
-
259
- # sampling
260
- samples_ddim, _ = self.sampler.sample(S=50,
261
- conditioning=c,
262
- batch_size=1,
263
- shape=shape,
264
- verbose=False,
265
- unconditional_guidance_scale=scale,
266
- unconditional_conditioning=nc,
267
- eta=0.0,
268
- x_T=None,
269
- features_adapter1=features_adapter,
270
- mode='sketch',
271
- con_strength=con_strength)
272
-
273
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
274
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
275
- x_samples_ddim = x_samples_ddim.to('cpu')
276
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
277
- x_samples_ddim = 255. * x_samples_ddim
278
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
279
-
280
- return [im_edge, x_samples_ddim]
281
-
282
- @torch.no_grad()
283
- def process_canny(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
284
- con_strength, base_model):
285
- if self.current_base != base_model:
286
- ckpt = os.path.join("models", base_model)
287
- pl_sd = torch.load(ckpt, map_location="cuda")
288
- if "state_dict" in pl_sd:
289
- sd = pl_sd["state_dict"]
290
- else:
291
- sd = pl_sd
292
- self.base_model.load_state_dict(sd, strict=False)
293
- self.current_base = base_model
294
- if 'anything' in base_model.lower():
295
- self.load_vae()
296
-
297
- con_strength = int((1 - con_strength) * 50)
298
- if fix_sample == 'True':
299
- seed_everything(42)
300
- im = cv2.resize(input_img, (512, 512))
301
-
302
- if type_in == 'Canny':
303
- if color_back == 'White':
304
- im = 255 - im
305
- im_edge = im.copy()
306
- im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
307
- elif type_in == 'Image':
308
- im = cv2.Canny(im,100,200)
309
- im = img2tensor(im[..., None], bgr2rgb=True, float32=True).unsqueeze(0) / 255.
310
- im_edge = tensor2img(im)
311
-
312
- # extract condition features
313
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
314
- nc = self.base_model.get_learned_conditioning([neg_prompt])
315
- features_adapter = self.model_canny(im.to(self.device))
316
- shape = [4, 64, 64]
317
-
318
- # sampling
319
- samples_ddim, _ = self.sampler.sample(S=50,
320
- conditioning=c,
321
- batch_size=1,
322
- shape=shape,
323
- verbose=False,
324
- unconditional_guidance_scale=scale,
325
- unconditional_conditioning=nc,
326
- eta=0.0,
327
- x_T=None,
328
- features_adapter1=features_adapter,
329
- mode='sketch',
330
- con_strength=con_strength)
331
-
332
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
333
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
334
- x_samples_ddim = x_samples_ddim.to('cpu')
335
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
336
- x_samples_ddim = 255. * x_samples_ddim
337
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
338
-
339
- return [im_edge, x_samples_ddim]
340
-
341
- @torch.no_grad()
342
- def process_color_sketch(self, input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
343
- if self.current_base != base_model:
344
- ckpt = os.path.join("models", base_model)
345
- pl_sd = torch.load(ckpt, map_location="cuda")
346
- if "state_dict" in pl_sd:
347
- sd = pl_sd["state_dict"]
348
- else:
349
- sd = pl_sd
350
- self.base_model.load_state_dict(sd, strict=False)
351
- self.current_base = base_model
352
- if 'anything' in base_model.lower():
353
- self.load_vae()
354
-
355
- con_strength = int((1 - con_strength) * 50)
356
- if fix_sample == 'True':
357
- seed_everything(42)
358
- im = cv2.resize(input_img_sketch, (512, 512))
359
-
360
- if type_in == 'Sketch':
361
- if color_back == 'White':
362
- im = 255 - im
363
- im_edge = im.copy()
364
- im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
365
- im = im > 0.5
366
- im = im.float()
367
- elif type_in == 'Image':
368
- im = img2tensor(im).unsqueeze(0) / 255.
369
- im = self.model_edge(im.to(self.device))[-1]#.cuda()
370
- im = im > 0.5
371
- im = im.float()
372
- im_edge = tensor2img(im)
373
- if type_in_color == 'Image':
374
- input_img_color = cv2.resize(input_img_color,(512//64, 512//64), interpolation=cv2.INTER_CUBIC)
375
- input_img_color = cv2.resize(input_img_color,(512,512), interpolation=cv2.INTER_NEAREST)
376
- else:
377
- input_img_color = cv2.resize(input_img_color, (512, 512))
378
- im_color = input_img_color.copy()
379
- im_color_tensor = img2tensor(input_img_color, bgr2rgb=False).unsqueeze(0) / 255.
380
-
381
- # extract condition features
382
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
383
- nc = self.base_model.get_learned_conditioning([neg_prompt])
384
- features_adapter_sketch = self.model_sketch(im.to(self.device))
385
- features_adapter_color = self.model_color(im_color_tensor.to(self.device))
386
- features_adapter = [fs*w_sketch+fc*w_color for fs, fc in zip(features_adapter_sketch,features_adapter_color)]
387
- shape = [4, 64, 64]
388
-
389
- # sampling
390
- samples_ddim, _ = self.sampler.sample(S=50,
391
- conditioning=c,
392
- batch_size=1,
393
- shape=shape,
394
- verbose=False,
395
- unconditional_guidance_scale=scale,
396
- unconditional_conditioning=nc,
397
- eta=0.0,
398
- x_T=None,
399
- features_adapter1=features_adapter,
400
- mode='sketch',
401
- con_strength=con_strength)
402
-
403
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
404
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
405
- x_samples_ddim = x_samples_ddim.to('cpu')
406
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
407
- x_samples_ddim = 255. * x_samples_ddim
408
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
409
-
410
- return [im_edge, im_color, x_samples_ddim]
411
-
412
- @torch.no_grad()
413
- def process_style_sketch(self, input_img_sketch, input_img_style, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
414
- if self.current_base != base_model:
415
- ckpt = os.path.join("models", base_model)
416
- pl_sd = torch.load(ckpt, map_location="cuda")
417
- if "state_dict" in pl_sd:
418
- sd = pl_sd["state_dict"]
419
- else:
420
- sd = pl_sd
421
- self.base_model.load_state_dict(sd, strict=False)
422
- self.current_base = base_model
423
- if 'anything' in base_model.lower():
424
- self.load_vae()
425
-
426
- con_strength = int((1 - con_strength) * 50)
427
- if fix_sample == 'True':
428
- seed_everything(42)
429
- im = cv2.resize(input_img_sketch, (512, 512))
430
-
431
- if type_in == 'Sketch':
432
- if color_back == 'White':
433
- im = 255 - im
434
- im_edge = im.copy()
435
- im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
436
- im = im > 0.5
437
- im = im.float()
438
- elif type_in == 'Image':
439
- im = img2tensor(im).unsqueeze(0) / 255.
440
- im = self.model_edge(im.to(self.device))[-1]#.cuda()
441
- im = im > 0.5
442
- im = im.float()
443
- im_edge = tensor2img(im)
444
-
445
- style = Image.fromarray(input_img_style)
446
- style_for_clip = self.clip_processor(images=style, return_tensors="pt")['pixel_values']
447
- style_feat = self.clip_vision_model(style_for_clip.to(self.device))['last_hidden_state']
448
- style_feat = self.model_style(style_feat)
449
-
450
- # extract condition features
451
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
452
- nc = self.base_model.get_learned_conditioning([neg_prompt])
453
- features_adapter = self.model_sketch(im.to(self.device))
454
- shape = [4, 64, 64]
455
-
456
- # sampling
457
- samples_ddim, _ = self.sampler.sample(S=50,
458
- conditioning=c,
459
- batch_size=1,
460
- shape=shape,
461
- verbose=False,
462
- unconditional_guidance_scale=scale,
463
- unconditional_conditioning=nc,
464
- eta=0.0,
465
- x_T=None,
466
- features_adapter1=features_adapter,
467
- mode='style',
468
- con_strength=con_strength,
469
- style_feature=style_feat)
470
-
471
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
472
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
473
- x_samples_ddim = x_samples_ddim.to('cpu')
474
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
475
- x_samples_ddim = 255. * x_samples_ddim
476
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
477
-
478
- return [im_edge, x_samples_ddim]
479
-
480
- @torch.no_grad()
481
- def process_color(self, input_img, prompt, neg_prompt, pos_prompt, w_color, type_in_color, fix_sample, scale, con_strength, base_model):
482
- if self.current_base != base_model:
483
- ckpt = os.path.join("models", base_model)
484
- pl_sd = torch.load(ckpt, map_location="cuda")
485
- if "state_dict" in pl_sd:
486
- sd = pl_sd["state_dict"]
487
- else:
488
- sd = pl_sd
489
- self.base_model.load_state_dict(sd, strict=False)
490
- self.current_base = base_model
491
- if 'anything' in base_model.lower():
492
- self.load_vae()
493
-
494
- con_strength = int((1 - con_strength) * 50)
495
- if fix_sample == 'True':
496
- seed_everything(42)
497
- if type_in_color == 'Image':
498
- input_img = cv2.resize(input_img,(512//64, 512//64), interpolation=cv2.INTER_CUBIC)
499
- input_img = cv2.resize(input_img,(512,512), interpolation=cv2.INTER_NEAREST)
500
- else:
501
- input_img = cv2.resize(input_img, (512, 512))
502
-
503
- im_color = input_img.copy()
504
- im = img2tensor(input_img, bgr2rgb=False).unsqueeze(0) / 255.
505
-
506
- # extract condition features
507
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
508
- nc = self.base_model.get_learned_conditioning([neg_prompt])
509
- features_adapter = self.model_color(im.to(self.device))
510
- features_adapter = [fi*w_color for fi in features_adapter]
511
- shape = [4, 64, 64]
512
-
513
- # sampling
514
- samples_ddim, _ = self.sampler.sample(S=50,
515
- conditioning=c,
516
- batch_size=1,
517
- shape=shape,
518
- verbose=False,
519
- unconditional_guidance_scale=scale,
520
- unconditional_conditioning=nc,
521
- eta=0.0,
522
- x_T=None,
523
- features_adapter1=features_adapter,
524
- mode='sketch',
525
- con_strength=con_strength)
526
-
527
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
528
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
529
- x_samples_ddim = x_samples_ddim.to('cpu')
530
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
531
- x_samples_ddim = 255. * x_samples_ddim
532
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
533
-
534
- return [im_color, x_samples_ddim]
535
-
536
- @torch.no_grad()
537
- def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
538
- con_strength, base_model):
539
- if self.current_base != base_model:
540
- ckpt = os.path.join("models", base_model)
541
- pl_sd = torch.load(ckpt, map_location="cuda")
542
- if "state_dict" in pl_sd:
543
- sd = pl_sd["state_dict"]
544
- else:
545
- sd = pl_sd
546
- self.base_model.load_state_dict(sd, strict=False)
547
- self.current_base = base_model
548
- if 'anything' in base_model.lower():
549
- self.load_vae()
550
-
551
- con_strength = int((1 - con_strength) * 50)
552
- if fix_sample == 'True':
553
- seed_everything(42)
554
- im = cv2.resize(input_img, (512, 512))
555
-
556
- if type_in == 'Depth':
557
- im_depth = im.copy()
558
- depth = img2tensor(im).unsqueeze(0) / 255.
559
- elif type_in == 'Image':
560
- im = img2tensor(im).unsqueeze(0) / 127.5 - 1.0
561
- depth = self.depth_model(im.to(self.device)).repeat(1, 3, 1, 1)
562
- depth -= torch.min(depth)
563
- depth /= torch.max(depth)
564
- im_depth = tensor2img(depth)
565
-
566
- # extract condition features
567
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
568
- nc = self.base_model.get_learned_conditioning([neg_prompt])
569
- features_adapter = self.model_depth(depth.to(self.device))
570
- shape = [4, 64, 64]
571
-
572
- # sampling
573
- samples_ddim, _ = self.sampler.sample(S=50,
574
- conditioning=c,
575
- batch_size=1,
576
- shape=shape,
577
- verbose=False,
578
- unconditional_guidance_scale=scale,
579
- unconditional_conditioning=nc,
580
- eta=0.0,
581
- x_T=None,
582
- features_adapter1=features_adapter,
583
- mode='sketch',
584
- con_strength=con_strength)
585
-
586
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
587
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
588
- x_samples_ddim = x_samples_ddim.to('cpu')
589
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
590
- x_samples_ddim = 255. * x_samples_ddim
591
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
592
-
593
- return [im_depth, x_samples_ddim]
594
-
595
- @torch.no_grad()
596
- def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth,
597
- w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
598
- if self.current_base != base_model:
599
- ckpt = os.path.join("models", base_model)
600
- pl_sd = torch.load(ckpt, map_location="cuda")
601
- if "state_dict" in pl_sd:
602
- sd = pl_sd["state_dict"]
603
- else:
604
- sd = pl_sd
605
- self.base_model.load_state_dict(sd, strict=False)
606
- self.current_base = base_model
607
- if 'anything' in base_model.lower():
608
- self.load_vae()
609
-
610
- if fix_sample == 'True':
611
- seed_everything(42)
612
- im_depth = cv2.resize(input_img_depth, (512, 512))
613
- im_keypose = cv2.resize(input_img_keypose, (512, 512))
614
-
615
- # get depth
616
- if type_in_depth == 'Depth':
617
- im_depth_out = im_depth.copy()
618
- depth = img2tensor(im_depth).unsqueeze(0) / 255.
619
- elif type_in_depth == 'Image':
620
- im_depth = img2tensor(im_depth).unsqueeze(0) / 127.5 - 1.0
621
- depth = self.depth_model(im_depth.to(self.device)).repeat(1, 3, 1, 1)
622
- depth -= torch.min(depth)
623
- depth /= torch.max(depth)
624
- im_depth_out = tensor2img(depth)
625
-
626
- # get keypose
627
- if type_in_keypose == 'Keypose':
628
- im_keypose_out = im_keypose.copy()[:,:,::-1]
629
- elif type_in_keypose == 'Image':
630
- image = im_keypose.copy()
631
- im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
632
- mmdet_results = inference_detector(self.det_model, image)
633
- # keep the person class bounding boxes.
634
- person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
635
-
636
- # optional
637
- return_heatmap = False
638
- dataset = self.pose_model.cfg.data['test']['type']
639
-
640
- # e.g. use ('backbone', ) to return backbone feature
641
- output_layer_names = None
642
- pose_results, _ = inference_top_down_pose_model(
643
- self.pose_model,
644
- image,
645
- person_results,
646
- bbox_thr=self.bbox_thr,
647
- format='xyxy',
648
- dataset=dataset,
649
- dataset_info=None,
650
- return_heatmap=return_heatmap,
651
- outputs=output_layer_names)
652
-
653
- # show the results
654
- im_keypose_out = imshow_keypoints(
655
- image,
656
- pose_results,
657
- skeleton=self.skeleton,
658
- pose_kpt_color=self.pose_kpt_color,
659
- pose_link_color=self.pose_link_color,
660
- radius=2,
661
- thickness=2)
662
- im_keypose_out = im_keypose_out.astype(np.uint8)
663
-
664
- # extract condition features
665
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
666
- nc = self.base_model.get_learned_conditioning([neg_prompt])
667
- features_adapter_depth = self.model_depth(depth.to(self.device))
668
- pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
669
- pose = pose.unsqueeze(0)
670
- features_adapter_keypose = self.model_pose(pose.to(self.device))
671
- features_adapter = [f_d * w_depth + f_k * w_keypose for f_d, f_k in
672
- zip(features_adapter_depth, features_adapter_keypose)]
673
- shape = [4, 64, 64]
674
-
675
- # sampling
676
- con_strength = int((1 - con_strength) * 50)
677
- samples_ddim, _ = self.sampler.sample(S=50,
678
- conditioning=c,
679
- batch_size=1,
680
- shape=shape,
681
- verbose=False,
682
- unconditional_guidance_scale=scale,
683
- unconditional_conditioning=nc,
684
- eta=0.0,
685
- x_T=None,
686
- features_adapter1=features_adapter,
687
- mode='sketch',
688
- con_strength=con_strength)
689
-
690
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
691
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
692
- x_samples_ddim = x_samples_ddim.to('cpu')
693
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
694
- x_samples_ddim = 255. * x_samples_ddim
695
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
696
-
697
- return [im_depth_out, im_keypose_out[:, :, ::-1], x_samples_ddim]
698
-
699
- @torch.no_grad()
700
- def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
701
- con_strength, base_model):
702
- if self.current_base != base_model:
703
- ckpt = os.path.join("models", base_model)
704
- pl_sd = torch.load(ckpt, map_location="cuda")
705
- if "state_dict" in pl_sd:
706
- sd = pl_sd["state_dict"]
707
- else:
708
- sd = pl_sd
709
- self.base_model.load_state_dict(sd, strict=False)
710
- self.current_base = base_model
711
- if 'anything' in base_model.lower():
712
- self.load_vae()
713
-
714
- con_strength = int((1 - con_strength) * 50)
715
- if fix_sample == 'True':
716
- seed_everything(42)
717
- im = cv2.resize(input_img, (512, 512))
718
-
719
- if type_in == 'Segmentation':
720
- im_seg = im.copy()
721
- im = img2tensor(im).unsqueeze(0) / 255.
722
- labelmap = im.float()
723
- elif type_in == 'Image':
724
- im, _ = preprocessing(im, self.device)
725
- _, _, H, W = im.shape
726
-
727
- # Image -> Probability map
728
- logits = self.model_seger(im)
729
- logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
730
- probs = F.softmax(logits, dim=1)[0]
731
- probs = probs.cpu().data.numpy()
732
- labelmap = np.argmax(probs, axis=0)
733
-
734
- labelmap = self.coler(labelmap)
735
- labelmap = np.transpose(labelmap, (1, 2, 0))
736
- labelmap = cv2.resize(labelmap, (512, 512))
737
- labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True) / 255.
738
- im_seg = tensor2img(labelmap)[:, :, ::-1]
739
- labelmap = labelmap.unsqueeze(0)
740
-
741
- # extract condition features
742
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
743
- nc = self.base_model.get_learned_conditioning([neg_prompt])
744
- features_adapter = self.model_seg(labelmap.to(self.device))
745
- shape = [4, 64, 64]
746
-
747
- # sampling
748
- samples_ddim, _ = self.sampler.sample(S=50,
749
- conditioning=c,
750
- batch_size=1,
751
- shape=shape,
752
- verbose=False,
753
- unconditional_guidance_scale=scale,
754
- unconditional_conditioning=nc,
755
- eta=0.0,
756
- x_T=None,
757
- features_adapter1=features_adapter,
758
- mode='sketch',
759
- con_strength=con_strength)
760
-
761
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
762
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
763
- x_samples_ddim = x_samples_ddim.to('cpu')
764
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
765
- x_samples_ddim = 255. * x_samples_ddim
766
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
767
-
768
- return [im_seg, x_samples_ddim]
769
-
770
- @torch.no_grad()
771
- def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
772
- if self.current_base != base_model:
773
- ckpt = os.path.join("models", base_model)
774
- pl_sd = torch.load(ckpt, map_location="cuda")
775
- if "state_dict" in pl_sd:
776
- sd = pl_sd["state_dict"]
777
- else:
778
- sd = pl_sd
779
- self.base_model.load_state_dict(sd, strict=False)
780
- self.current_base = base_model
781
- if 'anything' in base_model.lower():
782
- self.load_vae()
783
-
784
- con_strength = int((1 - con_strength) * 50)
785
- if fix_sample == 'True':
786
- seed_everything(42)
787
- input_img = input_img['mask']
788
- c = input_img[:, :, 0:3].astype(np.float32)
789
- a = input_img[:, :, 3:4].astype(np.float32) / 255.0
790
- im = c * a + 255.0 * (1.0 - a)
791
- im = im.clip(0, 255).astype(np.uint8)
792
- im = cv2.resize(im, (512, 512))
793
-
794
- im_edge = im.copy()
795
- im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
796
- im = im > 0.5
797
- im = im.float()
798
-
799
- # extract condition features
800
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
801
- nc = self.base_model.get_learned_conditioning([neg_prompt])
802
- features_adapter = self.model_sketch(im.to(self.device))
803
- shape = [4, 64, 64]
804
-
805
- # sampling
806
- samples_ddim, _ = self.sampler.sample(S=50,
807
- conditioning=c,
808
- batch_size=1,
809
- shape=shape,
810
- verbose=False,
811
- unconditional_guidance_scale=scale,
812
- unconditional_conditioning=nc,
813
- eta=0.0,
814
- x_T=None,
815
- features_adapter1=features_adapter,
816
- mode='sketch',
817
- con_strength=con_strength)
818
-
819
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
820
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
821
- x_samples_ddim = x_samples_ddim.to('cpu')
822
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
823
- x_samples_ddim = 255. * x_samples_ddim
824
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
825
-
826
- return [im_edge, x_samples_ddim]
827
-
828
- @torch.no_grad()
829
- def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
830
- base_model):
831
- if self.current_base != base_model:
832
- ckpt = os.path.join("models", base_model)
833
- pl_sd = torch.load(ckpt, map_location="cuda")
834
- if "state_dict" in pl_sd:
835
- sd = pl_sd["state_dict"]
836
- else:
837
- sd = pl_sd
838
- self.base_model.load_state_dict(sd, strict=False)
839
- self.current_base = base_model
840
- if 'anything' in base_model.lower():
841
- self.load_vae()
842
-
843
- con_strength = int((1 - con_strength) * 50)
844
- if fix_sample == 'True':
845
- seed_everything(42)
846
- im = cv2.resize(input_img, (512, 512))
847
-
848
- if type_in == 'Keypose':
849
- im_pose = im.copy()[:,:,::-1]
850
- elif type_in == 'Image':
851
- image = im.copy()
852
- im = img2tensor(im).unsqueeze(0) / 255.
853
- mmdet_results = inference_detector(self.det_model, image)
854
- # keep the person class bounding boxes.
855
- person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
856
-
857
- # optional
858
- return_heatmap = False
859
- dataset = self.pose_model.cfg.data['test']['type']
860
-
861
- # e.g. use ('backbone', ) to return backbone feature
862
- output_layer_names = None
863
- pose_results, _ = inference_top_down_pose_model(
864
- self.pose_model,
865
- image,
866
- person_results,
867
- bbox_thr=self.bbox_thr,
868
- format='xyxy',
869
- dataset=dataset,
870
- dataset_info=None,
871
- return_heatmap=return_heatmap,
872
- outputs=output_layer_names)
873
-
874
- # show the results
875
- im_pose = imshow_keypoints(
876
- image,
877
- pose_results,
878
- skeleton=self.skeleton,
879
- pose_kpt_color=self.pose_kpt_color,
880
- pose_link_color=self.pose_link_color,
881
- radius=2,
882
- thickness=2)
883
- # im_pose = cv2.resize(im_pose, (512, 512))
884
-
885
- # extract condition features
886
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
887
- nc = self.base_model.get_learned_conditioning([neg_prompt])
888
- pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
889
- pose = pose.unsqueeze(0)
890
- features_adapter = self.model_pose(pose.to(self.device))
891
-
892
- shape = [4, 64, 64]
893
-
894
- # sampling
895
- samples_ddim, _ = self.sampler.sample(S=50,
896
- conditioning=c,
897
- batch_size=1,
898
- shape=shape,
899
- verbose=False,
900
- unconditional_guidance_scale=scale,
901
- unconditional_conditioning=nc,
902
- eta=0.0,
903
- x_T=None,
904
- features_adapter1=features_adapter,
905
- mode='sketch',
906
- con_strength=con_strength)
907
-
908
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
909
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
910
- x_samples_ddim = x_samples_ddim.to('cpu')
911
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
912
- x_samples_ddim = 255. * x_samples_ddim
913
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
914
-
915
- return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
916
-
917
- @torch.no_grad()
918
- def process_openpose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength,
919
- base_model):
920
- if self.current_base != base_model:
921
- ckpt = os.path.join("models", base_model)
922
- pl_sd = torch.load(ckpt, map_location="cuda")
923
- if "state_dict" in pl_sd:
924
- sd = pl_sd["state_dict"]
925
- else:
926
- sd = pl_sd
927
- self.base_model.load_state_dict(sd, strict=False)
928
- self.current_base = base_model
929
- if 'anything' in base_model.lower():
930
- self.load_vae()
931
-
932
- con_strength = int((1 - con_strength) * 50)
933
- if fix_sample == 'True':
934
- seed_everything(42)
935
- im = cv2.resize(input_img, (512, 512))
936
-
937
- if type_in == 'Openpose':
938
- im_pose = im.copy()[:,:,::-1]
939
- elif type_in == 'Image':
940
- from ldm.modules.structure_condition.openpose.api import OpenposeInference
941
- model = OpenposeInference()
942
- keypose = model(im[:,:,::-1])
943
- im_pose = keypose.copy()
944
-
945
- # extract condition features
946
- c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
947
- nc = self.base_model.get_learned_conditioning([neg_prompt])
948
- pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
949
- pose = pose.unsqueeze(0)
950
- features_adapter = self.model_openpose(pose.to(self.device))
951
-
952
- shape = [4, 64, 64]
953
-
954
- # sampling
955
- samples_ddim, _ = self.sampler.sample(S=50,
956
- conditioning=c,
957
- batch_size=1,
958
- shape=shape,
959
- verbose=False,
960
- unconditional_guidance_scale=scale,
961
- unconditional_conditioning=nc,
962
- eta=0.0,
963
- x_T=None,
964
- features_adapter1=features_adapter,
965
- mode='sketch',
966
- con_strength=con_strength)
967
-
968
- x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
969
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
970
- x_samples_ddim = x_samples_ddim.to('cpu')
971
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
972
- x_samples_ddim = 255. * x_samples_ddim
973
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
974
-
975
- return [im_pose[:, :, ::-1].astype(np.uint8), x_samples_ddim]
976
-
977
-
978
- if __name__ == '__main__':
979
- model = Model_all('cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dist_util.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
9
+
10
+
11
+ def init_dist(launcher, backend='nccl', **kwargs):
12
+ if mp.get_start_method(allow_none=True) is None:
13
+ mp.set_start_method('spawn')
14
+ if launcher == 'pytorch':
15
+ _init_dist_pytorch(backend, **kwargs)
16
+ elif launcher == 'slurm':
17
+ _init_dist_slurm(backend, **kwargs)
18
+ else:
19
+ raise ValueError(f'Invalid launcher type: {launcher}')
20
+
21
+
22
+ def _init_dist_pytorch(backend, **kwargs):
23
+ rank = int(os.environ['RANK'])
24
+ num_gpus = torch.cuda.device_count()
25
+ torch.cuda.set_device(rank % num_gpus)
26
+ dist.init_process_group(backend=backend, **kwargs)
27
+
28
+
29
+ def _init_dist_slurm(backend, port=None):
30
+ """Initialize slurm distributed training environment.
31
+
32
+ If argument ``port`` is not specified, then the master port will be system
33
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
34
+ environment variable, then a default port ``29500`` will be used.
35
+
36
+ Args:
37
+ backend (str): Backend of torch.distributed.
38
+ port (int, optional): Master port. Defaults to None.
39
+ """
40
+ proc_id = int(os.environ['SLURM_PROCID'])
41
+ ntasks = int(os.environ['SLURM_NTASKS'])
42
+ node_list = os.environ['SLURM_NODELIST']
43
+ num_gpus = torch.cuda.device_count()
44
+ torch.cuda.set_device(proc_id % num_gpus)
45
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
46
+ # specify master port
47
+ if port is not None:
48
+ os.environ['MASTER_PORT'] = str(port)
49
+ elif 'MASTER_PORT' in os.environ:
50
+ pass # use MASTER_PORT in the environment variable
51
+ else:
52
+ # 29500 is torch.distributed default port
53
+ os.environ['MASTER_PORT'] = '29500'
54
+ os.environ['MASTER_ADDR'] = addr
55
+ os.environ['WORLD_SIZE'] = str(ntasks)
56
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57
+ os.environ['RANK'] = str(proc_id)
58
+ dist.init_process_group(backend=backend)
59
+
60
+
61
+ def get_dist_info():
62
+ if dist.is_available():
63
+ initialized = dist.is_initialized()
64
+ else:
65
+ initialized = False
66
+ if initialized:
67
+ rank = dist.get_rank()
68
+ world_size = dist.get_world_size()
69
+ else:
70
+ rank = 0
71
+ world_size = 1
72
+ return rank, world_size
73
+
74
+
75
+ def master_only(func):
76
+
77
+ @functools.wraps(func)
78
+ def wrapper(*args, **kwargs):
79
+ rank, _ = get_dist_info()
80
+ if rank == 0:
81
+ return func(*args, **kwargs)
82
+
83
+ return wrapper
84
+
85
+ def get_bare_model(net):
86
+ """Get bare model, especially under wrapping with
87
+ DistributedDataParallel or DataParallel.
88
+ """
89
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
90
+ net = net.module
91
+ return net
docs/AdapterZoo.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapter Zoo
2
+
3
+ You can download the adapters from <https://huggingface.co/TencentARC/T2I-Adapter/tree/main>
4
+
5
+ All the following adapters are trained with Stable Diffusion (SD) V1.4, and they can be directly used on custom models as long as they are fine-tuned from the same text-to-image models, such as Anything-4.0 or models on the <https://civitai.com/>.
6
+
7
+ | Adapter Name | Adapter Description | Demos|Model Parameters| Model Storage | |
8
+ | --- | --- |--- |--- |--- |---|
9
+ | t2iadapter_color_sd14v1.pth | Spatial color palette β†’ image | [Demos](examples.md#color-adapter-spatial-palette) |18 M | 75 MB | |
10
+ | t2iadapter_style_sd14v1.pth | Image style β†’ image | [Demos](examples.md#style-adapter)|| 154MB | Preliminary model. Style adapters with finer controls are on the way|
11
+ | t2iadapter_openpose_sd14v1.pth | Openpose β†’ image| [Demos](examples.md#openpose-adapter) |77 M| 309 MB | |
12
+ | t2iadapter_canny_sd14v1.pth | Canny edges β†’ image | [Demos](examples.md#canny-adapter-edge )|77 M | 309 MB ||
13
+ | t2iadapter_sketch_sd14v1.pth | sketch β†’ image ||77 M| 308 MB | |
14
+ | t2iadapter_keypose_sd14v1.pth | keypose β†’ image || 77 M| 309 MB | mmpose style |
15
+ | t2iadapter_seg_sd14v1.pth | segmentation β†’ image ||77 M| 309 MB ||
16
+ | t2iadapter_depth_sd14v1.pth | depth maps β†’ image ||77 M | 309 MB | Not the final model, still under training|
docs/FAQ.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # FAQ
2
+
3
+ - **Q: The openpose adapter (t2iadapter_openpose_sd14v1) outputs gray-scale images.**
4
+
5
+ **A:** You can add `colorful` in the prompt to avoid this problem.
docs/examples.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Demos
2
+
3
+ ## Style Adapter
4
+
5
+ <p align="center">
6
+ <img src="https://user-images.githubusercontent.com/17445847/222734169-d47789e8-e83c-48c2-80ef-a896c2bafbb0.png" height=450>
7
+ </p>
8
+
9
+ ## Color Adapter (Spatial Palette)
10
+
11
+ <p align="center">
12
+ <img src="https://user-images.githubusercontent.com/17445847/222915829-ccfb0366-13a8-484a-9561-627fabd87d29.png" height=450>
13
+ </p>
14
+
15
+ ## Openpose Adapter
16
+
17
+ <p align="center">
18
+ <img src="https://user-images.githubusercontent.com/17445847/222733916-dc26a66e-d786-4407-8889-b81804862b1a.png" height=450>
19
+ </p>
20
+
21
+ ## Canny Adapter (Edge)
22
+
23
+ <p align="center">
24
+ <img src="https://user-images.githubusercontent.com/17445847/222915813-c8f264bd-1be6-4496-97ff-aec4f6b53788.png" height=450>
25
+ </p>
26
+
27
+ ## Multi-adapters
28
+ <p align="center">
29
+ <img src="https://user-images.githubusercontent.com/17445847/220939329-379f88b7-444f-4a3a-9de0-8f90605d1d34.png" height=450>
30
+ </p>
31
+
32
+ <div align="center">
33
+
34
+ *T2I adapters naturally support using multiple adapters together.*
35
+
36
+ </div><br />
37
+ The testing script usage for this example is similar to the command line given below, except that we replaced the pretrained SD model with Anything 4.5 and Kenshi
38
+
39
+ >python test_composable_adapters.py --prompt "1gril, computer desk, best quality, extremely detailed" --neg_prompt "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality" --depth_cond_path examples/depth/desk_depth.png --depth_cond_weight 1.0 --depth_ckpt models/t2iadapter_depth_sd14v1.pth --depth_type_in depth --pose_cond_path examples/keypose/person_keypose.png --pose_cond_weight 1.5 --ckpt models/anything-v4.0-pruned.ckpt --n_sample 4 --max_resolution 524288
40
+
41
+ [Image source](https://twitter.com/toyxyz3/status/1628375164781211648)
environment.yaml DELETED
@@ -1,31 +0,0 @@
1
- name: ldm
2
- channels:
3
- - pytorch
4
- - defaults
5
- dependencies:
6
- - python=3.8.5
7
- - pip=20.3
8
- - cudatoolkit=11.3
9
- - pytorch=1.11.0
10
- - torchvision=0.12.0
11
- - numpy=1.19.2
12
- - pip:
13
- - albumentations==0.4.3
14
- - diffusers
15
- - opencv-python==4.1.2.30
16
- - pudb==2019.2
17
- - invisible-watermark
18
- - imageio==2.9.0
19
- - imageio-ffmpeg==0.4.2
20
- - pytorch-lightning==1.4.2
21
- - omegaconf==2.1.1
22
- - test-tube>=0.7.5
23
- - streamlit>=0.73.1
24
- - einops==0.3.0
25
- - torch-fidelity==0.3.0
26
- - transformers==4.19.2
27
- - torchmetrics==0.6.0
28
- - kornia==0.6
29
- - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30
- - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31
- - -e .
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/modules/structure_condition/midas/__init__.py β†’ experiments/README.md RENAMED
File without changes
ldm/data/base.py DELETED
@@ -1,23 +0,0 @@
1
- from abc import abstractmethod
2
- from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3
-
4
-
5
- class Txt2ImgIterableBaseDataset(IterableDataset):
6
- '''
7
- Define an interface to make the IterableDatasets for text2img data chainable
8
- '''
9
- def __init__(self, num_records=0, valid_ids=None, size=256):
10
- super().__init__()
11
- self.num_records = num_records
12
- self.valid_ids = valid_ids
13
- self.sample_ids = valid_ids
14
- self.size = size
15
-
16
- print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17
-
18
- def __len__(self):
19
- return self.num_records
20
-
21
- @abstractmethod
22
- def __iter__(self):
23
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/data/dataset_coco.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import os
4
+ from basicsr.utils import img2tensor
5
+
6
+
7
+ class dataset_coco_mask_color():
8
+ def __init__(self, path_json, root_path_im, root_path_mask, image_size):
9
+ super(dataset_coco_mask_color, self).__init__()
10
+ with open(path_json, 'r', encoding='utf-8') as fp:
11
+ data = json.load(fp)
12
+ data = data['annotations']
13
+ self.files = []
14
+ self.root_path_im = root_path_im
15
+ self.root_path_mask = root_path_mask
16
+ for file in data:
17
+ name = "%012d.png" % file['image_id']
18
+ self.files.append({'name': name, 'sentence': file['caption']})
19
+
20
+ def __getitem__(self, idx):
21
+ file = self.files[idx]
22
+ name = file['name']
23
+ # print(os.path.join(self.root_path_im, name))
24
+ im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png', '.jpg')))
25
+ im = cv2.resize(im, (512, 512))
26
+ im = img2tensor(im, bgr2rgb=True, float32=True) / 255.
27
+
28
+ mask = cv2.imread(os.path.join(self.root_path_mask, name)) # [:,:,0]
29
+ mask = cv2.resize(mask, (512, 512))
30
+ mask = img2tensor(mask, bgr2rgb=True, float32=True) / 255. # [0].unsqueeze(0)#/255.
31
+
32
+ sentence = file['sentence']
33
+ return {'im': im, 'mask': mask, 'sentence': sentence}
34
+
35
+ def __len__(self):
36
+ return len(self.files)
ldm/data/dataset_depth.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import os
4
+ from basicsr.utils import img2tensor
5
+
6
+
7
+ class DepthDataset():
8
+ def __init__(self, meta_file):
9
+ super(DepthDataset, self).__init__()
10
+
11
+ self.files = []
12
+ with open(meta_file, 'r') as f:
13
+ lines = f.readlines()
14
+ for line in lines:
15
+ img_path = line.strip()
16
+ depth_img_path = img_path.rsplit('.', 1)[0] + '.depth.png'
17
+ txt_path = img_path.rsplit('.', 1)[0] + '.txt'
18
+ self.files.append({'img_path': img_path, 'depth_img_path': depth_img_path, 'txt_path': txt_path})
19
+
20
+ def __getitem__(self, idx):
21
+ file = self.files[idx]
22
+
23
+ im = cv2.imread(file['img_path'])
24
+ im = img2tensor(im, bgr2rgb=True, float32=True) / 255.
25
+
26
+ depth = cv2.imread(file['depth_img_path']) # [:,:,0]
27
+ depth = img2tensor(depth, bgr2rgb=True, float32=True) / 255. # [0].unsqueeze(0)#/255.
28
+
29
+ with open(file['txt_path'], 'r') as fs:
30
+ sentence = fs.readline().strip()
31
+
32
+ return {'im': im, 'depth': depth, 'sentence': sentence}
33
+
34
+ def __len__(self):
35
+ return len(self.files)
ldm/data/dataset_laion.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+ import os
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import webdataset as wds
8
+ from torchvision.transforms import transforms
9
+
10
+ from ldm.util import instantiate_from_config
11
+
12
+
13
+ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
14
+ """Take a list of samples (as dictionary) and create a batch, preserving the keys.
15
+ If `tensors` is True, `ndarray` objects are combined into
16
+ tensor batches.
17
+ :param dict samples: list of samples
18
+ :param bool tensors: whether to turn lists of ndarrays into a single ndarray
19
+ :returns: single sample consisting of a batch
20
+ :rtype: dict
21
+ """
22
+ keys = set.intersection(*[set(sample.keys()) for sample in samples])
23
+ batched = {key: [] for key in keys}
24
+
25
+ for s in samples:
26
+ [batched[key].append(s[key]) for key in batched]
27
+
28
+ result = {}
29
+ for key in batched:
30
+ if isinstance(batched[key][0], (int, float)):
31
+ if combine_scalars:
32
+ result[key] = np.array(list(batched[key]))
33
+ elif isinstance(batched[key][0], torch.Tensor):
34
+ if combine_tensors:
35
+ result[key] = torch.stack(list(batched[key]))
36
+ elif isinstance(batched[key][0], np.ndarray):
37
+ if combine_tensors:
38
+ result[key] = np.array(list(batched[key]))
39
+ else:
40
+ result[key] = list(batched[key])
41
+ return result
42
+
43
+
44
+ class WebDataModuleFromConfig(pl.LightningDataModule):
45
+
46
+ def __init__(self,
47
+ tar_base,
48
+ batch_size,
49
+ train=None,
50
+ validation=None,
51
+ test=None,
52
+ num_workers=4,
53
+ multinode=True,
54
+ min_size=None,
55
+ max_pwatermark=1.0,
56
+ **kwargs):
57
+ super().__init__()
58
+ print(f'Setting tar base to {tar_base}')
59
+ self.tar_base = tar_base
60
+ self.batch_size = batch_size
61
+ self.num_workers = num_workers
62
+ self.train = train
63
+ self.validation = validation
64
+ self.test = test
65
+ self.multinode = multinode
66
+ self.min_size = min_size # filter out very small images
67
+ self.max_pwatermark = max_pwatermark # filter out watermarked images
68
+
69
+ def make_loader(self, dataset_config):
70
+ image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
71
+ image_transforms = transforms.Compose(image_transforms)
72
+
73
+ process = instantiate_from_config(dataset_config['process'])
74
+
75
+ shuffle = dataset_config.get('shuffle', 0)
76
+ shardshuffle = shuffle > 0
77
+
78
+ nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
79
+
80
+ tars = os.path.join(self.tar_base, dataset_config.shards)
81
+
82
+ dset = wds.WebDataset(
83
+ tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
84
+ handler=wds.warn_and_continue).repeat().shuffle(shuffle)
85
+ print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
86
+
87
+ dset = (
88
+ dset.select(self.filter_keys).decode('pil',
89
+ handler=wds.warn_and_continue).select(self.filter_size).map_dict(
90
+ jpg=image_transforms, handler=wds.warn_and_continue).map(process))
91
+ dset = (dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn))
92
+
93
+ loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)
94
+
95
+ return loader
96
+
97
+ def filter_size(self, x):
98
+ if self.min_size is None:
99
+ return True
100
+ try:
101
+ return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size and x[
102
+ 'json']['pwatermark'] <= self.max_pwatermark
103
+ except Exception:
104
+ return False
105
+
106
+ def filter_keys(self, x):
107
+ try:
108
+ return ("jpg" in x) and ("txt" in x)
109
+ except Exception:
110
+ return False
111
+
112
+ def train_dataloader(self):
113
+ return self.make_loader(self.train)
114
+
115
+ def val_dataloader(self):
116
+ return None
117
+
118
+ def test_dataloader(self):
119
+ return None
120
+
121
+
122
+ if __name__ == '__main__':
123
+ from omegaconf import OmegaConf
124
+ config = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
125
+ datamod = WebDataModuleFromConfig(**config["data"]["params"])
126
+ dataloader = datamod.train_dataloader()
127
+
128
+ for batch in dataloader:
129
+ print(batch.keys())
130
+ print(batch['jpg'].shape)
ldm/data/dataset_wikiart.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os.path
3
+
4
+ from PIL import Image
5
+ from torch.utils.data import DataLoader
6
+
7
+ from transformers import CLIPProcessor
8
+ from torchvision.transforms import transforms
9
+
10
+ import pytorch_lightning as pl
11
+
12
+
13
+ class WikiArtDataset():
14
+ def __init__(self, meta_file):
15
+ super(WikiArtDataset, self).__init__()
16
+
17
+ self.files = []
18
+ with open(meta_file, 'r') as f:
19
+ js = json.load(f)
20
+ for img_path in js:
21
+ img_name = os.path.splitext(os.path.basename(img_path))[0]
22
+ caption = img_name.split('_')[-1]
23
+ caption = caption.split('-')
24
+ j = len(caption) - 1
25
+ while j >= 0:
26
+ if not caption[j].isdigit():
27
+ break
28
+ j -= 1
29
+ if j < 0:
30
+ continue
31
+ sentence = ' '.join(caption[:j + 1])
32
+ self.files.append({'img_path': os.path.join('datasets/wikiart', img_path), 'sentence': sentence})
33
+
34
+ version = 'openai/clip-vit-large-patch14'
35
+ self.processor = CLIPProcessor.from_pretrained(version)
36
+
37
+ self.jpg_transform = transforms.Compose([
38
+ transforms.Resize(512),
39
+ transforms.RandomCrop(512),
40
+ transforms.ToTensor(),
41
+ ])
42
+
43
+ def __getitem__(self, idx):
44
+ file = self.files[idx]
45
+
46
+ im = Image.open(file['img_path'])
47
+
48
+ im_tensor = self.jpg_transform(im)
49
+
50
+ clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]
51
+
52
+ return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}
53
+
54
+ def __len__(self):
55
+ return len(self.files)
56
+
57
+
58
+ class WikiArtDataModule(pl.LightningDataModule):
59
+ def __init__(self, meta_file, batch_size, num_workers):
60
+ super(WikiArtDataModule, self).__init__()
61
+ self.train_dataset = WikiArtDataset(meta_file)
62
+ self.batch_size = batch_size
63
+ self.num_workers = num_workers
64
+
65
+ def train_dataloader(self):
66
+ return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
67
+ pin_memory=True)
ldm/data/imagenet.py DELETED
@@ -1,394 +0,0 @@
1
- import os, yaml, pickle, shutil, tarfile, glob
2
- import cv2
3
- import albumentations
4
- import PIL
5
- import numpy as np
6
- import torchvision.transforms.functional as TF
7
- from omegaconf import OmegaConf
8
- from functools import partial
9
- from PIL import Image
10
- from tqdm import tqdm
11
- from torch.utils.data import Dataset, Subset
12
-
13
- import taming.data.utils as tdu
14
- from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
- from taming.data.imagenet import ImagePaths
16
-
17
- from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
-
19
-
20
- def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
- with open(path_to_yaml) as f:
22
- di2s = yaml.load(f)
23
- return dict((v,k) for k,v in di2s.items())
24
-
25
-
26
- class ImageNetBase(Dataset):
27
- def __init__(self, config=None):
28
- self.config = config or OmegaConf.create()
29
- if not type(self.config)==dict:
30
- self.config = OmegaConf.to_container(self.config)
31
- self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
- self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
- self._prepare()
34
- self._prepare_synset_to_human()
35
- self._prepare_idx_to_synset()
36
- self._prepare_human_to_integer_label()
37
- self._load()
38
-
39
- def __len__(self):
40
- return len(self.data)
41
-
42
- def __getitem__(self, i):
43
- return self.data[i]
44
-
45
- def _prepare(self):
46
- raise NotImplementedError()
47
-
48
- def _filter_relpaths(self, relpaths):
49
- ignore = set([
50
- "n06596364_9591.JPEG",
51
- ])
52
- relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
- if "sub_indices" in self.config:
54
- indices = str_to_indices(self.config["sub_indices"])
55
- synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
- self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
- files = []
58
- for rpath in relpaths:
59
- syn = rpath.split("/")[0]
60
- if syn in synsets:
61
- files.append(rpath)
62
- return files
63
- else:
64
- return relpaths
65
-
66
- def _prepare_synset_to_human(self):
67
- SIZE = 2655750
68
- URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
- self.human_dict = os.path.join(self.root, "synset_human.txt")
70
- if (not os.path.exists(self.human_dict) or
71
- not os.path.getsize(self.human_dict)==SIZE):
72
- download(URL, self.human_dict)
73
-
74
- def _prepare_idx_to_synset(self):
75
- URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
- self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
- if (not os.path.exists(self.idx2syn)):
78
- download(URL, self.idx2syn)
79
-
80
- def _prepare_human_to_integer_label(self):
81
- URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
- self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
- if (not os.path.exists(self.human2integer)):
84
- download(URL, self.human2integer)
85
- with open(self.human2integer, "r") as f:
86
- lines = f.read().splitlines()
87
- assert len(lines) == 1000
88
- self.human2integer_dict = dict()
89
- for line in lines:
90
- value, key = line.split(":")
91
- self.human2integer_dict[key] = int(value)
92
-
93
- def _load(self):
94
- with open(self.txt_filelist, "r") as f:
95
- self.relpaths = f.read().splitlines()
96
- l1 = len(self.relpaths)
97
- self.relpaths = self._filter_relpaths(self.relpaths)
98
- print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
-
100
- self.synsets = [p.split("/")[0] for p in self.relpaths]
101
- self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
-
103
- unique_synsets = np.unique(self.synsets)
104
- class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
- if not self.keep_orig_class_label:
106
- self.class_labels = [class_dict[s] for s in self.synsets]
107
- else:
108
- self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
-
110
- with open(self.human_dict, "r") as f:
111
- human_dict = f.read().splitlines()
112
- human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
-
114
- self.human_labels = [human_dict[s] for s in self.synsets]
115
-
116
- labels = {
117
- "relpath": np.array(self.relpaths),
118
- "synsets": np.array(self.synsets),
119
- "class_label": np.array(self.class_labels),
120
- "human_label": np.array(self.human_labels),
121
- }
122
-
123
- if self.process_images:
124
- self.size = retrieve(self.config, "size", default=256)
125
- self.data = ImagePaths(self.abspaths,
126
- labels=labels,
127
- size=self.size,
128
- random_crop=self.random_crop,
129
- )
130
- else:
131
- self.data = self.abspaths
132
-
133
-
134
- class ImageNetTrain(ImageNetBase):
135
- NAME = "ILSVRC2012_train"
136
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
- AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
- FILES = [
139
- "ILSVRC2012_img_train.tar",
140
- ]
141
- SIZES = [
142
- 147897477120,
143
- ]
144
-
145
- def __init__(self, process_images=True, data_root=None, **kwargs):
146
- self.process_images = process_images
147
- self.data_root = data_root
148
- super().__init__(**kwargs)
149
-
150
- def _prepare(self):
151
- if self.data_root:
152
- self.root = os.path.join(self.data_root, self.NAME)
153
- else:
154
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
-
157
- self.datadir = os.path.join(self.root, "data")
158
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
- self.expected_length = 1281167
160
- self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
- default=True)
162
- if not tdu.is_prepared(self.root):
163
- # prep
164
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
-
166
- datadir = self.datadir
167
- if not os.path.exists(datadir):
168
- path = os.path.join(self.root, self.FILES[0])
169
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
- import academictorrents as at
171
- atpath = at.get(self.AT_HASH, datastore=self.root)
172
- assert atpath == path
173
-
174
- print("Extracting {} to {}".format(path, datadir))
175
- os.makedirs(datadir, exist_ok=True)
176
- with tarfile.open(path, "r:") as tar:
177
- tar.extractall(path=datadir)
178
-
179
- print("Extracting sub-tars.")
180
- subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
- for subpath in tqdm(subpaths):
182
- subdir = subpath[:-len(".tar")]
183
- os.makedirs(subdir, exist_ok=True)
184
- with tarfile.open(subpath, "r:") as tar:
185
- tar.extractall(path=subdir)
186
-
187
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
- filelist = sorted(filelist)
190
- filelist = "\n".join(filelist)+"\n"
191
- with open(self.txt_filelist, "w") as f:
192
- f.write(filelist)
193
-
194
- tdu.mark_prepared(self.root)
195
-
196
-
197
- class ImageNetValidation(ImageNetBase):
198
- NAME = "ILSVRC2012_validation"
199
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
- AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
- VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
- FILES = [
203
- "ILSVRC2012_img_val.tar",
204
- "validation_synset.txt",
205
- ]
206
- SIZES = [
207
- 6744924160,
208
- 1950000,
209
- ]
210
-
211
- def __init__(self, process_images=True, data_root=None, **kwargs):
212
- self.data_root = data_root
213
- self.process_images = process_images
214
- super().__init__(**kwargs)
215
-
216
- def _prepare(self):
217
- if self.data_root:
218
- self.root = os.path.join(self.data_root, self.NAME)
219
- else:
220
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
- self.datadir = os.path.join(self.root, "data")
223
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
- self.expected_length = 50000
225
- self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
- default=False)
227
- if not tdu.is_prepared(self.root):
228
- # prep
229
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
-
231
- datadir = self.datadir
232
- if not os.path.exists(datadir):
233
- path = os.path.join(self.root, self.FILES[0])
234
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
- import academictorrents as at
236
- atpath = at.get(self.AT_HASH, datastore=self.root)
237
- assert atpath == path
238
-
239
- print("Extracting {} to {}".format(path, datadir))
240
- os.makedirs(datadir, exist_ok=True)
241
- with tarfile.open(path, "r:") as tar:
242
- tar.extractall(path=datadir)
243
-
244
- vspath = os.path.join(self.root, self.FILES[1])
245
- if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
- download(self.VS_URL, vspath)
247
-
248
- with open(vspath, "r") as f:
249
- synset_dict = f.read().splitlines()
250
- synset_dict = dict(line.split() for line in synset_dict)
251
-
252
- print("Reorganizing into synset folders")
253
- synsets = np.unique(list(synset_dict.values()))
254
- for s in synsets:
255
- os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
- for k, v in synset_dict.items():
257
- src = os.path.join(datadir, k)
258
- dst = os.path.join(datadir, v)
259
- shutil.move(src, dst)
260
-
261
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
- filelist = sorted(filelist)
264
- filelist = "\n".join(filelist)+"\n"
265
- with open(self.txt_filelist, "w") as f:
266
- f.write(filelist)
267
-
268
- tdu.mark_prepared(self.root)
269
-
270
-
271
-
272
- class ImageNetSR(Dataset):
273
- def __init__(self, size=None,
274
- degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
- random_crop=True):
276
- """
277
- Imagenet Superresolution Dataloader
278
- Performs following ops in order:
279
- 1. crops a crop of size s from image either as random or center crop
280
- 2. resizes crop to size with cv2.area_interpolation
281
- 3. degrades resized crop with degradation_fn
282
-
283
- :param size: resizing to size after cropping
284
- :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
- :param downscale_f: Low Resolution Downsample factor
286
- :param min_crop_f: determines crop size s,
287
- where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
- :param max_crop_f: ""
289
- :param data_root:
290
- :param random_crop:
291
- """
292
- self.base = self.get_base()
293
- assert size
294
- assert (size / downscale_f).is_integer()
295
- self.size = size
296
- self.LR_size = int(size / downscale_f)
297
- self.min_crop_f = min_crop_f
298
- self.max_crop_f = max_crop_f
299
- assert(max_crop_f <= 1.)
300
- self.center_crop = not random_crop
301
-
302
- self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
-
304
- self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
-
306
- if degradation == "bsrgan":
307
- self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
-
309
- elif degradation == "bsrgan_light":
310
- self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
-
312
- else:
313
- interpolation_fn = {
314
- "cv_nearest": cv2.INTER_NEAREST,
315
- "cv_bilinear": cv2.INTER_LINEAR,
316
- "cv_bicubic": cv2.INTER_CUBIC,
317
- "cv_area": cv2.INTER_AREA,
318
- "cv_lanczos": cv2.INTER_LANCZOS4,
319
- "pil_nearest": PIL.Image.NEAREST,
320
- "pil_bilinear": PIL.Image.BILINEAR,
321
- "pil_bicubic": PIL.Image.BICUBIC,
322
- "pil_box": PIL.Image.BOX,
323
- "pil_hamming": PIL.Image.HAMMING,
324
- "pil_lanczos": PIL.Image.LANCZOS,
325
- }[degradation]
326
-
327
- self.pil_interpolation = degradation.startswith("pil_")
328
-
329
- if self.pil_interpolation:
330
- self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
-
332
- else:
333
- self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
- interpolation=interpolation_fn)
335
-
336
- def __len__(self):
337
- return len(self.base)
338
-
339
- def __getitem__(self, i):
340
- example = self.base[i]
341
- image = Image.open(example["file_path_"])
342
-
343
- if not image.mode == "RGB":
344
- image = image.convert("RGB")
345
-
346
- image = np.array(image).astype(np.uint8)
347
-
348
- min_side_len = min(image.shape[:2])
349
- crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
- crop_side_len = int(crop_side_len)
351
-
352
- if self.center_crop:
353
- self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
-
355
- else:
356
- self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
-
358
- image = self.cropper(image=image)["image"]
359
- image = self.image_rescaler(image=image)["image"]
360
-
361
- if self.pil_interpolation:
362
- image_pil = PIL.Image.fromarray(image)
363
- LR_image = self.degradation_process(image_pil)
364
- LR_image = np.array(LR_image).astype(np.uint8)
365
-
366
- else:
367
- LR_image = self.degradation_process(image=image)["image"]
368
-
369
- example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
- example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
-
372
- return example
373
-
374
-
375
- class ImageNetSRTrain(ImageNetSR):
376
- def __init__(self, **kwargs):
377
- super().__init__(**kwargs)
378
-
379
- def get_base(self):
380
- with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
- indices = pickle.load(f)
382
- dset = ImageNetTrain(process_images=False,)
383
- return Subset(dset, indices)
384
-
385
-
386
- class ImageNetSRValidation(ImageNetSR):
387
- def __init__(self, **kwargs):
388
- super().__init__(**kwargs)
389
-
390
- def get_base(self):
391
- with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
- indices = pickle.load(f)
393
- dset = ImageNetValidation(process_images=False,)
394
- return Subset(dset, indices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/data/lsun.py DELETED
@@ -1,92 +0,0 @@
1
- import os
2
- import numpy as np
3
- import PIL
4
- from PIL import Image
5
- from torch.utils.data import Dataset
6
- from torchvision import transforms
7
-
8
-
9
- class LSUNBase(Dataset):
10
- def __init__(self,
11
- txt_file,
12
- data_root,
13
- size=None,
14
- interpolation="bicubic",
15
- flip_p=0.5
16
- ):
17
- self.data_paths = txt_file
18
- self.data_root = data_root
19
- with open(self.data_paths, "r") as f:
20
- self.image_paths = f.read().splitlines()
21
- self._length = len(self.image_paths)
22
- self.labels = {
23
- "relative_file_path_": [l for l in self.image_paths],
24
- "file_path_": [os.path.join(self.data_root, l)
25
- for l in self.image_paths],
26
- }
27
-
28
- self.size = size
29
- self.interpolation = {"linear": PIL.Image.LINEAR,
30
- "bilinear": PIL.Image.BILINEAR,
31
- "bicubic": PIL.Image.BICUBIC,
32
- "lanczos": PIL.Image.LANCZOS,
33
- }[interpolation]
34
- self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
-
36
- def __len__(self):
37
- return self._length
38
-
39
- def __getitem__(self, i):
40
- example = dict((k, self.labels[k][i]) for k in self.labels)
41
- image = Image.open(example["file_path_"])
42
- if not image.mode == "RGB":
43
- image = image.convert("RGB")
44
-
45
- # default to score-sde preprocessing
46
- img = np.array(image).astype(np.uint8)
47
- crop = min(img.shape[0], img.shape[1])
48
- h, w, = img.shape[0], img.shape[1]
49
- img = img[(h - crop) // 2:(h + crop) // 2,
50
- (w - crop) // 2:(w + crop) // 2]
51
-
52
- image = Image.fromarray(img)
53
- if self.size is not None:
54
- image = image.resize((self.size, self.size), resample=self.interpolation)
55
-
56
- image = self.flip(image)
57
- image = np.array(image).astype(np.uint8)
58
- example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
- return example
60
-
61
-
62
- class LSUNChurchesTrain(LSUNBase):
63
- def __init__(self, **kwargs):
64
- super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
-
66
-
67
- class LSUNChurchesValidation(LSUNBase):
68
- def __init__(self, flip_p=0., **kwargs):
69
- super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
- flip_p=flip_p, **kwargs)
71
-
72
-
73
- class LSUNBedroomsTrain(LSUNBase):
74
- def __init__(self, **kwargs):
75
- super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
-
77
-
78
- class LSUNBedroomsValidation(LSUNBase):
79
- def __init__(self, flip_p=0.0, **kwargs):
80
- super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
- flip_p=flip_p, **kwargs)
82
-
83
-
84
- class LSUNCatsTrain(LSUNBase):
85
- def __init__(self, **kwargs):
86
- super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
-
88
-
89
- class LSUNCatsValidation(LSUNBase):
90
- def __init__(self, flip_p=0., **kwargs):
91
- super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
- flip_p=flip_p, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/data/utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from torchvision.transforms import transforms
6
+ from torchvision.transforms.functional import to_tensor
7
+ from transformers import CLIPProcessor
8
+
9
+ from basicsr.utils import img2tensor
10
+
11
+
12
+ class AddCannyFreezeThreshold(object):
13
+
14
+ def __init__(self, low_threshold=100, high_threshold=200):
15
+ self.low_threshold = low_threshold
16
+ self.high_threshold = high_threshold
17
+
18
+ def __call__(self, sample):
19
+ # sample['jpg'] is PIL image
20
+ x = sample['jpg']
21
+ img = cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR)
22
+ canny = cv2.Canny(img, self.low_threshold, self.high_threshold)[..., None]
23
+ sample['canny'] = img2tensor(canny, bgr2rgb=True, float32=True) / 255.
24
+ sample['jpg'] = to_tensor(x)
25
+ return sample
26
+
27
+
28
+ class AddStyle(object):
29
+
30
+ def __init__(self, version):
31
+ self.processor = CLIPProcessor.from_pretrained(version)
32
+ self.pil_to_tensor = transforms.ToTensor()
33
+
34
+ def __call__(self, sample):
35
+ # sample['jpg'] is PIL image
36
+ x = sample['jpg']
37
+ style = self.processor(images=x, return_tensors="pt")['pixel_values'][0]
38
+ sample['style'] = style
39
+ sample['jpg'] = to_tensor(x)
40
+ return sample
ldm/inference_base.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from omegaconf import OmegaConf
4
+
5
+ from ldm.models.diffusion.ddim import DDIMSampler
6
+ from ldm.models.diffusion.plms import PLMSSampler
7
+ from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light
8
+ from ldm.modules.extra_condition.api import ExtraCondition
9
+ from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict
10
+
11
+ DEFAULT_NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
12
+ 'fewer digits, cropped, worst quality, low quality'
13
+
14
+
15
+ def get_base_argument_parser() -> argparse.ArgumentParser:
16
+ """get the base argument parser for inference scripts"""
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ '--outdir',
20
+ type=str,
21
+ help='dir to write results to',
22
+ default=None,
23
+ )
24
+
25
+ parser.add_argument(
26
+ '--prompt',
27
+ type=str,
28
+ nargs='?',
29
+ default=None,
30
+ help='positive prompt',
31
+ )
32
+
33
+ parser.add_argument(
34
+ '--neg_prompt',
35
+ type=str,
36
+ default=DEFAULT_NEGATIVE_PROMPT,
37
+ help='negative prompt',
38
+ )
39
+
40
+ parser.add_argument(
41
+ '--cond_path',
42
+ type=str,
43
+ default=None,
44
+ help='condition image path',
45
+ )
46
+
47
+ parser.add_argument(
48
+ '--cond_inp_type',
49
+ type=str,
50
+ default='image',
51
+ help='the type of the input condition image, take depth T2I as example, the input can be raw image, '
52
+ 'which depth will be calculated, or the input can be a directly a depth map image',
53
+ )
54
+
55
+ parser.add_argument(
56
+ '--sampler',
57
+ type=str,
58
+ default='ddim',
59
+ choices=['ddim', 'plms'],
60
+ help='sampling algorithm, currently, only ddim and plms are supported, more are on the way',
61
+ )
62
+
63
+ parser.add_argument(
64
+ '--steps',
65
+ type=int,
66
+ default=50,
67
+ help='number of sampling steps',
68
+ )
69
+
70
+ parser.add_argument(
71
+ '--sd_ckpt',
72
+ type=str,
73
+ default='models/sd-v1-4.ckpt',
74
+ help='path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
75
+ )
76
+
77
+ parser.add_argument(
78
+ '--vae_ckpt',
79
+ type=str,
80
+ default=None,
81
+ help='vae checkpoint, anime SD models usually have seperate vae ckpt that need to be loaded',
82
+ )
83
+
84
+ parser.add_argument(
85
+ '--adapter_ckpt',
86
+ type=str,
87
+ default=None,
88
+ help='path to checkpoint of adapter',
89
+ )
90
+
91
+ parser.add_argument(
92
+ '--config',
93
+ type=str,
94
+ default='configs/stable-diffusion/sd-v1-inference.yaml',
95
+ help='path to config which constructs SD model',
96
+ )
97
+
98
+ parser.add_argument(
99
+ '--max_resolution',
100
+ type=float,
101
+ default=512 * 512,
102
+ help='max image height * width, only for computer with limited vram',
103
+ )
104
+
105
+ parser.add_argument(
106
+ '--resize_short_edge',
107
+ type=int,
108
+ default=None,
109
+ help='resize short edge of the input image, if this arg is set, max_resolution will not be used',
110
+ )
111
+
112
+ parser.add_argument(
113
+ '--C',
114
+ type=int,
115
+ default=4,
116
+ help='latent channels',
117
+ )
118
+
119
+ parser.add_argument(
120
+ '--f',
121
+ type=int,
122
+ default=8,
123
+ help='downsampling factor',
124
+ )
125
+
126
+ parser.add_argument(
127
+ '--scale',
128
+ type=float,
129
+ default=7.5,
130
+ help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
131
+ )
132
+
133
+ parser.add_argument(
134
+ '--cond_tau',
135
+ type=float,
136
+ default=1.0,
137
+ help='timestamp parameter that determines until which step the adapter is applied, '
138
+ 'similar as Prompt-to-Prompt tau')
139
+
140
+ parser.add_argument(
141
+ '--cond_weight',
142
+ type=float,
143
+ default=1.0,
144
+ help='the adapter features are multiplied by the cond_weight. The larger the cond_weight, the more aligned '
145
+ 'the generated image and condition will be, but the generated quality may be reduced',
146
+ )
147
+
148
+ parser.add_argument(
149
+ '--seed',
150
+ type=int,
151
+ default=42,
152
+ )
153
+
154
+ parser.add_argument(
155
+ '--n_samples',
156
+ type=int,
157
+ default=4,
158
+ help='# of samples to generate',
159
+ )
160
+
161
+ return parser
162
+
163
+
164
+ def get_sd_models(opt):
165
+ """
166
+ build stable diffusion model, sampler
167
+ """
168
+ # SD
169
+ config = OmegaConf.load(f"{opt.config}")
170
+ model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
171
+ sd_model = model.to(opt.device)
172
+
173
+ # sampler
174
+ if opt.sampler == 'plms':
175
+ sampler = PLMSSampler(model)
176
+ elif opt.sampler == 'ddim':
177
+ sampler = DDIMSampler(model)
178
+ else:
179
+ raise NotImplementedError
180
+
181
+ return sd_model, sampler
182
+
183
+
184
+ def get_t2i_adapter_models(opt):
185
+ config = OmegaConf.load(f"{opt.config}")
186
+ model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
187
+ adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
188
+ if adapter_ckpt_path is None:
189
+ adapter_ckpt_path = getattr(opt, 'adapter_ckpt')
190
+ adapter_ckpt = read_state_dict(adapter_ckpt_path)
191
+ new_state_dict = {}
192
+ for k, v in adapter_ckpt.items():
193
+ if not k.startswith('adapter.'):
194
+ new_state_dict[f'adapter.{k}'] = v
195
+ else:
196
+ new_state_dict[k] = v
197
+ m, u = model.load_state_dict(new_state_dict, strict=False)
198
+ if len(u) > 0:
199
+ print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
200
+ print(u)
201
+
202
+ model = model.to(opt.device)
203
+
204
+ # sampler
205
+ if opt.sampler == 'plms':
206
+ sampler = PLMSSampler(model)
207
+ elif opt.sampler == 'ddim':
208
+ sampler = DDIMSampler(model)
209
+ else:
210
+ raise NotImplementedError
211
+
212
+ return model, sampler
213
+
214
+
215
+ def get_cond_ch(cond_type: ExtraCondition):
216
+ if cond_type == ExtraCondition.sketch or cond_type == ExtraCondition.canny:
217
+ return 1
218
+ return 3
219
+
220
+
221
+ def get_adapters(opt, cond_type: ExtraCondition):
222
+ adapter = {}
223
+ cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
224
+ if cond_weight is None:
225
+ cond_weight = getattr(opt, 'cond_weight')
226
+ adapter['cond_weight'] = cond_weight
227
+
228
+ if cond_type == ExtraCondition.style:
229
+ adapter['model'] = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(opt.device)
230
+ elif cond_type == ExtraCondition.color:
231
+ adapter['model'] = Adapter_light(
232
+ cin=64 * get_cond_ch(cond_type),
233
+ channels=[320, 640, 1280, 1280],
234
+ nums_rb=4).to(opt.device)
235
+ else:
236
+ adapter['model'] = Adapter(
237
+ cin=64 * get_cond_ch(cond_type),
238
+ channels=[320, 640, 1280, 1280][:4],
239
+ nums_rb=2,
240
+ ksize=1,
241
+ sk=True,
242
+ use_conv=False).to(opt.device)
243
+ ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
244
+ if ckpt_path is None:
245
+ ckpt_path = getattr(opt, 'adapter_ckpt')
246
+ adapter['model'].load_state_dict(torch.load(ckpt_path))
247
+
248
+ return adapter
249
+
250
+
251
+ def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
252
+ # get text embedding
253
+ c = model.get_learned_conditioning([opt.prompt])
254
+ if opt.scale != 1.0:
255
+ uc = model.get_learned_conditioning([opt.neg_prompt])
256
+ else:
257
+ uc = None
258
+ c, uc = fix_cond_shapes(model, c, uc)
259
+
260
+ if not hasattr(opt, 'H'):
261
+ opt.H = 512
262
+ opt.W = 512
263
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
264
+
265
+ samples_latents, _ = sampler.sample(
266
+ S=opt.steps,
267
+ conditioning=c,
268
+ batch_size=1,
269
+ shape=shape,
270
+ verbose=False,
271
+ unconditional_guidance_scale=opt.scale,
272
+ unconditional_conditioning=uc,
273
+ x_T=None,
274
+ features_adapter=adapter_features,
275
+ append_to_context=append_to_context,
276
+ cond_tau=opt.cond_tau,
277
+ )
278
+
279
+ x_samples = model.decode_first_stage(samples_latents)
280
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
281
+
282
+ return x_samples
ldm/models/autoencoder.py CHANGED
@@ -1,64 +1,65 @@
1
  import torch
2
  import pytorch_lightning as pl
3
  import torch.nn.functional as F
 
4
  from contextlib import contextmanager
5
 
6
- from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
-
8
  from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
  from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
 
11
  from ldm.util import instantiate_from_config
 
12
 
13
 
14
- class VQModel(pl.LightningModule):
15
  def __init__(self,
16
  ddconfig,
17
  lossconfig,
18
- n_embed,
19
  embed_dim,
20
  ckpt_path=None,
21
  ignore_keys=[],
22
  image_key="image",
23
  colorize_nlabels=None,
24
  monitor=None,
25
- batch_resize_range=None,
26
- scheduler_config=None,
27
- lr_g_factor=1.0,
28
- remap=None,
29
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
- use_ema=False
31
  ):
32
  super().__init__()
33
- self.embed_dim = embed_dim
34
- self.n_embed = n_embed
35
  self.image_key = image_key
36
  self.encoder = Encoder(**ddconfig)
37
  self.decoder = Decoder(**ddconfig)
38
  self.loss = instantiate_from_config(lossconfig)
39
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
- remap=remap,
41
- sane_index_shape=sane_index_shape)
42
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
  self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
 
44
  if colorize_nlabels is not None:
45
  assert type(colorize_nlabels)==int
46
  self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
  if monitor is not None:
48
  self.monitor = monitor
49
- self.batch_resize_range = batch_resize_range
50
- if self.batch_resize_range is not None:
51
- print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
 
53
- self.use_ema = use_ema
54
  if self.use_ema:
55
- self.model_ema = LitEma(self)
 
 
56
  print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
 
58
  if ckpt_path is not None:
59
  self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
- self.scheduler_config = scheduler_config
61
- self.lr_g_factor = lr_g_factor
 
 
 
 
 
 
 
 
 
62
 
63
  @contextmanager
64
  def ema_scope(self, context=None):
@@ -75,252 +76,10 @@ class VQModel(pl.LightningModule):
75
  if context is not None:
76
  print(f"{context}: Restored training weights")
77
 
78
- def init_from_ckpt(self, path, ignore_keys=list()):
79
- sd = torch.load(path, map_location="cpu")["state_dict"]
80
- keys = list(sd.keys())
81
- for k in keys:
82
- for ik in ignore_keys:
83
- if k.startswith(ik):
84
- print("Deleting key {} from state_dict.".format(k))
85
- del sd[k]
86
- missing, unexpected = self.load_state_dict(sd, strict=False)
87
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
- if len(missing) > 0:
89
- print(f"Missing Keys: {missing}")
90
- print(f"Unexpected Keys: {unexpected}")
91
-
92
  def on_train_batch_end(self, *args, **kwargs):
93
  if self.use_ema:
94
  self.model_ema(self)
95
 
96
- def encode(self, x):
97
- h = self.encoder(x)
98
- h = self.quant_conv(h)
99
- quant, emb_loss, info = self.quantize(h)
100
- return quant, emb_loss, info
101
-
102
- def encode_to_prequant(self, x):
103
- h = self.encoder(x)
104
- h = self.quant_conv(h)
105
- return h
106
-
107
- def decode(self, quant):
108
- quant = self.post_quant_conv(quant)
109
- dec = self.decoder(quant)
110
- return dec
111
-
112
- def decode_code(self, code_b):
113
- quant_b = self.quantize.embed_code(code_b)
114
- dec = self.decode(quant_b)
115
- return dec
116
-
117
- def forward(self, input, return_pred_indices=False):
118
- quant, diff, (_,_,ind) = self.encode(input)
119
- dec = self.decode(quant)
120
- if return_pred_indices:
121
- return dec, diff, ind
122
- return dec, diff
123
-
124
- def get_input(self, batch, k):
125
- x = batch[k]
126
- if len(x.shape) == 3:
127
- x = x[..., None]
128
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
- if self.batch_resize_range is not None:
130
- lower_size = self.batch_resize_range[0]
131
- upper_size = self.batch_resize_range[1]
132
- if self.global_step <= 4:
133
- # do the first few batches with max size to avoid later oom
134
- new_resize = upper_size
135
- else:
136
- new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
- if new_resize != x.shape[2]:
138
- x = F.interpolate(x, size=new_resize, mode="bicubic")
139
- x = x.detach()
140
- return x
141
-
142
- def training_step(self, batch, batch_idx, optimizer_idx):
143
- # https://github.com/pytorch/pytorch/issues/37142
144
- # try not to fool the heuristics
145
- x = self.get_input(batch, self.image_key)
146
- xrec, qloss, ind = self(x, return_pred_indices=True)
147
-
148
- if optimizer_idx == 0:
149
- # autoencode
150
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
- last_layer=self.get_last_layer(), split="train",
152
- predicted_indices=ind)
153
-
154
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
- return aeloss
156
-
157
- if optimizer_idx == 1:
158
- # discriminator
159
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
- last_layer=self.get_last_layer(), split="train")
161
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
- return discloss
163
-
164
- def validation_step(self, batch, batch_idx):
165
- log_dict = self._validation_step(batch, batch_idx)
166
- with self.ema_scope():
167
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
- return log_dict
169
-
170
- def _validation_step(self, batch, batch_idx, suffix=""):
171
- x = self.get_input(batch, self.image_key)
172
- xrec, qloss, ind = self(x, return_pred_indices=True)
173
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
- self.global_step,
175
- last_layer=self.get_last_layer(),
176
- split="val"+suffix,
177
- predicted_indices=ind
178
- )
179
-
180
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
- self.global_step,
182
- last_layer=self.get_last_layer(),
183
- split="val"+suffix,
184
- predicted_indices=ind
185
- )
186
- rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
- self.log(f"val{suffix}/rec_loss", rec_loss,
188
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
- self.log(f"val{suffix}/aeloss", aeloss,
190
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
- del log_dict_ae[f"val{suffix}/rec_loss"]
193
- self.log_dict(log_dict_ae)
194
- self.log_dict(log_dict_disc)
195
- return self.log_dict
196
-
197
- def configure_optimizers(self):
198
- lr_d = self.learning_rate
199
- lr_g = self.lr_g_factor*self.learning_rate
200
- print("lr_d", lr_d)
201
- print("lr_g", lr_g)
202
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
- list(self.decoder.parameters())+
204
- list(self.quantize.parameters())+
205
- list(self.quant_conv.parameters())+
206
- list(self.post_quant_conv.parameters()),
207
- lr=lr_g, betas=(0.5, 0.9))
208
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
- lr=lr_d, betas=(0.5, 0.9))
210
-
211
- if self.scheduler_config is not None:
212
- scheduler = instantiate_from_config(self.scheduler_config)
213
-
214
- print("Setting up LambdaLR scheduler...")
215
- scheduler = [
216
- {
217
- 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
- 'interval': 'step',
219
- 'frequency': 1
220
- },
221
- {
222
- 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
- 'interval': 'step',
224
- 'frequency': 1
225
- },
226
- ]
227
- return [opt_ae, opt_disc], scheduler
228
- return [opt_ae, opt_disc], []
229
-
230
- def get_last_layer(self):
231
- return self.decoder.conv_out.weight
232
-
233
- def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
- log = dict()
235
- x = self.get_input(batch, self.image_key)
236
- x = x.to(self.device)
237
- if only_inputs:
238
- log["inputs"] = x
239
- return log
240
- xrec, _ = self(x)
241
- if x.shape[1] > 3:
242
- # colorize with random projection
243
- assert xrec.shape[1] > 3
244
- x = self.to_rgb(x)
245
- xrec = self.to_rgb(xrec)
246
- log["inputs"] = x
247
- log["reconstructions"] = xrec
248
- if plot_ema:
249
- with self.ema_scope():
250
- xrec_ema, _ = self(x)
251
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
- log["reconstructions_ema"] = xrec_ema
253
- return log
254
-
255
- def to_rgb(self, x):
256
- assert self.image_key == "segmentation"
257
- if not hasattr(self, "colorize"):
258
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
- x = F.conv2d(x, weight=self.colorize)
260
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
- return x
262
-
263
-
264
- class VQModelInterface(VQModel):
265
- def __init__(self, embed_dim, *args, **kwargs):
266
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
- self.embed_dim = embed_dim
268
-
269
- def encode(self, x):
270
- h = self.encoder(x)
271
- h = self.quant_conv(h)
272
- return h
273
-
274
- def decode(self, h, force_not_quantize=False):
275
- # also go through quantization layer
276
- if not force_not_quantize:
277
- quant, emb_loss, info = self.quantize(h)
278
- else:
279
- quant = h
280
- quant = self.post_quant_conv(quant)
281
- dec = self.decoder(quant)
282
- return dec
283
-
284
-
285
- class AutoencoderKL(pl.LightningModule):
286
- def __init__(self,
287
- ddconfig,
288
- lossconfig,
289
- embed_dim,
290
- ckpt_path=None,
291
- ignore_keys=[],
292
- image_key="image",
293
- colorize_nlabels=None,
294
- monitor=None,
295
- ):
296
- super().__init__()
297
- self.image_key = image_key
298
- self.encoder = Encoder(**ddconfig)
299
- self.decoder = Decoder(**ddconfig)
300
- self.loss = instantiate_from_config(lossconfig)
301
- assert ddconfig["double_z"]
302
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
- self.embed_dim = embed_dim
305
- if colorize_nlabels is not None:
306
- assert type(colorize_nlabels)==int
307
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
- if monitor is not None:
309
- self.monitor = monitor
310
- if ckpt_path is not None:
311
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
-
313
- def init_from_ckpt(self, path, ignore_keys=list()):
314
- sd = torch.load(path, map_location="cpu")["state_dict"]
315
- keys = list(sd.keys())
316
- for k in keys:
317
- for ik in ignore_keys:
318
- if k.startswith(ik):
319
- print("Deleting key {} from state_dict.".format(k))
320
- del sd[k]
321
- self.load_state_dict(sd, strict=False)
322
- print(f"Restored from {path}")
323
-
324
  def encode(self, x):
325
  h = self.encoder(x)
326
  moments = self.quant_conv(h)
@@ -370,25 +129,33 @@ class AutoencoderKL(pl.LightningModule):
370
  return discloss
371
 
372
  def validation_step(self, batch, batch_idx):
 
 
 
 
 
 
373
  inputs = self.get_input(batch, self.image_key)
374
  reconstructions, posterior = self(inputs)
375
  aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376
- last_layer=self.get_last_layer(), split="val")
377
 
378
  discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379
- last_layer=self.get_last_layer(), split="val")
380
 
381
- self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382
  self.log_dict(log_dict_ae)
383
  self.log_dict(log_dict_disc)
384
  return self.log_dict
385
 
386
  def configure_optimizers(self):
387
  lr = self.learning_rate
388
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389
- list(self.decoder.parameters())+
390
- list(self.quant_conv.parameters())+
391
- list(self.post_quant_conv.parameters()),
 
 
392
  lr=lr, betas=(0.5, 0.9))
393
  opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394
  lr=lr, betas=(0.5, 0.9))
@@ -398,7 +165,7 @@ class AutoencoderKL(pl.LightningModule):
398
  return self.decoder.conv_out.weight
399
 
400
  @torch.no_grad()
401
- def log_images(self, batch, only_inputs=False, **kwargs):
402
  log = dict()
403
  x = self.get_input(batch, self.image_key)
404
  x = x.to(self.device)
@@ -423,9 +190,9 @@ class AutoencoderKL(pl.LightningModule):
423
  return x
424
 
425
 
426
- class IdentityFirstStage(torch.nn.Module):
427
  def __init__(self, *args, vq_interface=False, **kwargs):
428
- self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429
  super().__init__()
430
 
431
  def encode(self, x, *args, **kwargs):
@@ -441,3 +208,4 @@ class IdentityFirstStage(torch.nn.Module):
441
 
442
  def forward(self, x, *args, **kwargs):
443
  return x
 
 
1
  import torch
2
  import pytorch_lightning as pl
3
  import torch.nn.functional as F
4
+ import torch.nn as nn
5
  from contextlib import contextmanager
6
 
 
 
7
  from ldm.modules.diffusionmodules.model import Encoder, Decoder
8
  from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
9
 
10
  from ldm.util import instantiate_from_config
11
+ from ldm.modules.ema import LitEma
12
 
13
 
14
+ class AutoencoderKL(pl.LightningModule):
15
  def __init__(self,
16
  ddconfig,
17
  lossconfig,
 
18
  embed_dim,
19
  ckpt_path=None,
20
  ignore_keys=[],
21
  image_key="image",
22
  colorize_nlabels=None,
23
  monitor=None,
24
+ ema_decay=None,
25
+ learn_logvar=False
 
 
 
 
26
  ):
27
  super().__init__()
28
+ self.learn_logvar = learn_logvar
 
29
  self.image_key = image_key
30
  self.encoder = Encoder(**ddconfig)
31
  self.decoder = Decoder(**ddconfig)
32
  self.loss = instantiate_from_config(lossconfig)
33
+ assert ddconfig["double_z"]
34
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
 
 
35
  self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
36
+ self.embed_dim = embed_dim
37
  if colorize_nlabels is not None:
38
  assert type(colorize_nlabels)==int
39
  self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
40
  if monitor is not None:
41
  self.monitor = monitor
 
 
 
42
 
43
+ self.use_ema = ema_decay is not None
44
  if self.use_ema:
45
+ self.ema_decay = ema_decay
46
+ assert 0. < ema_decay < 1.
47
+ self.model_ema = LitEma(self, decay=ema_decay)
48
  print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
49
 
50
  if ckpt_path is not None:
51
  self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
52
+
53
+ def init_from_ckpt(self, path, ignore_keys=list()):
54
+ sd = torch.load(path, map_location="cpu")["state_dict"]
55
+ keys = list(sd.keys())
56
+ for k in keys:
57
+ for ik in ignore_keys:
58
+ if k.startswith(ik):
59
+ print("Deleting key {} from state_dict.".format(k))
60
+ del sd[k]
61
+ self.load_state_dict(sd, strict=False)
62
+ print(f"Restored from {path}")
63
 
64
  @contextmanager
65
  def ema_scope(self, context=None):
 
76
  if context is not None:
77
  print(f"{context}: Restored training weights")
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def on_train_batch_end(self, *args, **kwargs):
80
  if self.use_ema:
81
  self.model_ema(self)
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def encode(self, x):
84
  h = self.encoder(x)
85
  moments = self.quant_conv(h)
 
129
  return discloss
130
 
131
  def validation_step(self, batch, batch_idx):
132
+ log_dict = self._validation_step(batch, batch_idx)
133
+ with self.ema_scope():
134
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
135
+ return log_dict
136
+
137
+ def _validation_step(self, batch, batch_idx, postfix=""):
138
  inputs = self.get_input(batch, self.image_key)
139
  reconstructions, posterior = self(inputs)
140
  aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
141
+ last_layer=self.get_last_layer(), split="val"+postfix)
142
 
143
  discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
144
+ last_layer=self.get_last_layer(), split="val"+postfix)
145
 
146
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
147
  self.log_dict(log_dict_ae)
148
  self.log_dict(log_dict_disc)
149
  return self.log_dict
150
 
151
  def configure_optimizers(self):
152
  lr = self.learning_rate
153
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
154
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
155
+ if self.learn_logvar:
156
+ print(f"{self.__class__.__name__}: Learning logvar")
157
+ ae_params_list.append(self.loss.logvar)
158
+ opt_ae = torch.optim.Adam(ae_params_list,
159
  lr=lr, betas=(0.5, 0.9))
160
  opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
161
  lr=lr, betas=(0.5, 0.9))
 
165
  return self.decoder.conv_out.weight
166
 
167
  @torch.no_grad()
168
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
169
  log = dict()
170
  x = self.get_input(batch, self.image_key)
171
  x = x.to(self.device)
 
190
  return x
191
 
192
 
193
+ class IdentityFirstStage(nn.Module):
194
  def __init__(self, *args, vq_interface=False, **kwargs):
195
+ self.vq_interface = vq_interface
196
  super().__init__()
197
 
198
  def encode(self, x, *args, **kwargs):
 
208
 
209
  def forward(self, x, *args, **kwargs):
210
  return x
211
+
ldm/models/diffusion/classifier.py DELETED
@@ -1,267 +0,0 @@
1
- import os
2
- import torch
3
- import pytorch_lightning as pl
4
- from omegaconf import OmegaConf
5
- from torch.nn import functional as F
6
- from torch.optim import AdamW
7
- from torch.optim.lr_scheduler import LambdaLR
8
- from copy import deepcopy
9
- from einops import rearrange
10
- from glob import glob
11
- from natsort import natsorted
12
-
13
- from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
- from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
-
16
- __models__ = {
17
- 'class_label': EncoderUNetModel,
18
- 'segmentation': UNetModel
19
- }
20
-
21
-
22
- def disabled_train(self, mode=True):
23
- """Overwrite model.train with this function to make sure train/eval mode
24
- does not change anymore."""
25
- return self
26
-
27
-
28
- class NoisyLatentImageClassifier(pl.LightningModule):
29
-
30
- def __init__(self,
31
- diffusion_path,
32
- num_classes,
33
- ckpt_path=None,
34
- pool='attention',
35
- label_key=None,
36
- diffusion_ckpt_path=None,
37
- scheduler_config=None,
38
- weight_decay=1.e-2,
39
- log_steps=10,
40
- monitor='val/loss',
41
- *args,
42
- **kwargs):
43
- super().__init__(*args, **kwargs)
44
- self.num_classes = num_classes
45
- # get latest config of diffusion model
46
- diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
- self.diffusion_config = OmegaConf.load(diffusion_config).model
48
- self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
- self.load_diffusion()
50
-
51
- self.monitor = monitor
52
- self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
- self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
- self.log_steps = log_steps
55
-
56
- self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
- else self.diffusion_model.cond_stage_key
58
-
59
- assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
-
61
- if self.label_key not in __models__:
62
- raise NotImplementedError()
63
-
64
- self.load_classifier(ckpt_path, pool)
65
-
66
- self.scheduler_config = scheduler_config
67
- self.use_scheduler = self.scheduler_config is not None
68
- self.weight_decay = weight_decay
69
-
70
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
- sd = torch.load(path, map_location="cpu")
72
- if "state_dict" in list(sd.keys()):
73
- sd = sd["state_dict"]
74
- keys = list(sd.keys())
75
- for k in keys:
76
- for ik in ignore_keys:
77
- if k.startswith(ik):
78
- print("Deleting key {} from state_dict.".format(k))
79
- del sd[k]
80
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
- sd, strict=False)
82
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
- if len(missing) > 0:
84
- print(f"Missing Keys: {missing}")
85
- if len(unexpected) > 0:
86
- print(f"Unexpected Keys: {unexpected}")
87
-
88
- def load_diffusion(self):
89
- model = instantiate_from_config(self.diffusion_config)
90
- self.diffusion_model = model.eval()
91
- self.diffusion_model.train = disabled_train
92
- for param in self.diffusion_model.parameters():
93
- param.requires_grad = False
94
-
95
- def load_classifier(self, ckpt_path, pool):
96
- model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
- model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
- model_config.out_channels = self.num_classes
99
- if self.label_key == 'class_label':
100
- model_config.pool = pool
101
-
102
- self.model = __models__[self.label_key](**model_config)
103
- if ckpt_path is not None:
104
- print('#####################################################################')
105
- print(f'load from ckpt "{ckpt_path}"')
106
- print('#####################################################################')
107
- self.init_from_ckpt(ckpt_path)
108
-
109
- @torch.no_grad()
110
- def get_x_noisy(self, x, t, noise=None):
111
- noise = default(noise, lambda: torch.randn_like(x))
112
- continuous_sqrt_alpha_cumprod = None
113
- if self.diffusion_model.use_continuous_noise:
114
- continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
- # todo: make sure t+1 is correct here
116
-
117
- return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
- continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
-
120
- def forward(self, x_noisy, t, *args, **kwargs):
121
- return self.model(x_noisy, t)
122
-
123
- @torch.no_grad()
124
- def get_input(self, batch, k):
125
- x = batch[k]
126
- if len(x.shape) == 3:
127
- x = x[..., None]
128
- x = rearrange(x, 'b h w c -> b c h w')
129
- x = x.to(memory_format=torch.contiguous_format).float()
130
- return x
131
-
132
- @torch.no_grad()
133
- def get_conditioning(self, batch, k=None):
134
- if k is None:
135
- k = self.label_key
136
- assert k is not None, 'Needs to provide label key'
137
-
138
- targets = batch[k].to(self.device)
139
-
140
- if self.label_key == 'segmentation':
141
- targets = rearrange(targets, 'b h w c -> b c h w')
142
- for down in range(self.numd):
143
- h, w = targets.shape[-2:]
144
- targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
-
146
- # targets = rearrange(targets,'b c h w -> b h w c')
147
-
148
- return targets
149
-
150
- def compute_top_k(self, logits, labels, k, reduction="mean"):
151
- _, top_ks = torch.topk(logits, k, dim=1)
152
- if reduction == "mean":
153
- return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
- elif reduction == "none":
155
- return (top_ks == labels[:, None]).float().sum(dim=-1)
156
-
157
- def on_train_epoch_start(self):
158
- # save some memory
159
- self.diffusion_model.model.to('cpu')
160
-
161
- @torch.no_grad()
162
- def write_logs(self, loss, logits, targets):
163
- log_prefix = 'train' if self.training else 'val'
164
- log = {}
165
- log[f"{log_prefix}/loss"] = loss.mean()
166
- log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
- logits, targets, k=1, reduction="mean"
168
- )
169
- log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
- logits, targets, k=5, reduction="mean"
171
- )
172
-
173
- self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
- self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
- self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
- lr = self.optimizers().param_groups[0]['lr']
177
- self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
-
179
- def shared_step(self, batch, t=None):
180
- x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
- targets = self.get_conditioning(batch)
182
- if targets.dim() == 4:
183
- targets = targets.argmax(dim=1)
184
- if t is None:
185
- t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
- else:
187
- t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
- x_noisy = self.get_x_noisy(x, t)
189
- logits = self(x_noisy, t)
190
-
191
- loss = F.cross_entropy(logits, targets, reduction='none')
192
-
193
- self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
-
195
- loss = loss.mean()
196
- return loss, logits, x_noisy, targets
197
-
198
- def training_step(self, batch, batch_idx):
199
- loss, *_ = self.shared_step(batch)
200
- return loss
201
-
202
- def reset_noise_accs(self):
203
- self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
- range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
-
206
- def on_validation_start(self):
207
- self.reset_noise_accs()
208
-
209
- @torch.no_grad()
210
- def validation_step(self, batch, batch_idx):
211
- loss, *_ = self.shared_step(batch)
212
-
213
- for t in self.noisy_acc:
214
- _, logits, _, targets = self.shared_step(batch, t)
215
- self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
- self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
-
218
- return loss
219
-
220
- def configure_optimizers(self):
221
- optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
-
223
- if self.use_scheduler:
224
- scheduler = instantiate_from_config(self.scheduler_config)
225
-
226
- print("Setting up LambdaLR scheduler...")
227
- scheduler = [
228
- {
229
- 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
- 'interval': 'step',
231
- 'frequency': 1
232
- }]
233
- return [optimizer], scheduler
234
-
235
- return optimizer
236
-
237
- @torch.no_grad()
238
- def log_images(self, batch, N=8, *args, **kwargs):
239
- log = dict()
240
- x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
- log['inputs'] = x
242
-
243
- y = self.get_conditioning(batch)
244
-
245
- if self.label_key == 'class_label':
246
- y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
- log['labels'] = y
248
-
249
- if ismap(y):
250
- log['labels'] = self.diffusion_model.to_rgb(y)
251
-
252
- for step in range(self.log_steps):
253
- current_time = step * self.log_time_interval
254
-
255
- _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
-
257
- log[f'inputs@t{current_time}'] = x_noisy
258
-
259
- pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
- pred = rearrange(pred, 'b h w c -> b c h w')
261
-
262
- log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
-
264
- for key in log:
265
- log[key] = log[key][:N]
266
-
267
- return log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/diffusion/ddim.py CHANGED
@@ -3,7 +3,6 @@
3
  import torch
4
  import numpy as np
5
  from tqdm import tqdm
6
- from functools import partial
7
 
8
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
  extract_into_tensor
@@ -24,7 +23,7 @@ class DDIMSampler(object):
24
 
25
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
  self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
  alphas_cumprod = self.model.alphas_cumprod
29
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
  to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
@@ -43,14 +42,14 @@ class DDIMSampler(object):
43
  # ddim sampling parameters
44
  ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
  ddim_timesteps=self.ddim_timesteps,
46
- eta=ddim_eta,verbose=verbose)
47
  self.register_buffer('ddim_sigmas', ddim_sigmas)
48
  self.register_buffer('ddim_alphas', ddim_alphas)
49
  self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50
  self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51
  sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52
  (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54
  self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55
 
56
  @torch.no_grad()
@@ -75,6 +74,9 @@ class DDIMSampler(object):
75
  log_every_t=100,
76
  unconditional_guidance_scale=1.,
77
  unconditional_conditioning=None,
 
 
 
78
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
79
  **kwargs
80
  ):
@@ -107,6 +109,9 @@ class DDIMSampler(object):
107
  log_every_t=log_every_t,
108
  unconditional_guidance_scale=unconditional_guidance_scale,
109
  unconditional_conditioning=unconditional_conditioning,
 
 
 
110
  )
111
  return samples, intermediates
112
 
@@ -116,7 +121,8 @@ class DDIMSampler(object):
116
  callback=None, timesteps=None, quantize_denoised=False,
117
  mask=None, x0=None, img_callback=None, log_every_t=100,
118
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
119
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
 
120
  device = self.model.betas.device
121
  b = shape[0]
122
  if x_T is None:
@@ -131,7 +137,7 @@ class DDIMSampler(object):
131
  timesteps = self.ddim_timesteps[:subset_end]
132
 
133
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
134
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
135
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
136
  print(f"Running DDIM Sampling with {total_steps} timesteps")
137
 
@@ -151,7 +157,13 @@ class DDIMSampler(object):
151
  noise_dropout=noise_dropout, score_corrector=score_corrector,
152
  corrector_kwargs=corrector_kwargs,
153
  unconditional_guidance_scale=unconditional_guidance_scale,
154
- unconditional_conditioning=unconditional_conditioning)
 
 
 
 
 
 
155
  img, pred_x0 = outs
156
  if callback: callback(i)
157
  if img_callback: img_callback(pred_x0, i)
@@ -165,20 +177,55 @@ class DDIMSampler(object):
165
  @torch.no_grad()
166
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
167
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
168
- unconditional_guidance_scale=1., unconditional_conditioning=None):
 
169
  b, *_, device = *x.shape, x.device
170
 
171
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
172
- e_t = self.model.apply_model(x, t, c)
 
 
 
 
173
  else:
174
  x_in = torch.cat([x] * 2)
175
  t_in = torch.cat([t] * 2)
176
- c_in = torch.cat([unconditional_conditioning, c])
177
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
178
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  if score_corrector is not None:
181
- assert self.model.parameterization == "eps"
182
  e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
183
 
184
  alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
@@ -189,14 +236,18 @@ class DDIMSampler(object):
189
  a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
190
  a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
191
  sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
192
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
193
 
194
  # current prediction for x_0
195
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
 
 
 
 
196
  if quantize_denoised:
197
  pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
198
  # direction pointing to x_t
199
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
200
  noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
201
  if noise_dropout > 0.:
202
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
@@ -238,4 +289,4 @@ class DDIMSampler(object):
238
  x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
239
  unconditional_guidance_scale=unconditional_guidance_scale,
240
  unconditional_conditioning=unconditional_conditioning)
241
- return x_dec
 
3
  import torch
4
  import numpy as np
5
  from tqdm import tqdm
 
6
 
7
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
8
  extract_into_tensor
 
23
 
24
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
  self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
27
  alphas_cumprod = self.model.alphas_cumprod
28
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
  to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
 
42
  # ddim sampling parameters
43
  ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44
  ddim_timesteps=self.ddim_timesteps,
45
+ eta=ddim_eta, verbose=verbose)
46
  self.register_buffer('ddim_sigmas', ddim_sigmas)
47
  self.register_buffer('ddim_alphas', ddim_alphas)
48
  self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49
  self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50
  sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51
  (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53
  self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54
 
55
  @torch.no_grad()
 
74
  log_every_t=100,
75
  unconditional_guidance_scale=1.,
76
  unconditional_conditioning=None,
77
+ features_adapter=None,
78
+ append_to_context=None,
79
+ cond_tau=0.4,
80
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
  **kwargs
82
  ):
 
109
  log_every_t=log_every_t,
110
  unconditional_guidance_scale=unconditional_guidance_scale,
111
  unconditional_conditioning=unconditional_conditioning,
112
+ features_adapter=features_adapter,
113
+ append_to_context=append_to_context,
114
+ cond_tau=cond_tau,
115
  )
116
  return samples, intermediates
117
 
 
121
  callback=None, timesteps=None, quantize_denoised=False,
122
  mask=None, x0=None, img_callback=None, log_every_t=100,
123
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
124
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
125
+ append_to_context=None, cond_tau=0.4):
126
  device = self.model.betas.device
127
  b = shape[0]
128
  if x_T is None:
 
137
  timesteps = self.ddim_timesteps[:subset_end]
138
 
139
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
140
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
141
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
142
  print(f"Running DDIM Sampling with {total_steps} timesteps")
143
 
 
157
  noise_dropout=noise_dropout, score_corrector=score_corrector,
158
  corrector_kwargs=corrector_kwargs,
159
  unconditional_guidance_scale=unconditional_guidance_scale,
160
+ unconditional_conditioning=unconditional_conditioning,
161
+ features_adapter=None if index < int(
162
+ (1 - cond_tau) * total_steps) else features_adapter,
163
+ # TODO support style_cond_tau
164
+ append_to_context=None if index < int(
165
+ 0.5 * total_steps) else append_to_context,
166
+ )
167
  img, pred_x0 = outs
168
  if callback: callback(i)
169
  if img_callback: img_callback(pred_x0, i)
 
177
  @torch.no_grad()
178
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
181
+ append_to_context=None):
182
  b, *_, device = *x.shape, x.device
183
 
184
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
185
+ if append_to_context is not None:
186
+ model_output = self.model.apply_model(x, t, torch.cat([c, append_to_context], dim=1),
187
+ features_adapter=features_adapter)
188
+ else:
189
+ model_output = self.model.apply_model(x, t, c, features_adapter=features_adapter)
190
  else:
191
  x_in = torch.cat([x] * 2)
192
  t_in = torch.cat([t] * 2)
193
+ if isinstance(c, dict):
194
+ assert isinstance(unconditional_conditioning, dict)
195
+ c_in = dict()
196
+ for k in c:
197
+ if isinstance(c[k], list):
198
+ c_in[k] = [torch.cat([
199
+ unconditional_conditioning[k][i],
200
+ c[k][i]]) for i in range(len(c[k]))]
201
+ else:
202
+ c_in[k] = torch.cat([
203
+ unconditional_conditioning[k],
204
+ c[k]])
205
+ elif isinstance(c, list):
206
+ c_in = list()
207
+ assert isinstance(unconditional_conditioning, list)
208
+ for i in range(len(c)):
209
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
210
+ else:
211
+ if append_to_context is not None:
212
+ pad_len = append_to_context.size(1)
213
+ new_unconditional_conditioning = torch.cat(
214
+ [unconditional_conditioning, unconditional_conditioning[:, -pad_len:, :]], dim=1)
215
+ new_c = torch.cat([c, append_to_context], dim=1)
216
+ c_in = torch.cat([new_unconditional_conditioning, new_c])
217
+ else:
218
+ c_in = torch.cat([unconditional_conditioning, c])
219
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
220
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
221
+
222
+ if self.model.parameterization == "v":
223
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
224
+ else:
225
+ e_t = model_output
226
 
227
  if score_corrector is not None:
228
+ assert self.model.parameterization == "eps", 'not implemented'
229
  e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
230
 
231
  alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
 
236
  a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
237
  a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
238
  sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
239
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
240
 
241
  # current prediction for x_0
242
+ if self.model.parameterization != "v":
243
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
244
+ else:
245
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
246
+
247
  if quantize_denoised:
248
  pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
249
  # direction pointing to x_t
250
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
251
  noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
252
  if noise_dropout > 0.:
253
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
 
289
  x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
290
  unconditional_guidance_scale=unconditional_guidance_scale,
291
  unconditional_conditioning=unconditional_conditioning)
292
+ return x_dec
ldm/models/diffusion/ddpm.py CHANGED
@@ -12,16 +12,18 @@ import numpy as np
12
  import pytorch_lightning as pl
13
  from torch.optim.lr_scheduler import LambdaLR
14
  from einops import rearrange, repeat
15
- from contextlib import contextmanager
16
  from functools import partial
 
17
  from tqdm import tqdm
18
  from torchvision.utils import make_grid
19
  from pytorch_lightning.utilities.distributed import rank_zero_only
 
20
 
21
  from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
22
  from ldm.modules.ema import LitEma
23
  from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
- from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
25
  from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
  from ldm.models.diffusion.ddim import DDIMSampler
27
 
@@ -71,9 +73,13 @@ class DDPM(pl.LightningModule):
71
  use_positional_encodings=False,
72
  learn_logvar=False,
73
  logvar_init=0.,
 
 
 
 
74
  ):
75
  super().__init__()
76
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
77
  self.parameterization = parameterization
78
  print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
79
  self.cond_stage_model = None
@@ -100,8 +106,18 @@ class DDPM(pl.LightningModule):
100
 
101
  if monitor is not None:
102
  self.monitor = monitor
 
 
103
  if ckpt_path is not None:
104
  self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
 
 
 
 
 
 
 
 
105
 
106
  self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
107
  linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -113,6 +129,9 @@ class DDPM(pl.LightningModule):
113
  if self.learn_logvar:
114
  self.logvar = nn.Parameter(self.logvar, requires_grad=True)
115
 
 
 
 
116
 
117
  def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
118
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@@ -146,7 +165,7 @@ class DDPM(pl.LightningModule):
146
 
147
  # calculations for posterior q(x_{t-1} | x_t, x_0)
148
  posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
149
- 1. - alphas_cumprod) + self.v_posterior * betas
150
  # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
151
  self.register_buffer('posterior_variance', to_torch(posterior_variance))
152
  # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
@@ -158,12 +177,14 @@ class DDPM(pl.LightningModule):
158
 
159
  if self.parameterization == "eps":
160
  lvlb_weights = self.betas ** 2 / (
161
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
162
  elif self.parameterization == "x0":
163
  lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
 
 
 
164
  else:
165
  raise NotImplementedError("mu not supported")
166
- # TODO how to choose this term
167
  lvlb_weights[0] = lvlb_weights[1]
168
  self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
169
  assert not torch.isnan(self.lvlb_weights).all()
@@ -183,6 +204,7 @@ class DDPM(pl.LightningModule):
183
  if context is not None:
184
  print(f"{context}: Restored training weights")
185
 
 
186
  def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
187
  sd = torch.load(path, map_location="cpu")
188
  if "state_dict" in list(sd.keys()):
@@ -193,13 +215,57 @@ class DDPM(pl.LightningModule):
193
  if k.startswith(ik):
194
  print("Deleting key {} from state_dict.".format(k))
195
  del sd[k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
197
  sd, strict=False)
198
  print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
199
  if len(missing) > 0:
200
- print(f"Missing Keys: {missing}")
201
  if len(unexpected) > 0:
202
- print(f"Unexpected Keys: {unexpected}")
203
 
204
  def q_mean_variance(self, x_start, t):
205
  """
@@ -219,6 +285,20 @@ class DDPM(pl.LightningModule):
219
  extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
220
  )
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  def q_posterior(self, x_start, x_t, t):
223
  posterior_mean = (
224
  extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -276,6 +356,12 @@ class DDPM(pl.LightningModule):
276
  return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
277
  extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
278
 
 
 
 
 
 
 
279
  def get_loss(self, pred, target, mean=True):
280
  if self.loss_type == 'l1':
281
  loss = (target - pred).abs()
@@ -301,6 +387,8 @@ class DDPM(pl.LightningModule):
301
  target = noise
302
  elif self.parameterization == "x0":
303
  target = x_start
 
 
304
  else:
305
  raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
306
 
@@ -328,10 +416,10 @@ class DDPM(pl.LightningModule):
328
 
329
  def get_input(self, batch, k):
330
  x = batch[k]
331
- if len(x.shape) == 3:
332
- x = x[..., None]
333
- x = rearrange(x, 'b h w c -> b c h w')
334
- x = x.to(memory_format=torch.contiguous_format).float()
335
  return x
336
 
337
  def shared_step(self, batch):
@@ -421,41 +509,12 @@ class DDPM(pl.LightningModule):
421
  return opt
422
 
423
 
424
- class DiffusionWrapper(pl.LightningModule):
425
- def __init__(self, diff_model_config, conditioning_key):
426
- super().__init__()
427
- self.diffusion_model = instantiate_from_config(diff_model_config)
428
- self.conditioning_key = conditioning_key
429
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
430
-
431
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, features_adapter=None):
432
- if self.conditioning_key is None:
433
- out = self.diffusion_model(x, t, features_adapter=features_adapter)
434
- elif self.conditioning_key == 'concat':
435
- xc = torch.cat([x] + c_concat, dim=1)
436
- out = self.diffusion_model(xc, t, features_adapter=features_adapter)
437
- elif self.conditioning_key == 'crossattn':
438
- cc = torch.cat(c_crossattn, 1)
439
- out = self.diffusion_model(x, t, context=cc, features_adapter=features_adapter)
440
- elif self.conditioning_key == 'hybrid':
441
- xc = torch.cat([x] + c_concat, dim=1)
442
- cc = torch.cat(c_crossattn, 1)
443
- out = self.diffusion_model(xc, t, context=cc, features_adapter=features_adapter)
444
- elif self.conditioning_key == 'adm':
445
- cc = c_crossattn[0]
446
- out = self.diffusion_model(x, t, y=cc, features_adapter=features_adapter)
447
- else:
448
- raise NotImplementedError()
449
-
450
- return out
451
-
452
-
453
  class LatentDiffusion(DDPM):
454
  """main class"""
 
455
  def __init__(self,
456
  first_stage_config,
457
  cond_stage_config,
458
- unet_config,
459
  num_timesteps_cond=None,
460
  cond_stage_key="image",
461
  cond_stage_trainable=False,
@@ -474,9 +533,10 @@ class LatentDiffusion(DDPM):
474
  if cond_stage_config == '__is_unconditional__':
475
  conditioning_key = None
476
  ckpt_path = kwargs.pop("ckpt_path", None)
 
 
477
  ignore_keys = kwargs.pop("ignore_keys", [])
478
- super().__init__(conditioning_key=conditioning_key, unet_config=unet_config, *args, **kwargs)
479
- self.model = DiffusionWrapper(unet_config, conditioning_key)
480
  self.concat_mode = concat_mode
481
  self.cond_stage_trainable = cond_stage_trainable
482
  self.cond_stage_key = cond_stage_key
@@ -492,35 +552,27 @@ class LatentDiffusion(DDPM):
492
  self.instantiate_cond_stage(cond_stage_config)
493
  self.cond_stage_forward = cond_stage_forward
494
  self.clip_denoised = False
495
- self.bbox_tokenizer = None
496
 
497
  self.restarted_from_ckpt = False
498
  if ckpt_path is not None:
499
  self.init_from_ckpt(ckpt_path, ignore_keys)
500
  self.restarted_from_ckpt = True
 
 
 
 
 
 
 
 
 
501
 
502
  def make_cond_schedule(self, ):
503
  self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
504
  ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
505
  self.cond_ids[:self.num_timesteps_cond] = ids
506
 
507
- @rank_zero_only
508
- @torch.no_grad()
509
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
510
- # only for very first batch
511
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
512
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
513
- # set rescale weight to 1./std of encodings
514
- print("### USING STD-RESCALING ###")
515
- x = super().get_input(batch, self.first_stage_key)
516
- x = x.to(self.device)
517
- encoder_posterior = self.encode_first_stage(x)
518
- z = self.get_first_stage_encoding(encoder_posterior).detach()
519
- del self.scale_factor
520
- self.register_buffer('scale_factor', 1. / z.flatten().std())
521
- print(f"setting self.scale_factor to {self.scale_factor}")
522
- print("### USING STD-RESCALING ###")
523
-
524
  def register_schedule(self,
525
  given_betas=None, beta_schedule="linear", timesteps=1000,
526
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@@ -562,7 +614,7 @@ class LatentDiffusion(DDPM):
562
  denoise_row = []
563
  for zd in tqdm(samples, desc=desc):
564
  denoise_row.append(self.decode_first_stage(zd.to(self.device),
565
- force_not_quantize=force_no_decoder_quantization))
566
  n_imgs_per_row = len(denoise_row)
567
  denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
568
  denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
@@ -695,9 +747,9 @@ class LatentDiffusion(DDPM):
695
  if cond_key is None:
696
  cond_key = self.cond_stage_key
697
  if cond_key != self.first_stage_key:
698
- if cond_key in ['caption', 'coordinates_bbox']:
699
  xc = batch[cond_key]
700
- elif cond_key == 'class_label':
701
  xc = batch
702
  else:
703
  xc = super().get_input(batch, cond_key).to(self.device)
@@ -742,181 +794,28 @@ class LatentDiffusion(DDPM):
742
  z = rearrange(z, 'b h w c -> b c h w').contiguous()
743
 
744
  z = 1. / self.scale_factor * z
745
-
746
- if hasattr(self, "split_input_params"):
747
- if self.split_input_params["patch_distributed_vq"]:
748
- ks = self.split_input_params["ks"] # eg. (128, 128)
749
- stride = self.split_input_params["stride"] # eg. (64, 64)
750
- uf = self.split_input_params["vqf"]
751
- bs, nc, h, w = z.shape
752
- if ks[0] > h or ks[1] > w:
753
- ks = (min(ks[0], h), min(ks[1], w))
754
- print("reducing Kernel")
755
-
756
- if stride[0] > h or stride[1] > w:
757
- stride = (min(stride[0], h), min(stride[1], w))
758
- print("reducing stride")
759
-
760
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
761
-
762
- z = unfold(z) # (bn, nc * prod(**ks), L)
763
- # 1. Reshape to img shape
764
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
765
-
766
- # 2. apply model loop over last dim
767
- if isinstance(self.first_stage_model, VQModelInterface):
768
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
769
- force_not_quantize=predict_cids or force_not_quantize)
770
- for i in range(z.shape[-1])]
771
- else:
772
-
773
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
774
- for i in range(z.shape[-1])]
775
-
776
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
777
- o = o * weighting
778
- # Reverse 1. reshape to img shape
779
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
780
- # stitch crops together
781
- decoded = fold(o)
782
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
783
- return decoded
784
- else:
785
- if isinstance(self.first_stage_model, VQModelInterface):
786
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
787
- else:
788
- return self.first_stage_model.decode(z)
789
-
790
- else:
791
- if isinstance(self.first_stage_model, VQModelInterface):
792
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
793
- else:
794
- return self.first_stage_model.decode(z)
795
-
796
- # same as above but without decorator
797
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
798
- if predict_cids:
799
- if z.dim() == 4:
800
- z = torch.argmax(z.exp(), dim=1).long()
801
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
802
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
803
-
804
- z = 1. / self.scale_factor * z
805
-
806
- if hasattr(self, "split_input_params"):
807
- if self.split_input_params["patch_distributed_vq"]:
808
- ks = self.split_input_params["ks"] # eg. (128, 128)
809
- stride = self.split_input_params["stride"] # eg. (64, 64)
810
- uf = self.split_input_params["vqf"]
811
- bs, nc, h, w = z.shape
812
- if ks[0] > h or ks[1] > w:
813
- ks = (min(ks[0], h), min(ks[1], w))
814
- print("reducing Kernel")
815
-
816
- if stride[0] > h or stride[1] > w:
817
- stride = (min(stride[0], h), min(stride[1], w))
818
- print("reducing stride")
819
-
820
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
821
-
822
- z = unfold(z) # (bn, nc * prod(**ks), L)
823
- # 1. Reshape to img shape
824
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
825
-
826
- # 2. apply model loop over last dim
827
- if isinstance(self.first_stage_model, VQModelInterface):
828
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
829
- force_not_quantize=predict_cids or force_not_quantize)
830
- for i in range(z.shape[-1])]
831
- else:
832
-
833
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
834
- for i in range(z.shape[-1])]
835
-
836
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
837
- o = o * weighting
838
- # Reverse 1. reshape to img shape
839
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
840
- # stitch crops together
841
- decoded = fold(o)
842
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
843
- return decoded
844
- else:
845
- if isinstance(self.first_stage_model, VQModelInterface):
846
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
847
- else:
848
- return self.first_stage_model.decode(z)
849
-
850
- else:
851
- if isinstance(self.first_stage_model, VQModelInterface):
852
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
853
- else:
854
- return self.first_stage_model.decode(z)
855
 
856
  @torch.no_grad()
857
  def encode_first_stage(self, x):
858
- if hasattr(self, "split_input_params"):
859
- if self.split_input_params["patch_distributed_vq"]:
860
- ks = self.split_input_params["ks"] # eg. (128, 128)
861
- stride = self.split_input_params["stride"] # eg. (64, 64)
862
- df = self.split_input_params["vqf"]
863
- self.split_input_params['original_image_size'] = x.shape[-2:]
864
- bs, nc, h, w = x.shape
865
- if ks[0] > h or ks[1] > w:
866
- ks = (min(ks[0], h), min(ks[1], w))
867
- print("reducing Kernel")
868
-
869
- if stride[0] > h or stride[1] > w:
870
- stride = (min(stride[0], h), min(stride[1], w))
871
- print("reducing stride")
872
-
873
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
874
- z = unfold(x) # (bn, nc * prod(**ks), L)
875
- # Reshape to img shape
876
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
877
-
878
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
879
- for i in range(z.shape[-1])]
880
-
881
- o = torch.stack(output_list, axis=-1)
882
- o = o * weighting
883
-
884
- # Reverse reshape to img shape
885
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
886
- # stitch crops together
887
- decoded = fold(o)
888
- decoded = decoded / normalization
889
- return decoded
890
-
891
- else:
892
- return self.first_stage_model.encode(x)
893
- else:
894
- return self.first_stage_model.encode(x)
895
 
896
  def shared_step(self, batch, **kwargs):
897
  x, c = self.get_input(batch, self.first_stage_key)
898
- loss = self(x, c)
899
  return loss
900
 
901
- def forward(self, x, c, features_adapter=None, *args, **kwargs):
902
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
903
-
904
- return self.p_losses(x, c, t, features_adapter, *args, **kwargs)
905
-
906
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
907
- def rescale_bbox(bbox):
908
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
909
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
910
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
911
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
912
- return x0, y0, w, h
913
-
914
- return [rescale_bbox(b) for b in bboxes]
915
 
916
- def apply_model(self, x_noisy, t, cond, features_adapter=None, return_ids=False):
917
 
 
918
  if isinstance(cond, dict):
919
- # hybrid case, cond is exptected to be a dict
920
  pass
921
  else:
922
  if not isinstance(cond, list):
@@ -924,98 +823,7 @@ class LatentDiffusion(DDPM):
924
  key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
925
  cond = {key: cond}
926
 
927
- if hasattr(self, "split_input_params"):
928
- assert len(cond) == 1 # todo can only deal with one conditioning atm
929
- assert not return_ids
930
- ks = self.split_input_params["ks"] # eg. (128, 128)
931
- stride = self.split_input_params["stride"] # eg. (64, 64)
932
-
933
- h, w = x_noisy.shape[-2:]
934
-
935
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
936
-
937
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
938
- # Reshape to img shape
939
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
940
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
941
-
942
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
943
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
944
- c_key = next(iter(cond.keys())) # get key
945
- c = next(iter(cond.values())) # get value
946
- assert (len(c) == 1) # todo extend to list with more than one elem
947
- c = c[0] # get element
948
-
949
- c = unfold(c)
950
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
951
-
952
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
953
-
954
- elif self.cond_stage_key == 'coordinates_bbox':
955
- assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
956
-
957
- # assuming padding of unfold is always 0 and its dilation is always 1
958
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
959
- full_img_h, full_img_w = self.split_input_params['original_image_size']
960
- # as we are operating on latents, we need the factor from the original image size to the
961
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
962
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
963
- rescale_latent = 2 ** (num_downs)
964
-
965
- # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
966
- # need to rescale the tl patch coordinates to be in between (0,1)
967
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
968
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
969
- for patch_nr in range(z.shape[-1])]
970
-
971
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
972
- patch_limits = [(x_tl, y_tl,
973
- rescale_latent * ks[0] / full_img_w,
974
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
975
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
976
-
977
- # tokenize crop coordinates for the bounding boxes of the respective patches
978
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
979
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
980
- print(patch_limits_tknzd[0].shape)
981
- # cut tknzd crop position from conditioning
982
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
983
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
984
- print(cut_cond.shape)
985
-
986
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
987
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
988
- print(adapted_cond.shape)
989
- adapted_cond = self.get_learned_conditioning(adapted_cond)
990
- print(adapted_cond.shape)
991
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
992
- print(adapted_cond.shape)
993
-
994
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
995
-
996
- else:
997
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
998
-
999
- # apply model by loop over crops
1000
- if features_adapter is not None:
1001
- output_list = [self.model(z_list[i], t, **cond_list[i], features_adapter=features_adapter) for i in range(z.shape[-1])]
1002
- else:
1003
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1004
- assert not isinstance(output_list[0],
1005
- tuple) # todo cant deal with multiple model outputs check this never happens
1006
-
1007
- o = torch.stack(output_list, axis=-1)
1008
- o = o * weighting
1009
- # Reverse reshape to img shape
1010
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1011
- # stitch crops together
1012
- x_recon = fold(o) / normalization
1013
-
1014
- else:
1015
- if features_adapter is not None:
1016
- x_recon = self.model(x_noisy, t, **cond, features_adapter=features_adapter)
1017
- else:
1018
- x_recon = self.model(x_noisy, t, **cond)
1019
 
1020
  if isinstance(x_recon, tuple) and not return_ids:
1021
  return x_recon[0]
@@ -1040,10 +848,10 @@ class LatentDiffusion(DDPM):
1040
  kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1041
  return mean_flat(kl_prior) / np.log(2.0)
1042
 
1043
- def p_losses(self, x_start, cond, t, features_adapter=None, noise=None):
1044
  noise = default(noise, lambda: torch.randn_like(x_start))
1045
  x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1046
- model_output = self.apply_model(x_noisy, t, cond, features_adapter)
1047
 
1048
  loss_dict = {}
1049
  prefix = 'train' if self.training else 'val'
@@ -1052,6 +860,8 @@ class LatentDiffusion(DDPM):
1052
  target = x_start
1053
  elif self.parameterization == "eps":
1054
  target = noise
 
 
1055
  else:
1056
  raise NotImplementedError()
1057
 
@@ -1247,7 +1057,7 @@ class LatentDiffusion(DDPM):
1247
  @torch.no_grad()
1248
  def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1249
  verbose=True, timesteps=None, quantize_denoised=False,
1250
- mask=None, x0=None, shape=None,**kwargs):
1251
  if shape is None:
1252
  shape = (batch_size, self.channels, self.image_size, self.image_size)
1253
  if cond is not None:
@@ -1263,26 +1073,51 @@ class LatentDiffusion(DDPM):
1263
  mask=mask, x0=x0)
1264
 
1265
  @torch.no_grad()
1266
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1267
-
1268
  if ddim:
1269
  ddim_sampler = DDIMSampler(self)
1270
  shape = (self.channels, self.image_size, self.image_size)
1271
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1272
- shape,cond,verbose=False,**kwargs)
1273
 
1274
  else:
1275
  samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1276
- return_intermediates=True,**kwargs)
1277
 
1278
  return samples, intermediates
1279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1280
 
1281
  @torch.no_grad()
1282
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1283
  quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1284
- plot_diffusion_rows=True, **kwargs):
1285
-
 
 
1286
  use_ddim = ddim_steps is not None
1287
 
1288
  log = dict()
@@ -1299,12 +1134,16 @@ class LatentDiffusion(DDPM):
1299
  if hasattr(self.cond_stage_model, "decode"):
1300
  xc = self.cond_stage_model.decode(c)
1301
  log["conditioning"] = xc
1302
- elif self.cond_stage_key in ["caption"]:
1303
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1304
  log["conditioning"] = xc
1305
- elif self.cond_stage_key == 'class_label':
1306
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1307
- log['conditioning'] = xc
 
 
 
 
1308
  elif isimage(xc):
1309
  log["conditioning"] = xc
1310
  if ismap(xc):
@@ -1330,9 +1169,9 @@ class LatentDiffusion(DDPM):
1330
 
1331
  if sample:
1332
  # get denoise row
1333
- with self.ema_scope("Plotting"):
1334
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1335
- ddim_steps=ddim_steps,eta=ddim_eta)
1336
  # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1337
  x_samples = self.decode_first_stage(samples)
1338
  log["samples"] = x_samples
@@ -1343,39 +1182,52 @@ class LatentDiffusion(DDPM):
1343
  if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1344
  self.first_stage_model, IdentityFirstStage):
1345
  # also display when quantizing x0 while sampling
1346
- with self.ema_scope("Plotting Quantized Denoised"):
1347
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1348
- ddim_steps=ddim_steps,eta=ddim_eta,
1349
  quantize_denoised=True)
1350
  # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1351
  # quantize_denoised=True)
1352
  x_samples = self.decode_first_stage(samples.to(self.device))
1353
  log["samples_x0_quantized"] = x_samples
1354
 
1355
- if inpaint:
1356
- # make a simple center square
1357
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
1358
- mask = torch.ones(N, h, w).to(self.device)
1359
- # zeros will be filled in
1360
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1361
- mask = mask[:, None, ...]
1362
- with self.ema_scope("Plotting Inpaint"):
1363
-
1364
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1365
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1366
- x_samples = self.decode_first_stage(samples.to(self.device))
1367
- log["samples_inpainting"] = x_samples
1368
- log["mask"] = mask
1369
-
1370
- # outpaint
1371
- with self.ema_scope("Plotting Outpaint"):
1372
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1373
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1374
- x_samples = self.decode_first_stage(samples.to(self.device))
1375
- log["samples_outpainting"] = x_samples
 
 
 
 
 
 
 
 
 
 
 
 
 
1376
 
1377
  if plot_progressive_rows:
1378
- with self.ema_scope("Plotting Progressives"):
1379
  img, progressives = self.progressive_denoising(c,
1380
  shape=(self.channels, self.image_size, self.image_size),
1381
  batch_size=N)
@@ -1422,25 +1274,40 @@ class LatentDiffusion(DDPM):
1422
  x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1423
  return x
1424
 
1425
- class Layout2ImgDiffusion(LatentDiffusion):
1426
- # TODO: move all layout-specific hacks to this class
1427
- def __init__(self, cond_stage_key, *args, **kwargs):
1428
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1429
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1430
-
1431
- def log_images(self, batch, N=8, *args, **kwargs):
1432
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1433
 
1434
- key = 'train' if self.training else 'validation'
1435
- dset = self.trainer.datamodule.datasets[key]
1436
- mapper = dset.conditional_builders[self.cond_stage_key]
 
 
 
1437
 
1438
- bbox_imgs = []
1439
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1440
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
1441
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1442
- bbox_imgs.append(bboximg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1443
 
1444
- cond_img = torch.stack(bbox_imgs, dim=0)
1445
- logs['bbox_image'] = cond_img
1446
- return logs
 
12
  import pytorch_lightning as pl
13
  from torch.optim.lr_scheduler import LambdaLR
14
  from einops import rearrange, repeat
15
+ from contextlib import contextmanager, nullcontext
16
  from functools import partial
17
+ import itertools
18
  from tqdm import tqdm
19
  from torchvision.utils import make_grid
20
  from pytorch_lightning.utilities.distributed import rank_zero_only
21
+ from omegaconf import ListConfig
22
 
23
  from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
24
  from ldm.modules.ema import LitEma
25
  from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
26
+ from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
27
  from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
28
  from ldm.models.diffusion.ddim import DDIMSampler
29
 
 
73
  use_positional_encodings=False,
74
  learn_logvar=False,
75
  logvar_init=0.,
76
+ make_it_fit=False,
77
+ ucg_training=None,
78
+ reset_ema=False,
79
+ reset_num_ema_updates=False,
80
  ):
81
  super().__init__()
82
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
83
  self.parameterization = parameterization
84
  print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
85
  self.cond_stage_model = None
 
106
 
107
  if monitor is not None:
108
  self.monitor = monitor
109
+ self.make_it_fit = make_it_fit
110
+ if reset_ema: assert exists(ckpt_path)
111
  if ckpt_path is not None:
112
  self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
113
+ if reset_ema:
114
+ assert self.use_ema
115
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
116
+ self.model_ema = LitEma(self.model)
117
+ if reset_num_ema_updates:
118
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
119
+ assert self.use_ema
120
+ self.model_ema.reset_num_updates()
121
 
122
  self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
123
  linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
 
129
  if self.learn_logvar:
130
  self.logvar = nn.Parameter(self.logvar, requires_grad=True)
131
 
132
+ self.ucg_training = ucg_training or dict()
133
+ if self.ucg_training:
134
+ self.ucg_prng = np.random.RandomState()
135
 
136
  def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
137
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
 
165
 
166
  # calculations for posterior q(x_{t-1} | x_t, x_0)
167
  posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
168
+ 1. - alphas_cumprod) + self.v_posterior * betas
169
  # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
170
  self.register_buffer('posterior_variance', to_torch(posterior_variance))
171
  # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
 
177
 
178
  if self.parameterization == "eps":
179
  lvlb_weights = self.betas ** 2 / (
180
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
181
  elif self.parameterization == "x0":
182
  lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
183
+ elif self.parameterization == "v":
184
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
185
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
186
  else:
187
  raise NotImplementedError("mu not supported")
 
188
  lvlb_weights[0] = lvlb_weights[1]
189
  self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
190
  assert not torch.isnan(self.lvlb_weights).all()
 
204
  if context is not None:
205
  print(f"{context}: Restored training weights")
206
 
207
+ @torch.no_grad()
208
  def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
209
  sd = torch.load(path, map_location="cpu")
210
  if "state_dict" in list(sd.keys()):
 
215
  if k.startswith(ik):
216
  print("Deleting key {} from state_dict.".format(k))
217
  del sd[k]
218
+ if self.make_it_fit:
219
+ n_params = len([name for name, _ in
220
+ itertools.chain(self.named_parameters(),
221
+ self.named_buffers())])
222
+ for name, param in tqdm(
223
+ itertools.chain(self.named_parameters(),
224
+ self.named_buffers()),
225
+ desc="Fitting old weights to new weights",
226
+ total=n_params
227
+ ):
228
+ if not name in sd:
229
+ continue
230
+ old_shape = sd[name].shape
231
+ new_shape = param.shape
232
+ assert len(old_shape) == len(new_shape)
233
+ if len(new_shape) > 2:
234
+ # we only modify first two axes
235
+ assert new_shape[2:] == old_shape[2:]
236
+ # assumes first axis corresponds to output dim
237
+ if not new_shape == old_shape:
238
+ new_param = param.clone()
239
+ old_param = sd[name]
240
+ if len(new_shape) == 1:
241
+ for i in range(new_param.shape[0]):
242
+ new_param[i] = old_param[i % old_shape[0]]
243
+ elif len(new_shape) >= 2:
244
+ for i in range(new_param.shape[0]):
245
+ for j in range(new_param.shape[1]):
246
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
247
+
248
+ n_used_old = torch.ones(old_shape[1])
249
+ for j in range(new_param.shape[1]):
250
+ n_used_old[j % old_shape[1]] += 1
251
+ n_used_new = torch.zeros(new_shape[1])
252
+ for j in range(new_param.shape[1]):
253
+ n_used_new[j] = n_used_old[j % old_shape[1]]
254
+
255
+ n_used_new = n_used_new[None, :]
256
+ while len(n_used_new.shape) < len(new_shape):
257
+ n_used_new = n_used_new.unsqueeze(-1)
258
+ new_param /= n_used_new
259
+
260
+ sd[name] = new_param
261
+
262
  missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
263
  sd, strict=False)
264
  print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
265
  if len(missing) > 0:
266
+ print(f"Missing Keys:\n {missing}")
267
  if len(unexpected) > 0:
268
+ print(f"\nUnexpected Keys:\n {unexpected}")
269
 
270
  def q_mean_variance(self, x_start, t):
271
  """
 
285
  extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
286
  )
287
 
288
+ def predict_start_from_z_and_v(self, x_t, t, v):
289
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
290
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
291
+ return (
292
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
293
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
294
+ )
295
+
296
+ def predict_eps_from_z_and_v(self, x_t, t, v):
297
+ return (
298
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
299
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
300
+ )
301
+
302
  def q_posterior(self, x_start, x_t, t):
303
  posterior_mean = (
304
  extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
 
356
  return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
357
  extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
358
 
359
+ def get_v(self, x, noise, t):
360
+ return (
361
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
362
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
363
+ )
364
+
365
  def get_loss(self, pred, target, mean=True):
366
  if self.loss_type == 'l1':
367
  loss = (target - pred).abs()
 
387
  target = noise
388
  elif self.parameterization == "x0":
389
  target = x_start
390
+ elif self.parameterization == "v":
391
+ target = self.get_v(x_start, noise, t)
392
  else:
393
  raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
394
 
 
416
 
417
  def get_input(self, batch, k):
418
  x = batch[k]
419
+ # if len(x.shape) == 3:
420
+ # x = x[..., None]
421
+ # x = rearrange(x, 'b h w c -> b c h w')
422
+ # x = x.to(memory_format=torch.contiguous_format).float()
423
  return x
424
 
425
  def shared_step(self, batch):
 
509
  return opt
510
 
511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  class LatentDiffusion(DDPM):
513
  """main class"""
514
+
515
  def __init__(self,
516
  first_stage_config,
517
  cond_stage_config,
 
518
  num_timesteps_cond=None,
519
  cond_stage_key="image",
520
  cond_stage_trainable=False,
 
533
  if cond_stage_config == '__is_unconditional__':
534
  conditioning_key = None
535
  ckpt_path = kwargs.pop("ckpt_path", None)
536
+ reset_ema = kwargs.pop("reset_ema", False)
537
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
538
  ignore_keys = kwargs.pop("ignore_keys", [])
539
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
 
540
  self.concat_mode = concat_mode
541
  self.cond_stage_trainable = cond_stage_trainable
542
  self.cond_stage_key = cond_stage_key
 
552
  self.instantiate_cond_stage(cond_stage_config)
553
  self.cond_stage_forward = cond_stage_forward
554
  self.clip_denoised = False
555
+ self.bbox_tokenizer = None
556
 
557
  self.restarted_from_ckpt = False
558
  if ckpt_path is not None:
559
  self.init_from_ckpt(ckpt_path, ignore_keys)
560
  self.restarted_from_ckpt = True
561
+ if reset_ema:
562
+ assert self.use_ema
563
+ print(
564
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
565
+ self.model_ema = LitEma(self.model)
566
+ if reset_num_ema_updates:
567
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
568
+ assert self.use_ema
569
+ self.model_ema.reset_num_updates()
570
 
571
  def make_cond_schedule(self, ):
572
  self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
573
  ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
574
  self.cond_ids[:self.num_timesteps_cond] = ids
575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  def register_schedule(self,
577
  given_betas=None, beta_schedule="linear", timesteps=1000,
578
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
 
614
  denoise_row = []
615
  for zd in tqdm(samples, desc=desc):
616
  denoise_row.append(self.decode_first_stage(zd.to(self.device),
617
+ force_not_quantize=force_no_decoder_quantization))
618
  n_imgs_per_row = len(denoise_row)
619
  denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
620
  denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
 
747
  if cond_key is None:
748
  cond_key = self.cond_stage_key
749
  if cond_key != self.first_stage_key:
750
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
751
  xc = batch[cond_key]
752
+ elif cond_key in ['class_label', 'cls']:
753
  xc = batch
754
  else:
755
  xc = super().get_input(batch, cond_key).to(self.device)
 
794
  z = rearrange(z, 'b h w c -> b c h w').contiguous()
795
 
796
  z = 1. / self.scale_factor * z
797
+ return self.first_stage_model.decode(z)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
 
799
  @torch.no_grad()
800
  def encode_first_stage(self, x):
801
+ return self.first_stage_model.encode(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
802
 
803
  def shared_step(self, batch, **kwargs):
804
  x, c = self.get_input(batch, self.first_stage_key)
805
+ loss = self(x, c, **kwargs)
806
  return loss
807
 
808
+ def forward(self, x, c, *args, **kwargs):
809
+ if 't' not in kwargs:
810
+ t = torch.randint(0, self.num_timesteps, (x.shape[0], ), device=self.device).long()
811
+ else:
812
+ t = kwargs.pop('t')
 
 
 
 
 
 
 
 
 
813
 
814
+ return self.p_losses(x, c, t, *args, **kwargs)
815
 
816
+ def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs):
817
  if isinstance(cond, dict):
818
+ # hybrid case, cond is expected to be a dict
819
  pass
820
  else:
821
  if not isinstance(cond, list):
 
823
  key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
824
  cond = {key: cond}
825
 
826
+ x_recon = self.model(x_noisy, t, **cond, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
827
 
828
  if isinstance(x_recon, tuple) and not return_ids:
829
  return x_recon[0]
 
848
  kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
849
  return mean_flat(kl_prior) / np.log(2.0)
850
 
851
+ def p_losses(self, x_start, cond, t, noise=None, **kwargs):
852
  noise = default(noise, lambda: torch.randn_like(x_start))
853
  x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
854
+ model_output = self.apply_model(x_noisy, t, cond, **kwargs)
855
 
856
  loss_dict = {}
857
  prefix = 'train' if self.training else 'val'
 
860
  target = x_start
861
  elif self.parameterization == "eps":
862
  target = noise
863
+ elif self.parameterization == "v":
864
+ target = self.get_v(x_start, noise, t)
865
  else:
866
  raise NotImplementedError()
867
 
 
1057
  @torch.no_grad()
1058
  def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1059
  verbose=True, timesteps=None, quantize_denoised=False,
1060
+ mask=None, x0=None, shape=None, **kwargs):
1061
  if shape is None:
1062
  shape = (batch_size, self.channels, self.image_size, self.image_size)
1063
  if cond is not None:
 
1073
  mask=mask, x0=x0)
1074
 
1075
  @torch.no_grad()
1076
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
 
1077
  if ddim:
1078
  ddim_sampler = DDIMSampler(self)
1079
  shape = (self.channels, self.image_size, self.image_size)
1080
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
1081
+ shape, cond, verbose=False, **kwargs)
1082
 
1083
  else:
1084
  samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1085
+ return_intermediates=True, **kwargs)
1086
 
1087
  return samples, intermediates
1088
 
1089
+ @torch.no_grad()
1090
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
1091
+ if null_label is not None:
1092
+ xc = null_label
1093
+ if isinstance(xc, ListConfig):
1094
+ xc = list(xc)
1095
+ if isinstance(xc, dict) or isinstance(xc, list):
1096
+ c = self.get_learned_conditioning(xc)
1097
+ else:
1098
+ if hasattr(xc, "to"):
1099
+ xc = xc.to(self.device)
1100
+ c = self.get_learned_conditioning(xc)
1101
+ else:
1102
+ if self.cond_stage_key in ["class_label", "cls"]:
1103
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
1104
+ return self.get_learned_conditioning(xc)
1105
+ else:
1106
+ raise NotImplementedError("todo")
1107
+ if isinstance(c, list): # in case the encoder gives us a list
1108
+ for i in range(len(c)):
1109
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
1110
+ else:
1111
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
1112
+ return c
1113
 
1114
  @torch.no_grad()
1115
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
1116
  quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1117
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1118
+ use_ema_scope=True,
1119
+ **kwargs):
1120
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1121
  use_ddim = ddim_steps is not None
1122
 
1123
  log = dict()
 
1134
  if hasattr(self.cond_stage_model, "decode"):
1135
  xc = self.cond_stage_model.decode(c)
1136
  log["conditioning"] = xc
1137
+ elif self.cond_stage_key in ["caption", "txt"]:
1138
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1139
  log["conditioning"] = xc
1140
+ elif self.cond_stage_key in ['class_label', "cls"]:
1141
+ try:
1142
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1143
+ log['conditioning'] = xc
1144
+ except KeyError:
1145
+ # probably no "human_label" in batch
1146
+ pass
1147
  elif isimage(xc):
1148
  log["conditioning"] = xc
1149
  if ismap(xc):
 
1169
 
1170
  if sample:
1171
  # get denoise row
1172
+ with ema_scope("Sampling"):
1173
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1174
+ ddim_steps=ddim_steps, eta=ddim_eta)
1175
  # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1176
  x_samples = self.decode_first_stage(samples)
1177
  log["samples"] = x_samples
 
1182
  if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1183
  self.first_stage_model, IdentityFirstStage):
1184
  # also display when quantizing x0 while sampling
1185
+ with ema_scope("Plotting Quantized Denoised"):
1186
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1187
+ ddim_steps=ddim_steps, eta=ddim_eta,
1188
  quantize_denoised=True)
1189
  # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1190
  # quantize_denoised=True)
1191
  x_samples = self.decode_first_stage(samples.to(self.device))
1192
  log["samples_x0_quantized"] = x_samples
1193
 
1194
+ if unconditional_guidance_scale > 1.0:
1195
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1196
+ if self.model.conditioning_key == "crossattn-adm":
1197
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
1198
+ with ema_scope("Sampling with classifier-free guidance"):
1199
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1200
+ ddim_steps=ddim_steps, eta=ddim_eta,
1201
+ unconditional_guidance_scale=unconditional_guidance_scale,
1202
+ unconditional_conditioning=uc,
1203
+ )
1204
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1205
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1206
+
1207
+ if inpaint:
1208
+ # make a simple center square
1209
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1210
+ mask = torch.ones(N, h, w).to(self.device)
1211
+ # zeros will be filled in
1212
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1213
+ mask = mask[:, None, ...]
1214
+ with ema_scope("Plotting Inpaint"):
1215
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1216
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1217
+ x_samples = self.decode_first_stage(samples.to(self.device))
1218
+ log["samples_inpainting"] = x_samples
1219
+ log["mask"] = mask
1220
+
1221
+ # outpaint
1222
+ mask = 1. - mask
1223
+ with ema_scope("Plotting Outpaint"):
1224
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1225
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1226
+ x_samples = self.decode_first_stage(samples.to(self.device))
1227
+ log["samples_outpainting"] = x_samples
1228
 
1229
  if plot_progressive_rows:
1230
+ with ema_scope("Plotting Progressives"):
1231
  img, progressives = self.progressive_denoising(c,
1232
  shape=(self.channels, self.image_size, self.image_size),
1233
  batch_size=N)
 
1274
  x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1275
  return x
1276
 
 
 
 
 
 
 
 
 
1277
 
1278
+ class DiffusionWrapper(pl.LightningModule):
1279
+ def __init__(self, diff_model_config, conditioning_key):
1280
+ super().__init__()
1281
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1282
+ self.conditioning_key = conditioning_key
1283
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
1284
 
1285
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, **kwargs):
1286
+ if self.conditioning_key is None:
1287
+ out = self.diffusion_model(x, t, **kwargs)
1288
+ elif self.conditioning_key == 'concat':
1289
+ xc = torch.cat([x] + c_concat, dim=1)
1290
+ out = self.diffusion_model(xc, t, **kwargs)
1291
+ elif self.conditioning_key == 'crossattn':
1292
+ cc = torch.cat(c_crossattn, 1)
1293
+ out = self.diffusion_model(x, t, context=cc, **kwargs)
1294
+ elif self.conditioning_key == 'hybrid':
1295
+ xc = torch.cat([x] + c_concat, dim=1)
1296
+ cc = torch.cat(c_crossattn, 1)
1297
+ out = self.diffusion_model(xc, t, context=cc, **kwargs)
1298
+ elif self.conditioning_key == 'hybrid-adm':
1299
+ assert c_adm is not None
1300
+ xc = torch.cat([x] + c_concat, dim=1)
1301
+ cc = torch.cat(c_crossattn, 1)
1302
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm, **kwargs)
1303
+ elif self.conditioning_key == 'crossattn-adm':
1304
+ assert c_adm is not None
1305
+ cc = torch.cat(c_crossattn, 1)
1306
+ out = self.diffusion_model(x, t, context=cc, y=c_adm, **kwargs)
1307
+ elif self.conditioning_key == 'adm':
1308
+ cc = c_crossattn[0]
1309
+ out = self.diffusion_model(x, t, y=cc, **kwargs)
1310
+ else:
1311
+ raise NotImplementedError()
1312
 
1313
+ return out
 
 
ldm/models/diffusion/dpm_solver/dpm_solver.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn.functional as F
3
  import math
 
4
 
5
 
6
  class NoiseScheduleVP:
@@ -11,7 +12,7 @@ class NoiseScheduleVP:
11
  alphas_cumprod=None,
12
  continuous_beta_0=0.1,
13
  continuous_beta_1=20.,
14
- ):
15
  """Create a wrapper class for the forward SDE (VP type).
16
 
17
  ***
@@ -93,7 +94,9 @@ class NoiseScheduleVP:
93
  """
94
 
95
  if schedule not in ['discrete', 'linear', 'cosine']:
96
- raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
 
 
97
 
98
  self.schedule = schedule
99
  if schedule == 'discrete':
@@ -112,7 +115,8 @@ class NoiseScheduleVP:
112
  self.beta_1 = continuous_beta_1
113
  self.cosine_s = 0.008
114
  self.cosine_beta_max = 999.
115
- self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
 
116
  self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
117
  self.schedule = schedule
118
  if schedule == 'cosine':
@@ -127,12 +131,13 @@ class NoiseScheduleVP:
127
  Compute log(alpha_t) of a given continuous-time label t in [0, T].
128
  """
129
  if self.schedule == 'discrete':
130
- return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
 
131
  elif self.schedule == 'linear':
132
  return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
133
  elif self.schedule == 'cosine':
134
  log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
135
- log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
136
  return log_alpha_t
137
 
138
  def marginal_alpha(self, t):
@@ -161,30 +166,32 @@ class NoiseScheduleVP:
161
  """
162
  if self.schedule == 'linear':
163
  tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
164
- Delta = self.beta_0**2 + tmp
165
  return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
166
  elif self.schedule == 'discrete':
167
  log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
168
- t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
 
169
  return t.reshape((-1,))
170
  else:
171
  log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
172
- t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
 
173
  t = t_fn(log_alpha)
174
  return t
175
 
176
 
177
  def model_wrapper(
178
- model,
179
- noise_schedule,
180
- model_type="noise",
181
- model_kwargs={},
182
- guidance_type="uncond",
183
- condition=None,
184
- unconditional_condition=None,
185
- guidance_scale=1.,
186
- classifier_fn=None,
187
- classifier_kwargs={},
188
  ):
189
  """Create a wrapper function for the noise prediction model.
190
 
@@ -392,7 +399,7 @@ class DPM_Solver:
392
  alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
393
  x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
394
  if self.thresholding:
395
- p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
396
  s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
397
  s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
398
  x0 = torch.clamp(x0, -s, s) / s
@@ -431,10 +438,11 @@ class DPM_Solver:
431
  return torch.linspace(t_T, t_0, N + 1).to(device)
432
  elif skip_type == 'time_quadratic':
433
  t_order = 2
434
- t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
435
  return t
436
  else:
437
- raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
 
438
 
439
  def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
440
  """
@@ -471,28 +479,29 @@ class DPM_Solver:
471
  if order == 3:
472
  K = steps // 3 + 1
473
  if steps % 3 == 0:
474
- orders = [3,] * (K - 2) + [2, 1]
475
  elif steps % 3 == 1:
476
- orders = [3,] * (K - 1) + [1]
477
  else:
478
- orders = [3,] * (K - 1) + [2]
479
  elif order == 2:
480
  if steps % 2 == 0:
481
  K = steps // 2
482
- orders = [2,] * K
483
  else:
484
  K = steps // 2 + 1
485
- orders = [2,] * (K - 1) + [1]
486
  elif order == 1:
487
  K = 1
488
- orders = [1,] * steps
489
  else:
490
  raise ValueError("'order' must be '1' or '2' or '3'.")
491
  if skip_type == 'logSNR':
492
  # To reproduce the results in DPM-Solver paper
493
  timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
494
  else:
495
- timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)]
 
496
  return timesteps_outer, orders
497
 
498
  def denoise_to_zero_fn(self, x, s):
@@ -528,8 +537,8 @@ class DPM_Solver:
528
  if model_s is None:
529
  model_s = self.model_fn(x, s)
530
  x_t = (
531
- expand_dims(sigma_t / sigma_s, dims) * x
532
- - expand_dims(alpha_t * phi_1, dims) * model_s
533
  )
534
  if return_intermediate:
535
  return x_t, {'model_s': model_s}
@@ -540,15 +549,16 @@ class DPM_Solver:
540
  if model_s is None:
541
  model_s = self.model_fn(x, s)
542
  x_t = (
543
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
544
- - expand_dims(sigma_t * phi_1, dims) * model_s
545
  )
546
  if return_intermediate:
547
  return x_t, {'model_s': model_s}
548
  else:
549
  return x_t
550
 
551
- def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'):
 
552
  """
553
  Singlestep solver DPM-Solver-2 from time `s` to time `t`.
554
 
@@ -575,7 +585,8 @@ class DPM_Solver:
575
  h = lambda_t - lambda_s
576
  lambda_s1 = lambda_s + r1 * h
577
  s1 = ns.inverse_lambda(lambda_s1)
578
- log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
 
579
  sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
580
  alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
581
 
@@ -586,21 +597,22 @@ class DPM_Solver:
586
  if model_s is None:
587
  model_s = self.model_fn(x, s)
588
  x_s1 = (
589
- expand_dims(sigma_s1 / sigma_s, dims) * x
590
- - expand_dims(alpha_s1 * phi_11, dims) * model_s
591
  )
592
  model_s1 = self.model_fn(x_s1, s1)
593
  if solver_type == 'dpm_solver':
594
  x_t = (
595
- expand_dims(sigma_t / sigma_s, dims) * x
596
- - expand_dims(alpha_t * phi_1, dims) * model_s
597
- - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
598
  )
599
  elif solver_type == 'taylor':
600
  x_t = (
601
- expand_dims(sigma_t / sigma_s, dims) * x
602
- - expand_dims(alpha_t * phi_1, dims) * model_s
603
- + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s)
 
604
  )
605
  else:
606
  phi_11 = torch.expm1(r1 * h)
@@ -609,28 +621,29 @@ class DPM_Solver:
609
  if model_s is None:
610
  model_s = self.model_fn(x, s)
611
  x_s1 = (
612
- expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
613
- - expand_dims(sigma_s1 * phi_11, dims) * model_s
614
  )
615
  model_s1 = self.model_fn(x_s1, s1)
616
  if solver_type == 'dpm_solver':
617
  x_t = (
618
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
619
- - expand_dims(sigma_t * phi_1, dims) * model_s
620
- - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
621
  )
622
  elif solver_type == 'taylor':
623
  x_t = (
624
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
625
- - expand_dims(sigma_t * phi_1, dims) * model_s
626
- - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
627
  )
628
  if return_intermediate:
629
  return x_t, {'model_s': model_s, 'model_s1': model_s1}
630
  else:
631
  return x_t
632
 
633
- def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'):
 
634
  """
635
  Singlestep solver DPM-Solver-3 from time `s` to time `t`.
636
 
@@ -664,8 +677,10 @@ class DPM_Solver:
664
  lambda_s2 = lambda_s + r2 * h
665
  s1 = ns.inverse_lambda(lambda_s1)
666
  s2 = ns.inverse_lambda(lambda_s2)
667
- log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
668
- sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
 
 
669
  alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
670
 
671
  if self.predict_x0:
@@ -680,21 +695,21 @@ class DPM_Solver:
680
  model_s = self.model_fn(x, s)
681
  if model_s1 is None:
682
  x_s1 = (
683
- expand_dims(sigma_s1 / sigma_s, dims) * x
684
- - expand_dims(alpha_s1 * phi_11, dims) * model_s
685
  )
686
  model_s1 = self.model_fn(x_s1, s1)
687
  x_s2 = (
688
- expand_dims(sigma_s2 / sigma_s, dims) * x
689
- - expand_dims(alpha_s2 * phi_12, dims) * model_s
690
- + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
691
  )
692
  model_s2 = self.model_fn(x_s2, s2)
693
  if solver_type == 'dpm_solver':
694
  x_t = (
695
- expand_dims(sigma_t / sigma_s, dims) * x
696
- - expand_dims(alpha_t * phi_1, dims) * model_s
697
- + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
698
  )
699
  elif solver_type == 'taylor':
700
  D1_0 = (1. / r1) * (model_s1 - model_s)
@@ -702,10 +717,10 @@ class DPM_Solver:
702
  D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
703
  D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
704
  x_t = (
705
- expand_dims(sigma_t / sigma_s, dims) * x
706
- - expand_dims(alpha_t * phi_1, dims) * model_s
707
- + expand_dims(alpha_t * phi_2, dims) * D1
708
- - expand_dims(alpha_t * phi_3, dims) * D2
709
  )
710
  else:
711
  phi_11 = torch.expm1(r1 * h)
@@ -719,21 +734,21 @@ class DPM_Solver:
719
  model_s = self.model_fn(x, s)
720
  if model_s1 is None:
721
  x_s1 = (
722
- expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
723
- - expand_dims(sigma_s1 * phi_11, dims) * model_s
724
  )
725
  model_s1 = self.model_fn(x_s1, s1)
726
  x_s2 = (
727
- expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
728
- - expand_dims(sigma_s2 * phi_12, dims) * model_s
729
- - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
730
  )
731
  model_s2 = self.model_fn(x_s2, s2)
732
  if solver_type == 'dpm_solver':
733
  x_t = (
734
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
735
- - expand_dims(sigma_t * phi_1, dims) * model_s
736
- - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
737
  )
738
  elif solver_type == 'taylor':
739
  D1_0 = (1. / r1) * (model_s1 - model_s)
@@ -741,10 +756,10 @@ class DPM_Solver:
741
  D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
742
  D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
743
  x_t = (
744
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
745
- - expand_dims(sigma_t * phi_1, dims) * model_s
746
- - expand_dims(sigma_t * phi_2, dims) * D1
747
- - expand_dims(sigma_t * phi_3, dims) * D2
748
  )
749
 
750
  if return_intermediate:
@@ -772,7 +787,8 @@ class DPM_Solver:
772
  dims = x.dim()
773
  model_prev_1, model_prev_0 = model_prev_list
774
  t_prev_1, t_prev_0 = t_prev_list
775
- lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
 
776
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
777
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
778
  alpha_t = torch.exp(log_alpha_t)
@@ -784,28 +800,28 @@ class DPM_Solver:
784
  if self.predict_x0:
785
  if solver_type == 'dpm_solver':
786
  x_t = (
787
- expand_dims(sigma_t / sigma_prev_0, dims) * x
788
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
789
- - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
790
  )
791
  elif solver_type == 'taylor':
792
  x_t = (
793
- expand_dims(sigma_t / sigma_prev_0, dims) * x
794
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
795
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
796
  )
797
  else:
798
  if solver_type == 'dpm_solver':
799
  x_t = (
800
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
801
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
802
- - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
803
  )
804
  elif solver_type == 'taylor':
805
  x_t = (
806
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
807
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
808
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
809
  )
810
  return x_t
811
 
@@ -827,7 +843,8 @@ class DPM_Solver:
827
  dims = x.dim()
828
  model_prev_2, model_prev_1, model_prev_0 = model_prev_list
829
  t_prev_2, t_prev_1, t_prev_0 = t_prev_list
830
- lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
 
831
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
832
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
833
  alpha_t = torch.exp(log_alpha_t)
@@ -842,21 +859,22 @@ class DPM_Solver:
842
  D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
843
  if self.predict_x0:
844
  x_t = (
845
- expand_dims(sigma_t / sigma_prev_0, dims) * x
846
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
847
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
848
- - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2
849
  )
850
  else:
851
  x_t = (
852
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
853
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
854
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
855
- - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2
856
  )
857
  return x_t
858
 
859
- def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None):
 
860
  """
861
  Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
862
 
@@ -876,9 +894,11 @@ class DPM_Solver:
876
  if order == 1:
877
  return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
878
  elif order == 2:
879
- return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
 
880
  elif order == 3:
881
- return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
 
882
  else:
883
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
884
 
@@ -906,7 +926,8 @@ class DPM_Solver:
906
  else:
907
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
908
 
909
- def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
 
910
  """
911
  The adaptive step size solver based on singlestep DPM-Solver.
912
 
@@ -938,11 +959,17 @@ class DPM_Solver:
938
  if order == 2:
939
  r1 = 0.5
940
  lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
941
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
 
 
942
  elif order == 3:
943
  r1, r2 = 1. / 3., 2. / 3.
944
- lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
945
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
 
 
 
 
946
  else:
947
  raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
948
  while torch.abs((s - t_0)).mean() > t_err:
@@ -963,9 +990,9 @@ class DPM_Solver:
963
  return x
964
 
965
  def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
966
- method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
967
- atol=0.0078, rtol=0.05,
968
- ):
969
  """
970
  Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
971
 
@@ -1073,7 +1100,8 @@ class DPM_Solver:
1073
  device = x.device
1074
  if method == 'adaptive':
1075
  with torch.no_grad():
1076
- x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
 
1077
  elif method == 'multistep':
1078
  assert steps >= order
1079
  timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
@@ -1083,19 +1111,21 @@ class DPM_Solver:
1083
  model_prev_list = [self.model_fn(x, vec_t)]
1084
  t_prev_list = [vec_t]
1085
  # Init the first `order` values by lower order multistep DPM-Solver.
1086
- for init_order in range(1, order):
1087
  vec_t = timesteps[init_order].expand(x.shape[0])
1088
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
 
1089
  model_prev_list.append(self.model_fn(x, vec_t))
1090
  t_prev_list.append(vec_t)
1091
  # Compute the remaining values by `order`-th order multistep DPM-Solver.
1092
- for step in range(order, steps + 1):
1093
  vec_t = timesteps[step].expand(x.shape[0])
1094
  if lower_order_final and steps < 15:
1095
  step_order = min(order, steps + 1 - step)
1096
  else:
1097
  step_order = order
1098
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type)
 
1099
  for i in range(order - 1):
1100
  t_prev_list[i] = t_prev_list[i + 1]
1101
  model_prev_list[i] = model_prev_list[i + 1]
@@ -1105,14 +1135,18 @@ class DPM_Solver:
1105
  model_prev_list[-1] = self.model_fn(x, vec_t)
1106
  elif method in ['singlestep', 'singlestep_fixed']:
1107
  if method == 'singlestep':
1108
- timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
 
 
 
1109
  elif method == 'singlestep_fixed':
1110
  K = steps // order
1111
- orders = [order,] * K
1112
  timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1113
  for i, order in enumerate(orders):
1114
  t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1115
- timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
 
1116
  lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1117
  vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1118
  h = lambda_inner[-1] - lambda_inner[0]
@@ -1124,7 +1158,6 @@ class DPM_Solver:
1124
  return x
1125
 
1126
 
1127
-
1128
  #############################################################
1129
  # other utility functions
1130
  #############################################################
@@ -1181,4 +1214,4 @@ def expand_dims(v, dims):
1181
  Returns:
1182
  a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1183
  """
1184
- return v[(...,) + (None,)*(dims - 1)]
 
1
  import torch
2
  import torch.nn.functional as F
3
  import math
4
+ from tqdm import tqdm
5
 
6
 
7
  class NoiseScheduleVP:
 
12
  alphas_cumprod=None,
13
  continuous_beta_0=0.1,
14
  continuous_beta_1=20.,
15
+ ):
16
  """Create a wrapper class for the forward SDE (VP type).
17
 
18
  ***
 
94
  """
95
 
96
  if schedule not in ['discrete', 'linear', 'cosine']:
97
+ raise ValueError(
98
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
99
+ schedule))
100
 
101
  self.schedule = schedule
102
  if schedule == 'discrete':
 
115
  self.beta_1 = continuous_beta_1
116
  self.cosine_s = 0.008
117
  self.cosine_beta_max = 999.
118
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
119
+ 1. + self.cosine_s) / math.pi - self.cosine_s
120
  self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
121
  self.schedule = schedule
122
  if schedule == 'cosine':
 
131
  Compute log(alpha_t) of a given continuous-time label t in [0, T].
132
  """
133
  if self.schedule == 'discrete':
134
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
135
+ self.log_alpha_array.to(t.device)).reshape((-1))
136
  elif self.schedule == 'linear':
137
  return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
138
  elif self.schedule == 'cosine':
139
  log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
140
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
141
  return log_alpha_t
142
 
143
  def marginal_alpha(self, t):
 
166
  """
167
  if self.schedule == 'linear':
168
  tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
169
+ Delta = self.beta_0 ** 2 + tmp
170
  return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
171
  elif self.schedule == 'discrete':
172
  log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
173
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
174
+ torch.flip(self.t_array.to(lamb.device), [1]))
175
  return t.reshape((-1,))
176
  else:
177
  log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
178
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
179
+ 1. + self.cosine_s) / math.pi - self.cosine_s
180
  t = t_fn(log_alpha)
181
  return t
182
 
183
 
184
  def model_wrapper(
185
+ model,
186
+ noise_schedule,
187
+ model_type="noise",
188
+ model_kwargs={},
189
+ guidance_type="uncond",
190
+ condition=None,
191
+ unconditional_condition=None,
192
+ guidance_scale=1.,
193
+ classifier_fn=None,
194
+ classifier_kwargs={},
195
  ):
196
  """Create a wrapper function for the noise prediction model.
197
 
 
399
  alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
400
  x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
401
  if self.thresholding:
402
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
403
  s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
404
  s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
405
  x0 = torch.clamp(x0, -s, s) / s
 
438
  return torch.linspace(t_T, t_0, N + 1).to(device)
439
  elif skip_type == 'time_quadratic':
440
  t_order = 2
441
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
442
  return t
443
  else:
444
+ raise ValueError(
445
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
446
 
447
  def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
448
  """
 
479
  if order == 3:
480
  K = steps // 3 + 1
481
  if steps % 3 == 0:
482
+ orders = [3, ] * (K - 2) + [2, 1]
483
  elif steps % 3 == 1:
484
+ orders = [3, ] * (K - 1) + [1]
485
  else:
486
+ orders = [3, ] * (K - 1) + [2]
487
  elif order == 2:
488
  if steps % 2 == 0:
489
  K = steps // 2
490
+ orders = [2, ] * K
491
  else:
492
  K = steps // 2 + 1
493
+ orders = [2, ] * (K - 1) + [1]
494
  elif order == 1:
495
  K = 1
496
+ orders = [1, ] * steps
497
  else:
498
  raise ValueError("'order' must be '1' or '2' or '3'.")
499
  if skip_type == 'logSNR':
500
  # To reproduce the results in DPM-Solver paper
501
  timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
502
  else:
503
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
504
+ torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
505
  return timesteps_outer, orders
506
 
507
  def denoise_to_zero_fn(self, x, s):
 
537
  if model_s is None:
538
  model_s = self.model_fn(x, s)
539
  x_t = (
540
+ expand_dims(sigma_t / sigma_s, dims) * x
541
+ - expand_dims(alpha_t * phi_1, dims) * model_s
542
  )
543
  if return_intermediate:
544
  return x_t, {'model_s': model_s}
 
549
  if model_s is None:
550
  model_s = self.model_fn(x, s)
551
  x_t = (
552
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
553
+ - expand_dims(sigma_t * phi_1, dims) * model_s
554
  )
555
  if return_intermediate:
556
  return x_t, {'model_s': model_s}
557
  else:
558
  return x_t
559
 
560
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
561
+ solver_type='dpm_solver'):
562
  """
563
  Singlestep solver DPM-Solver-2 from time `s` to time `t`.
564
 
 
585
  h = lambda_t - lambda_s
586
  lambda_s1 = lambda_s + r1 * h
587
  s1 = ns.inverse_lambda(lambda_s1)
588
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
589
+ s1), ns.marginal_log_mean_coeff(t)
590
  sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
591
  alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
592
 
 
597
  if model_s is None:
598
  model_s = self.model_fn(x, s)
599
  x_s1 = (
600
+ expand_dims(sigma_s1 / sigma_s, dims) * x
601
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
602
  )
603
  model_s1 = self.model_fn(x_s1, s1)
604
  if solver_type == 'dpm_solver':
605
  x_t = (
606
+ expand_dims(sigma_t / sigma_s, dims) * x
607
+ - expand_dims(alpha_t * phi_1, dims) * model_s
608
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
609
  )
610
  elif solver_type == 'taylor':
611
  x_t = (
612
+ expand_dims(sigma_t / sigma_s, dims) * x
613
+ - expand_dims(alpha_t * phi_1, dims) * model_s
614
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
615
+ model_s1 - model_s)
616
  )
617
  else:
618
  phi_11 = torch.expm1(r1 * h)
 
621
  if model_s is None:
622
  model_s = self.model_fn(x, s)
623
  x_s1 = (
624
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
625
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
626
  )
627
  model_s1 = self.model_fn(x_s1, s1)
628
  if solver_type == 'dpm_solver':
629
  x_t = (
630
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
631
+ - expand_dims(sigma_t * phi_1, dims) * model_s
632
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
633
  )
634
  elif solver_type == 'taylor':
635
  x_t = (
636
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
637
+ - expand_dims(sigma_t * phi_1, dims) * model_s
638
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
639
  )
640
  if return_intermediate:
641
  return x_t, {'model_s': model_s, 'model_s1': model_s1}
642
  else:
643
  return x_t
644
 
645
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
646
+ return_intermediate=False, solver_type='dpm_solver'):
647
  """
648
  Singlestep solver DPM-Solver-3 from time `s` to time `t`.
649
 
 
677
  lambda_s2 = lambda_s + r2 * h
678
  s1 = ns.inverse_lambda(lambda_s1)
679
  s2 = ns.inverse_lambda(lambda_s2)
680
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
681
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
682
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
683
+ s2), ns.marginal_std(t)
684
  alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
685
 
686
  if self.predict_x0:
 
695
  model_s = self.model_fn(x, s)
696
  if model_s1 is None:
697
  x_s1 = (
698
+ expand_dims(sigma_s1 / sigma_s, dims) * x
699
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
700
  )
701
  model_s1 = self.model_fn(x_s1, s1)
702
  x_s2 = (
703
+ expand_dims(sigma_s2 / sigma_s, dims) * x
704
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
705
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
706
  )
707
  model_s2 = self.model_fn(x_s2, s2)
708
  if solver_type == 'dpm_solver':
709
  x_t = (
710
+ expand_dims(sigma_t / sigma_s, dims) * x
711
+ - expand_dims(alpha_t * phi_1, dims) * model_s
712
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
713
  )
714
  elif solver_type == 'taylor':
715
  D1_0 = (1. / r1) * (model_s1 - model_s)
 
717
  D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
718
  D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
719
  x_t = (
720
+ expand_dims(sigma_t / sigma_s, dims) * x
721
+ - expand_dims(alpha_t * phi_1, dims) * model_s
722
+ + expand_dims(alpha_t * phi_2, dims) * D1
723
+ - expand_dims(alpha_t * phi_3, dims) * D2
724
  )
725
  else:
726
  phi_11 = torch.expm1(r1 * h)
 
734
  model_s = self.model_fn(x, s)
735
  if model_s1 is None:
736
  x_s1 = (
737
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
738
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
739
  )
740
  model_s1 = self.model_fn(x_s1, s1)
741
  x_s2 = (
742
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
743
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
744
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
745
  )
746
  model_s2 = self.model_fn(x_s2, s2)
747
  if solver_type == 'dpm_solver':
748
  x_t = (
749
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
750
+ - expand_dims(sigma_t * phi_1, dims) * model_s
751
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
752
  )
753
  elif solver_type == 'taylor':
754
  D1_0 = (1. / r1) * (model_s1 - model_s)
 
756
  D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
757
  D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
758
  x_t = (
759
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
760
+ - expand_dims(sigma_t * phi_1, dims) * model_s
761
+ - expand_dims(sigma_t * phi_2, dims) * D1
762
+ - expand_dims(sigma_t * phi_3, dims) * D2
763
  )
764
 
765
  if return_intermediate:
 
787
  dims = x.dim()
788
  model_prev_1, model_prev_0 = model_prev_list
789
  t_prev_1, t_prev_0 = t_prev_list
790
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
791
+ t_prev_0), ns.marginal_lambda(t)
792
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
793
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
794
  alpha_t = torch.exp(log_alpha_t)
 
800
  if self.predict_x0:
801
  if solver_type == 'dpm_solver':
802
  x_t = (
803
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
804
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
805
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
806
  )
807
  elif solver_type == 'taylor':
808
  x_t = (
809
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
810
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
811
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
812
  )
813
  else:
814
  if solver_type == 'dpm_solver':
815
  x_t = (
816
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
817
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
818
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
819
  )
820
  elif solver_type == 'taylor':
821
  x_t = (
822
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
823
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
824
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
825
  )
826
  return x_t
827
 
 
843
  dims = x.dim()
844
  model_prev_2, model_prev_1, model_prev_0 = model_prev_list
845
  t_prev_2, t_prev_1, t_prev_0 = t_prev_list
846
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
847
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
848
  log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
849
  sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
850
  alpha_t = torch.exp(log_alpha_t)
 
859
  D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
860
  if self.predict_x0:
861
  x_t = (
862
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
863
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
864
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
865
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
866
  )
867
  else:
868
  x_t = (
869
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
870
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
871
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
872
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
873
  )
874
  return x_t
875
 
876
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
877
+ r2=None):
878
  """
879
  Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
880
 
 
894
  if order == 1:
895
  return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
896
  elif order == 2:
897
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
898
+ solver_type=solver_type, r1=r1)
899
  elif order == 3:
900
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
901
+ solver_type=solver_type, r1=r1, r2=r2)
902
  else:
903
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
904
 
 
926
  else:
927
  raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
928
 
929
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
930
+ solver_type='dpm_solver'):
931
  """
932
  The adaptive step size solver based on singlestep DPM-Solver.
933
 
 
959
  if order == 2:
960
  r1 = 0.5
961
  lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
962
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
963
+ solver_type=solver_type,
964
+ **kwargs)
965
  elif order == 3:
966
  r1, r2 = 1. / 3., 2. / 3.
967
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
968
+ return_intermediate=True,
969
+ solver_type=solver_type)
970
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
971
+ solver_type=solver_type,
972
+ **kwargs)
973
  else:
974
  raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
975
  while torch.abs((s - t_0)).mean() > t_err:
 
990
  return x
991
 
992
  def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
993
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
994
+ atol=0.0078, rtol=0.05,
995
+ ):
996
  """
997
  Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
998
 
 
1100
  device = x.device
1101
  if method == 'adaptive':
1102
  with torch.no_grad():
1103
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1104
+ solver_type=solver_type)
1105
  elif method == 'multistep':
1106
  assert steps >= order
1107
  timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
 
1111
  model_prev_list = [self.model_fn(x, vec_t)]
1112
  t_prev_list = [vec_t]
1113
  # Init the first `order` values by lower order multistep DPM-Solver.
1114
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1115
  vec_t = timesteps[init_order].expand(x.shape[0])
1116
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1117
+ solver_type=solver_type)
1118
  model_prev_list.append(self.model_fn(x, vec_t))
1119
  t_prev_list.append(vec_t)
1120
  # Compute the remaining values by `order`-th order multistep DPM-Solver.
1121
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1122
  vec_t = timesteps[step].expand(x.shape[0])
1123
  if lower_order_final and steps < 15:
1124
  step_order = min(order, steps + 1 - step)
1125
  else:
1126
  step_order = order
1127
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1128
+ solver_type=solver_type)
1129
  for i in range(order - 1):
1130
  t_prev_list[i] = t_prev_list[i + 1]
1131
  model_prev_list[i] = model_prev_list[i + 1]
 
1135
  model_prev_list[-1] = self.model_fn(x, vec_t)
1136
  elif method in ['singlestep', 'singlestep_fixed']:
1137
  if method == 'singlestep':
1138
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1139
+ skip_type=skip_type,
1140
+ t_T=t_T, t_0=t_0,
1141
+ device=device)
1142
  elif method == 'singlestep_fixed':
1143
  K = steps // order
1144
+ orders = [order, ] * K
1145
  timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1146
  for i, order in enumerate(orders):
1147
  t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1148
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1149
+ N=order, device=device)
1150
  lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1151
  vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1152
  h = lambda_inner[-1] - lambda_inner[0]
 
1158
  return x
1159
 
1160
 
 
1161
  #############################################################
1162
  # other utility functions
1163
  #############################################################
 
1214
  Returns:
1215
  a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1216
  """
1217
+ return v[(...,) + (None,) * (dims - 1)]
ldm/models/diffusion/dpm_solver/sampler.py CHANGED
@@ -1,10 +1,15 @@
1
  """SAMPLING ONLY."""
2
-
3
  import torch
4
 
5
  from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6
 
7
 
 
 
 
 
 
 
8
  class DPMSolverSampler(object):
9
  def __init__(self, model, **kwargs):
10
  super().__init__()
@@ -56,7 +61,7 @@ class DPMSolverSampler(object):
56
  C, H, W = shape
57
  size = (batch_size, C, H, W)
58
 
59
- # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
60
 
61
  device = self.model.betas.device
62
  if x_T is None:
@@ -69,7 +74,7 @@ class DPMSolverSampler(object):
69
  model_fn = model_wrapper(
70
  lambda x, t, c: self.model.apply_model(x, t, c),
71
  ns,
72
- model_type="noise",
73
  guidance_type="classifier-free",
74
  condition=conditioning,
75
  unconditional_condition=unconditional_conditioning,
 
1
  """SAMPLING ONLY."""
 
2
  import torch
3
 
4
  from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5
 
6
 
7
+ MODEL_TYPES = {
8
+ "eps": "noise",
9
+ "v": "v"
10
+ }
11
+
12
+
13
  class DPMSolverSampler(object):
14
  def __init__(self, model, **kwargs):
15
  super().__init__()
 
61
  C, H, W = shape
62
  size = (batch_size, C, H, W)
63
 
64
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
65
 
66
  device = self.model.betas.device
67
  if x_T is None:
 
74
  model_fn = model_wrapper(
75
  lambda x, t, c: self.model.apply_model(x, t, c),
76
  ns,
77
+ model_type=MODEL_TYPES[self.model.parameterization],
78
  guidance_type="classifier-free",
79
  condition=conditioning,
80
  unconditional_condition=unconditional_conditioning,
ldm/models/diffusion/plms.py CHANGED
@@ -3,10 +3,9 @@
3
  import torch
4
  import numpy as np
5
  from tqdm import tqdm
6
- from functools import partial
7
- import copy
8
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
 
 
10
  class PLMSSampler(object):
11
  def __init__(self, model, schedule="linear", **kwargs):
12
  super().__init__()
@@ -24,7 +23,7 @@ class PLMSSampler(object):
24
  if ddim_eta != 0:
25
  raise ValueError('ddim_eta must be 0 for PLMS')
26
  self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
  alphas_cumprod = self.model.alphas_cumprod
29
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
  to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
@@ -43,14 +42,14 @@ class PLMSSampler(object):
43
  # ddim sampling parameters
44
  ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
  ddim_timesteps=self.ddim_timesteps,
46
- eta=ddim_eta,verbose=verbose)
47
  self.register_buffer('ddim_sigmas', ddim_sigmas)
48
  self.register_buffer('ddim_alphas', ddim_alphas)
49
  self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50
  self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51
  sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52
  (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54
  self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55
 
56
  @torch.no_grad()
@@ -75,11 +74,8 @@ class PLMSSampler(object):
75
  log_every_t=100,
76
  unconditional_guidance_scale=1.,
77
  unconditional_conditioning=None,
78
- features_adapter1=None,
79
- features_adapter2=None,
80
- mode = 'sketch',
81
- con_strength=30,
82
- style_feature=None,
83
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
84
  **kwargs
85
  ):
@@ -113,11 +109,8 @@ class PLMSSampler(object):
113
  log_every_t=log_every_t,
114
  unconditional_guidance_scale=unconditional_guidance_scale,
115
  unconditional_conditioning=unconditional_conditioning,
116
- features_adapter1=copy.deepcopy(features_adapter1),
117
- features_adapter2=copy.deepcopy(features_adapter2),
118
- mode = mode,
119
- con_strength = con_strength,
120
- style_feature=style_feature#.clone()
121
  )
122
  return samples, intermediates
123
 
@@ -127,7 +120,8 @@ class PLMSSampler(object):
127
  callback=None, timesteps=None, quantize_denoised=False,
128
  mask=None, x0=None, img_callback=None, log_every_t=100,
129
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
130
- unconditional_guidance_scale=1., unconditional_conditioning=None,features_adapter1=None, features_adapter2=None, mode='sketch', con_strength=30, style_feature=None):
 
131
  device = self.model.betas.device
132
  b = shape[0]
133
  if x_T is None:
@@ -141,7 +135,7 @@ class PLMSSampler(object):
141
  timesteps = self.ddim_timesteps[:subset_end]
142
 
143
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
144
- time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
145
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146
  print(f"Running PLMS Sampling with {total_steps} timesteps")
147
 
@@ -152,41 +146,21 @@ class PLMSSampler(object):
152
  index = total_steps - i - 1
153
  ts = torch.full((b,), step, device=device, dtype=torch.long)
154
  ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
155
- cond_in = cond
156
- unconditional_conditioning_in = unconditional_conditioning
157
 
158
- if mask is not None :#and index>=10:
159
  assert x0 is not None
160
  img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
161
  img = img_orig * mask + (1. - mask) * img
162
 
163
- if mode == 'sketch':
164
- if index<con_strength:
165
- features_adapter = None
166
- else:
167
- features_adapter = features_adapter1
168
- elif mode == 'style':
169
- if index<con_strength:
170
- features_adapter = None
171
- else:
172
- features_adapter = features_adapter1
173
-
174
- if index>25:
175
- cond_in = torch.cat([cond, style_feature.clone()], dim=1)
176
- unconditional_conditioning_in = torch.cat(
177
- [unconditional_conditioning, unconditional_conditioning[:, -8:, :]], dim=1)
178
- elif mode == 'mul':
179
- features_adapter = [a1i*0.5 + a2i for a1i, a2i in zip(features_adapter1, features_adapter2)]
180
- else:
181
- features_adapter = features_adapter1
182
-
183
- outs = self.p_sample_plms(img, cond_in, ts, index=index, use_original_steps=ddim_use_original_steps,
184
  quantize_denoised=quantize_denoised, temperature=temperature,
185
  noise_dropout=noise_dropout, score_corrector=score_corrector,
186
  corrector_kwargs=corrector_kwargs,
187
  unconditional_guidance_scale=unconditional_guidance_scale,
188
- unconditional_conditioning=unconditional_conditioning_in,
189
- old_eps=old_eps, t_next=ts_next, features_adapter=copy.deepcopy(features_adapter))
 
 
190
 
191
  img, pred_x0, e_t = outs
192
  old_eps.append(e_t)
@@ -204,17 +178,18 @@ class PLMSSampler(object):
204
  @torch.no_grad()
205
  def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
206
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
207
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, features_adapter=None):
 
208
  b, *_, device = *x.shape, x.device
209
 
210
  def get_model_output(x, t):
211
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
212
- e_t = self.model.apply_model(x, t, c, copy.deepcopy(features_adapter))
213
  else:
214
  x_in = torch.cat([x] * 2)
215
  t_in = torch.cat([t] * 2)
216
  c_in = torch.cat([unconditional_conditioning, c])
217
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, copy.deepcopy(features_adapter)).chunk(2)
218
  e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
219
 
220
  if score_corrector is not None:
@@ -233,14 +208,14 @@ class PLMSSampler(object):
233
  a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
234
  a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
235
  sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
236
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
237
 
238
  # current prediction for x_0
239
  pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
240
  if quantize_denoised:
241
  pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
242
  # direction pointing to x_t
243
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
244
  noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
245
  if noise_dropout > 0.:
246
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
 
3
  import torch
4
  import numpy as np
5
  from tqdm import tqdm
 
 
6
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
 
8
+
9
  class PLMSSampler(object):
10
  def __init__(self, model, schedule="linear", **kwargs):
11
  super().__init__()
 
23
  if ddim_eta != 0:
24
  raise ValueError('ddim_eta must be 0 for PLMS')
25
  self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
27
  alphas_cumprod = self.model.alphas_cumprod
28
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
  to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
 
42
  # ddim sampling parameters
43
  ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44
  ddim_timesteps=self.ddim_timesteps,
45
+ eta=ddim_eta, verbose=verbose)
46
  self.register_buffer('ddim_sigmas', ddim_sigmas)
47
  self.register_buffer('ddim_alphas', ddim_alphas)
48
  self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49
  self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50
  sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51
  (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53
  self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54
 
55
  @torch.no_grad()
 
74
  log_every_t=100,
75
  unconditional_guidance_scale=1.,
76
  unconditional_conditioning=None,
77
+ features_adapter=None,
78
+ cond_tau=0.4,
 
 
 
79
  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
  **kwargs
81
  ):
 
109
  log_every_t=log_every_t,
110
  unconditional_guidance_scale=unconditional_guidance_scale,
111
  unconditional_conditioning=unconditional_conditioning,
112
+ features_adapter=features_adapter,
113
+ cond_tau=cond_tau
 
 
 
114
  )
115
  return samples, intermediates
116
 
 
120
  callback=None, timesteps=None, quantize_denoised=False,
121
  mask=None, x0=None, img_callback=None, log_every_t=100,
122
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
124
+ cond_tau=0.4):
125
  device = self.model.betas.device
126
  b = shape[0]
127
  if x_T is None:
 
135
  timesteps = self.ddim_timesteps[:subset_end]
136
 
137
  intermediates = {'x_inter': [img], 'pred_x0': [img]}
138
+ time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
139
  total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
140
  print(f"Running PLMS Sampling with {total_steps} timesteps")
141
 
 
146
  index = total_steps - i - 1
147
  ts = torch.full((b,), step, device=device, dtype=torch.long)
148
  ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
 
 
149
 
150
+ if mask is not None: # and index>=10:
151
  assert x0 is not None
152
  img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
153
  img = img_orig * mask + (1. - mask) * img
154
 
155
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  quantize_denoised=quantize_denoised, temperature=temperature,
157
  noise_dropout=noise_dropout, score_corrector=score_corrector,
158
  corrector_kwargs=corrector_kwargs,
159
  unconditional_guidance_scale=unconditional_guidance_scale,
160
+ unconditional_conditioning=unconditional_conditioning,
161
+ old_eps=old_eps, t_next=ts_next,
162
+ features_adapter=None if index < int(
163
+ (1 - cond_tau) * total_steps) else features_adapter)
164
 
165
  img, pred_x0, e_t = outs
166
  old_eps.append(e_t)
 
178
  @torch.no_grad()
179
  def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
180
  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
181
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
182
+ features_adapter=None):
183
  b, *_, device = *x.shape, x.device
184
 
185
  def get_model_output(x, t):
186
  if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
187
+ e_t = self.model.apply_model(x, t, c, features_adapter=features_adapter)
188
  else:
189
  x_in = torch.cat([x] * 2)
190
  t_in = torch.cat([t] * 2)
191
  c_in = torch.cat([unconditional_conditioning, c])
192
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
193
  e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
194
 
195
  if score_corrector is not None:
 
208
  a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
209
  a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
210
  sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
211
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
212
 
213
  # current prediction for x_0
214
  pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
215
  if quantize_denoised:
216
  pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
217
  # direction pointing to x_t
218
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
219
  noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
220
  if noise_dropout > 0.:
221
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
ldm/modules/attention.py CHANGED
@@ -20,6 +20,10 @@ except:
20
  import os
21
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
 
 
 
 
 
23
  def exists(val):
24
  return val is not None
25
 
 
20
  import os
21
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
 
23
+ if os.environ.get("DISABLE_XFORMERS", "false").lower() == 'true':
24
+ XFORMERS_IS_AVAILBLE = False
25
+
26
+
27
  def exists(val):
28
  return val is not None
29
 
ldm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -1,7 +1,6 @@
1
  from abc import abstractmethod
2
- from functools import partial
3
  import math
4
- from typing import Iterable
5
 
6
  import numpy as np
7
  import torch as th
@@ -18,6 +17,7 @@ from ldm.modules.diffusionmodules.util import (
18
  timestep_embedding,
19
  )
20
  from ldm.modules.attention import SpatialTransformer
 
21
 
22
 
23
  # dummy replace
@@ -270,8 +270,6 @@ class ResBlock(TimestepBlock):
270
  h = out_norm(h) * (1 + scale) + shift
271
  h = out_rest(h)
272
  else:
273
- # print(h.shape, emb_out.shape)
274
- # exit(0)
275
  h = h + emb_out
276
  h = self.out_layers(h)
277
  return self.skip_connection(x) + h
@@ -468,16 +466,16 @@ class UNetModel(nn.Module):
468
  context_dim=None, # custom transformer support
469
  n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
470
  legacy=True,
471
- # l_cond = 4,
 
 
 
472
  ):
473
  super().__init__()
474
-
475
- # print('UNet', context_dim)
476
  if use_spatial_transformer:
477
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
478
 
479
  if context_dim is not None:
480
- # print('UNet not none', context_dim, context_dim is not None, context_dim != None, context_dim == "None")
481
  assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
482
  from omegaconf.listconfig import ListConfig
483
  if type(context_dim) == ListConfig:
@@ -496,7 +494,24 @@ class UNetModel(nn.Module):
496
  self.in_channels = in_channels
497
  self.model_channels = model_channels
498
  self.out_channels = out_channels
499
- self.num_res_blocks = num_res_blocks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  self.attention_resolutions = attention_resolutions
501
  self.dropout = dropout
502
  self.channel_mult = channel_mult
@@ -508,9 +523,6 @@ class UNetModel(nn.Module):
508
  self.num_head_channels = num_head_channels
509
  self.num_heads_upsample = num_heads_upsample
510
  self.predict_codebook_ids = n_embed is not None
511
- # self.l_cond = l_cond
512
- # print(self.l_cond)
513
- # exit(0)
514
 
515
  time_embed_dim = model_channels * 4
516
  self.time_embed = nn.Sequential(
@@ -520,7 +532,13 @@ class UNetModel(nn.Module):
520
  )
521
 
522
  if self.num_classes is not None:
523
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
 
 
 
 
 
 
524
 
525
  self.input_blocks = nn.ModuleList(
526
  [
@@ -534,7 +552,7 @@ class UNetModel(nn.Module):
534
  ch = model_channels
535
  ds = 1
536
  for level, mult in enumerate(channel_mult):
537
- for _ in range(num_res_blocks):
538
  layers = [
539
  ResBlock(
540
  ch,
@@ -556,17 +574,25 @@ class UNetModel(nn.Module):
556
  if legacy:
557
  #num_heads = 1
558
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
559
- layers.append(
560
- AttentionBlock(
561
- ch,
562
- use_checkpoint=use_checkpoint,
563
- num_heads=num_heads,
564
- num_head_channels=dim_head,
565
- use_new_attention_order=use_new_attention_order,
566
- ) if not use_spatial_transformer else SpatialTransformer(
567
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
 
 
 
 
 
 
 
 
 
568
  )
569
- )
570
  self.input_blocks.append(TimestepEmbedSequential(*layers))
571
  self._feature_size += ch
572
  input_block_chans.append(ch)
@@ -618,8 +644,10 @@ class UNetModel(nn.Module):
618
  num_heads=num_heads,
619
  num_head_channels=dim_head,
620
  use_new_attention_order=use_new_attention_order,
621
- ) if not use_spatial_transformer else SpatialTransformer(
622
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
 
 
623
  ),
624
  ResBlock(
625
  ch,
@@ -634,7 +662,7 @@ class UNetModel(nn.Module):
634
 
635
  self.output_blocks = nn.ModuleList([])
636
  for level, mult in list(enumerate(channel_mult))[::-1]:
637
- for i in range(num_res_blocks + 1):
638
  ich = input_block_chans.pop()
639
  layers = [
640
  ResBlock(
@@ -657,18 +685,26 @@ class UNetModel(nn.Module):
657
  if legacy:
658
  #num_heads = 1
659
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
660
- layers.append(
661
- AttentionBlock(
662
- ch,
663
- use_checkpoint=use_checkpoint,
664
- num_heads=num_heads_upsample,
665
- num_head_channels=dim_head,
666
- use_new_attention_order=use_new_attention_order,
667
- ) if not use_spatial_transformer else SpatialTransformer(
668
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
 
 
 
 
 
 
 
 
 
669
  )
670
- )
671
- if level and i == num_res_blocks:
672
  out_ch = ch
673
  layers.append(
674
  ResBlock(
@@ -716,7 +752,7 @@ class UNetModel(nn.Module):
716
  self.middle_block.apply(convert_module_to_f32)
717
  self.output_blocks.apply(convert_module_to_f32)
718
 
719
- def forward(self, x, timesteps=None, context=None, y=None, features_adapter=None, step_cur=0,**kwargs):
720
  """
721
  Apply the model to an input batch.
722
  :param x: an [N x C x ...] Tensor of inputs.
@@ -733,21 +769,26 @@ class UNetModel(nn.Module):
733
  emb = self.time_embed(t_emb)
734
 
735
  if self.num_classes is not None:
736
- assert y.shape == (x.shape[0],)
737
  emb = emb + self.label_emb(y)
738
 
739
  h = x.type(self.dtype)
740
 
 
 
 
 
741
  for id, module in enumerate(self.input_blocks):
742
  h = module(h, emb, context)
743
- if ((id+1)%3 == 0) and features_adapter is not None and len(features_adapter):
744
- h = h + features_adapter.pop(0)
 
745
  hs.append(h)
746
  if features_adapter is not None:
747
- assert len(features_adapter)==0, 'Wrong features_adapter'
748
 
749
  h = self.middle_block(h, emb, context)
750
- for id, module in enumerate(self.output_blocks):
751
  h = th.cat([h, hs.pop()], dim=1)
752
  h = module(h, emb, context)
753
  h = h.type(x.dtype)
@@ -755,222 +796,3 @@ class UNetModel(nn.Module):
755
  return self.id_predictor(h)
756
  else:
757
  return self.out(h)
758
-
759
-
760
- class EncoderUNetModel(nn.Module):
761
- """
762
- The half UNet model with attention and timestep embedding.
763
- For usage, see UNet.
764
- """
765
-
766
- def __init__(
767
- self,
768
- image_size,
769
- in_channels,
770
- model_channels,
771
- out_channels,
772
- num_res_blocks,
773
- attention_resolutions,
774
- dropout=0,
775
- channel_mult=(1, 2, 4, 8),
776
- conv_resample=True,
777
- dims=2,
778
- use_checkpoint=False,
779
- use_fp16=False,
780
- num_heads=1,
781
- num_head_channels=-1,
782
- num_heads_upsample=-1,
783
- use_scale_shift_norm=False,
784
- resblock_updown=False,
785
- use_new_attention_order=False,
786
- pool="adaptive",
787
- *args,
788
- **kwargs
789
- ):
790
- super().__init__()
791
-
792
- if num_heads_upsample == -1:
793
- num_heads_upsample = num_heads
794
-
795
- self.in_channels = in_channels
796
- self.model_channels = model_channels
797
- self.out_channels = out_channels
798
- self.num_res_blocks = num_res_blocks
799
- self.attention_resolutions = attention_resolutions
800
- self.dropout = dropout
801
- self.channel_mult = channel_mult
802
- self.conv_resample = conv_resample
803
- self.use_checkpoint = use_checkpoint
804
- self.dtype = th.float16 if use_fp16 else th.float32
805
- self.num_heads = num_heads
806
- self.num_head_channels = num_head_channels
807
- self.num_heads_upsample = num_heads_upsample
808
-
809
- time_embed_dim = model_channels * 4
810
- self.time_embed = nn.Sequential(
811
- linear(model_channels, time_embed_dim),
812
- nn.SiLU(),
813
- linear(time_embed_dim, time_embed_dim),
814
- )
815
-
816
- self.input_blocks = nn.ModuleList(
817
- [
818
- TimestepEmbedSequential(
819
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
820
- )
821
- ]
822
- )
823
- self._feature_size = model_channels
824
- input_block_chans = [model_channels]
825
- ch = model_channels
826
- ds = 1
827
- for level, mult in enumerate(channel_mult):
828
- for _ in range(num_res_blocks):
829
- layers = [
830
- ResBlock(
831
- ch,
832
- time_embed_dim,
833
- dropout,
834
- out_channels=mult * model_channels,
835
- dims=dims,
836
- use_checkpoint=use_checkpoint,
837
- use_scale_shift_norm=use_scale_shift_norm,
838
- )
839
- ]
840
- ch = mult * model_channels
841
- if ds in attention_resolutions:
842
- layers.append(
843
- AttentionBlock(
844
- ch,
845
- use_checkpoint=use_checkpoint,
846
- num_heads=num_heads,
847
- num_head_channels=num_head_channels,
848
- use_new_attention_order=use_new_attention_order,
849
- )
850
- )
851
- self.input_blocks.append(TimestepEmbedSequential(*layers))
852
- self._feature_size += ch
853
- input_block_chans.append(ch)
854
- if level != len(channel_mult) - 1:
855
- out_ch = ch
856
- self.input_blocks.append(
857
- TimestepEmbedSequential(
858
- ResBlock(
859
- ch,
860
- time_embed_dim,
861
- dropout,
862
- out_channels=out_ch,
863
- dims=dims,
864
- use_checkpoint=use_checkpoint,
865
- use_scale_shift_norm=use_scale_shift_norm,
866
- down=True,
867
- )
868
- if resblock_updown
869
- else Downsample(
870
- ch, conv_resample, dims=dims, out_channels=out_ch
871
- )
872
- )
873
- )
874
- ch = out_ch
875
- input_block_chans.append(ch)
876
- ds *= 2
877
- self._feature_size += ch
878
-
879
- self.middle_block = TimestepEmbedSequential(
880
- ResBlock(
881
- ch,
882
- time_embed_dim,
883
- dropout,
884
- dims=dims,
885
- use_checkpoint=use_checkpoint,
886
- use_scale_shift_norm=use_scale_shift_norm,
887
- ),
888
- AttentionBlock(
889
- ch,
890
- use_checkpoint=use_checkpoint,
891
- num_heads=num_heads,
892
- num_head_channels=num_head_channels,
893
- use_new_attention_order=use_new_attention_order,
894
- ),
895
- ResBlock(
896
- ch,
897
- time_embed_dim,
898
- dropout,
899
- dims=dims,
900
- use_checkpoint=use_checkpoint,
901
- use_scale_shift_norm=use_scale_shift_norm,
902
- ),
903
- )
904
- self._feature_size += ch
905
- self.pool = pool
906
- if pool == "adaptive":
907
- self.out = nn.Sequential(
908
- normalization(ch),
909
- nn.SiLU(),
910
- nn.AdaptiveAvgPool2d((1, 1)),
911
- zero_module(conv_nd(dims, ch, out_channels, 1)),
912
- nn.Flatten(),
913
- )
914
- elif pool == "attention":
915
- assert num_head_channels != -1
916
- self.out = nn.Sequential(
917
- normalization(ch),
918
- nn.SiLU(),
919
- AttentionPool2d(
920
- (image_size // ds), ch, num_head_channels, out_channels
921
- ),
922
- )
923
- elif pool == "spatial":
924
- self.out = nn.Sequential(
925
- nn.Linear(self._feature_size, 2048),
926
- nn.ReLU(),
927
- nn.Linear(2048, self.out_channels),
928
- )
929
- elif pool == "spatial_v2":
930
- self.out = nn.Sequential(
931
- nn.Linear(self._feature_size, 2048),
932
- normalization(2048),
933
- nn.SiLU(),
934
- nn.Linear(2048, self.out_channels),
935
- )
936
- else:
937
- raise NotImplementedError(f"Unexpected {pool} pooling")
938
-
939
- def convert_to_fp16(self):
940
- """
941
- Convert the torso of the model to float16.
942
- """
943
- self.input_blocks.apply(convert_module_to_f16)
944
- self.middle_block.apply(convert_module_to_f16)
945
-
946
- def convert_to_fp32(self):
947
- """
948
- Convert the torso of the model to float32.
949
- """
950
- self.input_blocks.apply(convert_module_to_f32)
951
- self.middle_block.apply(convert_module_to_f32)
952
-
953
- def forward(self, x, timesteps):
954
- """
955
- Apply the model to an input batch.
956
- :param x: an [N x C x ...] Tensor of inputs.
957
- :param timesteps: a 1-D batch of timesteps.
958
- :return: an [N x K] Tensor of outputs.
959
- """
960
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
961
-
962
- results = []
963
- h = x.type(self.dtype)
964
- for module in self.input_blocks:
965
- h = module(h, emb)
966
- if self.pool.startswith("spatial"):
967
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
968
- h = self.middle_block(h, emb)
969
- if self.pool.startswith("spatial"):
970
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
971
- h = th.cat(results, axis=-1)
972
- return self.out(h)
973
- else:
974
- h = h.type(x.dtype)
975
- return self.out(h)
976
-
 
1
  from abc import abstractmethod
 
2
  import math
3
+ import torch
4
 
5
  import numpy as np
6
  import torch as th
 
17
  timestep_embedding,
18
  )
19
  from ldm.modules.attention import SpatialTransformer
20
+ from ldm.util import exists
21
 
22
 
23
  # dummy replace
 
270
  h = out_norm(h) * (1 + scale) + shift
271
  h = out_rest(h)
272
  else:
 
 
273
  h = h + emb_out
274
  h = self.out_layers(h)
275
  return self.skip_connection(x) + h
 
466
  context_dim=None, # custom transformer support
467
  n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
  legacy=True,
469
+ disable_self_attentions=None,
470
+ num_attention_blocks=None,
471
+ disable_middle_self_attn=False,
472
+ use_linear_in_transformer=False,
473
  ):
474
  super().__init__()
 
 
475
  if use_spatial_transformer:
476
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
477
 
478
  if context_dim is not None:
 
479
  assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
480
  from omegaconf.listconfig import ListConfig
481
  if type(context_dim) == ListConfig:
 
494
  self.in_channels = in_channels
495
  self.model_channels = model_channels
496
  self.out_channels = out_channels
497
+ if isinstance(num_res_blocks, int):
498
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
499
+ else:
500
+ if len(num_res_blocks) != len(channel_mult):
501
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
502
+ "as a list/tuple (per-level) with the same length as channel_mult")
503
+ self.num_res_blocks = num_res_blocks
504
+ if disable_self_attentions is not None:
505
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
506
+ assert len(disable_self_attentions) == len(channel_mult)
507
+ if num_attention_blocks is not None:
508
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
509
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
510
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
511
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
512
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
513
+ f"attention will still not be set.")
514
+
515
  self.attention_resolutions = attention_resolutions
516
  self.dropout = dropout
517
  self.channel_mult = channel_mult
 
523
  self.num_head_channels = num_head_channels
524
  self.num_heads_upsample = num_heads_upsample
525
  self.predict_codebook_ids = n_embed is not None
 
 
 
526
 
527
  time_embed_dim = model_channels * 4
528
  self.time_embed = nn.Sequential(
 
532
  )
533
 
534
  if self.num_classes is not None:
535
+ if isinstance(self.num_classes, int):
536
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
537
+ elif self.num_classes == "continuous":
538
+ print("setting up linear c_adm embedding layer")
539
+ self.label_emb = nn.Linear(1, time_embed_dim)
540
+ else:
541
+ raise ValueError()
542
 
543
  self.input_blocks = nn.ModuleList(
544
  [
 
552
  ch = model_channels
553
  ds = 1
554
  for level, mult in enumerate(channel_mult):
555
+ for nr in range(self.num_res_blocks[level]):
556
  layers = [
557
  ResBlock(
558
  ch,
 
574
  if legacy:
575
  #num_heads = 1
576
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
577
+ if exists(disable_self_attentions):
578
+ disabled_sa = disable_self_attentions[level]
579
+ else:
580
+ disabled_sa = False
581
+
582
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
583
+ layers.append(
584
+ AttentionBlock(
585
+ ch,
586
+ use_checkpoint=use_checkpoint,
587
+ num_heads=num_heads,
588
+ num_head_channels=dim_head,
589
+ use_new_attention_order=use_new_attention_order,
590
+ ) if not use_spatial_transformer else SpatialTransformer(
591
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
592
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
593
+ use_checkpoint=use_checkpoint
594
+ )
595
  )
 
596
  self.input_blocks.append(TimestepEmbedSequential(*layers))
597
  self._feature_size += ch
598
  input_block_chans.append(ch)
 
644
  num_heads=num_heads,
645
  num_head_channels=dim_head,
646
  use_new_attention_order=use_new_attention_order,
647
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
648
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
649
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
650
+ use_checkpoint=use_checkpoint
651
  ),
652
  ResBlock(
653
  ch,
 
662
 
663
  self.output_blocks = nn.ModuleList([])
664
  for level, mult in list(enumerate(channel_mult))[::-1]:
665
+ for i in range(self.num_res_blocks[level] + 1):
666
  ich = input_block_chans.pop()
667
  layers = [
668
  ResBlock(
 
685
  if legacy:
686
  #num_heads = 1
687
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
688
+ if exists(disable_self_attentions):
689
+ disabled_sa = disable_self_attentions[level]
690
+ else:
691
+ disabled_sa = False
692
+
693
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
694
+ layers.append(
695
+ AttentionBlock(
696
+ ch,
697
+ use_checkpoint=use_checkpoint,
698
+ num_heads=num_heads_upsample,
699
+ num_head_channels=dim_head,
700
+ use_new_attention_order=use_new_attention_order,
701
+ ) if not use_spatial_transformer else SpatialTransformer(
702
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
703
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
704
+ use_checkpoint=use_checkpoint
705
+ )
706
  )
707
+ if level and i == self.num_res_blocks[level]:
 
708
  out_ch = ch
709
  layers.append(
710
  ResBlock(
 
752
  self.middle_block.apply(convert_module_to_f32)
753
  self.output_blocks.apply(convert_module_to_f32)
754
 
755
+ def forward(self, x, timesteps=None, context=None, y=None, features_adapter=None, append_to_context=None, **kwargs):
756
  """
757
  Apply the model to an input batch.
758
  :param x: an [N x C x ...] Tensor of inputs.
 
769
  emb = self.time_embed(t_emb)
770
 
771
  if self.num_classes is not None:
772
+ assert y.shape[0] == x.shape[0]
773
  emb = emb + self.label_emb(y)
774
 
775
  h = x.type(self.dtype)
776
 
777
+ if append_to_context is not None:
778
+ context = torch.cat([context, append_to_context], dim=1)
779
+
780
+ adapter_idx = 0
781
  for id, module in enumerate(self.input_blocks):
782
  h = module(h, emb, context)
783
+ if ((id+1)%3 == 0) and features_adapter is not None:
784
+ h = h + features_adapter[adapter_idx]
785
+ adapter_idx += 1
786
  hs.append(h)
787
  if features_adapter is not None:
788
+ assert len(features_adapter)==adapter_idx, 'Wrong features_adapter'
789
 
790
  h = self.middle_block(h, emb, context)
791
+ for module in self.output_blocks:
792
  h = th.cat([h, hs.pop()], dim=1)
793
  h = module(h, emb, context)
794
  h = h.type(x.dtype)
 
796
  return self.id_predictor(h)
797
  else:
798
  return self.out(h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/modules/diffusionmodules/util.py CHANGED
@@ -122,7 +122,9 @@ class CheckpointFunction(torch.autograd.Function):
122
  ctx.run_function = run_function
123
  ctx.input_tensors = list(args[:length])
124
  ctx.input_params = list(args[length:])
125
-
 
 
126
  with torch.no_grad():
127
  output_tensors = ctx.run_function(*ctx.input_tensors)
128
  return output_tensors
@@ -130,7 +132,8 @@ class CheckpointFunction(torch.autograd.Function):
130
  @staticmethod
131
  def backward(ctx, *output_grads):
132
  ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133
- with torch.enable_grad():
 
134
  # Fixes a bug where the first op in run_function modifies the
135
  # Tensor storage in place, which is not allowed for detach()'d
136
  # Tensors.
 
122
  ctx.run_function = run_function
123
  ctx.input_tensors = list(args[:length])
124
  ctx.input_params = list(args[length:])
125
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
126
+ "dtype": torch.get_autocast_gpu_dtype(),
127
+ "cache_enabled": torch.is_autocast_cache_enabled()}
128
  with torch.no_grad():
129
  output_tensors = ctx.run_function(*ctx.input_tensors)
130
  return output_tensors
 
132
  @staticmethod
133
  def backward(ctx, *output_grads):
134
  ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
135
+ with torch.enable_grad(), \
136
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
137
  # Fixes a bug where the first op in run_function modifies the
138
  # Tensor storage in place, which is not allowed for detach()'d
139
  # Tensors.
ldm/modules/ema.py CHANGED
@@ -10,24 +10,28 @@ class LitEma(nn.Module):
10
 
11
  self.m_name2s_name = {}
12
  self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
- self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14
- else torch.tensor(-1,dtype=torch.int))
15
 
16
  for name, p in model.named_parameters():
17
  if p.requires_grad:
18
- #remove as '.'-character is not allowed in buffers
19
- s_name = name.replace('.','')
20
- self.m_name2s_name.update({name:s_name})
21
- self.register_buffer(s_name,p.clone().detach().data)
22
 
23
  self.collected_params = []
24
 
25
- def forward(self,model):
 
 
 
 
26
  decay = self.decay
27
 
28
  if self.num_updates >= 0:
29
  self.num_updates += 1
30
- decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31
 
32
  one_minus_decay = 1.0 - decay
33
 
 
10
 
11
  self.m_name2s_name = {}
12
  self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1, dtype=torch.int))
15
 
16
  for name, p in model.named_parameters():
17
  if p.requires_grad:
18
+ # remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
 
23
  self.collected_params = []
24
 
25
+ def reset_num_updates(self):
26
+ del self.num_updates
27
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28
+
29
+ def forward(self, model):
30
  decay = self.decay
31
 
32
  if self.num_updates >= 0:
33
  self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
 
36
  one_minus_decay = 1.0 - decay
37
 
ldm/modules/encoders/adapter.py CHANGED
@@ -1,9 +1,8 @@
1
  import torch
2
  import torch.nn as nn
3
- import torch.nn.functional as F
4
- from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock
5
  from collections import OrderedDict
6
 
 
7
  def conv_nd(dims, *args, **kwargs):
8
  """
9
  Create a 1D, 2D, or 3D convolution module.
@@ -16,6 +15,7 @@ def conv_nd(dims, *args, **kwargs):
16
  return nn.Conv3d(*args, **kwargs)
17
  raise ValueError(f"unsupported dimensions: {dims}")
18
 
 
19
  def avg_pool_nd(dims, *args, **kwargs):
20
  """
21
  Create a 1D, 2D, or 3D average pooling module.
@@ -28,6 +28,7 @@ def avg_pool_nd(dims, *args, **kwargs):
28
  return nn.AvgPool3d(*args, **kwargs)
29
  raise ValueError(f"unsupported dimensions: {dims}")
30
 
 
31
  class Downsample(nn.Module):
32
  """
33
  A downsampling layer with an optional convolution.
@@ -37,7 +38,7 @@ class Downsample(nn.Module):
37
  downsampling occurs in the inner-two dimensions.
38
  """
39
 
40
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
41
  super().__init__()
42
  self.channels = channels
43
  self.out_channels = out_channels or channels
@@ -60,15 +61,16 @@ class Downsample(nn.Module):
60
  class ResnetBlock(nn.Module):
61
  def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
62
  super().__init__()
63
- ps = ksize//2
64
- if in_c != out_c or sk==False:
65
  self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
66
  else:
 
67
  self.in_conv = None
68
  self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
69
  self.act = nn.ReLU()
70
  self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
71
- if sk==False:
72
  self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
73
  else:
74
  self.skep = None
@@ -80,7 +82,7 @@ class ResnetBlock(nn.Module):
80
  def forward(self, x):
81
  if self.down == True:
82
  x = self.down_opt(x)
83
- if self.in_conv is not None: # edit
84
  x = self.in_conv(x)
85
 
86
  h = self.block1(x)
@@ -101,12 +103,14 @@ class Adapter(nn.Module):
101
  self.body = []
102
  for i in range(len(channels)):
103
  for j in range(nums_rb):
104
- if (i!=0) and (j==0):
105
- self.body.append(ResnetBlock(channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
 
106
  else:
107
- self.body.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
 
108
  self.body = nn.ModuleList(self.body)
109
- self.conv_in = nn.Conv2d(cin,channels[0], 3, 1, 1)
110
 
111
  def forward(self, x):
112
  # unshuffle
@@ -116,12 +120,79 @@ class Adapter(nn.Module):
116
  x = self.conv_in(x)
117
  for i in range(len(self.channels)):
118
  for j in range(self.nums_rb):
119
- idx = i*self.nums_rb +j
120
  x = self.body[idx](x)
121
  features.append(x)
122
 
123
  return features
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  class ResnetBlock_light(nn.Module):
127
  def __init__(self, in_c):
@@ -185,66 +256,3 @@ class Adapter_light(nn.Module):
185
  features.append(x)
186
 
187
  return features
188
-
189
- class QuickGELU(nn.Module):
190
-
191
- def forward(self, x: torch.Tensor):
192
- return x * torch.sigmoid(1.702 * x)
193
-
194
- class ResidualAttentionBlock(nn.Module):
195
-
196
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
197
- super().__init__()
198
-
199
- self.attn = nn.MultiheadAttention(d_model, n_head)
200
- self.ln_1 = LayerNorm(d_model)
201
- self.mlp = nn.Sequential(
202
- OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
203
- ("c_proj", nn.Linear(d_model * 4, d_model))]))
204
- self.ln_2 = LayerNorm(d_model)
205
- self.attn_mask = attn_mask
206
-
207
- def attention(self, x: torch.Tensor):
208
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
209
- return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
210
-
211
- def forward(self, x: torch.Tensor):
212
- x = x + self.attention(self.ln_1(x))
213
- x = x + self.mlp(self.ln_2(x))
214
- return x
215
-
216
- class LayerNorm(nn.LayerNorm):
217
- """Subclass torch's LayerNorm to handle fp16."""
218
-
219
- def forward(self, x: torch.Tensor):
220
- orig_type = x.dtype
221
- ret = super().forward(x.type(torch.float32))
222
- return ret.type(orig_type)
223
-
224
- class StyleAdapter(nn.Module):
225
-
226
- def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
227
- super().__init__()
228
-
229
- scale = width ** -0.5
230
- self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
231
- self.num_token = num_token
232
- self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
233
- self.ln_post = LayerNorm(width)
234
- self.ln_pre = LayerNorm(width)
235
- self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
236
-
237
- def forward(self, x):
238
- # x shape [N, HW+1, C]
239
- style_embedding = self.style_embedding + torch.zeros(
240
- (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
241
- x = torch.cat([x, style_embedding], dim=1)
242
- x = self.ln_pre(x)
243
- x = x.permute(1, 0, 2) # NLD -> LND
244
- x = self.transformer_layes(x)
245
- x = x.permute(1, 0, 2) # LND -> NLD
246
-
247
- x = self.ln_post(x[:, -self.num_token:, :])
248
- x = x @ self.proj
249
-
250
- return x
 
1
  import torch
2
  import torch.nn as nn
 
 
3
  from collections import OrderedDict
4
 
5
+
6
  def conv_nd(dims, *args, **kwargs):
7
  """
8
  Create a 1D, 2D, or 3D convolution module.
 
15
  return nn.Conv3d(*args, **kwargs)
16
  raise ValueError(f"unsupported dimensions: {dims}")
17
 
18
+
19
  def avg_pool_nd(dims, *args, **kwargs):
20
  """
21
  Create a 1D, 2D, or 3D average pooling module.
 
28
  return nn.AvgPool3d(*args, **kwargs)
29
  raise ValueError(f"unsupported dimensions: {dims}")
30
 
31
+
32
  class Downsample(nn.Module):
33
  """
34
  A downsampling layer with an optional convolution.
 
38
  downsampling occurs in the inner-two dimensions.
39
  """
40
 
41
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
42
  super().__init__()
43
  self.channels = channels
44
  self.out_channels = out_channels or channels
 
61
  class ResnetBlock(nn.Module):
62
  def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
63
  super().__init__()
64
+ ps = ksize // 2
65
+ if in_c != out_c or sk == False:
66
  self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
67
  else:
68
+ # print('n_in')
69
  self.in_conv = None
70
  self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
71
  self.act = nn.ReLU()
72
  self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
73
+ if sk == False:
74
  self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
75
  else:
76
  self.skep = None
 
82
  def forward(self, x):
83
  if self.down == True:
84
  x = self.down_opt(x)
85
+ if self.in_conv is not None: # edit
86
  x = self.in_conv(x)
87
 
88
  h = self.block1(x)
 
103
  self.body = []
104
  for i in range(len(channels)):
105
  for j in range(nums_rb):
106
+ if (i != 0) and (j == 0):
107
+ self.body.append(
108
+ ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
109
  else:
110
+ self.body.append(
111
+ ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
112
  self.body = nn.ModuleList(self.body)
113
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
114
 
115
  def forward(self, x):
116
  # unshuffle
 
120
  x = self.conv_in(x)
121
  for i in range(len(self.channels)):
122
  for j in range(self.nums_rb):
123
+ idx = i * self.nums_rb + j
124
  x = self.body[idx](x)
125
  features.append(x)
126
 
127
  return features
128
+
129
+
130
+ class LayerNorm(nn.LayerNorm):
131
+ """Subclass torch's LayerNorm to handle fp16."""
132
+
133
+ def forward(self, x: torch.Tensor):
134
+ orig_type = x.dtype
135
+ ret = super().forward(x.type(torch.float32))
136
+ return ret.type(orig_type)
137
+
138
+
139
+ class QuickGELU(nn.Module):
140
+
141
+ def forward(self, x: torch.Tensor):
142
+ return x * torch.sigmoid(1.702 * x)
143
+
144
+
145
+ class ResidualAttentionBlock(nn.Module):
146
+
147
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
148
+ super().__init__()
149
+
150
+ self.attn = nn.MultiheadAttention(d_model, n_head)
151
+ self.ln_1 = LayerNorm(d_model)
152
+ self.mlp = nn.Sequential(
153
+ OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
154
+ ("c_proj", nn.Linear(d_model * 4, d_model))]))
155
+ self.ln_2 = LayerNorm(d_model)
156
+ self.attn_mask = attn_mask
157
+
158
+ def attention(self, x: torch.Tensor):
159
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
160
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
161
+
162
+ def forward(self, x: torch.Tensor):
163
+ x = x + self.attention(self.ln_1(x))
164
+ x = x + self.mlp(self.ln_2(x))
165
+ return x
166
+
167
+
168
+ class StyleAdapter(nn.Module):
169
+
170
+ def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
171
+ super().__init__()
172
+
173
+ scale = width ** -0.5
174
+ self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
175
+ self.num_token = num_token
176
+ self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
177
+ self.ln_post = LayerNorm(width)
178
+ self.ln_pre = LayerNorm(width)
179
+ self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
180
+
181
+ def forward(self, x):
182
+ # x shape [N, HW+1, C]
183
+ style_embedding = self.style_embedding + torch.zeros(
184
+ (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
185
+ x = torch.cat([x, style_embedding], dim=1)
186
+ x = self.ln_pre(x)
187
+ x = x.permute(1, 0, 2) # NLD -> LND
188
+ x = self.transformer_layes(x)
189
+ x = x.permute(1, 0, 2) # LND -> NLD
190
+
191
+ x = self.ln_post(x[:, -self.num_token:, :])
192
+ x = x @ self.proj
193
+
194
+ return x
195
+
196
 
197
  class ResnetBlock_light(nn.Module):
198
  def __init__(self, in_c):
 
256
  features.append(x)
257
 
258
  return features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/modules/encoders/modules.py CHANGED
@@ -1,12 +1,13 @@
1
  import torch
2
  import torch.nn as nn
3
- from functools import partial
4
- import clip
5
- from einops import rearrange, repeat
6
- from transformers import CLIPTokenizer, CLIPTextModel
7
- import kornia
8
 
9
- from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
 
 
 
 
10
 
11
 
12
  class AbstractEncoder(nn.Module):
@@ -17,6 +18,11 @@ class AbstractEncoder(nn.Module):
17
  raise NotImplementedError
18
 
19
 
 
 
 
 
 
20
 
21
  class ClassEmbedder(nn.Module):
22
  def __init__(self, embed_dim, n_classes=1000, key='class'):
@@ -33,116 +39,48 @@ class ClassEmbedder(nn.Module):
33
  return c
34
 
35
 
36
- class TransformerEmbedder(AbstractEncoder):
37
- """Some transformer encoder layers"""
38
- def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39
  super().__init__()
 
 
40
  self.device = device
41
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42
- attn_layers=Encoder(dim=n_embed, depth=n_layer))
43
-
44
- def forward(self, tokens):
45
- tokens = tokens.to(self.device) # meh
46
- z = self.transformer(tokens, return_embeddings=True)
47
- return z
48
 
49
- def encode(self, x):
50
- return self(x)
51
-
52
-
53
- class BERTTokenizer(AbstractEncoder):
54
- """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55
- def __init__(self, device="cuda", vq_interface=True, max_length=77):
56
- super().__init__()
57
- from transformers import BertTokenizerFast # TODO: add to reuquirements
58
- self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
59
- self.device = device
60
- self.vq_interface = vq_interface
61
- self.max_length = max_length
62
 
63
  def forward(self, text):
64
  batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
65
  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
66
  tokens = batch_encoding["input_ids"].to(self.device)
67
- return tokens
68
-
69
- @torch.no_grad()
70
- def encode(self, text):
71
- tokens = self(text)
72
- if not self.vq_interface:
73
- return tokens
74
- return None, None, [None, None, tokens]
75
-
76
- def decode(self, text):
77
- return text
78
-
79
-
80
- class BERTEmbedder(AbstractEncoder):
81
- """Uses the BERT tokenizr model and add some transformer encoder layers"""
82
- def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
83
- device="cuda",use_tokenizer=True, embedding_dropout=0.0):
84
- super().__init__()
85
- self.use_tknz_fn = use_tokenizer
86
- if self.use_tknz_fn:
87
- self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
88
- self.device = device
89
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
90
- attn_layers=Encoder(dim=n_embed, depth=n_layer),
91
- emb_dropout=embedding_dropout)
92
 
93
- def forward(self, text):
94
- if self.use_tknz_fn:
95
- tokens = self.tknz_fn(text)#.to(self.device)
96
- else:
97
- tokens = text
98
- z = self.transformer(tokens, return_embeddings=True)
99
  return z
100
 
101
  def encode(self, text):
102
- # output of length 77
103
  return self(text)
104
 
105
 
106
- class SpatialRescaler(nn.Module):
107
- def __init__(self,
108
- n_stages=1,
109
- method='bilinear',
110
- multiplier=0.5,
111
- in_channels=3,
112
- out_channels=None,
113
- bias=False):
114
- super().__init__()
115
- self.n_stages = n_stages
116
- assert self.n_stages >= 0
117
- assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
118
- self.multiplier = multiplier
119
- self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
120
- self.remap_output = out_channels is not None
121
- if self.remap_output:
122
- print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
123
- self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
124
-
125
- def forward(self,x):
126
- for stage in range(self.n_stages):
127
- x = self.interpolator(x, scale_factor=self.multiplier)
128
-
129
-
130
- if self.remap_output:
131
- x = self.channel_mapper(x)
132
- return x
133
-
134
- def encode(self, x):
135
- return self(x)
136
-
137
  class FrozenCLIPEmbedder(AbstractEncoder):
138
- """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139
- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
 
140
  super().__init__()
141
  self.tokenizer = CLIPTokenizer.from_pretrained(version)
142
- self.transformer = CLIPTextModel.from_pretrained(version)
143
  self.device = device
144
  self.max_length = max_length
145
- self.freeze()
 
 
146
 
147
  def freeze(self):
148
  self.transformer = self.transformer.eval()
@@ -153,26 +91,47 @@ class FrozenCLIPEmbedder(AbstractEncoder):
153
  batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
154
  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
155
  tokens = batch_encoding["input_ids"].to(self.device)
156
- outputs = self.transformer(input_ids=tokens)
157
 
158
- z = outputs.last_hidden_state
 
 
 
 
159
  return z
160
 
161
  def encode(self, text):
162
  return self(text)
163
 
164
 
165
- class FrozenCLIPTextEmbedder(nn.Module):
166
  """
167
- Uses the CLIP transformer encoder for text.
168
  """
169
- def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
 
 
 
 
 
 
170
  super().__init__()
171
- self.model, _ = clip.load(version, jit=False, device="cpu")
 
 
 
 
172
  self.device = device
173
  self.max_length = max_length
174
- self.n_repeat = n_repeat
175
- self.normalize = normalize
 
 
 
 
 
 
 
176
 
177
  def freeze(self):
178
  self.model = self.model.eval()
@@ -180,55 +139,303 @@ class FrozenCLIPTextEmbedder(nn.Module):
180
  param.requires_grad = False
181
 
182
  def forward(self, text):
183
- tokens = clip.tokenize(text).to(self.device)
184
- z = self.model.encode_text(tokens)
185
- if self.normalize:
186
- z = z / torch.linalg.norm(z, dim=1, keepdim=True)
187
  return z
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  def encode(self, text):
190
- z = self(text)
191
- if z.ndim==2:
192
- z = z[:, None, :]
193
- z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
194
- return z
195
 
196
 
197
- class FrozenClipImageEmbedder(nn.Module):
198
- """
199
- Uses the CLIP image encoder.
200
- """
201
- def __init__(
202
- self,
203
- model,
204
- jit=False,
205
- device='cuda' if torch.cuda.is_available() else 'cpu',
206
- antialias=False,
207
- ):
208
  super().__init__()
209
- self.model, _ = clip.load(name=model, device=device, jit=jit)
 
 
 
210
 
211
- self.antialias = antialias
 
212
 
213
- self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
214
- self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- def preprocess(self, x):
217
- # normalize to [0,1]
218
- x = kornia.geometry.resize(x, (224, 224),
219
- interpolation='bicubic',align_corners=True,
220
- antialias=self.antialias)
221
- x = (x + 1.) / 2.
222
- # renormalize according to clip
223
- x = kornia.enhance.normalize(x, self.mean, self.std)
224
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- def forward(self, x):
227
- # x is assumed to be in range [-1,1]
228
- return self.model.encode_image(self.preprocess(x))
229
 
230
 
231
  if __name__ == "__main__":
232
- from ldm.util import count_params
233
  model = FrozenCLIPEmbedder()
234
- count_params(model, verbose=True)
 
1
  import torch
2
  import torch.nn as nn
3
+ import math
4
+ from torch.utils.checkpoint import checkpoint
 
 
 
5
 
6
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, CLIPModel
7
+
8
+ import open_clip
9
+ import re
10
+ from ldm.util import default, count_params
11
 
12
 
13
  class AbstractEncoder(nn.Module):
 
18
  raise NotImplementedError
19
 
20
 
21
+ class IdentityEncoder(AbstractEncoder):
22
+
23
+ def encode(self, x):
24
+ return x
25
+
26
 
27
  class ClassEmbedder(nn.Module):
28
  def __init__(self, embed_dim, n_classes=1000, key='class'):
 
39
  return c
40
 
41
 
42
+ class FrozenT5Embedder(AbstractEncoder):
43
+ """Uses the T5 transformer encoder for text"""
44
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
45
  super().__init__()
46
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
47
+ self.transformer = T5EncoderModel.from_pretrained(version)
48
  self.device = device
49
+ self.max_length = max_length # TODO: typical value?
50
+ if freeze:
51
+ self.freeze()
 
 
 
 
52
 
53
+ def freeze(self):
54
+ self.transformer = self.transformer.eval()
55
+ #self.train = disabled_train
56
+ for param in self.parameters():
57
+ param.requires_grad = False
 
 
 
 
 
 
 
 
58
 
59
  def forward(self, text):
60
  batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
61
  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
62
  tokens = batch_encoding["input_ids"].to(self.device)
63
+ outputs = self.transformer(input_ids=tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ z = outputs.last_hidden_state
 
 
 
 
 
66
  return z
67
 
68
  def encode(self, text):
 
69
  return self(text)
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  class FrozenCLIPEmbedder(AbstractEncoder):
73
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
74
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
75
+ freeze=True, layer="last"): # clip-vit-base-patch32
76
  super().__init__()
77
  self.tokenizer = CLIPTokenizer.from_pretrained(version)
78
+ self.transformer = CLIPModel.from_pretrained(version).text_model
79
  self.device = device
80
  self.max_length = max_length
81
+ if freeze:
82
+ self.freeze()
83
+ self.layer = layer
84
 
85
  def freeze(self):
86
  self.transformer = self.transformer.eval()
 
91
  batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
92
  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
93
  tokens = batch_encoding["input_ids"].to(self.device)
94
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer != 'last')
95
 
96
+ if self.layer == 'penultimate':
97
+ z = outputs.hidden_states[-2]
98
+ z = self.transformer.final_layer_norm(z)
99
+ else:
100
+ z = outputs.last_hidden_state
101
  return z
102
 
103
  def encode(self, text):
104
  return self(text)
105
 
106
 
107
+ class FrozenOpenCLIPEmbedder(AbstractEncoder):
108
  """
109
+ Uses the OpenCLIP transformer encoder for text
110
  """
111
+ LAYERS = [
112
+ #"pooled",
113
+ "last",
114
+ "penultimate"
115
+ ]
116
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
117
+ freeze=True, layer="last"):
118
  super().__init__()
119
+ assert layer in self.LAYERS
120
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
121
+ del model.visual
122
+ self.model = model
123
+
124
  self.device = device
125
  self.max_length = max_length
126
+ if freeze:
127
+ self.freeze()
128
+ self.layer = layer
129
+ if self.layer == "last":
130
+ self.layer_idx = 0
131
+ elif self.layer == "penultimate":
132
+ self.layer_idx = 1
133
+ else:
134
+ raise NotImplementedError()
135
 
136
  def freeze(self):
137
  self.model = self.model.eval()
 
139
  param.requires_grad = False
140
 
141
  def forward(self, text):
142
+ tokens = open_clip.tokenize(text)
143
+ z = self.encode_with_transformer(tokens.to(self.device))
 
 
144
  return z
145
 
146
+ def encode_with_transformer(self, text):
147
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
148
+ x = x + self.model.positional_embedding
149
+ x = x.permute(1, 0, 2) # NLD -> LND
150
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
151
+ x = x.permute(1, 0, 2) # LND -> NLD
152
+ x = self.model.ln_final(x)
153
+ return x
154
+
155
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
156
+ for i, r in enumerate(self.model.transformer.resblocks):
157
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
158
+ break
159
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
160
+ x = checkpoint(r, x, attn_mask)
161
+ else:
162
+ x = r(x, attn_mask=attn_mask)
163
+ return x
164
+
165
  def encode(self, text):
166
+ return self(text)
 
 
 
 
167
 
168
 
169
+ class FrozenCLIPT5Encoder(AbstractEncoder):
170
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
171
+ clip_max_length=77, t5_max_length=77):
 
 
 
 
 
 
 
 
172
  super().__init__()
173
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
174
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
175
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
176
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
177
 
178
+ def encode(self, text):
179
+ return self(text)
180
 
181
+ def forward(self, text):
182
+ clip_z = self.clip_encoder.encode(text)
183
+ t5_z = self.t5_encoder.encode(text)
184
+ return [clip_z, t5_z]
185
+
186
+
187
+ # code from sd-webui
188
+ re_attention = re.compile(r"""
189
+ \\\(|
190
+ \\\)|
191
+ \\\[|
192
+ \\]|
193
+ \\\\|
194
+ \\|
195
+ \(|
196
+ \[|
197
+ :([+-]?[.\d]+)\)|
198
+ \)|
199
+ ]|
200
+ [^\\()\[\]:]+|
201
+ :
202
+ """, re.X)
203
+
204
+
205
+ def parse_prompt_attention(text):
206
+ """
207
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
208
+ Accepted tokens are:
209
+ (abc) - increases attention to abc by a multiplier of 1.1
210
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
211
+ [abc] - decreases attention to abc by a multiplier of 1.1
212
+ \( - literal character '('
213
+ \[ - literal character '['
214
+ \) - literal character ')'
215
+ \] - literal character ']'
216
+ \\ - literal character '\'
217
+ anything else - just text
218
+
219
+ >>> parse_prompt_attention('normal text')
220
+ [['normal text', 1.0]]
221
+ >>> parse_prompt_attention('an (important) word')
222
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
223
+ >>> parse_prompt_attention('(unbalanced')
224
+ [['unbalanced', 1.1]]
225
+ >>> parse_prompt_attention('\(literal\]')
226
+ [['(literal]', 1.0]]
227
+ >>> parse_prompt_attention('(unnecessary)(parens)')
228
+ [['unnecessaryparens', 1.1]]
229
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
230
+ [['a ', 1.0],
231
+ ['house', 1.5730000000000004],
232
+ [' ', 1.1],
233
+ ['on', 1.0],
234
+ [' a ', 1.1],
235
+ ['hill', 0.55],
236
+ [', sun, ', 1.1],
237
+ ['sky', 1.4641000000000006],
238
+ ['.', 1.1]]
239
+ """
240
 
241
+ res = []
242
+ round_brackets = []
243
+ square_brackets = []
244
+
245
+ round_bracket_multiplier = 1.1
246
+ square_bracket_multiplier = 1 / 1.1
247
+
248
+ def multiply_range(start_position, multiplier):
249
+ for p in range(start_position, len(res)):
250
+ res[p][1] *= multiplier
251
+
252
+ for m in re_attention.finditer(text):
253
+ text = m.group(0)
254
+ weight = m.group(1)
255
+
256
+ if text.startswith('\\'):
257
+ res.append([text[1:], 1.0])
258
+ elif text == '(':
259
+ round_brackets.append(len(res))
260
+ elif text == '[':
261
+ square_brackets.append(len(res))
262
+ elif weight is not None and len(round_brackets) > 0:
263
+ multiply_range(round_brackets.pop(), float(weight))
264
+ elif text == ')' and len(round_brackets) > 0:
265
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
266
+ elif text == ']' and len(square_brackets) > 0:
267
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
268
+ else:
269
+ res.append([text, 1.0])
270
+
271
+ for pos in round_brackets:
272
+ multiply_range(pos, round_bracket_multiplier)
273
+
274
+ for pos in square_brackets:
275
+ multiply_range(pos, square_bracket_multiplier)
276
+
277
+ if len(res) == 0:
278
+ res = [["", 1.0]]
279
+
280
+ # merge runs of identical weights
281
+ i = 0
282
+ while i + 1 < len(res):
283
+ if res[i][1] == res[i + 1][1]:
284
+ res[i][0] += res[i + 1][0]
285
+ res.pop(i + 1)
286
+ else:
287
+ i += 1
288
+
289
+ return res
290
+
291
+ class WebUIFrozenCLIPEmebedder(AbstractEncoder):
292
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", freeze=True, layer="penultimate"):
293
+ super(WebUIFrozenCLIPEmebedder, self).__init__()
294
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
295
+ self.transformer = CLIPModel.from_pretrained(version).text_model
296
+ self.device = device
297
+ self.layer = layer
298
+ if freeze:
299
+ self.freeze()
300
+
301
+ self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
302
+ self.comma_padding_backtrack = 20
303
+
304
+ def freeze(self):
305
+ self.transformer = self.transformer.eval()
306
+ for param in self.parameters():
307
+ param.requires_grad = False
308
+
309
+ def tokenize(self, texts):
310
+ tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
311
+ return tokenized
312
+
313
+ def encode_with_transformers(self, tokens):
314
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer!='last')
315
+
316
+ if self.layer == 'penultimate':
317
+ z = outputs.hidden_states[-2]
318
+ z = self.transformer.final_layer_norm(z)
319
+ else:
320
+ z = outputs.last_hidden_state
321
+
322
+ return z
323
+
324
+ def tokenize_line(self, line):
325
+ parsed = parse_prompt_attention(line)
326
+ # print(parsed)
327
+
328
+ tokenized = self.tokenize([text for text, _ in parsed])
329
+
330
+ remade_tokens = []
331
+ multipliers = []
332
+ last_comma = -1
333
+
334
+ for tokens, (text, weight) in zip(tokenized, parsed):
335
+ i = 0
336
+ while i < len(tokens):
337
+ token = tokens[i]
338
+
339
+ if token == self.comma_token:
340
+ last_comma = len(remade_tokens)
341
+ elif self.comma_padding_backtrack != 0 and max(len(remade_tokens),
342
+ 1) % 75 == 0 and last_comma != -1 and len(
343
+ remade_tokens) - last_comma <= self.comma_padding_backtrack:
344
+ last_comma += 1
345
+ reloc_tokens = remade_tokens[last_comma:]
346
+ reloc_mults = multipliers[last_comma:]
347
+
348
+ remade_tokens = remade_tokens[:last_comma]
349
+ length = len(remade_tokens)
350
+
351
+ rem = int(math.ceil(length / 75)) * 75 - length
352
+ remade_tokens += [self.tokenizer.eos_token_id] * rem + reloc_tokens
353
+ multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
354
+
355
+ remade_tokens.append(token)
356
+ multipliers.append(weight)
357
+ i += 1
358
+
359
+ token_count = len(remade_tokens)
360
+ prompt_target_length = math.ceil(max(token_count, 1) / 75) * 75
361
+ tokens_to_add = prompt_target_length - len(remade_tokens)
362
+
363
+ remade_tokens = remade_tokens + [self.tokenizer.eos_token_id] * tokens_to_add
364
+ multipliers = multipliers + [1.0] * tokens_to_add
365
+
366
+ return remade_tokens, multipliers, token_count
367
+
368
+ def process_text(self, texts):
369
+ remade_batch_tokens = []
370
+ token_count = 0
371
+
372
+ cache = {}
373
+ batch_multipliers = []
374
+ for line in texts:
375
+ if line in cache:
376
+ remade_tokens, multipliers = cache[line]
377
+ else:
378
+ remade_tokens, multipliers, current_token_count = self.tokenize_line(line)
379
+ token_count = max(current_token_count, token_count)
380
+
381
+ cache[line] = (remade_tokens, multipliers)
382
+
383
+ remade_batch_tokens.append(remade_tokens)
384
+ batch_multipliers.append(multipliers)
385
+
386
+ return batch_multipliers, remade_batch_tokens, token_count
387
+
388
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
389
+ remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] + [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
390
+ batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
391
+
392
+ tokens = torch.asarray(remade_batch_tokens).to(self.device)
393
+
394
+ z = self.encode_with_transformers(tokens)
395
+
396
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
397
+ batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
398
+ batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(self.device)
399
+ original_mean = z.mean()
400
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
401
+ new_mean = z.mean()
402
+ z *= original_mean / new_mean
403
+
404
+ return z
405
+
406
+ def forward(self, text):
407
+ batch_multipliers, remade_batch_tokens, token_count = self.process_text(text)
408
+
409
+ z = None
410
+ i = 0
411
+ while max(map(len, remade_batch_tokens)) != 0:
412
+ rem_tokens = [x[75:] for x in remade_batch_tokens]
413
+ rem_multipliers = [x[75:] for x in batch_multipliers]
414
+
415
+ tokens = []
416
+ multipliers = []
417
+ for j in range(len(remade_batch_tokens)):
418
+ if len(remade_batch_tokens[j]) > 0:
419
+ tokens.append(remade_batch_tokens[j][:75])
420
+ multipliers.append(batch_multipliers[j][:75])
421
+ else:
422
+ tokens.append([self.tokenizer.eos_token_id] * 75)
423
+ multipliers.append([1.0] * 75)
424
+
425
+ z1 = self.process_tokens(tokens, multipliers)
426
+ z = z1 if z is None else torch.cat((z, z1), axis=-2)
427
+
428
+ remade_batch_tokens = rem_tokens
429
+ batch_multipliers = rem_multipliers
430
+ i += 1
431
+
432
+ return z
433
+
434
+ def encode(self, text):
435
+ return self(text)
436
 
 
 
 
437
 
438
 
439
  if __name__ == "__main__":
 
440
  model = FrozenCLIPEmbedder()
441
+ count_params(model, verbose=True)
ldm/modules/{structure_condition β†’ extra_condition}/__init__.py RENAMED
File without changes
ldm/modules/extra_condition/api.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum, unique
2
+
3
+ import cv2
4
+ import torch
5
+ from basicsr.utils import img2tensor
6
+ from ldm.util import resize_numpy_image
7
+ from PIL import Image
8
+ from torch import autocast
9
+
10
+
11
+ @unique
12
+ class ExtraCondition(Enum):
13
+ sketch = 0
14
+ keypose = 1
15
+ seg = 2
16
+ depth = 3
17
+ canny = 4
18
+ style = 5
19
+ color = 6
20
+ openpose = 7
21
+
22
+
23
+ def get_cond_model(opt, cond_type: ExtraCondition):
24
+ if cond_type == ExtraCondition.sketch:
25
+ from ldm.modules.extra_condition.model_edge import pidinet
26
+ model = pidinet()
27
+ ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
28
+ model.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()}, strict=True)
29
+ model.to(opt.device)
30
+ return model
31
+ elif cond_type == ExtraCondition.seg:
32
+ raise NotImplementedError
33
+ elif cond_type == ExtraCondition.keypose:
34
+ import mmcv
35
+ from mmdet.apis import init_detector
36
+ from mmpose.apis import init_pose_model
37
+ det_config = 'configs/mm/faster_rcnn_r50_fpn_coco.py'
38
+ det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
39
+ pose_config = 'configs/mm/hrnet_w48_coco_256x192.py'
40
+ pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
41
+ det_config_mmcv = mmcv.Config.fromfile(det_config)
42
+ det_model = init_detector(det_config_mmcv, det_checkpoint, device=opt.device)
43
+ pose_config_mmcv = mmcv.Config.fromfile(pose_config)
44
+ pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=opt.device)
45
+ return {'pose_model': pose_model, 'det_model': det_model}
46
+ elif cond_type == ExtraCondition.depth:
47
+ from ldm.modules.extra_condition.midas.api import MiDaSInference
48
+ model = MiDaSInference(model_type='dpt_hybrid').to(opt.device)
49
+ return model
50
+ elif cond_type == ExtraCondition.canny:
51
+ return None
52
+ elif cond_type == ExtraCondition.style:
53
+ from transformers import CLIPProcessor, CLIPVisionModel
54
+ version = 'openai/clip-vit-large-patch14'
55
+ processor = CLIPProcessor.from_pretrained(version)
56
+ clip_vision_model = CLIPVisionModel.from_pretrained(version).to(opt.device)
57
+ return {'processor': processor, 'clip_vision_model': clip_vision_model}
58
+ elif cond_type == ExtraCondition.color:
59
+ return None
60
+ elif cond_type == ExtraCondition.openpose:
61
+ from ldm.modules.extra_condition.openpose.api import OpenposeInference
62
+ model = OpenposeInference().to(opt.device)
63
+ return model
64
+ else:
65
+ raise NotImplementedError
66
+
67
+
68
+ def get_cond_sketch(opt, cond_image, cond_inp_type, cond_model=None):
69
+ if isinstance(cond_image, str):
70
+ edge = cv2.imread(cond_image)
71
+ else:
72
+ # for gradio input, pay attention, it's rgb numpy
73
+ edge = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
74
+ edge = resize_numpy_image(edge, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
75
+ opt.H, opt.W = edge.shape[:2]
76
+ if cond_inp_type == 'sketch':
77
+ edge = img2tensor(edge)[0].unsqueeze(0).unsqueeze(0) / 255.
78
+ edge = edge.to(opt.device)
79
+ elif cond_inp_type == 'image':
80
+ edge = img2tensor(edge).unsqueeze(0) / 255.
81
+ edge = cond_model(edge.to(opt.device))[-1]
82
+ else:
83
+ raise NotImplementedError
84
+
85
+ # edge = 1-edge # for white background
86
+ edge = edge > 0.5
87
+ edge = edge.float()
88
+
89
+ return edge
90
+
91
+
92
+ def get_cond_seg(opt, cond_image, cond_inp_type='image', cond_model=None):
93
+ if isinstance(cond_image, str):
94
+ seg = cv2.imread(cond_image)
95
+ else:
96
+ seg = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
97
+ seg = resize_numpy_image(seg, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
98
+ opt.H, opt.W = seg.shape[:2]
99
+ if cond_inp_type == 'seg':
100
+ seg = img2tensor(seg).unsqueeze(0) / 255.
101
+ seg = seg.to(opt.device)
102
+ else:
103
+ raise NotImplementedError
104
+
105
+ return seg
106
+
107
+
108
+ def get_cond_keypose(opt, cond_image, cond_inp_type='image', cond_model=None):
109
+ if isinstance(cond_image, str):
110
+ pose = cv2.imread(cond_image)
111
+ else:
112
+ pose = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
113
+ pose = resize_numpy_image(pose, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
114
+ opt.H, opt.W = pose.shape[:2]
115
+ if cond_inp_type == 'keypose':
116
+ pose = img2tensor(pose).unsqueeze(0) / 255.
117
+ pose = pose.to(opt.device)
118
+ elif cond_inp_type == 'image':
119
+ from ldm.modules.extra_condition.utils import imshow_keypoints
120
+ from mmdet.apis import inference_detector
121
+ from mmpose.apis import (inference_top_down_pose_model, process_mmdet_results)
122
+
123
+ # mmpose seems not compatible with autocast fp16
124
+ with autocast("cuda", dtype=torch.float32):
125
+ mmdet_results = inference_detector(cond_model['det_model'], pose)
126
+ # keep the person class bounding boxes.
127
+ person_results = process_mmdet_results(mmdet_results, 1)
128
+
129
+ # optional
130
+ return_heatmap = False
131
+ dataset = cond_model['pose_model'].cfg.data['test']['type']
132
+
133
+ # e.g. use ('backbone', ) to return backbone feature
134
+ output_layer_names = None
135
+ pose_results, returned_outputs = inference_top_down_pose_model(
136
+ cond_model['pose_model'],
137
+ pose,
138
+ person_results,
139
+ bbox_thr=0.2,
140
+ format='xyxy',
141
+ dataset=dataset,
142
+ dataset_info=None,
143
+ return_heatmap=return_heatmap,
144
+ outputs=output_layer_names)
145
+
146
+ # show the results
147
+ pose = imshow_keypoints(pose, pose_results, radius=2, thickness=2)
148
+ pose = img2tensor(pose).unsqueeze(0) / 255.
149
+ pose = pose.to(opt.device)
150
+ else:
151
+ raise NotImplementedError
152
+
153
+ return pose
154
+
155
+
156
+ def get_cond_depth(opt, cond_image, cond_inp_type='image', cond_model=None):
157
+ if isinstance(cond_image, str):
158
+ depth = cv2.imread(cond_image)
159
+ else:
160
+ depth = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
161
+ depth = resize_numpy_image(depth, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
162
+ opt.H, opt.W = depth.shape[:2]
163
+ if cond_inp_type == 'depth':
164
+ depth = img2tensor(depth).unsqueeze(0) / 255.
165
+ depth = depth.to(opt.device)
166
+ elif cond_inp_type == 'image':
167
+ depth = img2tensor(depth).unsqueeze(0) / 127.5 - 1.0
168
+ depth = cond_model(depth.to(opt.device)).repeat(1, 3, 1, 1)
169
+ depth -= torch.min(depth)
170
+ depth /= torch.max(depth)
171
+ else:
172
+ raise NotImplementedError
173
+
174
+ return depth
175
+
176
+
177
+ def get_cond_canny(opt, cond_image, cond_inp_type='image', cond_model=None):
178
+ if isinstance(cond_image, str):
179
+ canny = cv2.imread(cond_image)
180
+ else:
181
+ canny = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
182
+ canny = resize_numpy_image(canny, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
183
+ opt.H, opt.W = canny.shape[:2]
184
+ if cond_inp_type == 'canny':
185
+ canny = img2tensor(canny)[0:1].unsqueeze(0) / 255.
186
+ canny = canny.to(opt.device)
187
+ elif cond_inp_type == 'image':
188
+ canny = cv2.Canny(canny, 100, 200)[..., None]
189
+ canny = img2tensor(canny).unsqueeze(0) / 255.
190
+ canny = canny.to(opt.device)
191
+ else:
192
+ raise NotImplementedError
193
+
194
+ return canny
195
+
196
+
197
+ def get_cond_style(opt, cond_image, cond_inp_type='image', cond_model=None):
198
+ assert cond_inp_type == 'image'
199
+ if isinstance(cond_image, str):
200
+ style = Image.open(cond_image)
201
+ else:
202
+ # numpy image to PIL image
203
+ style = Image.fromarray(cond_image)
204
+
205
+ style_for_clip = cond_model['processor'](images=style, return_tensors="pt")['pixel_values']
206
+ style_feat = cond_model['clip_vision_model'](style_for_clip.to(opt.device))['last_hidden_state']
207
+
208
+ return style_feat
209
+
210
+
211
+ def get_cond_color(opt, cond_image, cond_inp_type='image', cond_model=None):
212
+ if isinstance(cond_image, str):
213
+ color = cv2.imread(cond_image)
214
+ else:
215
+ color = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
216
+ color = resize_numpy_image(color, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
217
+ opt.H, opt.W = color.shape[:2]
218
+ if cond_inp_type == 'image':
219
+ color = cv2.resize(color, (opt.W//64, opt.H//64), interpolation=cv2.INTER_CUBIC)
220
+ color = cv2.resize(color, (opt.W, opt.H), interpolation=cv2.INTER_NEAREST)
221
+ color = img2tensor(color).unsqueeze(0) / 255.
222
+ color = color.to(opt.device)
223
+ return color
224
+
225
+
226
+ def get_cond_openpose(opt, cond_image, cond_inp_type='image', cond_model=None):
227
+ if isinstance(cond_image, str):
228
+ openpose_keypose = cv2.imread(cond_image)
229
+ else:
230
+ openpose_keypose = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
231
+ openpose_keypose = resize_numpy_image(
232
+ openpose_keypose, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
233
+ opt.H, opt.W = openpose_keypose.shape[:2]
234
+ if cond_inp_type == 'openpose':
235
+ openpose_keypose = img2tensor(openpose_keypose).unsqueeze(0) / 255.
236
+ openpose_keypose = openpose_keypose.to(opt.device)
237
+ elif cond_inp_type == 'image':
238
+ with autocast('cuda', dtype=torch.float32):
239
+ openpose_keypose = cond_model(openpose_keypose)
240
+ openpose_keypose = img2tensor(openpose_keypose).unsqueeze(0) / 255.
241
+ openpose_keypose = openpose_keypose.to(opt.device)
242
+
243
+ else:
244
+ raise NotImplementedError
245
+
246
+ return openpose_keypose
247
+
248
+
249
+ def get_adapter_feature(inputs, adapters):
250
+ ret_feat_map = None
251
+ ret_feat_seq = None
252
+ if not isinstance(inputs, list):
253
+ inputs = [inputs]
254
+ adapters = [adapters]
255
+
256
+ for input, adapter in zip(inputs, adapters):
257
+ cur_feature = adapter['model'](input)
258
+ if isinstance(cur_feature, list):
259
+ if ret_feat_map is None:
260
+ ret_feat_map = list(map(lambda x: x * adapter['cond_weight'], cur_feature))
261
+ else:
262
+ ret_feat_map = list(map(lambda x, y: x + y * adapter['cond_weight'], ret_feat_map, cur_feature))
263
+ else:
264
+ if ret_feat_seq is None:
265
+ ret_feat_seq = cur_feature
266
+ else:
267
+ ret_feat_seq = torch.cat([ret_feat_seq, cur_feature], dim=1)
268
+
269
+ return ret_feat_map, ret_feat_seq
ldm/modules/{structure_condition/midas β†’ extra_condition}/midas/__init__.py RENAMED
File without changes
ldm/modules/{structure_condition β†’ extra_condition}/midas/api.py RENAMED
@@ -6,10 +6,10 @@ import torch
6
  import torch.nn as nn
7
  from torchvision.transforms import Compose
8
 
9
- from ldm.modules.structure_condition.midas.midas.dpt_depth import DPTDepthModel
10
- from ldm.modules.structure_condition.midas.midas.midas_net import MidasNet
11
- from ldm.modules.structure_condition.midas.midas.midas_net_custom import MidasNet_small
12
- from ldm.modules.structure_condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
13
 
14
 
15
  ISL_PATHS = {
 
6
  import torch.nn as nn
7
  from torchvision.transforms import Compose
8
 
9
+ from ldm.modules.extra_condition.midas.midas.dpt_depth import DPTDepthModel
10
+ from ldm.modules.extra_condition.midas.midas.midas_net import MidasNet
11
+ from ldm.modules.extra_condition.midas.midas.midas_net_custom import MidasNet_small
12
+ from ldm.modules.extra_condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
13
 
14
 
15
  ISL_PATHS = {
ldm/modules/{structure_condition/openpose β†’ extra_condition/midas/midas}/__init__.py RENAMED
File without changes
ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/base_model.py RENAMED
File without changes
ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/blocks.py RENAMED
File without changes
ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/dpt_depth.py RENAMED
File without changes
ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/midas_net.py RENAMED
File without changes
ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/midas_net_custom.py RENAMED
File without changes
ldm/modules/{structure_condition β†’ extra_condition}/midas/midas/transforms.py RENAMED
File without changes