File size: 2,862 Bytes
dd78229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import backbone_picie as backbone


class PanopticFPN(nn.Module):
    """Backbone followed by an FPN-style decoder producing a dense feature map.

    The backbone is looked up by name (``args.arch``) in ``backbone_picie`` and
    built with ``pretrained=args.pretrain``. ``'vit_small'`` is paired with
    ``FPNDecoderViT``; every other architecture uses ``FPNDecoder``.
    """

    def __init__(self, args):
        super().__init__()
        self.backbone = backbone.__dict__[args.arch](pretrained=args.pretrain)
        if args.arch == 'vit_small':
            self.decoder = FPNDecoderViT(args)
        else:
            self.decoder = FPNDecoder(args)

    def forward(self, x, encoder_features=False, decoder_features=False):
        """Run the backbone and the decoder on ``x``.

        Args:
            x: input batch, passed unchanged to the backbone.
            encoder_features: if true, also return the deepest backbone
                feature map next to the decoder output.
            decoder_features: accepted for API compatibility; not used.

        Returns:
            The decoder output, or ``(deepest_encoder_feature, decoder_output)``
            when ``encoder_features`` is true.
        """
        feats = self.backbone(x)
        dec_outs = self.decoder(feats)

        if encoder_features:
            return self._deepest_feature(feats), dec_outs
        return dec_outs

    @staticmethod
    def _deepest_feature(feats):
        """Return the top-of-pyramid entry of the backbone output.

        The ResNet path (``FPNDecoder``) reads ``'res2'``..``'res5'`` by key,
        while the ViT path (``FPNDecoderViT``) reads ``x[0]``..``x[3]`` by
        position, so a hard-coded ``feats['res5']`` breaks for ViT. Prefer the
        ``'res5'`` key when present, otherwise use index 3 -- the same entry
        ``FPNDecoderViT`` treats as the top level.
        """
        if isinstance(feats, dict) and 'res5' in feats:
            return feats['res5']
        return feats[3]


class FPNDecoder(nn.Module):
    """Top-down FPN decoder over four ResNet stages.

    ``forward`` takes a mapping with keys ``'res2'``..``'res5'``. Every stage
    is projected to ``out_dim`` channels by a 1x1 convolution, then the maps
    are merged coarse-to-fine: the running map is bilinearly resized to the
    next stage's height/width and added to it. The result has the spatial
    size of ``'res2'``.
    """

    def __init__(self, args):
        super().__init__()
        small = args.arch == 'resnet18'
        width_mult = 1 if small else 4
        out_dim = 128 if small else 256
        # 'res5' is expected to carry 512 * width_mult channels and each
        # shallower stage half as many as the one above it.
        top_channels = 512 * width_mult

        def proj(in_channels):
            return nn.Conv2d(in_channels, out_dim, kernel_size=1, stride=1, padding=0)

        # Constructed layer4 -> layer1 so seeded initialisation and the
        # parameter registration order match.
        self.layer4 = proj(top_channels // 8)  # applied to 'res2'
        self.layer3 = proj(top_channels // 4)  # applied to 'res3'
        self.layer2 = proj(top_channels // 2)  # applied to 'res4'
        self.layer1 = proj(top_channels)       # applied to 'res5'

    def forward(self, x):
        """Merge the projected stages from ``'res5'`` down to ``'res2'``."""
        merged = self.layer1(x['res5'])
        pathway = (
            ('res4', self.layer2),
            ('res3', self.layer3),
            ('res2', self.layer4),
        )
        for stage, lateral in pathway:
            merged = self.upsample_add(merged, lateral(x[stage]))
        return merged

    def upsample_add(self, x, y):
        """Resize ``x`` to ``y``'s height/width (bilinear) and add ``y``."""
        _, _, height, width = y.shape
        resized = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        return resized + y


class FPNDecoderViT(nn.Module):
    """Decoder that fuses four ViT feature maps and upsamples the result.

    ``forward`` reads ``x[0]``..``x[3]``. Each is projected from 384 channels
    (presumably the ViT-Small embedding width -- confirm against the backbone)
    to ``out_dim`` with a 1x1 convolution, bilinearly upsampled by
    ``self.upsample_rate``, and summed. Because every level is upsampled by
    the same factor and then added, all four inputs must share one spatial
    size; the output is that size times ``self.upsample_rate``.
    """

    def __init__(self, args):
        super().__init__()
        if args.arch == 'resnet18' or args.arch == 'vit_small':
            out_dim = 128
        else:
            out_dim = 256
        in_dim = 384  # channel width of every input feature map

        # Spatial upsampling factor applied to every level (top and laterals).
        self.upsample_rate = 4

        # Constructed layer4 -> layer1 so seeded initialisation and the
        # parameter registration order stay the same.
        self.layer4 = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer3 = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer2 = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer1 = nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Project, upsample and sum ``x[3]``, ``x[2]``, ``x[1]``, ``x[0]``."""
        o1 = self.layer1(x[3])
        # Use the configured rate so the top level matches the laterals
        # upsampled in upsample_add; a separate literal here would break the
        # additions as soon as upsample_rate is changed.
        o1 = F.interpolate(o1, scale_factor=self.upsample_rate, mode='bilinear', align_corners=False)
        o2 = self.upsample_add(o1, self.layer2(x[2]))
        o3 = self.upsample_add(o2, self.layer3(x[1]))
        o4 = self.upsample_add(o3, self.layer4(x[0]))

        return o4

    def upsample_add(self, x, y):
        """Upsample ``y`` by ``self.upsample_rate`` and add it to ``x``.

        The roles are the reverse of ``FPNDecoder.upsample_add``: here ``x``
        is the running, already-upsampled sum and ``y`` is the new lateral
        that still needs upsampling.
        """
        return F.interpolate(y, scale_factor=self.upsample_rate, mode='bilinear', align_corners=False) + x