Thesis / models /yowo /encoder.py
Ryan-Pham's picture
Upload 103 files
beb7843 verified
raw
history blame
4.22 kB
import torch
import torch.nn as nn
from ..basic.conv import Conv2d
# Channel Self Attention Module
class CSAM(nn.Module):
    """Channel self-attention module.

    Builds a C x C attention matrix from the input itself (no learned
    projections) and adds the attended features back to the input, scaled
    by the learnable scalar ``gamma``. ``gamma`` is initialised to zero, so
    a freshly constructed module returns its input unchanged.
    """

    def __init__(self):
        super(CSAM, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply channel attention.

        Args:
            x: input feature maps of shape (B, C, H, W). Does not need to be
                contiguous.

        Returns:
            Tensor of shape (B, C, H, W): ``gamma * attended + x``.
        """
        B, C, H, W = x.size()

        # Flatten the spatial dims. reshape (rather than view) also accepts
        # non-contiguous inputs such as permuted tensors, copying only when
        # it has to.
        feats = x.reshape(B, C, -1)          # [B, C, N]; serves as query and value
        feats_t = feats.permute(0, 2, 1)     # [B, N, C]; serves as key

        # Channel-to-channel similarity.
        energy = torch.bmm(feats, feats_t)   # [B, C, C]

        # Subtract every entry from its row maximum before the softmax, so
        # channel pairs with lower similarity receive higher weights.
        row_max = torch.max(energy, -1, keepdim=True)[0]
        attention = self.softmax(row_max - energy)   # [B, C, C]

        # Re-weight the channels and restore the spatial layout.
        out = torch.bmm(attention, feats).view(B, C, H, W)

        # Residual connection scaled by the learnable gamma.
        return self.gamma * out + x
# Spatial Self Attention Module
class SSAM(nn.Module):
    """Spatial self-attention module.

    Builds an N x N attention matrix over the N = H * W spatial positions
    of the input (no learned projections) and adds the attended features
    back to the input, scaled by the learnable scalar ``gamma``. ``gamma``
    is initialised to zero, so a freshly constructed module returns its
    input unchanged.
    """

    def __init__(self):
        super(SSAM, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply spatial attention.

        Args:
            x: input feature maps of shape (B, C, H, W). Does not need to be
                contiguous.

        Returns:
            Tensor of shape (B, C, H, W): ``gamma * attended + x``.
        """
        B, C, H, W = x.size()

        # Flatten the spatial dims. reshape (rather than view) also accepts
        # non-contiguous inputs such as permuted tensors, copying only when
        # it has to.
        feats = x.reshape(B, C, -1)          # [B, C, N]; serves as key
        feats_t = feats.permute(0, 2, 1)     # [B, N, C]; serves as query and value

        # Position-to-position similarity.
        energy = torch.bmm(feats_t, feats)   # [B, N, N]

        # Subtract every entry from its row maximum before the softmax, so
        # position pairs with lower similarity receive higher weights.
        row_max = torch.max(energy, -1, keepdim=True)[0]
        attention = self.softmax(row_max - energy)   # [B, N, N]

        # Aggregate the positions, then move channels back in front of the
        # spatial dims.
        out = torch.bmm(attention, feats_t)                  # [B, N, C]
        out = out.permute(0, 2, 1).reshape(B, C, H, W)

        # Residual connection scaled by the learnable gamma.
        return self.gamma * out + x
# Channel Encoder
class ChannelEncoder(nn.Module):
    """Fuse two feature maps and refine the result with channel attention.

    The two inputs are concatenated along the channel axis and passed
    through, in order: a k=1 ``Conv2d``, a k=3 ``Conv2d``, ``CSAM``,
    another k=3 ``Conv2d``, dropout (p=0.1) and a plain 1x1 ``nn.Conv2d``.

    Args:
        in_dim: input channel count handed to the first ``Conv2d``
            (presumably the combined channels of ``x1`` and ``x2`` --
            confirm against the caller).
        out_dim: channel count used by every layer after the first.
        act_type: forwarded unchanged to each project ``Conv2d``.
        norm_type: forwarded unchanged to each project ``Conv2d``.
    """

    def __init__(self, in_dim, out_dim, act_type='', norm_type=''):
        super().__init__()
        # Options shared by every project-level Conv2d in the stack.
        conv_opts = dict(act_type=act_type, norm_type=norm_type)
        stages = [
            Conv2d(in_dim, out_dim, k=1, **conv_opts),
            Conv2d(out_dim, out_dim, k=3, p=1, **conv_opts),
            CSAM(),
            Conv2d(out_dim, out_dim, k=3, p=1, **conv_opts),
            nn.Dropout(0.1, inplace=False),
            nn.Conv2d(out_dim, out_dim, kernel_size=1),
        ]
        self.fuse_convs = nn.Sequential(*stages)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` on dim 1 and run the fusion stack.

        x1, x2: feature maps to fuse (expected [B, C, H, W] each).
        """
        fused = torch.cat([x1, x2], dim=1)
        return self.fuse_convs(fused)
# Spatial Encoder
class SpatialEncoder(nn.Module):
    """Refine a single feature map with spatial attention.

    The input is passed through, in order: a k=1 ``Conv2d``, a k=3
    ``Conv2d``, ``SSAM``, another k=3 ``Conv2d``, dropout (p=0.1) and a
    plain 1x1 ``nn.Conv2d``.

    Args:
        in_dim: input channel count handed to the first ``Conv2d``.
        out_dim: channel count used by every layer after the first.
        act_type: forwarded unchanged to each project ``Conv2d``.
        norm_type: forwarded unchanged to each project ``Conv2d``.
    """

    def __init__(self, in_dim, out_dim, act_type='', norm_type=''):
        super().__init__()
        # Options shared by every project-level Conv2d in the stack.
        conv_opts = dict(act_type=act_type, norm_type=norm_type)
        stages = [
            Conv2d(in_dim, out_dim, k=1, **conv_opts),
            Conv2d(out_dim, out_dim, k=3, p=1, **conv_opts),
            SSAM(),
            Conv2d(out_dim, out_dim, k=3, p=1, **conv_opts),
            nn.Dropout(0.1, inplace=False),
            nn.Conv2d(out_dim, out_dim, kernel_size=1),
        ]
        self.fuse_convs = nn.Sequential(*stages)

    def forward(self, x):
        """Run the encoder stack on ``x`` (expected [B, C, H, W])."""
        return self.fuse_convs(x)
def build_channel_encoder(cfg, in_dim, out_dim):
    """Construct a ``ChannelEncoder`` from a config mapping.

    ``cfg['head_act']`` and ``cfg['head_norm']`` supply the encoder's
    ``act_type`` and ``norm_type``; ``in_dim`` and ``out_dim`` are passed
    through unchanged.
    """
    act = cfg['head_act']
    norm = cfg['head_norm']
    return ChannelEncoder(in_dim, out_dim, act_type=act, norm_type=norm)
def build_spatial_encoder(cfg, in_dim, out_dim):
    """Construct a ``SpatialEncoder`` from a config mapping.

    ``cfg['head_act']`` and ``cfg['head_norm']`` supply the encoder's
    ``act_type`` and ``norm_type``; ``in_dim`` and ``out_dim`` are passed
    through unchanged.
    """
    act = cfg['head_act']
    norm = cfg['head_norm']
    return SpatialEncoder(in_dim, out_dim, act_type=act, norm_type=norm)