# GraCo / isegm /model /modeling /hrformer.py
# zhaoyian01's picture
# Add application file
# 6d1366a
# raw
# history blame
# 16.8 kB
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Microsoft Research
## [email protected], [email protected]
## Copyright (c) 2021
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# from .hrformer_helper.backbone_selector import BackboneSelector
from .hrformer_helper.hrt.module_helper import ModuleHelper
from .hrformer_helper.hrt.modules.spatial_ocr_block import SpatialGather_Module, SpatialOCR_Module
from .hrformer_helper.hrt.logger import Logger as Log
from .hrformer_helper.hrt.hrt_backbone import HRTBackbone, HRTBackbone_v2
class BackboneSelector(object):
    """Build the backbone network named by a configuration object.

    Only backbone names containing the substring ``"hrt"`` are handled;
    any other name is reported through ``Log.error`` and terminates the
    process with exit status 1.
    """

    def __init__(self, configer):
        """Remember the configuration object.

        Args:
            configer: Object exposing ``get(section, key)``. It is queried
                for ``("network", "backbone")`` and is also passed to the
                backbone factory.
        """
        self.configer = configer

    def get_backbone(self, **params):
        """Instantiate the backbone selected by ``network.backbone``.

        Args:
            **params: Forwarded unchanged to the backbone factory call.

        Returns:
            Whatever ``HRTBackbone(self.configer)(**params)`` returns.

        Raises:
            SystemExit: With status 1 when the configured backbone name
                does not contain ``"hrt"``.
        """
        backbone = self.configer.get("network", "backbone")

        if "hrt" in backbone:
            return HRTBackbone(self.configer)(**params)

        Log.error("Backbone {} is invalid.".format(backbone))
        # Raise SystemExit directly: the ``exit`` builtin is injected by the
        # ``site`` module for interactive sessions and is not guaranteed to
        # exist (e.g. ``python -S`` or frozen applications).
        raise SystemExit(1)
class HRT_B_OCR_V3(nn.Module):
    """HRT backbone with an OCR head that uses grouped 7x7 fusion convolutions.

    The four feature maps returned by the backbone are resized to the
    resolution of the first one, concatenated along the channel axis
    (1170 channels are expected) and fed to an auxiliary head and to the
    OCR gather/distribute modules.
    """

    def __init__(self, num_classes, in_ch=3, backbone='hrt_base', bn_type="torchbn", pretrained=None):
        """Create the backbone and the prediction heads.

        Args:
            num_classes: Number of output channels of both heads.
            in_ch: Passed as third argument to ``HRTBackbone_v2``.
            backbone: Backbone name passed to ``HRTBackbone_v2``.
            bn_type: Normalisation type handed to ``ModuleHelper.BNReLU``
                and ``SpatialOCR_Module``.
            pretrained: Passed as second argument to ``HRTBackbone_v2``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.bn_type = bn_type
        self.backbone = HRTBackbone_v2(backbone, pretrained, in_ch)()

        fused_channels = 1170  # channels of the concatenated backbone outputs
        width = 512
        n_groups = math.gcd(fused_channels, width)

        # NOTE: the attribute is called ``conv3x3`` although the kernel is 7x7;
        # the name is kept because it is part of the parameter names.
        fuse_conv = nn.Conv2d(
            fused_channels, width, kernel_size=7, stride=1, padding=3, groups=n_groups
        )
        fuse_act = ModuleHelper.BNReLU(width, bn_type=self.bn_type)
        self.conv3x3 = nn.Sequential(fuse_conv, fuse_act)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=width,
            key_channels=width // 2,
            out_channels=width,
            scale=1,
            dropout=0.05,
            bn_type=self.bn_type,
        )
        self.cls_head = nn.Conv2d(
            width, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

        aux_conv = nn.Conv2d(
            fused_channels, width, kernel_size=7, stride=1, padding=3, groups=n_groups
        )
        aux_act = ModuleHelper.BNReLU(width, bn_type=self.bn_type)
        aux_cls = nn.Conv2d(
            width, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.aux_head = nn.Sequential(aux_conv, aux_act, aux_cls)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: Input tensor; its dimensions 2 and 3 define the output size.

        Returns:
            Tuple ``(out_aux, out)``: auxiliary-head and OCR-head outputs,
            both bilinearly resized to the spatial size of ``x_``.
        """
        stages = self.backbone(x_)
        _, _, ref_h, ref_w = stages[0].size()

        # Bring the three later outputs to the resolution of the first one.
        resized = [
            F.interpolate(
                stages[idx], size=(ref_h, ref_w), mode="bilinear", align_corners=True
            )
            for idx in (1, 2, 3)
        ]
        fused = torch.cat([stages[0]] + resized, 1)

        aux_logits = self.aux_head(fused)
        fused = self.conv3x3(fused)
        context = self.ocr_gather_head(fused, aux_logits)
        fused = self.ocr_distri_head(fused, context)
        logits = self.cls_head(fused)

        full_size = (x_.size(2), x_.size(3))
        aux_logits = F.interpolate(
            aux_logits, size=full_size, mode="bilinear", align_corners=True
        )
        logits = F.interpolate(
            logits, size=full_size, mode="bilinear", align_corners=True
        )
        return aux_logits, logits
class HRT_S_OCR_V2(nn.Module):
    """HRT backbone with an OCR head that uses plain 3x3 fusion convolutions.

    The concatenation of the four (resized) backbone outputs is expected
    to have 480 channels.
    """

    def __init__(self, num_classes, backbone='hrt_small', bn_type="torchbn", pretrained=None):
        """Create the backbone and the prediction heads.

        Args:
            num_classes: Number of output channels of both heads.
            backbone: Backbone name passed to ``HRTBackbone_v2``.
            bn_type: Normalisation type handed to ``ModuleHelper.BNReLU``
                and ``SpatialOCR_Module``.
            pretrained: Passed as second argument to ``HRTBackbone_v2``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.bn_type = bn_type
        self.backbone = HRTBackbone_v2(backbone, pretrained)()

        fused_channels = 480  # channels of the concatenated backbone outputs

        fuse_conv = nn.Conv2d(fused_channels, 512, kernel_size=3, stride=1, padding=1)
        fuse_act = ModuleHelper.BNReLU(512, bn_type=self.bn_type)
        self.conv3x3 = nn.Sequential(fuse_conv, fuse_act)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=512,
            key_channels=256,
            out_channels=512,
            scale=1,
            dropout=0.05,
            bn_type=self.bn_type,
        )
        self.cls_head = nn.Conv2d(
            512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

        aux_conv = nn.Conv2d(fused_channels, 512, kernel_size=3, stride=1, padding=1)
        aux_act = ModuleHelper.BNReLU(512, bn_type=self.bn_type)
        aux_cls = nn.Conv2d(
            512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.aux_head = nn.Sequential(aux_conv, aux_act, aux_cls)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: Input tensor; its dimensions 2 and 3 define the output size.

        Returns:
            Tuple ``(out_aux, out)``: auxiliary-head and OCR-head outputs,
            both bilinearly resized to the spatial size of ``x_``.
        """
        stages = self.backbone(x_)
        _, _, ref_h, ref_w = stages[0].size()

        # Bring the three later outputs to the resolution of the first one.
        resized = [
            F.interpolate(
                stages[idx], size=(ref_h, ref_w), mode="bilinear", align_corners=True
            )
            for idx in (1, 2, 3)
        ]
        fused = torch.cat([stages[0]] + resized, 1)

        aux_logits = self.aux_head(fused)
        fused = self.conv3x3(fused)
        context = self.ocr_gather_head(fused, aux_logits)
        fused = self.ocr_distri_head(fused, context)
        logits = self.cls_head(fused)

        full_size = (x_.size(2), x_.size(3))
        aux_logits = F.interpolate(
            aux_logits, size=full_size, mode="bilinear", align_corners=True
        )
        logits = F.interpolate(
            logits, size=full_size, mode="bilinear", align_corners=True
        )
        return aux_logits, logits
class HRT_SMALL_OCR_V2(nn.Module):
    """Config-driven HRT + OCR model with plain 3x3 fusion convolutions.

    The number of classes, the backbone and the normalisation type are
    read from ``configer``. The concatenation of the four (resized)
    backbone outputs is expected to have 480 channels.
    """

    def __init__(self, configer):
        """Create the backbone and the prediction heads.

        Args:
            configer: Object exposing ``get(section, key)``; queried for
                ``("data", "num_classes")`` and ``("network", "bn_type")``
                and passed to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        fused_channels = 480  # channels of the concatenated backbone outputs

        fuse_conv = nn.Conv2d(fused_channels, 512, kernel_size=3, stride=1, padding=1)
        fuse_act = ModuleHelper.BNReLU(
            512, bn_type=self.configer.get("network", "bn_type")
        )
        self.conv3x3 = nn.Sequential(fuse_conv, fuse_act)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=512,
            key_channels=256,
            out_channels=512,
            scale=1,
            dropout=0.05,
            bn_type=self.configer.get("network", "bn_type"),
        )
        self.cls_head = nn.Conv2d(
            512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

        aux_conv = nn.Conv2d(fused_channels, 512, kernel_size=3, stride=1, padding=1)
        aux_act = ModuleHelper.BNReLU(
            512, bn_type=self.configer.get("network", "bn_type")
        )
        aux_cls = nn.Conv2d(
            512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.aux_head = nn.Sequential(aux_conv, aux_act, aux_cls)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: Input tensor; its dimensions 2 and 3 define the output size.

        Returns:
            Tuple ``(out_aux, out)``: auxiliary-head and OCR-head outputs,
            both bilinearly resized to the spatial size of ``x_``.
        """
        stages = self.backbone(x_)
        _, _, ref_h, ref_w = stages[0].size()

        # Bring the three later outputs to the resolution of the first one.
        resized = [
            F.interpolate(
                stages[idx], size=(ref_h, ref_w), mode="bilinear", align_corners=True
            )
            for idx in (1, 2, 3)
        ]
        fused = torch.cat([stages[0]] + resized, 1)

        aux_logits = self.aux_head(fused)
        fused = self.conv3x3(fused)
        context = self.ocr_gather_head(fused, aux_logits)
        fused = self.ocr_distri_head(fused, context)
        logits = self.cls_head(fused)

        full_size = (x_.size(2), x_.size(3))
        aux_logits = F.interpolate(
            aux_logits, size=full_size, mode="bilinear", align_corners=True
        )
        logits = F.interpolate(
            logits, size=full_size, mode="bilinear", align_corners=True
        )
        return aux_logits, logits
class HRT_BASE_OCR_V2(nn.Module):
    """Config-driven HRT + OCR model with plain 3x3 fusion convolutions.

    The number of classes, the backbone and the normalisation type are
    read from ``configer``. The concatenation of the four (resized)
    backbone outputs is expected to have 1170 channels.
    """

    def __init__(self, configer):
        """Create the backbone and the prediction heads.

        Args:
            configer: Object exposing ``get(section, key)``; queried for
                ``("data", "num_classes")`` and ``("network", "bn_type")``
                and passed to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        fused_channels = 1170  # channels of the concatenated backbone outputs

        fuse_conv = nn.Conv2d(fused_channels, 512, kernel_size=3, stride=1, padding=1)
        fuse_act = ModuleHelper.BNReLU(
            512, bn_type=self.configer.get("network", "bn_type")
        )
        self.conv3x3 = nn.Sequential(fuse_conv, fuse_act)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=512,
            key_channels=256,
            out_channels=512,
            scale=1,
            dropout=0.05,
            bn_type=self.configer.get("network", "bn_type"),
        )
        self.cls_head = nn.Conv2d(
            512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

        aux_conv = nn.Conv2d(fused_channels, 512, kernel_size=3, stride=1, padding=1)
        aux_act = ModuleHelper.BNReLU(
            512, bn_type=self.configer.get("network", "bn_type")
        )
        aux_cls = nn.Conv2d(
            512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.aux_head = nn.Sequential(aux_conv, aux_act, aux_cls)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: Input tensor; its dimensions 2 and 3 define the output size.

        Returns:
            Tuple ``(out_aux, out)``: auxiliary-head and OCR-head outputs,
            both bilinearly resized to the spatial size of ``x_``.
        """
        stages = self.backbone(x_)
        _, _, ref_h, ref_w = stages[0].size()

        # Bring the three later outputs to the resolution of the first one.
        resized = [
            F.interpolate(
                stages[idx], size=(ref_h, ref_w), mode="bilinear", align_corners=True
            )
            for idx in (1, 2, 3)
        ]
        fused = torch.cat([stages[0]] + resized, 1)

        aux_logits = self.aux_head(fused)
        fused = self.conv3x3(fused)
        context = self.ocr_gather_head(fused, aux_logits)
        fused = self.ocr_distri_head(fused, context)
        logits = self.cls_head(fused)

        full_size = (x_.size(2), x_.size(3))
        aux_logits = F.interpolate(
            aux_logits, size=full_size, mode="bilinear", align_corners=True
        )
        logits = F.interpolate(
            logits, size=full_size, mode="bilinear", align_corners=True
        )
        return aux_logits, logits
class HRT_SMALL_OCR_V3(nn.Module):
    """Config-driven HRT + OCR model with grouped 7x7 fusion convolutions.

    The number of classes, the backbone and the normalisation type are
    read from ``configer``. The concatenation of the four (resized)
    backbone outputs is expected to have 480 channels.
    """

    def __init__(self, configer):
        """Create the backbone and the prediction heads.

        Args:
            configer: Object exposing ``get(section, key)``; queried for
                ``("data", "num_classes")`` and ``("network", "bn_type")``
                and passed to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        fused_channels = 480  # channels of the concatenated backbone outputs
        width = 512
        n_groups = math.gcd(fused_channels, width)

        # NOTE: the attribute is called ``conv3x3`` although the kernel is 7x7;
        # the name is kept because it is part of the parameter names.
        fuse_conv = nn.Conv2d(
            fused_channels, width, kernel_size=7, stride=1, padding=3, groups=n_groups
        )
        fuse_act = ModuleHelper.BNReLU(
            width, bn_type=self.configer.get("network", "bn_type")
        )
        self.conv3x3 = nn.Sequential(fuse_conv, fuse_act)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=width,
            key_channels=width // 2,
            out_channels=width,
            scale=1,
            dropout=0.05,
            bn_type=self.configer.get("network", "bn_type"),
        )
        self.cls_head = nn.Conv2d(
            width, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

        aux_conv = nn.Conv2d(
            fused_channels, width, kernel_size=7, stride=1, padding=3, groups=n_groups
        )
        aux_act = ModuleHelper.BNReLU(
            width, bn_type=self.configer.get("network", "bn_type")
        )
        aux_cls = nn.Conv2d(
            width, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.aux_head = nn.Sequential(aux_conv, aux_act, aux_cls)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: Input tensor; its dimensions 2 and 3 define the output size.

        Returns:
            Tuple ``(out_aux, out)``: auxiliary-head and OCR-head outputs,
            both bilinearly resized to the spatial size of ``x_``.
        """
        stages = self.backbone(x_)
        _, _, ref_h, ref_w = stages[0].size()

        # Bring the three later outputs to the resolution of the first one.
        resized = [
            F.interpolate(
                stages[idx], size=(ref_h, ref_w), mode="bilinear", align_corners=True
            )
            for idx in (1, 2, 3)
        ]
        fused = torch.cat([stages[0]] + resized, 1)

        aux_logits = self.aux_head(fused)
        fused = self.conv3x3(fused)
        context = self.ocr_gather_head(fused, aux_logits)
        fused = self.ocr_distri_head(fused, context)
        logits = self.cls_head(fused)

        full_size = (x_.size(2), x_.size(3))
        aux_logits = F.interpolate(
            aux_logits, size=full_size, mode="bilinear", align_corners=True
        )
        logits = F.interpolate(
            logits, size=full_size, mode="bilinear", align_corners=True
        )
        return aux_logits, logits
class HRT_BASE_OCR_V3(nn.Module):
    """Config-driven HRT + OCR model with grouped 7x7 fusion convolutions.

    The number of classes, the backbone and the normalisation type are
    read from ``configer``. The concatenation of the four (resized)
    backbone outputs is expected to have 1170 channels.
    """

    def __init__(self, configer):
        """Create the backbone and the prediction heads.

        Args:
            configer: Object exposing ``get(section, key)``; queried for
                ``("data", "num_classes")`` and ``("network", "bn_type")``
                and passed to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        fused_channels = 1170  # channels of the concatenated backbone outputs
        width = 512
        n_groups = math.gcd(fused_channels, width)

        # NOTE: the attribute is called ``conv3x3`` although the kernel is 7x7;
        # the name is kept because it is part of the parameter names.
        fuse_conv = nn.Conv2d(
            fused_channels, width, kernel_size=7, stride=1, padding=3, groups=n_groups
        )
        fuse_act = ModuleHelper.BNReLU(
            width, bn_type=self.configer.get("network", "bn_type")
        )
        self.conv3x3 = nn.Sequential(fuse_conv, fuse_act)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=width,
            key_channels=width // 2,
            out_channels=width,
            scale=1,
            dropout=0.05,
            bn_type=self.configer.get("network", "bn_type"),
        )
        self.cls_head = nn.Conv2d(
            width, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )

        aux_conv = nn.Conv2d(
            fused_channels, width, kernel_size=7, stride=1, padding=3, groups=n_groups
        )
        aux_act = ModuleHelper.BNReLU(
            width, bn_type=self.configer.get("network", "bn_type")
        )
        aux_cls = nn.Conv2d(
            width, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.aux_head = nn.Sequential(aux_conv, aux_act, aux_cls)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: Input tensor; its dimensions 2 and 3 define the output size.

        Returns:
            Tuple ``(out_aux, out)``: auxiliary-head and OCR-head outputs,
            both bilinearly resized to the spatial size of ``x_``.
        """
        stages = self.backbone(x_)
        _, _, ref_h, ref_w = stages[0].size()

        # Bring the three later outputs to the resolution of the first one.
        resized = [
            F.interpolate(
                stages[idx], size=(ref_h, ref_w), mode="bilinear", align_corners=True
            )
            for idx in (1, 2, 3)
        ]
        fused = torch.cat([stages[0]] + resized, 1)

        aux_logits = self.aux_head(fused)
        fused = self.conv3x3(fused)
        context = self.ocr_gather_head(fused, aux_logits)
        fused = self.ocr_distri_head(fused, context)
        logits = self.cls_head(fused)

        full_size = (x_.size(2), x_.size(3))
        aux_logits = F.interpolate(
            aux_logits, size=full_size, mode="bilinear", align_corners=True
        )
        logits = F.interpolate(
            logits, size=full_size, mode="bilinear", align_corners=True
        )
        return aux_logits, logits