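"""FastGAN-style Generator and Discriminator.

The architecture follows FastGAN ("Towards Faster and Stabilized GAN
Training for High-fidelity Few-shot Image Synthesis", Liu et al., ICLR 2021):
skip-layer excitation (SLE) connections in both networks, a generator with
1024x1024 and 128x128 outputs, and a discriminator with auxiliary decoders
that provide a self-supervised reconstruction signal on real images.
"""
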
import torch
import torch.nn as nn

from typing import Any, Tuple, Union

from utils import (
    ImageType,
    crop_image_part,
)

from layers import (
    SpectralConv2d,
    InitLayer,
    SLEBlock,
    UpsampleBlockT1,
    UpsampleBlockT2,
    DownsampleBlockT1,
    DownsampleBlockT2,
    Decoder,
)

from huggan.pytorch.huggan_mixin import HugGANModelHubMixin


class Generator(nn.Module, HugGANModelHubMixin):

    def __init__(self, in_channels: int,
                       out_channels: int):
        super().__init__()

        # feature-map channel width at each spatial resolution
        self._channels = {
                4:    1024,
                8:    512,
                16:   256,
                32:   128,
                64:   128,
                128:  64,
                256:  32,
                512:  16,
                1024: 8,
            }

        self._init = InitLayer(
                in_channels=in_channels,
                out_channels=self._channels[4],
            )

        self._upsample_8    = UpsampleBlockT2(in_channels=self._channels[4],   out_channels=self._channels[8]   )
        self._upsample_16   = UpsampleBlockT1(in_channels=self._channels[8],   out_channels=self._channels[16]  )
        self._upsample_32   = UpsampleBlockT2(in_channels=self._channels[16],  out_channels=self._channels[32]  )
        self._upsample_64   = UpsampleBlockT1(in_channels=self._channels[32],  out_channels=self._channels[64]  )
        self._upsample_128  = UpsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[128] )
        self._upsample_256  = UpsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[256] )
        self._upsample_512  = UpsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[512] )
        self._upsample_1024 = UpsampleBlockT1(in_channels=self._channels[512], out_channels=self._channels[1024])

        self._sle_64  = SLEBlock(in_channels=self._channels[4],  out_channels=self._channels[64] )
        self._sle_128 = SLEBlock(in_channels=self._channels[8],  out_channels=self._channels[128])
        self._sle_256 = SLEBlock(in_channels=self._channels[16], out_channels=self._channels[256])
        self._sle_512 = SLEBlock(in_channels=self._channels[32], out_channels=self._channels[512])

        # two image heads: a 128x128 output for the discriminator's small
        # track, and the full 1024x1024 output
        self._out_128 = nn.Sequential(
                SpectralConv2d(
                    in_channels=self._channels[128],
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=1,
                    padding='same',
                    bias=False,
                ),
                nn.Tanh(),
            )

        self._out_1024 = nn.Sequential(
                SpectralConv2d(
                    in_channels=self._channels[1024],
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    padding='same',
                    bias=False,
                ),
                nn.Tanh(),
            )

    def forward(self, input: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        # progressively upsample the latent to a 32x32 feature map
        size_4  = self._init(input)
        size_8  = self._upsample_8(size_4)
        size_16 = self._upsample_16(size_8)
        size_32 = self._upsample_32(size_16)

        # skip-layer excitation: low-resolution activations gate the channels
        # of the corresponding high-resolution feature maps
        size_64  = self._sle_64 (size_4,  self._upsample_64 (size_32) )
        size_128 = self._sle_128(size_8,  self._upsample_128(size_64) )
        size_256 = self._sle_256(size_16, self._upsample_256(size_128))
        size_512 = self._sle_512(size_32, self._upsample_512(size_256))

        size_1024 = self._upsample_1024(size_512)

        out_128  = self._out_128 (size_128)
        out_1024 = self._out_1024(size_1024)
        return out_1024, out_128
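
# A minimal usage sketch for the generator (assumption: InitLayer expects a
# latent shaped [batch, in_channels, 1, 1], as is typical for this
# architecture):
#
#   generator = Generator(in_channels=256, out_channels=3)
#   noise = torch.randn(8, 256, 1, 1)
#   img_1024, img_128 = generator(noise)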


class Discriminator(nn.Module, HugGANModelHubMixin):

    def __init__(self, in_channels: int):
        super().__init__()

        # same resolution-to-channel-width mapping as the generator
        self._channels = {
                4:    1024,
                8:    512,
                16:   256,
                32:   128,
                64:   128,
                128:  64,
                256:  32,
                512:  16,
                1024: 8,
            }

        self._init = nn.Sequential(
                SpectralConv2d(
                        in_channels=in_channels,
                        out_channels=self._channels[1024],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                        bias=False,
                    ),
                nn.LeakyReLU(negative_slope=0.2),
                SpectralConv2d(
                        in_channels=self._channels[1024],
                        out_channels=self._channels[512],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                        bias=False,
                    ),
                nn.BatchNorm2d(num_features=self._channels[512]),
                nn.LeakyReLU(negative_slope=0.2),
            )

        self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
        self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
        self._downsample_64  = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64] )
        self._downsample_32  = DownsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[32] )
        self._downsample_16  = DownsampleBlockT2(in_channels=self._channels[32],  out_channels=self._channels[16] )

        self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
        self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
        self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])

        # second branch that processes the 128x128 image
        self._small_track = nn.Sequential(
                SpectralConv2d(
                        in_channels=in_channels,
                        out_channels=self._channels[256],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                        bias=False,
                    ),
                nn.LeakyReLU(negative_slope=0.2),
                DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
                DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64] ),
                DownsampleBlockT1(in_channels=self._channels[64],  out_channels=self._channels[32] ),
            )

        self._features_large = nn.Sequential(
                SpectralConv2d(
                        in_channels=self._channels[16],
                        out_channels=self._channels[8],
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
                nn.BatchNorm2d(num_features=self._channels[8]),
                nn.LeakyReLU(negative_slope=0.2),
                SpectralConv2d(
                        in_channels=self._channels[8],
                        out_channels=1,
                        kernel_size=4,
                        stride=1,
                        padding=0,
                        bias=False,
                    )
            )

        self._features_small = nn.Sequential(
                SpectralConv2d(
                        in_channels=self._channels[32],
                        out_channels=1,
                        kernel_size=4,
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
            )

        # decoders reconstruct images from intermediate features; used as a
        # self-supervised consistency signal on real images (see forward)
        self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=3)
        self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=3)
        self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=3)

    def forward(self, images_1024: torch.Tensor,
                      images_128: torch.Tensor,
                      image_type: ImageType) -> \
            Union[
                torch.Tensor,
                Tuple[torch.Tensor, Tuple[Any, Any, Any]]
            ]:
        # large track

        down_512 = self._init(images_1024)
        down_256 = self._downsample_256(down_512)
        down_128 = self._downsample_128(down_256)

        down_64 = self._downsample_64(down_128)
        down_64 = self._sle_64(down_512, down_64)

        down_32 = self._downsample_32(down_64)
        down_32 = self._sle_32(down_256, down_32)

        down_16 = self._downsample_16(down_32)
        down_16 = self._sle_16(down_128, down_16)

        # small track

        down_small = self._small_track(images_128)

        # features

        features_large = self._features_large(down_16).view(-1)
        features_small = self._features_small(down_small).view(-1)
        features = torch.cat([features_large, features_small], dim=0)

        # decoder

        if image_type != ImageType.FAKE:
            dec_large = self._decoder_large(down_16)
            dec_small = self._decoder_small(down_small)
            dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
            return features, (dec_large, dec_small, dec_piece)

        return features
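

if __name__ == '__main__':
    # Smoke-test sketch, not part of the original training code. The latent
    # shape and the assumption that ImageType is an Enum (it is only ever
    # compared against ImageType.FAKE above) are inferred from this file.
    generator = Generator(in_channels=256, out_channels=3)
    discriminator = Discriminator(in_channels=3)

    noise = torch.randn(2, 256, 1, 1)
    img_1024, img_128 = generator(noise)

    # For fake images the discriminator returns only the flattened logits.
    fake_logits = discriminator(img_1024.detach(), img_128.detach(), ImageType.FAKE)

    # For any other image type it also returns the three decoder
    # reconstructions used for the self-supervised reconstruction loss.
    real_type = next(t for t in ImageType if t is not ImageType.FAKE)
    real_logits, (dec_large, dec_small, dec_piece) = discriminator(
        img_1024.detach(), img_128.detach(), real_type)

    print(img_1024.shape, img_128.shape, fake_logits.shape, dec_large.shape)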