Upload 4 files
Browse files- __init__.py +18 -0
- decoder.py +94 -0
- encoder.py +93 -0
- utils.py +59 -0
__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io, requests
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from dall_e.encoder import Encoder
|
6 |
+
from dall_e.decoder import Decoder
|
7 |
+
from dall_e.utils import map_pixels, unmap_pixels
|
8 |
+
|
9 |
+
def load_model(path: str, device: torch.device = None, timeout: float = 60.0) -> nn.Module:
    """Load a serialized model from a local file or from an HTTP(S) URL.

    Args:
        path: Either a filesystem path or a URL starting with ``http://`` or
            ``https://``.
        device: Forwarded to ``torch.load`` as ``map_location``; ``None``
            leaves the placement to ``torch.load``.
        timeout: Timeout in seconds handed to ``requests.get``. Without one a
            stalled server would block the caller indefinitely. Pass ``None``
            to wait forever (the previous behaviour).

    Returns:
        Whatever object ``torch.load`` deserializes from the data.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
        OSError: If a local ``path`` cannot be opened.
    """
    # SECURITY NOTE(review): torch.load deserializes with pickle, which can
    # execute arbitrary code. Only load checkpoints from sources you trust.
    if path.startswith(('http://', 'https://')):
        resp = requests.get(path, timeout=timeout)
        resp.raise_for_status()

        # The whole payload is held in memory and exposed as a file object.
        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)

    with open(path, 'rb') as f:
        return torch.load(f, map_location=device)
|
decoder.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import attr
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from collections import OrderedDict
|
9 |
+
from functools import partial
|
10 |
+
from dall_e.utils import Conv2d
|
11 |
+
|
12 |
+
@attr.s(eq=False, repr=False)
class DecoderBlock(nn.Module):
    """Residual block computing ``id_path(x) + post_gain * res_path(x)``.

    ``res_path`` is four ReLU -> Conv2d pairs whose inner width is
    ``n_out // 4``; ``id_path`` is a 1x1 Conv2d when the channel count
    changes and ``nn.Identity`` otherwise.
    """

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these boolean lambdas never reject anything -- confirm whether real
    # enforcement was intended.
    n_in: int = attr.ib(validator=lambda _inst, _field, value: value >= 1)
    n_out: int = attr.ib(validator=lambda _inst, _field, value: value >= 1 and value % 4 == 0)
    n_layers: int = attr.ib(validator=lambda _inst, _field, value: value >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Build the shortcut and residual paths once the fields are set."""
        super().__init__()
        self.n_hid = self.n_out // 4
        # Residual branch is scaled down by the square of the layer count.
        self.post_gain = 1 / (self.n_layers ** 2)

        conv_opts = dict(device=self.device, requires_grad=self.requires_grad)

        # The shortcut conv (if any) is created before the residual convs.
        if self.n_in != self.n_out:
            self.id_path = Conv2d(self.n_in, self.n_out, 1, **conv_opts)
        else:
            self.id_path = nn.Identity()

        # (input channels, output channels, kernel width) for conv_1..conv_4.
        conv_specs = (
            (self.n_in, self.n_hid, 1),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_out, 3),
        )
        layers = OrderedDict()
        for idx, (c_in, c_out, width) in enumerate(conv_specs, start=1):
            layers[f'relu_{idx}'] = nn.ReLU()
            layers[f'conv_{idx}'] = Conv2d(c_in, c_out, width, **conv_opts)
        self.res_path = nn.Sequential(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shortcut output plus the gain-scaled residual output."""
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
|
40 |
+
|
41 |
+
@attr.s(eq=False, repr=False)
class Decoder(nn.Module):
    """Convolutional decoder built from four groups of ``DecoderBlock``s.

    The network takes a 4-d float32 tensor with ``vocab_size`` channels on
    axis 1 and ends in a 1x1 convolution producing ``2 * output_channels``
    channels. The first three groups are each followed by a nearest-neighbour
    upsampling with scale factor 2.
    """

    # Declared without attr.ib; read as ``self.group_count`` when sizing blocks.
    group_count: int = 4

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these boolean lambdas never reject anything -- confirm whether real
    # enforcement was intended.
    n_init: int = attr.ib(default=128, validator=lambda _inst, _field, value: value >= 8)
    n_hid: int = attr.ib(default=256, validator=lambda _inst, _field, value: value >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda _inst, _field, value: value >= 1)
    output_channels: int = attr.ib(default=3, validator=lambda _inst, _field, value: value >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda _inst, _field, value: value >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    # Not read anywhere inside this class.
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        """Assemble ``self.blocks``: input conv, four groups, output head."""
        super().__init__()

        total_blocks = self.group_count * self.n_blk_per_group
        shared = dict(device=self.device, requires_grad=self.requires_grad)

        def build_group(first_in: int, width: int, upsample: bool) -> nn.Sequential:
            # Only the first block of a group changes the channel count.
            members = OrderedDict()
            for idx in range(self.n_blk_per_group):
                members[f'block_{idx + 1}'] = DecoderBlock(
                    first_in if idx == 0 else width, width,
                    n_layers=total_blocks, **shared)
            if upsample:
                members['upsample'] = nn.Upsample(scale_factor=2, mode='nearest')
            return nn.Sequential(members)

        # (name, channels entering the group, group width, upsample afterwards)
        group_plan = (
            ('group_1', self.n_init, 8 * self.n_hid, True),
            ('group_2', 8 * self.n_hid, 4 * self.n_hid, True),
            ('group_3', 4 * self.n_hid, 2 * self.n_hid, True),
            ('group_4', 2 * self.n_hid, 1 * self.n_hid, False),
        )

        # Modules are created in the same order they appear in the network.
        stages = OrderedDict()
        stages['input'] = Conv2d(self.vocab_size, self.n_init, 1, use_float16=False, **shared)
        for name, first_in, width, upsample in group_plan:
            stages[name] = build_group(first_in, width, upsample)
        stages['output'] = nn.Sequential(OrderedDict([
            ('relu', nn.ReLU()),
            ('conv', Conv2d(1 * self.n_hid, 2 * self.output_channels, 1, **shared)),
        ]))
        self.blocks = nn.Sequential(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Validate ``x`` and run it through ``self.blocks``.

        Raises:
            ValueError: If ``x`` is not 4-d, its axis-1 size differs from
                ``vocab_size``, or its dtype is not ``torch.float32``.
        """
        if x.ndim != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.vocab_size:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
|
encoder.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import attr
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from collections import OrderedDict
|
9 |
+
from functools import partial
|
10 |
+
from dall_e.utils import Conv2d
|
11 |
+
|
12 |
+
@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    """Residual block computing ``id_path(x) + post_gain * res_path(x)``.

    ``res_path`` is four ReLU -> Conv2d pairs (three 3x3 convs followed by a
    1x1 conv) whose inner width is ``n_out // 4``; ``id_path`` is a 1x1
    Conv2d when the channel count changes and ``nn.Identity`` otherwise.
    """

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these boolean lambdas never reject anything -- confirm whether real
    # enforcement was intended.
    n_in: int = attr.ib(validator=lambda _inst, _field, value: value >= 1)
    n_out: int = attr.ib(validator=lambda _inst, _field, value: value >= 1 and value % 4 == 0)
    n_layers: int = attr.ib(validator=lambda _inst, _field, value: value >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Build the shortcut and residual paths once the fields are set."""
        super().__init__()
        self.n_hid = self.n_out // 4
        # Residual branch is scaled down by the square of the layer count.
        self.post_gain = 1 / (self.n_layers ** 2)

        conv_opts = dict(device=self.device, requires_grad=self.requires_grad)

        # The shortcut conv (if any) is created before the residual convs.
        if self.n_in != self.n_out:
            self.id_path = Conv2d(self.n_in, self.n_out, 1, **conv_opts)
        else:
            self.id_path = nn.Identity()

        # (input channels, output channels, kernel width) for conv_1..conv_4.
        conv_specs = (
            (self.n_in, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_out, 1),
        )
        layers = OrderedDict()
        for idx, (c_in, c_out, width) in enumerate(conv_specs, start=1):
            layers[f'relu_{idx}'] = nn.ReLU()
            layers[f'conv_{idx}'] = Conv2d(c_in, c_out, width, **conv_opts)
        self.res_path = nn.Sequential(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shortcut output plus the gain-scaled residual output."""
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
|
40 |
+
|
41 |
+
@attr.s(eq=False, repr=False)
class Encoder(nn.Module):
    """Convolutional encoder built from four groups of ``EncoderBlock``s.

    The network takes a 4-d float32 tensor with ``input_channels`` channels
    on axis 1 and ends in a 1x1 convolution producing ``vocab_size``
    channels. The first three groups are each followed by a 2x2 max-pool.
    """

    # Declared without attr.ib; read as ``self.group_count`` when sizing blocks.
    group_count: int = 4

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these boolean lambdas never reject anything -- confirm whether real
    # enforcement was intended.
    n_hid: int = attr.ib(default=256, validator=lambda _inst, _field, value: value >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda _inst, _field, value: value >= 1)
    input_channels: int = attr.ib(default=3, validator=lambda _inst, _field, value: value >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda _inst, _field, value: value >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    # Not read anywhere inside this class.
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        """Assemble ``self.blocks``: input conv, four groups, output head."""
        super().__init__()

        total_blocks = self.group_count * self.n_blk_per_group
        shared = dict(device=self.device, requires_grad=self.requires_grad)

        def build_group(first_in: int, width: int, pool: bool) -> nn.Sequential:
            # Only the first block of a group changes the channel count.
            members = OrderedDict()
            for idx in range(self.n_blk_per_group):
                members[f'block_{idx + 1}'] = EncoderBlock(
                    first_in if idx == 0 else width, width,
                    n_layers=total_blocks, **shared)
            if pool:
                members['pool'] = nn.MaxPool2d(kernel_size=2)
            return nn.Sequential(members)

        # (name, channels entering the group, group width, max-pool afterwards)
        group_plan = (
            ('group_1', 1 * self.n_hid, 1 * self.n_hid, True),
            ('group_2', 1 * self.n_hid, 2 * self.n_hid, True),
            ('group_3', 2 * self.n_hid, 4 * self.n_hid, True),
            ('group_4', 4 * self.n_hid, 8 * self.n_hid, False),
        )

        # Modules are created in the same order they appear in the network.
        stages = OrderedDict()
        stages['input'] = Conv2d(self.input_channels, 1 * self.n_hid, 7, **shared)
        for name, first_in, width, pool in group_plan:
            stages[name] = build_group(first_in, width, pool)
        stages['output'] = nn.Sequential(OrderedDict([
            ('relu', nn.ReLU()),
            ('conv', Conv2d(8 * self.n_hid, self.vocab_size, 1, use_float16=False, **shared)),
        ]))
        self.blocks = nn.Sequential(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Validate ``x`` and run it through ``self.blocks``.

        Raises:
            ValueError: If ``x`` is not 4-d, its axis-1 size differs from
                ``input_channels``, or its dtype is not ``torch.float32``.
        """
        if x.ndim != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
|
utils.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import attr
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
# Margin used by map_pixels / unmap_pixels: map_pixels rescales its input by
# (1 - 2 * eps) and shifts it by eps; unmap_pixels inverts that and clamps.
logit_laplace_eps: float = 0.1
|
9 |
+
|
10 |
+
@attr.s(eq=False)
class Conv2d(nn.Module):
    """Square-kernel 2-d convolution with optional float16 execution on CUDA.

    The weight ``w`` has shape ``(n_out, n_in, kw, kw)`` and is drawn from a
    normal distribution with std ``1 / sqrt(n_in * kw ** 2)``; the bias ``b``
    starts at zero. Both are stored as float32 parameters.
    """

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these boolean lambdas never reject anything -- confirm whether real
    # enforcement was intended.
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
    kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)

    use_float16: bool = attr.ib(default=True)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Allocate and initialise the weight and bias parameters."""
        super().__init__()

        # Initialise plain tensors first and wrap them afterwards. Creating
        # the weight with requires_grad=True and then calling the in-place
        # normal_() on that leaf tensor raises a RuntimeError, which made
        # requires_grad=True unusable.
        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
                        device=self.device)
        w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))

        b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)

        # NOTE(review): nn.Parameter defaults to requires_grad=True, so (as
        # before this change) the parameters track gradients even when
        # self.requires_grad is False. Kept as-is for backward compatibility;
        # decide whether requires_grad=False should freeze the parameters.
        self.w, self.b = nn.Parameter(w), nn.Parameter(b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with ``w`` and ``b``.

        When ``use_float16`` is set and the weight lives on a CUDA device,
        the input, weight and bias are cast to float16; otherwise the input
        is cast to float32 and the stored parameters are used directly.
        """
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        # Padding of (kw - 1) // 2 on each side of both spatial axes.
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
|
44 |
+
|
45 |
+
def map_pixels(x: torch.Tensor) -> torch.Tensor:
    """Apply the affine map ``(1 - 2 * eps) * x + eps`` with eps = logit_laplace_eps.

    Raises:
        ValueError: If ``x`` is not 4-d or its dtype is not ``torch.float``.
    """
    if x.ndim != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    scale = 1 - 2 * logit_laplace_eps
    return scale * x + logit_laplace_eps
|
52 |
+
|
53 |
+
def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    """Compute ``(x - eps) / (1 - 2 * eps)`` with eps = logit_laplace_eps, clamped to [0, 1].

    Raises:
        ValueError: If ``x`` is not 4-d or its dtype is not ``torch.float``.
    """
    if x.ndim != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    scale = 1 - 2 * logit_laplace_eps
    rescaled = (x - logit_laplace_eps) / scale
    return torch.clamp(rescaled, 0, 1)
|