jadechoghari committed on
Commit 5537459
1 Parent(s): 687f75f

Create openaimodel.py

Files changed (1)
  1. unet/openaimodel.py +1009 -0
unet/openaimodel.py ADDED
@@ -0,0 +1,1009 @@
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .util import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)

# replace with custom transformer
from .mv_attention import SPADTransformer as SpatialTransformer


from .util import exists
from torch import autocast

# dummy replace
def convert_module_to_f16(x):
    pass

def convert_module_to_f32(x):
    pass


## go
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels

        # hack
        orig_dtype = x.dtype
        x = x.to(th.float32)
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = x.to(orig_dtype)
        if self.use_conv:
            x = self.conv(x)
        return x

class TransposedUpsample(nn.Module):
    'Learned 2x upsampling without padding'
    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)

    def forward(self, x):
        return self.up(x)


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )


    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        #return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
                              a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,            # custom transformer support
        context_dim=None,               # custom transformer support
        n_embed=None,                   # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        **kwargs
    ):
        for k, v in kwargs.items():
            print(f"UNetModel: unused parameter {k}={v}")

        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        #self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")  # todo: convert to warning

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)


class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
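
For reference, a minimal smoke-test sketch of the UNetModel added above. It is not part of the commit: it assumes the file is importable as unet.openaimodel together with the sibling modules it imports (unet.util and unet.mv_attention), and the hyperparameters are illustrative only. With use_spatial_transformer left at its default of False, the plain AttentionBlock path is used and no cross-attention context is needed.

import torch as th
from unet.openaimodel import UNetModel

# Small illustrative config; real checkpoints use their own hyperparameters.
model = UNetModel(
    image_size=32,               # latent resolution (illustrative)
    in_channels=4,
    model_channels=64,
    out_channels=4,
    num_res_blocks=1,
    attention_resolutions=(4,),  # attention at 4x downsampling
    channel_mult=(1, 2, 4),
    num_heads=4,                 # required since num_head_channels defaults to -1
)

x = th.randn(2, 4, 32, 32)       # [N x C x H x W] latents
t = th.randint(0, 1000, (2,))    # one diffusion timestep per sample
with th.no_grad():
    eps = model(x, timesteps=t)  # output has the same shape as x
print(eps.shape)                 # torch.Size([2, 4, 32, 32])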