|
import abc |
|
from typing import Tuple, List |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv |
|
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv |
|
|
|
|
|
class BaseDiscriminator(nn.Module):
    """Interface for discriminators: subclasses must override ``forward``.

    NOTE: ``@abc.abstractmethod`` below only marks the method. This class
    derives from ``nn.Module`` alone (not ``abc.ABC``), so instantiating a
    subclass that forgets to override ``forward`` is not blocked; calling
    ``forward`` on it raises ``NotImplementedError`` instead.
    """

    @abc.abstractmethod
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Predict scores and get intermediate activations. Useful for feature matching loss

        :param x: input batch to score
        :return tuple (scores, list of intermediate activations)
        :raises NotImplementedError: always, in this base implementation
        """
        # `NotImplemented` (no "Error") is the binary-operator sentinel and is
        # not callable; `raise NotImplemented()` produced a TypeError instead.
        raise NotImplementedError()
|
|
|
|
|
def get_conv_block_ctor(kind='default'):
    """Resolve a convolution-block constructor from a short name.

    :param kind: one of ``'default'`` (``nn.Conv2d``), ``'depthwise'``
        (``DepthWiseSeperableConv``) or ``'multidilated'``
        (``MultidilatedConv``). Any non-string value is passed through
        unchanged (presumably an already-resolved constructor).
    :return: the constructor (class/callable) for the requested block kind
    :raises ValueError: if ``kind`` is a string that is not recognized
    """
    if isinstance(kind, str):
        if kind == 'default':
            return nn.Conv2d
        elif kind == 'depthwise':
            return DepthWiseSeperableConv
        elif kind == 'multidilated':
            return MultidilatedConv
        raise ValueError(f'Unknown convolutional block kind {kind}')
    # Non-string input: the caller already supplied the constructor itself.
    return kind
|
|
|
|
|
def get_norm_layer(kind='bn'):
    """Resolve a normalization-layer constructor from a short name.

    :param kind: ``'bn'`` for ``nn.BatchNorm2d`` or ``'in'`` for
        ``nn.InstanceNorm2d``. Any non-string value is returned unchanged
        (presumably an already-resolved constructor).
    :return: the normalization constructor (the class, not an instance)
    :raises ValueError: if ``kind`` is a string that is not recognized
    """
    if not isinstance(kind, str):
        return kind

    norm_ctors = {
        'bn': nn.BatchNorm2d,
        'in': nn.InstanceNorm2d,
    }
    ctor = norm_ctors.get(kind)
    if ctor is None:
        raise ValueError(f'Unknown norm block kind {kind}')
    return ctor
|
|
|
|
|
def get_activation(kind='tanh'):
    """Build an activation module instance from a short name.

    :param kind: ``'tanh'`` -> ``nn.Tanh()``, ``'sigmoid'`` -> ``nn.Sigmoid()``,
        the literal ``False`` -> ``nn.Identity()`` (no activation). Anything
        else, including ``True``, ``None`` and ``0``, is rejected.
    :return: a freshly constructed ``nn.Module``
    :raises ValueError: if ``kind`` is not one of the accepted values
    """
    # Identity check (not ==) so that 0 is not treated like False.
    if kind is False:
        return nn.Identity()
    if kind == 'tanh':
        return nn.Tanh()
    if kind == 'sigmoid':
        return nn.Sigmoid()
    raise ValueError(f'Unknown activation kind {kind}')
|
|
|
|
|
class SimpleMultiStepGenerator(nn.Module):
    """Run a sequence of generator steps, each conditioned on all earlier results.

    Step ``k`` receives the original input concatenated (along dim 1,
    presumably the channel axis of NCHW tensors — confirm with callers) with
    the outputs of steps ``0..k-1``. The final result concatenates every
    step's output along dim 1, with the LAST step's output first.
    """

    def __init__(self, steps: List[nn.Module]):
        super().__init__()
        # Registered as a ModuleList so the steps' parameters are tracked;
        # the attribute name `steps` determines state_dict keys.
        self.steps = nn.ModuleList(steps)

    def forward(self, x):
        accumulated = x
        step_outputs = []
        for module in self.steps:
            produced = module(accumulated)
            step_outputs.append(produced)
            # Feed the next step everything seen so far.
            accumulated = torch.cat((accumulated, produced), dim=1)
        # Most recent step's output comes first in the returned tensor.
        step_outputs.reverse()
        return torch.cat(step_outputs, dim=1)
|
|
|
def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
    """Build the layer list for one 2x upsampling stage.

    The stage maps ``ngf * mult`` input channels to ``int(ngf * mult / 2)``
    output channels, both capped at ``max_features``.

    :param kind: ``'convtranspose'`` for a stride-2 transposed convolution, or
        ``'bilinear'`` for bilinear upsampling followed by a
        ``DepthWiseSeperableConv``.
    :param ngf: base number of feature channels.
    :param mult: channel multiplier of this stage's input.
    :param norm_layer: normalization constructor, called with the output
        channel count.
    :param activation: activation module, appended as-is (the same object is
        placed in the returned list, not a copy).
    :param max_features: upper bound on input and output channel counts.
    :return: list of modules ``[upsampling layer(s)..., norm, activation]``.
    :raises ValueError: if ``kind`` is not recognized (a subclass of the
        ``Exception`` previously raised, so existing handlers still work).
    """
    if kind not in ('convtranspose', 'bilinear'):
        raise ValueError(f"Invalid deconv kind: {kind}")

    in_channels = min(max_features, ngf * mult)
    out_channels = min(max_features, int(ngf * mult / 2))

    if kind == 'convtranspose':
        # With kernel 3, stride 2, padding 1, output_padding=1 gives exactly
        # 2x the input spatial size: (H - 1) * 2 - 2 + 3 + 1 = 2H.
        upsample = [nn.ConvTranspose2d(in_channels, out_channels,
                                       kernel_size=3, stride=2, padding=1, output_padding=1)]
    else:  # 'bilinear'
        # align_corners=False is PyTorch's default; stated explicitly so the
        # behavior is unambiguous across versions.
        upsample = [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                    DepthWiseSeperableConv(in_channels, out_channels,
                                           kernel_size=3, stride=1, padding=1)]

    # Norm is built after the upsampling layers, matching the original
    # construction order (and hence RNG consumption for weight init).
    return upsample + [norm_layer(out_channels), activation]