|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
from saicinpainting.training.modules.ffc0 import FFCResnetBlock |
|
from saicinpainting.training.modules.ffc0 import FFC_BN_ACT |
|
|
|
|
|
|
|
|
|
class myFFCResblock(nn.Module):
    """Stack of ``FFC_BN_ACT`` -> ``FFCResnetBlock`` (applied twice) -> ``FFC_BN_ACT``.

    The first ``FFC_BN_ACT`` keeps ``input_nc`` channels, the residual block
    is built for ``input_nc`` channels, and the last ``FFC_BN_ACT`` is built
    with ``input_nc`` input and ``output_nc`` output channels.

    Args:
        input_nc: Channel count passed to all three sub-layers as their input
            size.
        output_nc: Output channel count passed to the final ``FFC_BN_ACT``.
        n_blocks: Only validated to be non-negative. ``forward`` always
            applies the single residual block twice, whatever this value is.
        norm_layer: Normalization layer class forwarded to every sub-layer.
        padding_type: Padding mode forwarded to every sub-layer.
        activation_layer: Activation class forwarded to every sub-layer.
        resnet_conv_kwargs: Extra keyword arguments unpacked into every
            sub-layer constructor. ``None`` (the default) means none.
        spatial_transform_layers: Accepted for interface compatibility;
            not used by this class.
        spatial_transform_kwargs: Accepted for interface compatibility;
            not used by this class.
        add_out_act: Accepted for interface compatibility; not used.
        max_features: Accepted for interface compatibility; not used.
        out_ffc: Accepted for interface compatibility; not used.
        out_ffc_kwargs: Accepted for interface compatibility; not used.

    Raises:
        AssertionError: If ``n_blocks`` is negative.
    """

    def __init__(self, input_nc, output_nc, n_blocks=2, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', activation_layer=nn.ReLU,
                 resnet_conv_kwargs=None,
                 spatial_transform_layers=None, spatial_transform_kwargs=None,
                 add_out_act=True, max_features=1024, out_ffc=False,
                 out_ffc_kwargs=None):
        assert (n_blocks >= 0)

        super().__init__()

        # ``None`` replaces the former ``{}`` defaults so that no dict object
        # is shared between instances; an explicit mapping is used as given.
        if resnet_conv_kwargs is None:
            resnet_conv_kwargs = {}

        # Options common to all three sub-layers.
        shared_kwargs = dict(norm_layer=norm_layer,
                             activation_layer=activation_layer,
                             padding_type=padding_type)

        self.initial = FFC_BN_ACT(input_nc, input_nc, kernel_size=3, padding=1, dilation=1,
                                  **shared_kwargs,
                                  **resnet_conv_kwargs)

        # A single residual block instance; ``forward`` calls it twice, so
        # both applications share the same parameters.
        self.ffcresblock = FFCResnetBlock(input_nc,
                                          **shared_kwargs,
                                          **resnet_conv_kwargs)

        self.final = FFC_BN_ACT(input_nc, output_nc, kernel_size=3, padding=1, dilation=1,
                                **shared_kwargs,
                                **resnet_conv_kwargs)

    def forward(self, x) -> torch.Tensor:
        """Run ``x`` through the block and return the concatenated result.

        ``x`` is handed to ``self.initial`` unchanged; each FFC layer is
        expected to return a pair, presumably the local and global branches
        (TODO confirm against ``ffc0``). The pair is concatenated along
        dimension 1 before the final layer and again for the return value.
        """
        x_l, x_g = self.initial(x)

        # Same module applied twice (weight sharing between the two passes).
        x_l, x_g = self.ffcresblock(x_l, x_g)
        x_l, x_g = self.ffcresblock(x_l, x_g)

        merged = torch.cat([x_l, x_g], 1)

        x_lout, x_gout = self.final(merged)

        return torch.cat([x_lout, x_gout], 1)
|
|
|
|
|
|
|
|