import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResnetEncoder(nn.Module):
    """ResNet-50 backbone mapping an image batch to a 2048-d feature vector."""

    def __init__(self, append_layers=None):
        super(ResnetEncoder, self).__init__()
        from . import resnet

        # feature_size = 2048
        self.feature_dim = 2048
        self.encoder = resnet.load_ResNet50Model()    # out: 2048
        # regressor
        self.append_layers = append_layers

    def forward(self, inputs):
        """inputs: [bz, 3, h, w], range: [0,1]"""
        features = self.encoder(inputs)
        if self.append_layers:
            # apply the optional regressor head to the backbone features
            features = self.append_layers(features)
        return features


class MLP(nn.Module):
    """Simple fully connected head: Linear layers with ReLU in between."""

    def __init__(self, channels=[2048, 1024, 1], last_op=None):
        super(MLP, self).__init__()
        layers = []

        for l in range(0, len(channels) - 1):
            layers.append(nn.Linear(channels[l], channels[l + 1]))
            # ReLU between hidden layers; the final Linear layer stays bare
            if l < len(channels) - 2:
                layers.append(nn.ReLU())
        if last_op is not None:
            # optional output activation, e.g. nn.Sigmoid() for outputs in [0, 1]
            layers.append(last_op)

        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        outs = self.layers(inputs)
        return outs


class HRNEncoder(nn.Module):
    """HRNet backbone mapping an image batch to a 2048-d concatenated feature vector."""

    def __init__(self, append_layers=None):
        super(HRNEncoder, self).__init__()
        from . import hrnet

        self.feature_dim = 2048
        self.encoder = hrnet.load_HRNet(pretrained=True)    # out: 2048
        # regressor
        self.append_layers = append_layers

    def forward(self, inputs):
        """inputs: [bz, 3, h, w], range: [-1,1]"""
        features = self.encoder(inputs)["concat"]
        if self.append_layers:
            # apply the optional regressor head to the backbone features
            features = self.append_layers(features)
        return features
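

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming the sibling `resnet`/`hrnet` modules imported
# above are available in this package and the backbones accept e.g. 224x224
# inputs:
#
#   images = torch.rand(2, 3, 224, 224)            # [bz, 3, h, w], range [0, 1]
#   encoder = ResnetEncoder()
#   feats = encoder(images)                        # [2, 2048]
#   regressor = MLP(channels=[2048, 1024, 1], last_op=nn.Sigmoid())
#   preds = regressor(feats)                       # [2, 1], values in [0, 1]
#
# HRNEncoder is used the same way, but expects inputs scaled to [-1, 1].
# ---------------------------------------------------------------------------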