Update model.py
model.py
CHANGED
@@ -118,3 +118,104 @@ class Generator(nn.Module):
         final_upscaled = self.rgb_layers[steps - 1](upscaled)
         final_out = self.rgb_layers[steps](out)
         return self.fade_in(alpha, final_upscaled, final_out)
+
+
+class Discriminator(nn.Module):
+    def __init__(self, z_dim, in_channels, img_channels=3):
+        super(Discriminator, self).__init__()
+        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
+        self.leaky = nn.LeakyReLU(0.2)
+
+        # here we work backwards through factors because the discriminator
+        # should mirror the generator. So the first prog_block and rgb layer
+        # we append will work for input size 1024x1024, then 512x512 -> 256x256, etc.
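+        # (for intuition, assuming the usual factors list defined earlier in
+        # model.py, e.g. factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32], and
+        # in_channels=256: the first appended block maps int(256 * 1/32) = 8
+        # channels to int(256 * 1/16) = 16 channels, i.e. the 1024x1024 stage)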
+        for i in range(len(factors) - 1, 0, -1):
+            conv_in = int(in_channels * factors[i])
+            conv_out = int(in_channels * factors[i - 1])
+            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
+            self.rgb_layers.append(
+                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
+            )
+
+        # the name "initial_rgb" is perhaps confusing; it is just the RGB layer
+        # for the 4x4 input size, named this way to mirror the generator's initial_rgb
+        self.initial_rgb = WSConv2d(
+            img_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.rgb_layers.append(self.initial_rgb)
+        self.avg_pool = nn.AvgPool2d(
+            kernel_size=2, stride=2
+        )  # downsampling using avg pool
+
+        # this is the block for the 4x4 input size
+        self.final_block = nn.Sequential(
+            # +1 to in_channels because we concatenate the minibatch std channel
+            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2),
+            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
+            nn.LeakyReLU(0.2),
+            WSConv2d(
+                in_channels, 1, kernel_size=1, padding=0, stride=1
+            ),  # we use this instead of a linear layer
+        )
+
+    def fade_in(self, alpha, downscaled, out):
+        """Fade in between the downscaled (avg-pooled) input and the CNN output"""
+        # alpha should be a scalar in [0, 1], and downscaled.shape == out.shape
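+        # e.g. alpha=0 returns the downscaled path only, alpha=1 the new block only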
+        return alpha * out + (1 - alpha) * downscaled
+
+    def minibatch_std(self, x):
+        batch_statistics = (
+            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
+        )
+        # we take the std across the batch for each feature (channel and pixel),
+        # average it to a single scalar, and repeat it as one extra channel
+        # concatenated onto the image. In this way the discriminator
+        # gets information about the variation in the batch/image
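+        # e.g. an input of shape (N, C, H, W) comes back as (N, C + 1, H, W)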
+        return torch.cat([x, batch_statistics], dim=1)
+
+    def forward(self, x, alpha, steps):
+        # where we should start in the list of prog_blocks; maybe a bit confusing,
+        # but the last block is for 4x4. For example, if steps=1 we should start
+        # at the second-to-last block because the input size will be 8x8. If
+        # steps==0 we just use the final block
+        cur_step = len(self.prog_blocks) - steps
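+        # (e.g. with the usual 8 prog_blocks: steps=0 -> cur_step=8, so only
+        # final_block runs; steps=1 (8x8) -> cur_step=7; steps=8 (1024x1024) -> cur_step=0)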
+
+        # convert from rgb as the initial step; this depends on
+        # the image size (each size has its own rgb layer)
+        out = self.leaky(self.rgb_layers[cur_step](x))
+
+        if steps == 0:  # i.e., image is 4x4
+            out = self.minibatch_std(out)
+            return self.final_block(out).view(out.shape[0], -1)
+
+        # because prog_blocks might change the channels, for the downscaled path we use
+        # the rgb_layer from the previous/smaller size, which corresponds to +1 in the indexing
+        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
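+        # (e.g. at 8x8 with cur_step=7: avg_pool halves x to 4x4 and
+        # rgb_layers[8], the 4x4 rgb layer, converts it from RGB)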
+        out = self.avg_pool(self.prog_blocks[cur_step](out))
+
+        # the fade_in is done first, between the downscaled input and the block output;
+        # this is the opposite order from the generator, where fade_in comes last
+        out = self.fade_in(alpha, downscaled, out)
+
+        for step in range(cur_step + 1, len(self.prog_blocks)):
+            out = self.prog_blocks[step](out)
+            out = self.avg_pool(out)
+
+        out = self.minibatch_std(out)
+        return self.final_block(out).view(out.shape[0], -1)
+
+
+if __name__ == "__main__":
+    Z_DIM = 100
+    IN_CHANNELS = 256
+    gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
+    critic = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)
+
+    for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
+        num_steps = int(log2(img_size / 4))
+        x = torch.randn((1, Z_DIM, 1, 1))
+        z = gen(x, 0.5, steps=num_steps)
+        assert z.shape == (1, 3, img_size, img_size)
+        out = critic(z, alpha=0.5, steps=num_steps)
+        assert out.shape == (1, 1)
+        print(f"Success! At img size: {img_size}")
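
For context, a minimal sketch of how alpha and steps are typically driven when training these classes progressively. The training loop itself is not part of this commit; NUM_EPOCHS, BATCHES_PER_EPOCH, and the batch size are illustrative assumptions, and only the Generator/Discriminator calls mirror the test block above.

from math import log2

import torch

from model import Generator, Discriminator

Z_DIM, IN_CHANNELS = 100, 256
NUM_EPOCHS, BATCHES_PER_EPOCH = 10, 100  # illustrative values

gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
critic = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)

step = int(log2(64 / 4))  # train the 64x64 stage
alpha = 0.0
for epoch in range(NUM_EPOCHS):
    for _ in range(BATCHES_PER_EPOCH):
        noise = torch.randn(4, Z_DIM, 1, 1)
        fake = gen(noise, alpha, steps=step)           # (4, 3, 64, 64)
        score = critic(fake, alpha=alpha, steps=step)  # (4, 1)
        # ... losses and optimizer updates would go here ...
        # ramp alpha linearly from 0 to 1 over the stage
        alpha = min(alpha + 1 / (NUM_EPOCHS * BATCHES_PER_EPOCH), 1.0)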