JustinLin610 committed
Commit dd78d66 • 1 Parent(s): 9eb2477

remove unnecessary files
Files changed:
- models/clip/__init__.py +0 -1
- models/clip/clip.py +0 -229
- models/clip/model.py +0 -437
- models/clip/simple_tokenizer.py +0 -132
- models/taming/lr_scheduler.py +0 -39
- models/taming/models/vqgan.py +0 -262
- models/taming/modules/diffusionmodules/model.py +0 -776
- models/taming/modules/discriminator/model.py +0 -67
- models/taming/modules/losses/__init__.py +0 -2
- models/taming/modules/losses/lpips.py +0 -123
- models/taming/modules/losses/segmentation.py +0 -22
- models/taming/modules/losses/vqperceptual.py +0 -136
- models/taming/modules/misc/coord.py +0 -31
- models/taming/modules/util.py +0 -130
- models/taming/modules/vqvae/quantize.py +0 -445
- models/taming/util.py +0 -172
- utils/eval_utils.py +1 -1

models/clip/__init__.py
DELETED
@@ -1 +0,0 @@
from .clip import *

models/clip/clip.py
DELETED
@@ -1,229 +0,0 @@
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
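
For context, a minimal sketch of how this removed clip.py module was typically used; the model name, image path, and prompt strings below are illustrative assumptions, not taken from this commit:

import torch
from PIL import Image

from models.clip import load, tokenize  # module removed by this commit

device = "cuda" if torch.cuda.is_available() else "cpu"
# load() downloads and checksum-verifies the checkpoint, then returns the model
# together with the matching torchvision preprocessing pipeline.
model, preprocess = load("ViT-B/16", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)   # [1, 3, n_px, n_px]
text = tokenize(["a photo of a dog", "a photo of a cat"]).to(device)    # [2, 77]

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1)  # similarity of the image to each prompt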

models/clip/model.py
DELETED
@@ -1,437 +0,0 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length
        self.input_resolution = image_resolution

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in
                        [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
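
As a rough illustration of how this removed model.py was consumed, build_model() infers the architecture hyperparameters from a checkpoint's state_dict; a hedged sketch, where "clip_checkpoint.pt" is a placeholder for a saved non-JIT CLIP state_dict:

import torch
from models.clip.model import build_model  # module removed by this commit

state_dict = torch.load("clip_checkpoint.pt", map_location="cpu")

# build_model() reads widths, layer counts and patch size from tensor shapes,
# converts applicable weights to fp16 via convert_weights(), loads the weights
# and returns the module in eval mode.
model = build_model(state_dict)
print(model.context_length, model.vocab_size, model.visual.input_resolution)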

models/clip/simple_tokenizer.py
DELETED
@@ -1,132 +0,0 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
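
The removed tokenizer is a byte-level BPE; a small hedged usage sketch, assuming the companion bpe_simple_vocab_16e6.txt.gz vocabulary file sits next to the module as default_bpe() expects:

from models.clip.simple_tokenizer import SimpleTokenizer  # module removed by this commit

tokenizer = SimpleTokenizer()          # loads bpe_simple_vocab_16e6.txt.gz via default_bpe()
ids = tokenizer.encode("A photo of a cat!")
print(ids)                             # BPE token ids (no start/end-of-text tokens added here)
print(tokenizer.decode(ids))           # approximately "a photo of a cat !" (text is lower-cased during encoding)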

models/taming/lr_scheduler.py
DELETED
@@ -1,39 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n):
        return self.schedule(n)
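
Because schedule() returns the learning rate itself rather than a ratio, the class is meant to be wrapped in torch.optim.lr_scheduler.LambdaLR with the optimizer's base lr set to 1.0 (as the docstring notes). A sketch under assumed warm-up and decay settings that are not taken from this Space's configuration:

import torch
from torch.optim.lr_scheduler import LambdaLR

from models.taming.lr_scheduler import LambdaWarmUpCosineScheduler  # module removed by this commit

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1.0)   # base lr 1.0 so the multiplier is the actual lr

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=1e-6, lr_max=5e-4, lr_start=1e-6,
    max_decay_steps=100_000)                   # illustrative values
scheduler = LambdaLR(optimizer, lr_lambda=schedule)

for step in range(3):
    optimizer.step()
    scheduler.step()                           # linear warm-up, then cosine decay to lr_min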

models/taming/models/vqgan.py
DELETED
@@ -1,262 +0,0 @@
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from models.taming.util import instantiate_from_config

from models.taming.modules.diffusionmodules.model import Encoder, Decoder
from models.taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from models.taming.modules.vqvae.quantize import GumbelQuantize

class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.image_key = image_key
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class GumbelVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 temperature_scheduler_config,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 kl_weight=1e-8,
                 remap=None,
                 ):

        z_channels = ddconfig["z_channels"]
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )

        self.loss.n_classes = n_embed
        self.vocab_size = n_embed

        self.quantize = GumbelQuantize(z_channels, embed_dim,
                                       n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0,
                                       remap=remap)

        self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)  # annealing of temp

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def temperature_scheduling(self):
        self.quantize.temperature = self.temperature_scheduler(self.global_step)

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode_code(self, code_b):
        quant_b = self.quantize.get_codebook_entry(code_b.view(-1), list(code_b.size())+[self.quantize.embedding_dim])
        dec = self.decode(quant_b)
        return dec

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.temperature_scheduling()
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        # encode
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # decode
        x_rec = self.decode(quant)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log
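
For context, a rough sketch of the encode, quantize, decode round trip these classes provide. It assumes a VQModel has already been instantiated and loaded from its training configuration (ddconfig/lossconfig), which is not part of this commit:

import torch
from models.taming.models.vqgan import VQModel  # module removed by this commit


def reconstruct(model: VQModel, images: torch.Tensor) -> torch.Tensor:
    """Round-trip a batch of NCHW images in [-1, 1] through the VQ-GAN autoencoder."""
    with torch.no_grad():
        quant, emb_loss, info = model.encode(images)  # continuous latents snapped to codebook entries
        recon = model.decode(quant)                   # codebook entries decoded back to an image
    return recon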

models/taming/modules/diffusionmodules/model.py
DELETED
@@ -1,776 +0,0 @@
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0,1,0,0))
    return emb


def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0,1,0,1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)    # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)    # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_


class Model(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)


    def forward(self, x, t=None):
        #assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
|
304 |
-
temb = self.temb.dense[1](temb)
|
305 |
-
else:
|
306 |
-
temb = None
|
307 |
-
|
308 |
-
# downsampling
|
309 |
-
hs = [self.conv_in(x)]
|
310 |
-
for i_level in range(self.num_resolutions):
|
311 |
-
for i_block in range(self.num_res_blocks):
|
312 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
313 |
-
if len(self.down[i_level].attn) > 0:
|
314 |
-
h = self.down[i_level].attn[i_block](h)
|
315 |
-
hs.append(h)
|
316 |
-
if i_level != self.num_resolutions-1:
|
317 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
318 |
-
|
319 |
-
# middle
|
320 |
-
h = hs[-1]
|
321 |
-
h = self.mid.block_1(h, temb)
|
322 |
-
h = self.mid.attn_1(h)
|
323 |
-
h = self.mid.block_2(h, temb)
|
324 |
-
|
325 |
-
# upsampling
|
326 |
-
for i_level in reversed(range(self.num_resolutions)):
|
327 |
-
for i_block in range(self.num_res_blocks+1):
|
328 |
-
h = self.up[i_level].block[i_block](
|
329 |
-
torch.cat([h, hs.pop()], dim=1), temb)
|
330 |
-
if len(self.up[i_level].attn) > 0:
|
331 |
-
h = self.up[i_level].attn[i_block](h)
|
332 |
-
if i_level != 0:
|
333 |
-
h = self.up[i_level].upsample(h)
|
334 |
-
|
335 |
-
# end
|
336 |
-
h = self.norm_out(h)
|
337 |
-
h = nonlinearity(h)
|
338 |
-
h = self.conv_out(h)
|
339 |
-
return h
|
340 |
-
|
341 |
-
|
342 |
-
class Encoder(nn.Module):
|
343 |
-
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
344 |
-
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
345 |
-
resolution, z_channels, double_z=True, **ignore_kwargs):
|
346 |
-
super().__init__()
|
347 |
-
self.ch = ch
|
348 |
-
self.temb_ch = 0
|
349 |
-
self.num_resolutions = len(ch_mult)
|
350 |
-
self.num_res_blocks = num_res_blocks
|
351 |
-
self.resolution = resolution
|
352 |
-
self.in_channels = in_channels
|
353 |
-
|
354 |
-
# downsampling
|
355 |
-
self.conv_in = torch.nn.Conv2d(in_channels,
|
356 |
-
self.ch,
|
357 |
-
kernel_size=3,
|
358 |
-
stride=1,
|
359 |
-
padding=1)
|
360 |
-
|
361 |
-
curr_res = resolution
|
362 |
-
in_ch_mult = (1,)+tuple(ch_mult)
|
363 |
-
self.down = nn.ModuleList()
|
364 |
-
for i_level in range(self.num_resolutions):
|
365 |
-
block = nn.ModuleList()
|
366 |
-
attn = nn.ModuleList()
|
367 |
-
block_in = ch*in_ch_mult[i_level]
|
368 |
-
block_out = ch*ch_mult[i_level]
|
369 |
-
for i_block in range(self.num_res_blocks):
|
370 |
-
block.append(ResnetBlock(in_channels=block_in,
|
371 |
-
out_channels=block_out,
|
372 |
-
temb_channels=self.temb_ch,
|
373 |
-
dropout=dropout))
|
374 |
-
block_in = block_out
|
375 |
-
if curr_res in attn_resolutions:
|
376 |
-
attn.append(AttnBlock(block_in))
|
377 |
-
down = nn.Module()
|
378 |
-
down.block = block
|
379 |
-
down.attn = attn
|
380 |
-
if i_level != self.num_resolutions-1:
|
381 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
382 |
-
curr_res = curr_res // 2
|
383 |
-
self.down.append(down)
|
384 |
-
|
385 |
-
# middle
|
386 |
-
self.mid = nn.Module()
|
387 |
-
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
388 |
-
out_channels=block_in,
|
389 |
-
temb_channels=self.temb_ch,
|
390 |
-
dropout=dropout)
|
391 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
392 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
393 |
-
out_channels=block_in,
|
394 |
-
temb_channels=self.temb_ch,
|
395 |
-
dropout=dropout)
|
396 |
-
|
397 |
-
# end
|
398 |
-
self.norm_out = Normalize(block_in)
|
399 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
400 |
-
2*z_channels if double_z else z_channels,
|
401 |
-
kernel_size=3,
|
402 |
-
stride=1,
|
403 |
-
padding=1)
|
404 |
-
|
405 |
-
|
406 |
-
def forward(self, x):
|
407 |
-
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
408 |
-
|
409 |
-
# timestep embedding
|
410 |
-
temb = None
|
411 |
-
|
412 |
-
# downsampling
|
413 |
-
hs = [self.conv_in(x)]
|
414 |
-
for i_level in range(self.num_resolutions):
|
415 |
-
for i_block in range(self.num_res_blocks):
|
416 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
417 |
-
if len(self.down[i_level].attn) > 0:
|
418 |
-
h = self.down[i_level].attn[i_block](h)
|
419 |
-
hs.append(h)
|
420 |
-
if i_level != self.num_resolutions-1:
|
421 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
422 |
-
|
423 |
-
# middle
|
424 |
-
h = hs[-1]
|
425 |
-
h = self.mid.block_1(h, temb)
|
426 |
-
h = self.mid.attn_1(h)
|
427 |
-
h = self.mid.block_2(h, temb)
|
428 |
-
|
429 |
-
# end
|
430 |
-
h = self.norm_out(h)
|
431 |
-
h = nonlinearity(h)
|
432 |
-
h = self.conv_out(h)
|
433 |
-
return h
|
434 |
-
|
435 |
-
|
436 |
-
class Decoder(nn.Module):
|
437 |
-
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
438 |
-
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
439 |
-
resolution, z_channels, give_pre_end=False, **ignorekwargs):
|
440 |
-
super().__init__()
|
441 |
-
self.ch = ch
|
442 |
-
self.temb_ch = 0
|
443 |
-
self.num_resolutions = len(ch_mult)
|
444 |
-
self.num_res_blocks = num_res_blocks
|
445 |
-
self.resolution = resolution
|
446 |
-
self.in_channels = in_channels
|
447 |
-
self.give_pre_end = give_pre_end
|
448 |
-
|
449 |
-
# compute in_ch_mult, block_in and curr_res at lowest res
|
450 |
-
in_ch_mult = (1,)+tuple(ch_mult)
|
451 |
-
block_in = ch*ch_mult[self.num_resolutions-1]
|
452 |
-
curr_res = resolution // 2**(self.num_resolutions-1)
|
453 |
-
self.z_shape = (1,z_channels,curr_res,curr_res)
|
454 |
-
print("Working with z of shape {} = {} dimensions.".format(
|
455 |
-
self.z_shape, np.prod(self.z_shape)))
|
456 |
-
|
457 |
-
# z to block_in
|
458 |
-
self.conv_in = torch.nn.Conv2d(z_channels,
|
459 |
-
block_in,
|
460 |
-
kernel_size=3,
|
461 |
-
stride=1,
|
462 |
-
padding=1)
|
463 |
-
|
464 |
-
# middle
|
465 |
-
self.mid = nn.Module()
|
466 |
-
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
467 |
-
out_channels=block_in,
|
468 |
-
temb_channels=self.temb_ch,
|
469 |
-
dropout=dropout)
|
470 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
471 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
472 |
-
out_channels=block_in,
|
473 |
-
temb_channels=self.temb_ch,
|
474 |
-
dropout=dropout)
|
475 |
-
|
476 |
-
# upsampling
|
477 |
-
self.up = nn.ModuleList()
|
478 |
-
for i_level in reversed(range(self.num_resolutions)):
|
479 |
-
block = nn.ModuleList()
|
480 |
-
attn = nn.ModuleList()
|
481 |
-
block_out = ch*ch_mult[i_level]
|
482 |
-
for i_block in range(self.num_res_blocks+1):
|
483 |
-
block.append(ResnetBlock(in_channels=block_in,
|
484 |
-
out_channels=block_out,
|
485 |
-
temb_channels=self.temb_ch,
|
486 |
-
dropout=dropout))
|
487 |
-
block_in = block_out
|
488 |
-
if curr_res in attn_resolutions:
|
489 |
-
attn.append(AttnBlock(block_in))
|
490 |
-
up = nn.Module()
|
491 |
-
up.block = block
|
492 |
-
up.attn = attn
|
493 |
-
if i_level != 0:
|
494 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
495 |
-
curr_res = curr_res * 2
|
496 |
-
self.up.insert(0, up) # prepend to get consistent order
|
497 |
-
|
498 |
-
# end
|
499 |
-
self.norm_out = Normalize(block_in)
|
500 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
501 |
-
out_ch,
|
502 |
-
kernel_size=3,
|
503 |
-
stride=1,
|
504 |
-
padding=1)
|
505 |
-
|
506 |
-
def forward(self, z):
|
507 |
-
#assert z.shape[1:] == self.z_shape[1:]
|
508 |
-
self.last_z_shape = z.shape
|
509 |
-
|
510 |
-
# timestep embedding
|
511 |
-
temb = None
|
512 |
-
|
513 |
-
# z to block_in
|
514 |
-
h = self.conv_in(z)
|
515 |
-
|
516 |
-
# middle
|
517 |
-
h = self.mid.block_1(h, temb)
|
518 |
-
h = self.mid.attn_1(h)
|
519 |
-
h = self.mid.block_2(h, temb)
|
520 |
-
|
521 |
-
# upsampling
|
522 |
-
for i_level in reversed(range(self.num_resolutions)):
|
523 |
-
for i_block in range(self.num_res_blocks+1):
|
524 |
-
h = self.up[i_level].block[i_block](h, temb)
|
525 |
-
if len(self.up[i_level].attn) > 0:
|
526 |
-
h = self.up[i_level].attn[i_block](h)
|
527 |
-
if i_level != 0:
|
528 |
-
h = self.up[i_level].upsample(h)
|
529 |
-
|
530 |
-
# end
|
531 |
-
if self.give_pre_end:
|
532 |
-
return h
|
533 |
-
|
534 |
-
h = self.norm_out(h)
|
535 |
-
h = nonlinearity(h)
|
536 |
-
h = self.conv_out(h)
|
537 |
-
return h
|
538 |
-
|
539 |
-
|
540 |
-
class VUNet(nn.Module):
|
541 |
-
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
542 |
-
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
543 |
-
in_channels, c_channels,
|
544 |
-
resolution, z_channels, use_timestep=False, **ignore_kwargs):
|
545 |
-
super().__init__()
|
546 |
-
self.ch = ch
|
547 |
-
self.temb_ch = self.ch*4
|
548 |
-
self.num_resolutions = len(ch_mult)
|
549 |
-
self.num_res_blocks = num_res_blocks
|
550 |
-
self.resolution = resolution
|
551 |
-
|
552 |
-
self.use_timestep = use_timestep
|
553 |
-
if self.use_timestep:
|
554 |
-
# timestep embedding
|
555 |
-
self.temb = nn.Module()
|
556 |
-
self.temb.dense = nn.ModuleList([
|
557 |
-
torch.nn.Linear(self.ch,
|
558 |
-
self.temb_ch),
|
559 |
-
torch.nn.Linear(self.temb_ch,
|
560 |
-
self.temb_ch),
|
561 |
-
])
|
562 |
-
|
563 |
-
# downsampling
|
564 |
-
self.conv_in = torch.nn.Conv2d(c_channels,
|
565 |
-
self.ch,
|
566 |
-
kernel_size=3,
|
567 |
-
stride=1,
|
568 |
-
padding=1)
|
569 |
-
|
570 |
-
curr_res = resolution
|
571 |
-
in_ch_mult = (1,)+tuple(ch_mult)
|
572 |
-
self.down = nn.ModuleList()
|
573 |
-
for i_level in range(self.num_resolutions):
|
574 |
-
block = nn.ModuleList()
|
575 |
-
attn = nn.ModuleList()
|
576 |
-
block_in = ch*in_ch_mult[i_level]
|
577 |
-
block_out = ch*ch_mult[i_level]
|
578 |
-
for i_block in range(self.num_res_blocks):
|
579 |
-
block.append(ResnetBlock(in_channels=block_in,
|
580 |
-
out_channels=block_out,
|
581 |
-
temb_channels=self.temb_ch,
|
582 |
-
dropout=dropout))
|
583 |
-
block_in = block_out
|
584 |
-
if curr_res in attn_resolutions:
|
585 |
-
attn.append(AttnBlock(block_in))
|
586 |
-
down = nn.Module()
|
587 |
-
down.block = block
|
588 |
-
down.attn = attn
|
589 |
-
if i_level != self.num_resolutions-1:
|
590 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
591 |
-
curr_res = curr_res // 2
|
592 |
-
self.down.append(down)
|
593 |
-
|
594 |
-
self.z_in = torch.nn.Conv2d(z_channels,
|
595 |
-
block_in,
|
596 |
-
kernel_size=1,
|
597 |
-
stride=1,
|
598 |
-
padding=0)
|
599 |
-
# middle
|
600 |
-
self.mid = nn.Module()
|
601 |
-
self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
|
602 |
-
out_channels=block_in,
|
603 |
-
temb_channels=self.temb_ch,
|
604 |
-
dropout=dropout)
|
605 |
-
self.mid.attn_1 = AttnBlock(block_in)
|
606 |
-
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
607 |
-
out_channels=block_in,
|
608 |
-
temb_channels=self.temb_ch,
|
609 |
-
dropout=dropout)
|
610 |
-
|
611 |
-
# upsampling
|
612 |
-
self.up = nn.ModuleList()
|
613 |
-
for i_level in reversed(range(self.num_resolutions)):
|
614 |
-
block = nn.ModuleList()
|
615 |
-
attn = nn.ModuleList()
|
616 |
-
block_out = ch*ch_mult[i_level]
|
617 |
-
skip_in = ch*ch_mult[i_level]
|
618 |
-
for i_block in range(self.num_res_blocks+1):
|
619 |
-
if i_block == self.num_res_blocks:
|
620 |
-
skip_in = ch*in_ch_mult[i_level]
|
621 |
-
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
622 |
-
out_channels=block_out,
|
623 |
-
temb_channels=self.temb_ch,
|
624 |
-
dropout=dropout))
|
625 |
-
block_in = block_out
|
626 |
-
if curr_res in attn_resolutions:
|
627 |
-
attn.append(AttnBlock(block_in))
|
628 |
-
up = nn.Module()
|
629 |
-
up.block = block
|
630 |
-
up.attn = attn
|
631 |
-
if i_level != 0:
|
632 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
633 |
-
curr_res = curr_res * 2
|
634 |
-
self.up.insert(0, up) # prepend to get consistent order
|
635 |
-
|
636 |
-
# end
|
637 |
-
self.norm_out = Normalize(block_in)
|
638 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
639 |
-
out_ch,
|
640 |
-
kernel_size=3,
|
641 |
-
stride=1,
|
642 |
-
padding=1)
|
643 |
-
|
644 |
-
|
645 |
-
def forward(self, x, z):
|
646 |
-
#assert x.shape[2] == x.shape[3] == self.resolution
|
647 |
-
|
648 |
-
if self.use_timestep:
|
649 |
-
# timestep embedding
|
650 |
-
assert t is not None
|
651 |
-
temb = get_timestep_embedding(t, self.ch)
|
652 |
-
temb = self.temb.dense[0](temb)
|
653 |
-
temb = nonlinearity(temb)
|
654 |
-
temb = self.temb.dense[1](temb)
|
655 |
-
else:
|
656 |
-
temb = None
|
657 |
-
|
658 |
-
# downsampling
|
659 |
-
hs = [self.conv_in(x)]
|
660 |
-
for i_level in range(self.num_resolutions):
|
661 |
-
for i_block in range(self.num_res_blocks):
|
662 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
663 |
-
if len(self.down[i_level].attn) > 0:
|
664 |
-
h = self.down[i_level].attn[i_block](h)
|
665 |
-
hs.append(h)
|
666 |
-
if i_level != self.num_resolutions-1:
|
667 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
668 |
-
|
669 |
-
# middle
|
670 |
-
h = hs[-1]
|
671 |
-
z = self.z_in(z)
|
672 |
-
h = torch.cat((h,z),dim=1)
|
673 |
-
h = self.mid.block_1(h, temb)
|
674 |
-
h = self.mid.attn_1(h)
|
675 |
-
h = self.mid.block_2(h, temb)
|
676 |
-
|
677 |
-
# upsampling
|
678 |
-
for i_level in reversed(range(self.num_resolutions)):
|
679 |
-
for i_block in range(self.num_res_blocks+1):
|
680 |
-
h = self.up[i_level].block[i_block](
|
681 |
-
torch.cat([h, hs.pop()], dim=1), temb)
|
682 |
-
if len(self.up[i_level].attn) > 0:
|
683 |
-
h = self.up[i_level].attn[i_block](h)
|
684 |
-
if i_level != 0:
|
685 |
-
h = self.up[i_level].upsample(h)
|
686 |
-
|
687 |
-
# end
|
688 |
-
h = self.norm_out(h)
|
689 |
-
h = nonlinearity(h)
|
690 |
-
h = self.conv_out(h)
|
691 |
-
return h
|
692 |
-
|
693 |
-
|
694 |
-
class SimpleDecoder(nn.Module):
|
695 |
-
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
696 |
-
super().__init__()
|
697 |
-
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
698 |
-
ResnetBlock(in_channels=in_channels,
|
699 |
-
out_channels=2 * in_channels,
|
700 |
-
temb_channels=0, dropout=0.0),
|
701 |
-
ResnetBlock(in_channels=2 * in_channels,
|
702 |
-
out_channels=4 * in_channels,
|
703 |
-
temb_channels=0, dropout=0.0),
|
704 |
-
ResnetBlock(in_channels=4 * in_channels,
|
705 |
-
out_channels=2 * in_channels,
|
706 |
-
temb_channels=0, dropout=0.0),
|
707 |
-
nn.Conv2d(2*in_channels, in_channels, 1),
|
708 |
-
Upsample(in_channels, with_conv=True)])
|
709 |
-
# end
|
710 |
-
self.norm_out = Normalize(in_channels)
|
711 |
-
self.conv_out = torch.nn.Conv2d(in_channels,
|
712 |
-
out_channels,
|
713 |
-
kernel_size=3,
|
714 |
-
stride=1,
|
715 |
-
padding=1)
|
716 |
-
|
717 |
-
def forward(self, x):
|
718 |
-
for i, layer in enumerate(self.model):
|
719 |
-
if i in [1,2,3]:
|
720 |
-
x = layer(x, None)
|
721 |
-
else:
|
722 |
-
x = layer(x)
|
723 |
-
|
724 |
-
h = self.norm_out(x)
|
725 |
-
h = nonlinearity(h)
|
726 |
-
x = self.conv_out(h)
|
727 |
-
return x
|
728 |
-
|
729 |
-
|
730 |
-
class UpsampleDecoder(nn.Module):
|
731 |
-
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
732 |
-
ch_mult=(2,2), dropout=0.0):
|
733 |
-
super().__init__()
|
734 |
-
# upsampling
|
735 |
-
self.temb_ch = 0
|
736 |
-
self.num_resolutions = len(ch_mult)
|
737 |
-
self.num_res_blocks = num_res_blocks
|
738 |
-
block_in = in_channels
|
739 |
-
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
740 |
-
self.res_blocks = nn.ModuleList()
|
741 |
-
self.upsample_blocks = nn.ModuleList()
|
742 |
-
for i_level in range(self.num_resolutions):
|
743 |
-
res_block = []
|
744 |
-
block_out = ch * ch_mult[i_level]
|
745 |
-
for i_block in range(self.num_res_blocks + 1):
|
746 |
-
res_block.append(ResnetBlock(in_channels=block_in,
|
747 |
-
out_channels=block_out,
|
748 |
-
temb_channels=self.temb_ch,
|
749 |
-
dropout=dropout))
|
750 |
-
block_in = block_out
|
751 |
-
self.res_blocks.append(nn.ModuleList(res_block))
|
752 |
-
if i_level != self.num_resolutions - 1:
|
753 |
-
self.upsample_blocks.append(Upsample(block_in, True))
|
754 |
-
curr_res = curr_res * 2
|
755 |
-
|
756 |
-
# end
|
757 |
-
self.norm_out = Normalize(block_in)
|
758 |
-
self.conv_out = torch.nn.Conv2d(block_in,
|
759 |
-
out_channels,
|
760 |
-
kernel_size=3,
|
761 |
-
stride=1,
|
762 |
-
padding=1)
|
763 |
-
|
764 |
-
def forward(self, x):
|
765 |
-
# upsampling
|
766 |
-
h = x
|
767 |
-
for k, i_level in enumerate(range(self.num_resolutions)):
|
768 |
-
for i_block in range(self.num_res_blocks + 1):
|
769 |
-
h = self.res_blocks[i_level][i_block](h, None)
|
770 |
-
if i_level != self.num_resolutions - 1:
|
771 |
-
h = self.upsample_blocks[k](h)
|
772 |
-
h = self.norm_out(h)
|
773 |
-
h = nonlinearity(h)
|
774 |
-
h = self.conv_out(h)
|
775 |
-
return h
|
776 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/taming/modules/discriminator/model.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
import torch.nn as nn
|
3 |
-
|
4 |
-
|
5 |
-
from models.taming.modules.util import ActNorm
|
6 |
-
|
7 |
-
|
8 |
-
def weights_init(m):
|
9 |
-
classname = m.__class__.__name__
|
10 |
-
if classname.find('Conv') != -1:
|
11 |
-
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
12 |
-
elif classname.find('BatchNorm') != -1:
|
13 |
-
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
14 |
-
nn.init.constant_(m.bias.data, 0)
|
15 |
-
|
16 |
-
|
17 |
-
class NLayerDiscriminator(nn.Module):
|
18 |
-
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
-
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
-
"""
|
21 |
-
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
22 |
-
"""Construct a PatchGAN discriminator
|
23 |
-
Parameters:
|
24 |
-
input_nc (int) -- the number of channels in input images
|
25 |
-
ndf (int) -- the number of filters in the last conv layer
|
26 |
-
n_layers (int) -- the number of conv layers in the discriminator
|
27 |
-
norm_layer -- normalization layer
|
28 |
-
"""
|
29 |
-
super(NLayerDiscriminator, self).__init__()
|
30 |
-
if not use_actnorm:
|
31 |
-
norm_layer = nn.BatchNorm2d
|
32 |
-
else:
|
33 |
-
norm_layer = ActNorm
|
34 |
-
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
35 |
-
use_bias = norm_layer.func != nn.BatchNorm2d
|
36 |
-
else:
|
37 |
-
use_bias = norm_layer != nn.BatchNorm2d
|
38 |
-
|
39 |
-
kw = 4
|
40 |
-
padw = 1
|
41 |
-
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
42 |
-
nf_mult = 1
|
43 |
-
nf_mult_prev = 1
|
44 |
-
for n in range(1, n_layers): # gradually increase the number of filters
|
45 |
-
nf_mult_prev = nf_mult
|
46 |
-
nf_mult = min(2 ** n, 8)
|
47 |
-
sequence += [
|
48 |
-
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
49 |
-
norm_layer(ndf * nf_mult),
|
50 |
-
nn.LeakyReLU(0.2, True)
|
51 |
-
]
|
52 |
-
|
53 |
-
nf_mult_prev = nf_mult
|
54 |
-
nf_mult = min(2 ** n_layers, 8)
|
55 |
-
sequence += [
|
56 |
-
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
57 |
-
norm_layer(ndf * nf_mult),
|
58 |
-
nn.LeakyReLU(0.2, True)
|
59 |
-
]
|
60 |
-
|
61 |
-
sequence += [
|
62 |
-
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
63 |
-
self.main = nn.Sequential(*sequence)
|
64 |
-
|
65 |
-
def forward(self, input):
|
66 |
-
"""Standard forward."""
|
67 |
-
return self.main(input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/taming/modules/losses/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from models.taming.modules.losses.vqperceptual import DummyLoss
|
2 |
-
|
|
|
|
|
|
models/taming/modules/losses/lpips.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
from torchvision import models
|
6 |
-
from collections import namedtuple
|
7 |
-
|
8 |
-
from models.taming.util import get_ckpt_path
|
9 |
-
|
10 |
-
|
11 |
-
class LPIPS(nn.Module):
|
12 |
-
# Learned perceptual metric
|
13 |
-
def __init__(self, use_dropout=True):
|
14 |
-
super().__init__()
|
15 |
-
self.scaling_layer = ScalingLayer()
|
16 |
-
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
17 |
-
self.net = vgg16(pretrained=True, requires_grad=False)
|
18 |
-
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
19 |
-
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
20 |
-
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
21 |
-
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
22 |
-
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
23 |
-
self.load_from_pretrained()
|
24 |
-
for param in self.parameters():
|
25 |
-
param.requires_grad = False
|
26 |
-
|
27 |
-
def load_from_pretrained(self, name="vgg_lpips"):
|
28 |
-
ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
|
29 |
-
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
30 |
-
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
31 |
-
|
32 |
-
@classmethod
|
33 |
-
def from_pretrained(cls, name="vgg_lpips"):
|
34 |
-
if name != "vgg_lpips":
|
35 |
-
raise NotImplementedError
|
36 |
-
model = cls()
|
37 |
-
ckpt = get_ckpt_path(name)
|
38 |
-
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
39 |
-
return model
|
40 |
-
|
41 |
-
def forward(self, input, target):
|
42 |
-
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
43 |
-
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
44 |
-
feats0, feats1, diffs = {}, {}, {}
|
45 |
-
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
46 |
-
for kk in range(len(self.chns)):
|
47 |
-
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
48 |
-
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
49 |
-
|
50 |
-
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
51 |
-
val = res[0]
|
52 |
-
for l in range(1, len(self.chns)):
|
53 |
-
val += res[l]
|
54 |
-
return val
|
55 |
-
|
56 |
-
|
57 |
-
class ScalingLayer(nn.Module):
|
58 |
-
def __init__(self):
|
59 |
-
super(ScalingLayer, self).__init__()
|
60 |
-
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
61 |
-
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
62 |
-
|
63 |
-
def forward(self, inp):
|
64 |
-
return (inp - self.shift) / self.scale
|
65 |
-
|
66 |
-
|
67 |
-
class NetLinLayer(nn.Module):
|
68 |
-
""" A single linear layer which does a 1x1 conv """
|
69 |
-
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
70 |
-
super(NetLinLayer, self).__init__()
|
71 |
-
layers = [nn.Dropout(), ] if (use_dropout) else []
|
72 |
-
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
73 |
-
self.model = nn.Sequential(*layers)
|
74 |
-
|
75 |
-
|
76 |
-
class vgg16(torch.nn.Module):
|
77 |
-
def __init__(self, requires_grad=False, pretrained=True):
|
78 |
-
super(vgg16, self).__init__()
|
79 |
-
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
80 |
-
self.slice1 = torch.nn.Sequential()
|
81 |
-
self.slice2 = torch.nn.Sequential()
|
82 |
-
self.slice3 = torch.nn.Sequential()
|
83 |
-
self.slice4 = torch.nn.Sequential()
|
84 |
-
self.slice5 = torch.nn.Sequential()
|
85 |
-
self.N_slices = 5
|
86 |
-
for x in range(4):
|
87 |
-
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
88 |
-
for x in range(4, 9):
|
89 |
-
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
90 |
-
for x in range(9, 16):
|
91 |
-
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
92 |
-
for x in range(16, 23):
|
93 |
-
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
94 |
-
for x in range(23, 30):
|
95 |
-
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
96 |
-
if not requires_grad:
|
97 |
-
for param in self.parameters():
|
98 |
-
param.requires_grad = False
|
99 |
-
|
100 |
-
def forward(self, X):
|
101 |
-
h = self.slice1(X)
|
102 |
-
h_relu1_2 = h
|
103 |
-
h = self.slice2(h)
|
104 |
-
h_relu2_2 = h
|
105 |
-
h = self.slice3(h)
|
106 |
-
h_relu3_3 = h
|
107 |
-
h = self.slice4(h)
|
108 |
-
h_relu4_3 = h
|
109 |
-
h = self.slice5(h)
|
110 |
-
h_relu5_3 = h
|
111 |
-
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
112 |
-
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
113 |
-
return out
|
114 |
-
|
115 |
-
|
116 |
-
def normalize_tensor(x,eps=1e-10):
|
117 |
-
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
|
118 |
-
return x/(norm_factor+eps)
|
119 |
-
|
120 |
-
|
121 |
-
def spatial_average(x, keepdim=True):
|
122 |
-
return x.mean([2,3],keepdim=keepdim)
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/taming/modules/losses/segmentation.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
import torch.nn as nn
|
2 |
-
import torch.nn.functional as F
|
3 |
-
|
4 |
-
|
5 |
-
class BCELoss(nn.Module):
|
6 |
-
def forward(self, prediction, target):
|
7 |
-
loss = F.binary_cross_entropy_with_logits(prediction,target)
|
8 |
-
return loss, {}
|
9 |
-
|
10 |
-
|
11 |
-
class BCELossWithQuant(nn.Module):
|
12 |
-
def __init__(self, codebook_weight=1.):
|
13 |
-
super().__init__()
|
14 |
-
self.codebook_weight = codebook_weight
|
15 |
-
|
16 |
-
def forward(self, qloss, target, prediction, split):
|
17 |
-
bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
|
18 |
-
loss = bce_loss + self.codebook_weight*qloss
|
19 |
-
return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
20 |
-
"{}/bce_loss".format(split): bce_loss.detach().mean(),
|
21 |
-
"{}/quant_loss".format(split): qloss.detach().mean()
|
22 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/taming/modules/losses/vqperceptual.py
DELETED
@@ -1,136 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
|
5 |
-
from models.taming.modules.losses.lpips import LPIPS
|
6 |
-
from models.taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
7 |
-
|
8 |
-
|
9 |
-
class DummyLoss(nn.Module):
|
10 |
-
def __init__(self):
|
11 |
-
super().__init__()
|
12 |
-
|
13 |
-
|
14 |
-
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
15 |
-
if global_step < threshold:
|
16 |
-
weight = value
|
17 |
-
return weight
|
18 |
-
|
19 |
-
|
20 |
-
def hinge_d_loss(logits_real, logits_fake):
|
21 |
-
loss_real = torch.mean(F.relu(1. - logits_real))
|
22 |
-
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
23 |
-
d_loss = 0.5 * (loss_real + loss_fake)
|
24 |
-
return d_loss
|
25 |
-
|
26 |
-
|
27 |
-
def vanilla_d_loss(logits_real, logits_fake):
|
28 |
-
d_loss = 0.5 * (
|
29 |
-
torch.mean(torch.nn.functional.softplus(-logits_real)) +
|
30 |
-
torch.mean(torch.nn.functional.softplus(logits_fake)))
|
31 |
-
return d_loss
|
32 |
-
|
33 |
-
|
34 |
-
class VQLPIPSWithDiscriminator(nn.Module):
|
35 |
-
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
|
36 |
-
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
37 |
-
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
38 |
-
disc_ndf=64, disc_loss="hinge"):
|
39 |
-
super().__init__()
|
40 |
-
assert disc_loss in ["hinge", "vanilla"]
|
41 |
-
self.codebook_weight = codebook_weight
|
42 |
-
self.pixel_weight = pixelloss_weight
|
43 |
-
self.perceptual_loss = LPIPS().eval()
|
44 |
-
self.perceptual_weight = perceptual_weight
|
45 |
-
|
46 |
-
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
47 |
-
n_layers=disc_num_layers,
|
48 |
-
use_actnorm=use_actnorm,
|
49 |
-
ndf=disc_ndf
|
50 |
-
).apply(weights_init)
|
51 |
-
self.discriminator_iter_start = disc_start
|
52 |
-
if disc_loss == "hinge":
|
53 |
-
self.disc_loss = hinge_d_loss
|
54 |
-
elif disc_loss == "vanilla":
|
55 |
-
self.disc_loss = vanilla_d_loss
|
56 |
-
else:
|
57 |
-
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
58 |
-
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
59 |
-
self.disc_factor = disc_factor
|
60 |
-
self.discriminator_weight = disc_weight
|
61 |
-
self.disc_conditional = disc_conditional
|
62 |
-
|
63 |
-
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
64 |
-
if last_layer is not None:
|
65 |
-
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
66 |
-
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
67 |
-
else:
|
68 |
-
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
69 |
-
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
70 |
-
|
71 |
-
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
72 |
-
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
73 |
-
d_weight = d_weight * self.discriminator_weight
|
74 |
-
return d_weight
|
75 |
-
|
76 |
-
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
|
77 |
-
global_step, last_layer=None, cond=None, split="train"):
|
78 |
-
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
79 |
-
if self.perceptual_weight > 0:
|
80 |
-
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
81 |
-
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
82 |
-
else:
|
83 |
-
p_loss = torch.tensor([0.0])
|
84 |
-
|
85 |
-
nll_loss = rec_loss
|
86 |
-
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
87 |
-
nll_loss = torch.mean(nll_loss)
|
88 |
-
|
89 |
-
# now the GAN part
|
90 |
-
if optimizer_idx == 0:
|
91 |
-
# generator update
|
92 |
-
if cond is None:
|
93 |
-
assert not self.disc_conditional
|
94 |
-
logits_fake = self.discriminator(reconstructions.contiguous())
|
95 |
-
else:
|
96 |
-
assert self.disc_conditional
|
97 |
-
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
98 |
-
g_loss = -torch.mean(logits_fake)
|
99 |
-
|
100 |
-
try:
|
101 |
-
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
102 |
-
except RuntimeError:
|
103 |
-
assert not self.training
|
104 |
-
d_weight = torch.tensor(0.0)
|
105 |
-
|
106 |
-
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
107 |
-
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
|
108 |
-
|
109 |
-
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
110 |
-
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
111 |
-
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
112 |
-
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
113 |
-
"{}/p_loss".format(split): p_loss.detach().mean(),
|
114 |
-
"{}/d_weight".format(split): d_weight.detach(),
|
115 |
-
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
116 |
-
"{}/g_loss".format(split): g_loss.detach().mean(),
|
117 |
-
}
|
118 |
-
return loss, log
|
119 |
-
|
120 |
-
if optimizer_idx == 1:
|
121 |
-
# second pass for discriminator update
|
122 |
-
if cond is None:
|
123 |
-
logits_real = self.discriminator(inputs.contiguous().detach())
|
124 |
-
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
125 |
-
else:
|
126 |
-
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
127 |
-
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
128 |
-
|
129 |
-
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
130 |
-
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
131 |
-
|
132 |
-
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
133 |
-
"{}/logits_real".format(split): logits_real.detach().mean(),
|
134 |
-
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
135 |
-
}
|
136 |
-
return d_loss, log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/taming/modules/misc/coord.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
class CoordStage(object):
|
4 |
-
def __init__(self, n_embed, down_factor):
|
5 |
-
self.n_embed = n_embed
|
6 |
-
self.down_factor = down_factor
|
7 |
-
|
8 |
-
def eval(self):
|
9 |
-
return self
|
10 |
-
|
11 |
-
def encode(self, c):
|
12 |
-
"""fake vqmodel interface"""
|
13 |
-
assert 0.0 <= c.min() and c.max() <= 1.0
|
14 |
-
b,ch,h,w = c.shape
|
15 |
-
assert ch == 1
|
16 |
-
|
17 |
-
c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
|
18 |
-
mode="area")
|
19 |
-
c = c.clamp(0.0, 1.0)
|
20 |
-
c = self.n_embed*c
|
21 |
-
c_quant = c.round()
|
22 |
-
c_ind = c_quant.to(dtype=torch.long)
|
23 |
-
|
24 |
-
info = None, None, c_ind
|
25 |
-
return c_quant, None, info
|
26 |
-
|
27 |
-
def decode(self, c):
|
28 |
-
c = c/self.n_embed
|
29 |
-
c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
|
30 |
-
mode="nearest")
|
31 |
-
return c
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/taming/modules/util.py
DELETED
@@ -1,130 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
|
4 |
-
|
5 |
-
def count_params(model):
|
6 |
-
total_params = sum(p.numel() for p in model.parameters())
|
7 |
-
return total_params
|
8 |
-
|
9 |
-
|
10 |
-
class ActNorm(nn.Module):
|
11 |
-
def __init__(self, num_features, logdet=False, affine=True,
|
12 |
-
allow_reverse_init=False):
|
13 |
-
assert affine
|
14 |
-
super().__init__()
|
15 |
-
self.logdet = logdet
|
16 |
-
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
17 |
-
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
18 |
-
self.allow_reverse_init = allow_reverse_init
|
19 |
-
|
20 |
-
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
21 |
-
|
22 |
-
def initialize(self, input):
|
23 |
-
with torch.no_grad():
|
24 |
-
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
25 |
-
mean = (
|
26 |
-
flatten.mean(1)
|
27 |
-
.unsqueeze(1)
|
28 |
-
.unsqueeze(2)
|
29 |
-
.unsqueeze(3)
|
30 |
-
.permute(1, 0, 2, 3)
|
31 |
-
)
|
32 |
-
std = (
|
33 |
-
flatten.std(1)
|
34 |
-
.unsqueeze(1)
|
35 |
-
.unsqueeze(2)
|
36 |
-
.unsqueeze(3)
|
37 |
-
.permute(1, 0, 2, 3)
|
38 |
-
)
|
39 |
-
|
40 |
-
self.loc.data.copy_(-mean)
|
41 |
-
self.scale.data.copy_(1 / (std + 1e-6))
|
42 |
-
|
43 |
-
def forward(self, input, reverse=False):
|
44 |
-
if reverse:
|
45 |
-
return self.reverse(input)
|
46 |
-
if len(input.shape) == 2:
|
47 |
-
input = input[:,:,None,None]
|
48 |
-
squeeze = True
|
49 |
-
else:
|
50 |
-
squeeze = False
|
51 |
-
|
52 |
-
_, _, height, width = input.shape
|
53 |
-
|
54 |
-
if self.training and self.initialized.item() == 0:
|
55 |
-
self.initialize(input)
|
56 |
-
self.initialized.fill_(1)
|
57 |
-
|
58 |
-
h = self.scale * (input + self.loc)
|
59 |
-
|
60 |
-
if squeeze:
|
61 |
-
h = h.squeeze(-1).squeeze(-1)
|
62 |
-
|
63 |
-
if self.logdet:
|
64 |
-
log_abs = torch.log(torch.abs(self.scale))
|
65 |
-
logdet = height*width*torch.sum(log_abs)
|
66 |
-
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
67 |
-
return h, logdet
|
68 |
-
|
69 |
-
return h
|
70 |
-
|
71 |
-
def reverse(self, output):
|
72 |
-
if self.training and self.initialized.item() == 0:
|
73 |
-
if not self.allow_reverse_init:
|
74 |
-
raise RuntimeError(
|
75 |
-
"Initializing ActNorm in reverse direction is "
|
76 |
-
"disabled by default. Use allow_reverse_init=True to enable."
|
77 |
-
)
|
78 |
-
else:
|
79 |
-
self.initialize(output)
|
80 |
-
self.initialized.fill_(1)
|
81 |
-
|
82 |
-
if len(output.shape) == 2:
|
83 |
-
output = output[:,:,None,None]
|
84 |
-
squeeze = True
|
85 |
-
else:
|
86 |
-
squeeze = False
|
87 |
-
|
88 |
-
h = output / self.scale - self.loc
|
89 |
-
|
90 |
-
if squeeze:
|
91 |
-
h = h.squeeze(-1).squeeze(-1)
|
92 |
-
return h
|
93 |
-
|
94 |
-
|
95 |
-
class AbstractEncoder(nn.Module):
|
96 |
-
def __init__(self):
|
97 |
-
super().__init__()
|
98 |
-
|
99 |
-
def encode(self, *args, **kwargs):
|
100 |
-
raise NotImplementedError
|
101 |
-
|
102 |
-
|
103 |
-
class Labelator(AbstractEncoder):
|
104 |
-
"""Net2Net Interface for Class-Conditional Model"""
|
105 |
-
def __init__(self, n_classes, quantize_interface=True):
|
106 |
-
super().__init__()
|
107 |
-
self.n_classes = n_classes
|
108 |
-
self.quantize_interface = quantize_interface
|
109 |
-
|
110 |
-
def encode(self, c):
|
111 |
-
c = c[:,None]
|
112 |
-
if self.quantize_interface:
|
113 |
-
return c, None, [None, None, c.long()]
|
114 |
-
return c
|
115 |
-
|
116 |
-
|
117 |
-
class SOSProvider(AbstractEncoder):
|
118 |
-
# for unconditional training
|
119 |
-
def __init__(self, sos_token, quantize_interface=True):
|
120 |
-
super().__init__()
|
121 |
-
self.sos_token = sos_token
|
122 |
-
self.quantize_interface = quantize_interface
|
123 |
-
|
124 |
-
def encode(self, x):
|
125 |
-
# get batch size from data and replicate sos_token
|
126 |
-
c = torch.ones(x.shape[0], 1)*self.sos_token
|
127 |
-
c = c.long().to(x.device)
|
128 |
-
if self.quantize_interface:
|
129 |
-
return c, None, [None, None, c]
|
130 |
-
return c
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/taming/modules/vqvae/quantize.py
DELETED
@@ -1,445 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch.nn.functional as F
|
4 |
-
import numpy as np
|
5 |
-
from torch import einsum
|
6 |
-
from einops import rearrange
|
7 |
-
|
8 |
-
|
9 |
-
class VectorQuantizer(nn.Module):
|
10 |
-
"""
|
11 |
-
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
12 |
-
____________________________________________
|
13 |
-
Discretization bottleneck part of the VQ-VAE.
|
14 |
-
Inputs:
|
15 |
-
- n_e : number of embeddings
|
16 |
-
- e_dim : dimension of embedding
|
17 |
-
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
18 |
-
_____________________________________________
|
19 |
-
"""
|
20 |
-
|
21 |
-
# NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
|
22 |
-
# a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
|
23 |
-
# used wherever VectorQuantizer has been used before and is additionally
|
24 |
-
# more efficient.
|
25 |
-
def __init__(self, n_e, e_dim, beta):
|
26 |
-
super(VectorQuantizer, self).__init__()
|
27 |
-
self.n_e = n_e
|
28 |
-
self.e_dim = e_dim
|
29 |
-
self.beta = beta
|
30 |
-
|
31 |
-
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
32 |
-
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
33 |
-
|
34 |
-
def forward(self, z):
|
35 |
-
"""
|
36 |
-
Inputs the output of the encoder network z and maps it to a discrete
|
37 |
-
one-hot vector that is the index of the closest embedding vector e_j
|
38 |
-
z (continuous) -> z_q (discrete)
|
39 |
-
z.shape = (batch, channel, height, width)
|
40 |
-
quantization pipeline:
|
41 |
-
1. get encoder input (B,C,H,W)
|
42 |
-
2. flatten input to (B*H*W,C)
|
43 |
-
"""
|
44 |
-
# reshape z -> (batch, height, width, channel) and flatten
|
45 |
-
z = z.permute(0, 2, 3, 1).contiguous()
|
46 |
-
z_flattened = z.view(-1, self.e_dim)
|
47 |
-
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
48 |
-
|
49 |
-
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
50 |
-
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
51 |
-
torch.matmul(z_flattened, self.embedding.weight.t())
|
52 |
-
|
53 |
-
## could possible replace this here
|
54 |
-
# #\start...
|
55 |
-
# find closest encodings
|
56 |
-
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
57 |
-
|
58 |
-
min_encodings = torch.zeros(
|
59 |
-
min_encoding_indices.shape[0], self.n_e).to(z)
|
60 |
-
min_encodings.scatter_(1, min_encoding_indices, 1)
|
61 |
-
|
62 |
-
# dtype min encodings: torch.float32
|
63 |
-
# min_encodings shape: torch.Size([2048, 512])
|
64 |
-
# min_encoding_indices.shape: torch.Size([2048, 1])
|
65 |
-
|
66 |
-
# get quantized latent vectors
|
67 |
-
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
68 |
-
#.........\end
|
69 |
-
|
70 |
-
# with:
|
71 |
-
# .........\start
|
72 |
-
#min_encoding_indices = torch.argmin(d, dim=1)
|
73 |
-
#z_q = self.embedding(min_encoding_indices)
|
74 |
-
# ......\end......... (TODO)
|
75 |
-
|
76 |
-
# compute loss for embedding
|
77 |
-
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
78 |
-
torch.mean((z_q - z.detach()) ** 2)
|
79 |
-
|
80 |
-
# preserve gradients
|
81 |
-
z_q = z + (z_q - z).detach()
|
82 |
-
|
83 |
-
# perplexity
|
84 |
-
e_mean = torch.mean(min_encodings, dim=0)
|
85 |
-
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
86 |
-
|
87 |
-
# reshape back to match original input shape
|
88 |
-
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
89 |
-
|
90 |
-
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
91 |
-
|
92 |
-
def get_codebook_entry(self, indices, shape):
|
93 |
-
# shape specifying (batch, height, width, channel)
|
94 |
-
# TODO: check for more easy handling with nn.Embedding
|
95 |
-
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
96 |
-
min_encodings.scatter_(1, indices[:,None], 1)
|
97 |
-
|
98 |
-
# get quantized latent vectors
|
99 |
-
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
100 |
-
|
101 |
-
if shape is not None:
|
102 |
-
z_q = z_q.view(shape)
|
103 |
-
|
104 |
-
# reshape back to match original input shape
|
105 |
-
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
106 |
-
|
107 |
-
return z_q
|
108 |
-
|
109 |
-
|
110 |
-
class GumbelQuantize(nn.Module):
|
111 |
-
"""
|
112 |
-
credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
113 |
-
Gumbel Softmax trick quantizer
|
114 |
-
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
115 |
-
https://arxiv.org/abs/1611.01144
|
116 |
-
"""
|
117 |
-
def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
|
118 |
-
kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
|
119 |
-
remap=None, unknown_index="random"):
|
120 |
-
super().__init__()
|
121 |
-
|
122 |
-
self.embedding_dim = embedding_dim
|
123 |
-
self.n_embed = n_embed
|
124 |
-
|
125 |
-
self.straight_through = straight_through
|
126 |
-
self.temperature = temp_init
|
127 |
-
self.kl_weight = kl_weight
|
128 |
-
|
129 |
-
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
130 |
-
self.embed = nn.Embedding(n_embed, embedding_dim)
|
131 |
-
|
132 |
-
self.use_vqinterface = use_vqinterface
|
133 |
-
|
134 |
-
self.remap = remap
|
135 |
-
if self.remap is not None:
|
136 |
-
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
137 |
-
self.re_embed = self.used.shape[0]
|
138 |
-
self.unknown_index = unknown_index # "random" or "extra" or integer
|
139 |
-
if self.unknown_index == "extra":
|
140 |
-
self.unknown_index = self.re_embed
|
141 |
-
self.re_embed = self.re_embed+1
|
142 |
-
print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
143 |
-
f"Using {self.unknown_index} for unknown indices.")
|
144 |
-
else:
|
145 |
-
self.re_embed = n_embed
|
146 |
-
|
147 |
-
def remap_to_used(self, inds):
|
148 |
-
ishape = inds.shape
|
149 |
-
assert len(ishape)>1
|
150 |
-
inds = inds.reshape(ishape[0],-1)
|
151 |
-
used = self.used.to(inds)
|
152 |
-
match = (inds[:,:,None]==used[None,None,...]).long()
|
153 |
-
new = match.argmax(-1)
|
154 |
-
unknown = match.sum(2)<1
|
155 |
-
if self.unknown_index == "random":
|
156 |
-
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
157 |
-
else:
|
158 |
-
new[unknown] = self.unknown_index
|
159 |
-
return new.reshape(ishape)
|
160 |
-
|
161 |
-
def unmap_to_all(self, inds):
|
162 |
-
ishape = inds.shape
|
163 |
-
assert len(ishape)>1
|
164 |
-
inds = inds.reshape(ishape[0],-1)
|
165 |
-
used = self.used.to(inds)
|
166 |
-
if self.re_embed > self.used.shape[0]: # extra token
|
167 |
-
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
168 |
-
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
169 |
-
return back.reshape(ishape)
|
170 |
-
|
171 |
-
def forward(self, z, temp=None, return_logits=False):
|
172 |
-
# force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
|
173 |
-
hard = self.straight_through if self.training else True
|
174 |
-
temp = self.temperature if temp is None else temp
|
175 |
-
|
176 |
-
logits = self.proj(z)
|
177 |
-
if self.remap is not None:
|
178 |
-
# continue only with used logits
|
179 |
-
full_zeros = torch.zeros_like(logits)
|
180 |
-
logits = logits[:,self.used,...]
|
181 |
-
|
182 |
-
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
183 |
-
if self.remap is not None:
|
184 |
-
# go back to all entries but unused set to zero
|
185 |
-
full_zeros[:,self.used,...] = soft_one_hot
|
186 |
-
soft_one_hot = full_zeros
|
187 |
-
z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
|
188 |
-
|
189 |
-
# + kl divergence to the prior loss
|
190 |
-
qy = F.softmax(logits, dim=1)
|
191 |
-
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
192 |
-
|
193 |
-
ind = soft_one_hot.argmax(dim=1)
|
194 |
-
if self.remap is not None:
|
195 |
-
ind = self.remap_to_used(ind)
|
196 |
-
if self.use_vqinterface:
|
197 |
-
if return_logits:
|
198 |
-
return z_q, diff, (None, None, ind), logits
|
199 |
-
return z_q, diff, (None, None, ind)
|
200 |
-
return z_q, diff, ind
|
201 |
-
|
202 |
-
def get_codebook_entry(self, indices, shape):
|
203 |
-
b, h, w, c = shape
|
204 |
-
assert b*h*w == indices.shape[0]
|
205 |
-
indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
|
206 |
-
if self.remap is not None:
|
207 |
-
indices = self.unmap_to_all(indices)
|
208 |
-
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
209 |
-
z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
|
210 |
-
return z_q
|
211 |
-
|
212 |
-
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds>=self.used.shape[0]] = 0  # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits==False, "Only for interface compatible with Gumbel"
        assert return_logits==False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1,1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0],-1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
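The forward pass above is the standard VQ recipe: a nearest-codebook lookup via the expansion (z - e)^2 = z^2 + e^2 - 2*e*z, a commitment/codebook loss weighted by beta, and the straight-through estimator that copies gradients from z_q back to z. A minimal sketch on flattened vectors (illustrative names, not code from the removed file):

import torch

codebook = torch.randn(16, 8, requires_grad=True)    # (n_e, e_dim)
z = torch.randn(32, 8, requires_grad=True)           # flattened encoder outputs

# squared distances without materializing the (32, 16, 8) difference tensor
d = z.pow(2).sum(1, keepdim=True) + codebook.pow(2).sum(1) - 2 * z @ codebook.t()
idx = d.argmin(dim=1)
z_q = codebook[idx]

beta = 0.25
# legacy=True ordering from the removed class: beta weights the codebook term
loss = ((z_q.detach() - z) ** 2).mean() + beta * ((z_q - z.detach()) ** 2).mean()

# straight-through: values come from z_q, gradients flow to z
z_q = z + (z_q - z).detach()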
class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        super().__init__()
        self.decay = decay
        self.eps = eps
        weight = torch.randn(num_tokens, codebook_dim)
        self.weight = nn.Parameter(weight, requires_grad = False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
        self.update = True

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)
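EmbeddingEMA keeps the codebook out of the optimizer (requires_grad=False) and instead tracks exponential-moving-average statistics of code usage and of the inputs assigned to each code, with Laplace smoothing in weight_update so rarely used codes do not collapse to zero. The update rules it implements, written out on toy tensors (illustrative names, not code from the removed file):

import torch
import torch.nn.functional as F

decay, eps, num_tokens, dim = 0.99, 1e-5, 16, 8
cluster_size = torch.zeros(num_tokens)
embed_avg = torch.randn(num_tokens, dim)

# one batch of code assignments and the flattened inputs they came from
encodings = F.one_hot(torch.randint(0, num_tokens, (32,)), num_tokens).float()
z_flat = torch.randn(32, dim)

# EMA of usage counts and of the summed inputs per code
cluster_size = decay * cluster_size + (1 - decay) * encodings.sum(0)
embed_avg = decay * embed_avg + (1 - decay) * (encodings.t() @ z_flat)

# Laplace-smoothed normalization, as in weight_update()
n = cluster_size.sum()
smoothed = (cluster_size + eps) / (n + num_tokens * eps) * n
weight = embed_avg / smoothed.unsqueeze(1)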
class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 remap=None, unknown_index="random"):
        super().__init__()
        self.codebook_dim = codebook_dim
        self.num_tokens = num_tokens
        self.beta = beta
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds>=self.used.shape[0]] = 0  # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        # z, 'b c h w -> b h w c'
        z = rearrange(z, 'b c h w -> b h w c')
        z_flattened = z.reshape(-1, self.codebook_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training and self.embedding.update:
            # EMA cluster size
            encodings_sum = encodings.sum(0)
            self.embedding.cluster_size_ema_update(encodings_sum)
            # EMA embedding average
            embed_sum = encodings.transpose(0,1) @ z_flattened
            self.embedding.embed_avg_ema_update(embed_sum)
            # normalize embed_avg and update weight
            self.embedding.weight_update(self.num_tokens)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q, 'b h w c -> b c h w'
        z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, (perplexity, encodings, encoding_indices)
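Two observations on the removed EMAVectorQuantizer. First, its __init__ assigns self.codebook_dim = codebook_dim and self.num_tokens = num_tokens even though the parameters are named embedding_dim and n_embed, so instantiating the class as written would apparently raise a NameError; the code is preserved above exactly as deleted. Second, the perplexity it returns is the exponentiated entropy of the average code-usage distribution, a common codebook-health metric. A small sketch of that metric (toy tensors, not code from the removed file):

import torch
import torch.nn.functional as F

num_tokens = 16
encoding_indices = torch.randint(0, num_tokens, (1024,))   # codes chosen for a batch
encodings = F.one_hot(encoding_indices, num_tokens).float()
avg_probs = encodings.mean(dim=0)                          # empirical code-usage distribution
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
# ~num_tokens when usage is uniform, ~1 when a single code dominates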
models/taming/util.py
DELETED
@@ -1,172 +0,0 @@
import os, hashlib
import requests
from tqdm import tqdm
import importlib

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
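get_obj_from_str and instantiate_from_config implement the usual target/params configuration pattern: the dotted path in "target" is imported and the resulting object is called with "params" as keyword arguments. A usage sketch (the target class here is illustrative, not tied to this repository):

config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)   # same as torch.nn.Linear(in_features=8, out_features=4)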
def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
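get_ckpt_path lazily downloads a named checkpoint into a local root and, when check=True, verifies it against the MD5 table before returning the path. A usage sketch with an illustrative cache directory:

import os

cache_root = os.path.join(os.path.expanduser("~"), ".cache", "taming")   # illustrative location
vgg_ckpt = get_ckpt_path("vgg_lpips", cache_root, check=True)            # downloads vgg.pth once, MD5-checked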
class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf

    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")
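The __main__ block above only exercises a flat key; the point of retrieve is path-like access into nested configs with optional defaults, as in this sketch (illustrative config, not from this repository):

config = {"model": {"params": {"embed_dim": 256}}}
retrieve(config, "model/params/embed_dim")                                 # -> 256
retrieve(config, "model/params/n_layers", default=12)                      # missing key -> 12
retrieve(config, "model/params/n_layers", default=12, pass_success=True)   # -> (12, False)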
utils/eval_utils.py
CHANGED
@@ -63,7 +63,7 @@ def eval_ocr(task, generator, models, sample, **kwargs):
 
 
 def eval_step(task, generator, models, sample, **kwargs):
-    if task.cfg._name ==
+    if task.cfg._name == "ocr":
         return eval_ocr(task, generator, models, sample, **kwargs)
     else:
         raise NotImplementedError
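This is the one file edited rather than deleted: eval_step now dispatches to eval_ocr when task.cfg._name equals "ocr" (the rendered diff truncates the old right-hand side of the comparison, so only the new literal is certain). If further tasks were supported, the dispatch would presumably grow one branch per task, e.g. (hypothetical extension, not part of this commit):

def eval_step(task, generator, models, sample, **kwargs):
    if task.cfg._name == "ocr":
        return eval_ocr(task, generator, models, sample, **kwargs)
    elif task.cfg._name == "caption":                                     # hypothetical additional task
        return eval_caption(task, generator, models, sample, **kwargs)    # hypothetical helper
    else:
        raise NotImplementedError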