# Source: Hugging Face Hub snapshot (user: qubvel-hf, "Init project", commit c509e76).
from __future__ import absolute_import
import math
import numpy as np
import sys
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
def conv3x3_block(in_planes, out_planes, stride=1):
    """Return a 3x3 Conv -> BatchNorm2d -> ReLU block with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride. BUGFIX: the original hard-coded
            ``stride=1`` in the Conv2d call and silently ignored this
            parameter; it is now honored. All callers in this file use
            the default, so their behavior is unchanged.
    """
    conv_layer = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
    block = nn.Sequential(
        conv_layer,
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
    return block
class STNHead(nn.Module):
    """Localization head of a Spatial Transformer Network.

    Regresses ``num_ctrlpoints`` 2-D control points (in normalized [0, 1]
    image coordinates) from an input image, via a small conv net followed by
    two fully-connected layers.

    NOTE(review): ``stn_fc1`` expects a flattened 8x8x256 feature map, which
    corresponds to a 256x256 input image (five 2x2 max-pools:
    256 -> 128 -> 64 -> 32 -> 16 -> 8). The original per-layer comments
    claimed 32x64 feature sizes, which contradicts ``nn.Linear(8*8*256, 512)``
    — confirm the intended input resolution against callers.

    Args:
        in_planes: number of input image channels.
        num_ctrlpoints: number of control points to predict.
        activation: ``'none'`` for raw outputs, or ``'sigmoid'`` to squash
            predictions into (0, 1).
    """

    def __init__(self, in_planes, num_ctrlpoints, activation='none'):
        super(STNHead, self).__init__()
        self.in_planes = in_planes
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation
        # Feature extractor; spatial size halves after each max-pool.
        self.stn_convnet = nn.Sequential(
            conv3x3_block(in_planes, 32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(32, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(64, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(128, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256))  # -> 256 x 8 x 8 for a 256x256 input
        self.stn_fc1 = nn.Sequential(
            nn.Linear(8 * 8 * 256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_ctrlpoints * 2)
        self.init_weights(self.stn_convnet)
        self.init_weights(self.stn_fc1)
        self.init_stn(self.stn_fc2)

    def init_weights(self, module):
        """He-normal init for convs, unit gamma for BN, small noise for linears."""
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming/He normal initialization (fan-out variant).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()

    def init_stn(self, stn_fc2):
        """Initialize the final FC layer to predict a fixed inset rectangle.

        The weights are zeroed and the bias is set to the target control
        points, so at initialization the head outputs those points regardless
        of the input (an identity-like starting transform).
        """
        margin_x, margin_y = 0.35, 0.35
        # Control points lie on the four sides of a rectangle inset by the
        # margins; the four corners are shared between adjacent sides, hence
        # the (num - 4) // 4 + 2 bookkeeping.
        num_ctrl_pts_per_side = (self.num_ctrlpoints - 4) // 4 + 2
        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
        ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_x)
        # NOTE(review): the side points reuse ctrl_pts_x as their y values;
        # this is only correct because margin_x == margin_y — confirm intent.
        # Interior points only ([1:-1]) to avoid duplicating the corners.
        ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
        ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)
        ctrl_points = np.concatenate(
            [ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right],
            axis=0).astype(np.float32)
        # BUGFIX: was `self.activation is 'none'` — identity comparison on a
        # string literal only works by CPython interning accident and raises
        # SyntaxWarning on Python >= 3.8.
        if self.activation == 'none':
            pass
        elif self.activation == 'sigmoid':
            # Store the logit (sigmoid pre-image) so sigmoid(bias) reproduces
            # the desired control points in forward().
            ctrl_points = -np.log(1. / ctrl_points - 1.)
        stn_fc2.weight.data.zero_()
        stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

    def forward(self, x):
        """Return ``(img_feat, control_points)``.

        Args:
            x: input image batch; must flatten to 8*8*256 features after the
               conv net (see class note on input resolution).

        Returns:
            img_feat: (batch, 512) feature vector from stn_fc1.
            control_points: (batch, num_ctrlpoints, 2) predicted points.
        """
        x = self.stn_convnet(x)
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        img_feat = self.stn_fc1(x)
        # The 0.1 factor damps the features so early predictions stay close
        # to the bias initialization set in init_stn().
        x = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            # BUGFIX: F.sigmoid is deprecated; torch.sigmoid is equivalent.
            x = torch.sigmoid(x)
        x = x.view(-1, self.num_ctrlpoints, 2)
        return img_feat, x
if __name__ == "__main__":
    # Smoke test: build the head and run one forward pass.
    in_planes = 3
    num_ctrlpoints = 20
    activation = 'none'  # or 'sigmoid'
    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
    # BUGFIX: stn_fc1 expects a flattened 8x8x256 feature map, so the input
    # must be 256x256 — the original 32x64 input crashed inside stn_fc1.
    # (Renamed from `input`, which shadowed the builtin.)
    images = torch.randn(10, 3, 256, 256)
    # BUGFIX: forward() returns (img_feat, control_points); the original
    # called .size() on the returned tuple, raising AttributeError.
    img_feat, control_points = stn_head(images)
    print(control_points.size())  # torch.Size([10, 20, 2])