Add copy of github repo
- Paella/src/modules.py +283 -0
- Paella/src/train.py +80 -0
- Paella/src/utils.py +55 -0
- Paella/src/vqgan.py +140 -0
- Paella/utils/alter_attention.py +53 -0
- Paella/utils/modules.py +291 -0
Paella/src/modules.py
ADDED
@@ -0,0 +1,283 @@
import math
import torch
import numpy as np
from torch import nn


class Attention2D(nn.Module):
    def __init__(self, c, nhead, dropout=0.0):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)

    def forward(self, x, kv, self_attn=False):
        orig_shape = x.shape
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        if self_attn:
            kv = torch.cat([x, kv], dim=1)
        x = self.attn(x, kv, kv, need_weights=False)[0]
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x


class LayerNorm2d(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class GlobalResponseNorm(nn.Module):
    "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class ResBlock(nn.Module):
    def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
        super().__init__()
        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        return x + x_res


class AttnBlock(nn.Module):
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.attention = Attention2D(c, nhead, dropout)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            nn.Linear(c_cond, c)
        )

    def forward(self, x, kv):
        kv = self.kv_mapper(kv)
        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
        return x


class FeedForwardBlock(nn.Module):
    def __init__(self, c, dropout=0.0):
        super().__init__()
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c)
        )

    def forward(self, x):
        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x


class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep):
        super().__init__()
        self.mapper = nn.Linear(c_timestep, c * 2)

    def forward(self, x, t):
        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
        return x * (1 + a) + b


class Paella(nn.Module):
    def __init__(self, c_in=256, c_out=256, num_labels=8192, c_r=64, patch_size=2, c_cond=1024,
                 c_hidden=[640, 1280, 1280], nhead=[-1, 16, 16], blocks=[6, 16, 6], level_config=['CT', 'CTA', 'CTA'],
                 clip_embd=1024, byt5_embd=1536, clip_seq_len=4, kernel_size=3, dropout=0.1, self_attn=True):
        super().__init__()
        self.c_r = c_r
        self.c_cond = c_cond
        self.num_labels = num_labels
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)

        # CONDITIONING
        self.byt5_mapper = nn.Linear(byt5_embd, c_cond)
        self.clip_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
        self.clip_image_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
        self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.in_mapper = nn.Sequential(
            nn.Embedding(num_labels, c_in),
            nn.LayerNorm(c_in, elementwise_affine=False, eps=1e-6)
        )
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # DOWN BLOCKS
        self.down_blocks = nn.ModuleList()
        for i in range(len(c_hidden)):
            down_block = nn.ModuleList()
            if i > 0:
                down_block.append(nn.Sequential(
                    LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
                ))
            for _ in range(blocks[i]):
                for block_type in level_config[i]:
                    down_block.append(get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i]))
            self.down_blocks.append(down_block)

        # UP BLOCKS
        self.up_blocks = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            up_block = nn.ModuleList()
            for j in range(blocks[i]):
                for k, block_type in enumerate(level_config[i]):
                    up_block.append(get_block(block_type, c_hidden[i], nhead[i],
                                              c_skip=c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0,
                                              dropout=dropout[i]))
            if i > 0:
                up_block.append(nn.Sequential(
                    LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
                ))
            self.up_blocks.append(up_block)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )
        self.out_mapper = nn.Sequential(
            LayerNorm2d(c_out, elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_out, num_labels, kernel_size=1, bias=False)
        )

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init
        nn.init.normal_(self.byt5_mapper.weight, std=0.02)
        nn.init.normal_(self.clip_mapper.weight, std=0.02)
        nn.init.normal_(self.clip_image_mapper.weight, std=0.02)
        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)
        nn.init.constant_(self.clf[1].weight, 0)
        nn.init.normal_(self.in_mapper[0].weight, std=np.sqrt(1 / num_labels))
        self.out_mapper[-1].weight.data = self.in_mapper[0].weight.data[:, :, None, None].clone()

        for level_block in self.down_blocks + self.up_blocks:
            for block in level_block:
                if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks))
                elif isinstance(block, TimestepBlock):
                    nn.init.constant_(block.mapper.weight, 0)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, byt5, clip, clip_image):
        seq = self.byt5_mapper(byt5)
        if clip is not None:
            clip = self.clip_mapper(clip).view(clip.size(0), -1, self.c_cond)
            seq = torch.cat([seq, clip], dim=1)
        if clip_image is not None:
            clip_image = self.clip_image_mapper(clip_image).view(clip_image.size(0), -1, self.c_cond)
            seq = torch.cat([seq, clip_image], dim=1)
        seq = self.seq_norm(seq)
        return seq

    def _down_encode(self, x, r_embed, c_embed):
        level_outputs = []
        for down_block in self.down_blocks:
            for block in down_block:
                if isinstance(block, ResBlock):
                    x = block(x)
                elif isinstance(block, AttnBlock):
                    x = block(x, c_embed)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, c_embed):
        x = level_outputs[0]
        for i, up_block in enumerate(self.up_blocks):
            for j, block in enumerate(up_block):
                if isinstance(block, ResBlock):
                    x = block(x, level_outputs[i] if j == 0 and i > 0 else None)
                elif isinstance(block, AttnBlock):
                    x = block(x, c_embed)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
        return x

    def forward(self, x, r, byt5, clip=None, clip_image=None, x_cat=None):
        if x_cat is not None:
            x = torch.cat([x, x_cat], dim=1)
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r)
        c_embed = self.gen_c_embeddings(byt5, clip, clip_image)

        # Model Blocks
        x = self.embedding(self.in_mapper(x).permute(0, 3, 1, 2))
        level_outputs = self._down_encode(x, r_embed, c_embed)
        x = self._up_decode(level_outputs, r_embed, c_embed)
        x = self.out_mapper(self.clf(x))
        return x

    def add_noise(self, x, t, mask=None, random_x=None):
        if mask is None:
            mask = (torch.rand_like(x.float()) <= t[:, None, None]).long()
        if random_x is None:
            random_x = torch.randint_like(x, 0, self.num_labels)
        x = x * (1 - mask) + random_x * mask
        return x, mask
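The file above only defines the network. A minimal shape-check sketch of how its pieces fit together follows; the tiny hyperparameters, the batch of random token indices, and the stand-in ByT5 states are assumptions made here for illustration and are not part of the repository.

# Usage sketch (not part of the repository files) -- tiny hypothetical config for a shape check.
import torch
from modules import Paella  # assumes src/modules.py is importable as `modules`

model = Paella(c_in=16, c_out=16, num_labels=32, c_r=8, c_cond=64,
               c_hidden=[64, 128], nhead=[-1, 4], blocks=[1, 1],
               level_config=['CT', 'CTA'], byt5_embd=48)

tokens = torch.randint(0, 32, (2, 8, 8))    # an 8x8 grid of VQ token indices per image
t = torch.rand(2)                           # one noise level per sample in [0, 1]
noised, mask = model.add_noise(tokens, t)   # randomly replaces roughly a fraction t of the tokens
byt5_states = torch.randn(2, 12, 48)        # stand-in for ByT5 encoder hidden states
logits = model(noised, t, byt5_states)      # per-token logits over the codebook
assert logits.shape == (2, 32, 8, 8)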
Paella/src/train.py
ADDED
@@ -0,0 +1,80 @@
import os
import torch
import numpy as np
from tqdm import tqdm
from modules import Paella
from torch import nn, optim
from warmup_scheduler import GradualWarmupScheduler
from utils import get_dataloader, load_conditional_models

steps = 100_000
warmup_updates = 10000
batch_size = 16
checkpoint_frequency = 2000
lr = 1e-4

train_device = "cuda"
dataset_path = ""
byt5_model_name = "google/byt5-xl"
vqmodel_path = ""
run_name = "Paella-ByT5-XL-v1"
output_path = "output"
checkpoint_path = f"{run_name}.pt"


def train():
    os.makedirs(output_path, exist_ok=True)
    device = torch.device(train_device)

    dataloader = get_dataloader(dataset_path, batch_size=batch_size)
    checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None

    model = Paella(byt5_embd=2560).to(device)
    vqgan, (byt5_tokenizer, byt5) = load_conditional_models(byt5_model_name, vqmodel_path, device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction='none')

    start_iter = 1
    if checkpoint is not None:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.last_epoch = checkpoint['scheduler_last_step']
        start_iter = checkpoint['scheduler_last_step'] + 1
        del checkpoint

    pbar = tqdm(range(start_iter, steps + 1))
    model.train()
    for i, (images, captions) in enumerate(dataloader, start=start_iter):
        images = images.to(device)

        with torch.no_grad():
            # Drop the caption 5% of the time so the model also learns the unconditional distribution
            if np.random.rand() < 0.05:
                byt5_captions = [''] * len(captions)
            else:
                byt5_captions = captions
            byt5_tokens = byt5_tokenizer(byt5_captions, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
            byt_embeddings = byt5(input_ids=byt5_tokens).last_hidden_state

            t = (1 - torch.rand(images.size(0), device=device))
            latents = vqgan.encode(images)[2]
            noised_latents, _ = model.add_noise(latents, t)

        pred = model(noised_latents, t, byt_embeddings)
        loss = criterion(pred, latents).mean()  # per-token cross-entropy reduced to a scalar

        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        acc = (pred.argmax(1) == latents).float().mean()

        pbar.update(1)
        pbar.set_postfix({'bs': images.size(0), 'loss': loss.item(), 'acc': acc.item(), 'grad_norm': grad_norm.item(), 'lr': optimizer.param_groups[0]['lr'], 'total_steps': scheduler.last_epoch})

        if i % checkpoint_frequency == 0:
            torch.save({'state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_last_step': scheduler.last_epoch, 'iter': i}, checkpoint_path)


if __name__ == '__main__':
    train()
Paella/src/utils.py
ADDED
@@ -0,0 +1,55 @@
import torch
import torchvision
from vqgan import VQModel
from torch.utils.data import Dataset, DataLoader
from transformers import T5EncoderModel, AutoTokenizer

transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(256),
    torchvision.transforms.RandomCrop(256),
])


class YOUR_DATASET(Dataset):
    def __init__(self, dataset_path):
        pass


def get_dataloader(dataset_path, batch_size):
    dataset = YOUR_DATASET(dataset_path)
    return DataLoader(dataset, batch_size=batch_size, num_workers=8, pin_memory=True)


def load_conditional_models(byt5_model_name, vqgan_path, device):
    vqgan = VQModel().to(device)
    vqgan.load_state_dict(torch.load(vqgan_path, map_location=device)['state_dict'])
    vqgan.eval().requires_grad_(False)

    byt5 = T5EncoderModel.from_pretrained(byt5_model_name).to(device).eval().requires_grad_(False)
    byt5_tokenizer = AutoTokenizer.from_pretrained(byt5_model_name)

    return vqgan, (byt5_tokenizer, byt5)


def sample(model, model_inputs, latent_shape, unconditional_inputs=None, steps=12, renoise_steps=11, temperature=(1.0, 0.2), cfg=8.0, t_start=1.0, t_end=0.0, device="cuda"):
    with torch.inference_mode():
        sampled = torch.randint(0, model.num_labels, size=latent_shape, device=device)
        init_noise = sampled.clone()
        t_list = torch.linspace(t_start, t_end, steps + 1)
        temperatures = torch.linspace(temperature[0], temperature[1], steps)
        for i, t in enumerate(t_list[:steps]):
            t = torch.ones(latent_shape[0], device=device) * t

            logits = model(sampled, t, **model_inputs)
            if cfg:
                logits = logits * cfg + model(sampled, t, **unconditional_inputs) * (1 - cfg)
            scores = logits.div(temperatures[i]).softmax(dim=1)

            sampled = scores.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
            sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])

            if i < renoise_steps:
                t_next = torch.ones(latent_shape[0], device=device) * t_list[i + 1]
                sampled = model.add_noise(sampled, t_next, random_x=init_noise)[0]
    return sampled
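The sample function above is typically driven with a trained Paella model and the VQGAN from vqgan.py. The sketch below is a hypothetical call, assuming model, vqgan, byt_embeddings (prompt ByT5 states) and uncond_embeddings (empty-prompt ByT5 states) already exist.

# Usage sketch (not part of the repository files); all variables referenced are assumptions.
latent_shape = (4, 32, 32)                            # 4 images, 32x32 token grid
tokens = sample(model, {'byt5': byt_embeddings}, latent_shape,
                unconditional_inputs={'byt5': uncond_embeddings},
                steps=12, renoise_steps=11, cfg=4.0)
images = vqgan.decode_indices(tokens).clamp(0, 1)     # token indices -> RGB reconstruction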
Paella/src/vqgan.py
ADDED
@@ -0,0 +1,140 @@
import torch
from torch import nn
from torchtools.nn import VectorQuantize


class ResBlock(nn.Module):
    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        def _basic_init(module):
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        mods = self.gammas
        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        x = x + self.depthwise(x_temp) * mods[2]
        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
        return x


class VQModel(nn.Module):
    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
                 scale_factor=0.3764):  # 1.0
        super().__init__()
        self.c_latent = c_latent
        self.scale_factor = scale_factor
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                       padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x):
        x = self.in_block(x)
        x = self.down_blocks(x)
        qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
        return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25

    def decode(self, x):
        x = x * self.scale_factor
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def decode_indices(self, x):
        x = self.vquantizer.idx2vq(x, dim=1)
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):  # `quantize` kept for API compatibility; encode always quantizes
        qe, x, _, vq_loss = self.encode(x)
        x = self.decode(qe)
        return x, vq_loss


class Discriminator(nn.Module):
    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = c_hidden // (2 ** max((d - i), 0))
            c_out = c_hidden // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
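A short round-trip sketch of the VQModel interface follows. It assumes pretrained weights have already been loaded (for example via load_conditional_models from utils.py); with random weights the codes are only useful for checking shapes.

# Usage sketch (not part of the repository files).
import torch

vqmodel = VQModel().eval()
images = torch.rand(2, 3, 256, 256)                    # RGB in [0, 1]
with torch.no_grad():
    qe, z, indices, vq_loss = vqmodel.encode(images)   # indices: (2, 64, 64) codebook ids
    recon = vqmodel.decode(qe)                         # (2, 3, 256, 256) reconstruction
    recon_from_idx = vqmodel.decode_indices(indices)   # same decoder path used after sampling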
Paella/utils/alter_attention.py
ADDED
@@ -0,0 +1,53 @@
import torch
from torch import nn


class CustomMultiheadAttention(nn.MultiheadAttention):
    def forward(self, *args, attn_weights=None, **kwargs):
        q, k, v = args[:3]
        need_weights = kwargs.get('need_weights', False)

        w = self.in_proj_weight.chunk(3, dim=0)
        b = self.in_proj_bias.chunk(3, dim=0)

        if not self.batch_first:
            q, k, v = q.permute(1, 0, 2), k.permute(1, 0, 2), v.permute(1, 0, 2)

        q = nn.functional.linear(q, w[0], bias=b[0]).view(q.size(0), q.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
        k = nn.functional.linear(k, w[1], bias=b[1]).view(k.size(0), k.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
        v = nn.functional.linear(v, w[2], bias=b[2]).view(v.size(0), v.size(1), self.num_heads, -1).permute(0, 2, 1, 3)

        scores = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
        attention = scores.softmax(dim=-1)

        if attn_weights is not None:
            # Rescale the attention paid to the last tokens of the key sequence (the conditioning tokens)
            weights = torch.ones((attention.shape[2], attention.shape[3])).to(q.device)
            attn_weights = attn_weights.expand(attention.shape[2], attn_weights.shape[0])
            weights[-attn_weights.shape[0]:, -attn_weights.shape[1]:] = attn_weights
            attn_weights = weights.clone()
            attention = attention * attn_weights

        x = attention @ v
        x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)
        x = self.out_proj(x)

        if not self.batch_first:
            x = x.permute(1, 0, 2)

        return (x, attention if need_weights else None)


def replace_attention_layers(model):
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_attention_layers(module)

        if isinstance(module, nn.MultiheadAttention):
            new_module = CustomMultiheadAttention(module.embed_dim, module.num_heads, dropout=module.dropout, bias=True, batch_first=module.batch_first)
            new_module.load_state_dict(module.state_dict())
            setattr(model, n, new_module)
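replace_attention_layers is meant to be paired with the Paella variant in Paella/utils/modules.py, which forwards extra keyword arguments down to every attention call. The sketch below is hypothetical: model, noised_tokens, t and byt_embeddings are assumed to exist, and the weight values are arbitrary.

# Usage sketch (not part of the repository files); variables other than the weights are assumptions.
replace_attention_layers(model)                        # swap nn.MultiheadAttention for the custom class
attn_weights = torch.ones(byt_embeddings.size(1), device=byt_embeddings.device)
attn_weights[5:10] = 1.5                               # emphasize tokens 5-9 of the conditioning sequence
logits = model(noised_tokens, t, byt_embeddings, attn_weights=attn_weights)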
Paella/utils/modules.py
ADDED
@@ -0,0 +1,291 @@
import torch
from torch import nn
import numpy as np
import math


class Attention2D(nn.Module):
    def __init__(self, c, nhead, dropout=0.0):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)

    def forward(self, x, kv, self_attn=False, **kwargs):
        orig_shape = x.shape
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
        if self_attn:
            kv = torch.cat([x, kv], dim=1)
        x = self.attn(x, kv, kv, need_weights=False, **kwargs)[0]
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x


class LayerNorm2d(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class GlobalResponseNorm(nn.Module):
    "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class ResBlock(nn.Module):
    def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
        super().__init__()
        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        return x + x_res


class AttnBlock(nn.Module):
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.attention = Attention2D(c, nhead, dropout)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            nn.Linear(c_cond, c)
        )

    def forward(self, x, kv, **kwargs):
        kv = self.kv_mapper(kv)
        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, **kwargs)
        return x


class FeedForwardBlock(nn.Module):
    def __init__(self, c, dropout=0.0):
        super().__init__()
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c * 4),
            nn.GELU(),
            GlobalResponseNorm(c * 4),
            nn.Dropout(dropout),
            nn.Linear(c * 4, c)
        )

    def forward(self, x):
        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x


class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep):
        super().__init__()
        self.mapper = nn.Linear(c_timestep, c * 2)

    def forward(self, x, t):
        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
        return x * (1 + a) + b


class Paella(nn.Module):
    def __init__(self, c_in=256, c_out=256, num_labels=8192, c_r=64, patch_size=2, c_cond=1024,
                 c_hidden=[640, 1280, 1280], nhead=[-1, 16, 16], blocks=[6, 16, 6], level_config=['CT', 'CTA', 'CTA'],
                 clip_embd=1024, byt5_embd=1536, clip_seq_len=4, kernel_size=3, dropout=0.1, self_attn=True):
        super().__init__()
        self.c_r = c_r
        self.c_cond = c_cond
        self.num_labels = num_labels
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)

        # CONDITIONING
        self.byt5_mapper = nn.Linear(byt5_embd, c_cond)
        self.clip_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
        self.clip_image_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
        self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.in_mapper = nn.Sequential(
            nn.Embedding(num_labels, c_in),
            nn.LayerNorm(c_in, elementwise_affine=False, eps=1e-6)
        )
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # DOWN BLOCK
        self.down_blocks = nn.ModuleList()
        for i in range(len(c_hidden)):
            down_block = nn.ModuleList()
            if i > 0:
                down_block.append(nn.Sequential(
                    LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
                ))
            for _ in range(blocks[i]):
                for block_type in level_config[i]:
                    down_block.append(get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i]))
            self.down_blocks.append(down_block)

        # UP BLOCKS
        self.up_blocks = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            up_block = nn.ModuleList()
            for j in range(blocks[i]):
                for k, block_type in enumerate(level_config[i]):
                    up_block.append(get_block(block_type, c_hidden[i], nhead[i],
                                              c_skip=c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0,
                                              dropout=dropout[i]))
            if i > 0:
                up_block.append(nn.Sequential(
                    LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
                ))
            self.up_blocks.append(up_block)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )
        self.out_mapper = nn.Sequential(
            LayerNorm2d(c_out, elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_out, num_labels, kernel_size=1, bias=False)
        )

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)
        nn.init.normal_(self.byt5_mapper.weight, std=0.02)
        nn.init.normal_(self.clip_mapper.weight, std=0.02)
        nn.init.normal_(self.clip_image_mapper.weight, std=0.02)
        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        nn.init.constant_(self.clf[1].weight, 0)  # outputs
        nn.init.normal_(self.in_mapper[0].weight, std=np.sqrt(1 / num_labels))  # out mapper
        self.out_mapper[-1].weight.data = self.in_mapper[0].weight.data[:, :, None, None].clone()

        for level_block in self.down_blocks + self.up_blocks:
            for block in level_block:
                if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks))
                elif isinstance(block, TimestepBlock):
                    nn.init.constant_(block.mapper.weight, 0)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, byt5, clip, clip_image):
        seq = self.byt5_mapper(byt5)
        if clip is not None:
            clip = self.clip_mapper(clip).view(clip.size(0), -1, self.c_cond)
            seq = torch.cat([seq, clip], dim=1)
        if clip_image is not None:
            if isinstance(clip_image, list):
                for ci in clip_image:
                    ci = self.clip_image_mapper(ci).view(ci.size(0), -1, self.c_cond)
                    seq = torch.cat([seq, ci], dim=1)
            else:
                clip_image = self.clip_image_mapper(clip_image).view(clip_image.size(0), -1, self.c_cond)
                seq = torch.cat([seq, clip_image], dim=1)
        seq = self.seq_norm(seq)
        return seq

    def _down_encode(self, x, r_embed, c_embed, **kwargs):
        level_outputs = []
        for down_block in self.down_blocks:
            for block in down_block:
                if isinstance(block, ResBlock):
                    x = block(x)
                elif isinstance(block, AttnBlock):
                    x = block(x, c_embed, **kwargs)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, c_embed, **kwargs):
        x = level_outputs[0]
        for i, up_block in enumerate(self.up_blocks):
            for j, block in enumerate(up_block):
                if isinstance(block, ResBlock):
                    x = block(x, level_outputs[i] if j == 0 and i > 0 else None)
                elif isinstance(block, AttnBlock):
                    x = block(x, c_embed, **kwargs)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    x = block(x)
        return x

    def forward(self, x, r, byt5, clip=None, clip_image=None, x_cat=None, **kwargs):
        if x_cat is not None:
            x = torch.cat([x, x_cat], dim=1)
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r)
        c_embed = self.gen_c_embeddings(byt5, clip, clip_image)

        # Model Blocks
        x = self.embedding(self.in_mapper(x).permute(0, 3, 1, 2))
        level_outputs = self._down_encode(x, r_embed, c_embed, **kwargs)
        x = self._up_decode(level_outputs, r_embed, c_embed, **kwargs)
        x = self.out_mapper(self.clf(x))
        return x

    def add_noise(self, x, t, mask=None, random_x=None):
        if mask is None:
            mask = (torch.rand_like(x.float()) <= t[:, None, None]).long()
        if random_x is None:
            random_x = torch.randint_like(x, 0, self.num_labels)
        x = x * (1 - mask) + random_x * mask
        return x, mask

    def get_loss_weight(self, t, mask, min_val=0.3):
        return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None, None]
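Unlike the src variant, this copy exposes get_loss_weight, which down-weights tokens that were left un-noised, never below min_val. A hedged sketch of how it might be combined with the training objective follows; model, latents, t and byt_embeddings are assumed to exist, with model an instance of this class.

# Usage sketch (not part of the repository files); the surrounding variables are assumptions.
noised, mask = model.add_noise(latents, t)
pred = model(noised, t, byt_embeddings)
loss = nn.CrossEntropyLoss(reduction='none')(pred, latents)    # per-token loss, shape (B, H, W)
loss = (loss * model.get_loss_weight(t, mask)).mean()          # un-noised tokens contribute at least min_val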