Upload dalle/models/__init__.py with huggingface_hub

dalle/models/__init__.py: ADDED (+198 -0)
@@ -0,0 +1,198 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import torch
import logging
import torch.nn as nn
import pytorch_lightning as pl
from typing import Optional, Tuple
from omegaconf import OmegaConf
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import functional as F
from .stage1.vqgan import VQGAN
from .stage2.transformer import Transformer1d, iGPT
from .. import utils
from ..utils.config import get_base_config
from ..utils.sampling import sampling, sampling_igpt
from .tokenizer import build_tokenizer

_MODELS = {
    'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
}


class Dalle(nn.Module):
    def __init__(self,
                 config: OmegaConf) -> None:
        super().__init__()
        self.tokenizer = None
        self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
                            embed_dim=config.stage1.embed_dim,
                            hparams=config.stage1.hparams)
        self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
                                    vocab_size_img=config.stage2.vocab_size_img,
                                    hparams=config.stage2.hparams)
        self.config_stage1 = config.stage1
        self.config_stage2 = config.stage2
        self.config_dataset = config.dataset

    @classmethod
    def from_pretrained(cls,
                        path: str) -> nn.Module:
        config_base = get_base_config()
        config_new = OmegaConf.load('config.yaml')
        config_update = OmegaConf.merge(config_base, config_new)

        model = cls(config_update)
        model.tokenizer = build_tokenizer('tokenizer',
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)
        return model

    @torch.no_grad()
    def sampling(self,
                 prompt: str,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True) -> torch.FloatTensor:
        self.stage1.eval()
        self.stage2.eval()

        tokens = self.tokenizer.encode(prompt)
        tokens = torch.LongTensor(tokens.ids)
        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)

        # Check if the encoding works as intended
        # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])

        tokens = tokens.to(device)
        codes = sampling(self.stage2,
                         tokens,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels
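Note that this variant of Dalle.from_pretrained only builds the model and tokenizer from the config.yaml and tokenizer files in the working directory; it does not restore checkpoint weights. A minimal usage sketch follows, assuming the stage-1/stage-2 weights are loaded separately (e.g. via stage1.from_ckpt / stage2.from_ckpt) and a CUDA device is available; the prompt and output filename are illustrative:

import torch
from PIL import Image
from dalle.models import Dalle

model = Dalle.from_pretrained('minDALL-E/1.3B')  # path argument is unused in this variant
model.to('cuda:0')

# Sample a few candidate images for a text prompt; sampling() already runs under torch.no_grad()
pixels = model.sampling(prompt='a painting of a tree on the ocean',
                        top_k=256,
                        num_candidates=4,
                        device='cuda:0',
                        use_fp16=True)

# pixels is an image batch in [0, 1]; convert to uint8 HWC and save the first candidate
images = (pixels.cpu().permute(0, 2, 3, 1).numpy() * 255).astype('uint8')
Image.fromarray(images[0]).save('candidate_0.png')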
class ImageGPT(pl.LightningModule):
    def __init__(self,
                 config: OmegaConf) -> None:
        super().__init__()
        self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
                            embed_dim=config.stage1.embed_dim,
                            hparams=config.stage1.hparams)
        self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
                           use_cls_cond=config.stage2.use_cls_cond,
                           hparams=config.stage2.hparams)
        self.config = config
        self.use_cls_cond = config.stage2.use_cls_cond

        # make the parameters in stage 1 not trainable
        self.stage1.eval()
        for p in self.stage1.parameters():
            p.requires_grad = False

    @classmethod
    def from_pretrained(cls,
                        path_upstream: str,
                        path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
        config_base = get_base_config(use_default=False)
        config_down = OmegaConf.load(path_downstream)
        config_down = OmegaConf.merge(config_base, config_down)

        model = cls(config_down)
        model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
        model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
        return model, config_down

    def sample(self,
               cls_idx: Optional[int] = None,
               top_k: int = 256,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               num_candidates: int = 16,
               device: str = 'cuda:0',
               use_fp16: bool = True,
               is_tqdm: bool = True) -> torch.FloatTensor:
        self.stage1.eval()
        self.stage2.eval()

        if cls_idx is None:
            sos = self.stage2.sos.repeat(num_candidates, 1, 1)
        else:
            sos = torch.LongTensor([cls_idx]).to(device=device)
            sos = sos.repeat(num_candidates)
            sos = self.stage2.sos(sos).unsqueeze(1)

        codes = sampling_igpt(self.stage2,
                              sos=sos,
                              top_k=top_k,
                              top_p=top_p,
                              softmax_temperature=softmax_temperature,
                              use_fp16=use_fp16,
                              is_tqdm=is_tqdm)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels

    def forward(self,
                images: torch.FloatTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        B, C, H, W = images.shape
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()
        logits = self.stage2(codes, labels)
        return logits, codes

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def configure_optimizers(self):
        assert self.config.optimizer.opt_type == 'adamW'
        assert self.config.optimizer.sched_type == 'cosine'

        opt = torch.optim.AdamW(self.parameters(),
                                lr=self.config.optimizer.base_lr,
                                betas=self.config.optimizer.betas,
                                weight_decay=self.config.optimizer.weight_decay)
        sched = CosineAnnealingLR(opt,
                                  T_max=self.config.optimizer.max_steps,
                                  eta_min=self.config.optimizer.min_lr)
        sched = {
            'scheduler': sched,
            'name': 'cosine'
        }
        return [opt], [sched]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        optimizer.step(closure=optimizer_closure)
        self.lr_schedulers().step()
        self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        self.stage1.eval()