File size: 1,681 Bytes
bd1c686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


from saicinpainting.training.modules.ffc0 import FFCResnetBlock
from saicinpainting.training.modules.ffc0 import FFC_BN_ACT




class myFFCResblock(nn.Module):
    """FFC stage: ``FFC_BN_ACT`` -> ``FFCResnetBlock`` (applied twice) -> ``FFC_BN_ACT``.

    The module owns three sub-modules:

    * ``initial``     -- ``FFC_BN_ACT(input_nc, input_nc)``, 3x3 kernel, padding 1.
    * ``ffcresblock`` -- one ``FFCResnetBlock(input_nc)``; ``forward`` applies
      this *same* instance twice, so both passes share weights.
    * ``final``       -- ``FFC_BN_ACT(input_nc, output_nc)``, 3x3 kernel, padding 1.

    Original author note on the signature: ``128--->64`` (presumably the
    typical ``input_nc`` -> ``output_nc``; confirm against the caller).

    Args:
        input_nc: Channel count passed to ``initial`` (in and out), to
            ``ffcresblock``, and as the input channel count of ``final``.
        output_nc: Output channel count passed to ``final``.
        n_blocks: Only checked to be non-negative. It does NOT control how
            many residual passes run; ``forward`` always performs two.
        norm_layer: Forwarded to all three sub-modules.
        padding_type: Forwarded to all three sub-modules.
        activation_layer: Forwarded to all three sub-modules.
        resnet_conv_kwargs: Extra keyword arguments unpacked into all three
            sub-module constructors. ``None`` (the default) means no extras.
        spatial_transform_layers: Accepted for signature compatibility; unused.
        spatial_transform_kwargs: Accepted for signature compatibility; unused.
        add_out_act: Accepted for signature compatibility; unused.
        max_features: Accepted for signature compatibility; unused.
        out_ffc: Accepted for signature compatibility; unused.
        out_ffc_kwargs: Accepted for signature compatibility; unused.

    Raises:
        AssertionError: If ``n_blocks`` is negative (skipped under ``python -O``).
    """

    # How many times ``forward`` runs the shared ``ffcresblock``. Fixed at the
    # value the original code hard-coded; intentionally independent of
    # ``n_blocks`` so existing callers get identical results.
    _RESBLOCK_PASSES = 2

    def __init__(self, input_nc, output_nc, n_blocks=2, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', activation_layer=nn.ReLU,
                 resnet_conv_kwargs=None,
                 spatial_transform_layers=None, spatial_transform_kwargs=None,
                 add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs=None):
        assert (n_blocks >= 0)

        super().__init__()

        # ``None`` replaces the former ``{}`` defaults so that no dict object
        # is shared between separate constructions of this class.
        if resnet_conv_kwargs is None:
            resnet_conv_kwargs = {}

        # Registration order (initial, ffcresblock, final) is kept as in the
        # original so the module/state-dict ordering does not change.
        self.initial = FFC_BN_ACT(input_nc, input_nc, kernel_size=3, padding=1, dilation=1,
                                  norm_layer=norm_layer,
                                  activation_layer=activation_layer,
                                  padding_type=padding_type,
                                  **resnet_conv_kwargs)

        self.ffcresblock = FFCResnetBlock(input_nc, padding_type=padding_type,
                                          activation_layer=activation_layer,
                                          norm_layer=norm_layer,
                                          **resnet_conv_kwargs)

        self.final = FFC_BN_ACT(input_nc, output_nc, kernel_size=3, padding=1, dilation=1,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **resnet_conv_kwargs)

    def forward(self, x):
        """Run ``x`` through initial -> residual block (x2) -> final.

        Each sub-module yields a pair of tensors (named ``x_l`` / ``x_g``
        here -- presumably the local and global FFC branches; confirm against
        the ``ffc0`` module). The pair is concatenated along dim 1 (assumed to
        be the channel axis) before ``final`` and again for the return value.

        Args:
            x: Input passed directly to ``self.initial``.

        Returns:
            The two outputs of ``self.final`` concatenated along dim 1.
        """
        x_l, x_g = self.initial(x)

        # Same block instance on every pass: weights are shared between passes.
        for _ in range(self._RESBLOCK_PASSES):
            x_l, x_g = self.ffcresblock(x_l, x_g)

        # ``final`` takes a single tensor, so merge the two branches first.
        merged = torch.cat([x_l, x_g], 1)
        x_l_out, x_g_out = self.final(merged)

        return torch.cat([x_l_out, x_g_out], 1)