Spaces:
Running
Running
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Microsoft Research
## [email protected], [email protected]
## Copyright (c) 2021
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# from .hrformer_helper.backbone_selector import BackboneSelector | |
from .hrformer_helper.hrt.module_helper import ModuleHelper | |
from .hrformer_helper.hrt.modules.spatial_ocr_block import SpatialGather_Module, SpatialOCR_Module | |
from .hrformer_helper.hrt.logger import Logger as Log | |
from .hrformer_helper.hrt.hrt_backbone import HRTBackbone, HRTBackbone_v2 | |
class BackboneSelector(object):
    """Build the backbone network named in the configuration.

    Only backbones whose configured name contains ``"hrt"`` are handled
    here; the ResNet / HRNet / Swin branches of this selector are disabled.
    """

    def __init__(self, configer):
        """Keep a reference to the configuration object.

        Args:
            configer: object exposing ``get(section, key)``; it is also
                passed on to ``HRTBackbone``.
        """
        self.configer = configer

    def get_backbone(self, **params):
        """Instantiate the backbone selected by ``network.backbone``.

        Args:
            **params: forwarded unchanged to the ``HRTBackbone`` factory call.

        Returns:
            Whatever ``HRTBackbone(self.configer)(**params)`` returns.

        Raises:
            SystemExit: with status 1, after logging an error, when the
                configured backbone name does not contain ``"hrt"``.
        """
        backbone = self.configer.get("network", "backbone")
        if "hrt" not in backbone:
            Log.error("Backbone {} is invalid.".format(backbone))
            # Raise SystemExit directly rather than calling the ``exit``
            # helper: that helper is injected by the ``site`` module and is
            # not available under ``python -S`` or in frozen builds.
            raise SystemExit(1)
        return HRTBackbone(self.configer)(**params)
class HRT_B_OCR_V3(nn.Module):
    """HRT backbone followed by an OCR head built from grouped 7x7 convs.

    The four backbone outputs are resized to the resolution of the first one
    and concatenated; the heads below expect that concatenation to carry
    1170 channels.
    """

    def __init__(self, num_classes, in_ch=3, backbone='hrt_base', bn_type="torchbn", pretrained=None):
        """Create the backbone and the segmentation heads.

        Args:
            num_classes: number of output channels of both classifiers.
            in_ch: third positional argument given to ``HRTBackbone_v2``.
            backbone: backbone name given to ``HRTBackbone_v2``.
            bn_type: normalisation type given to ``ModuleHelper.BNReLU`` and
                ``SpatialOCR_Module``.
            pretrained: second positional argument given to ``HRTBackbone_v2``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.bn_type = bn_type
        backbone_factory = HRTBackbone_v2(backbone, pretrained, in_ch)
        self.backbone = backbone_factory()

        fused_channels = 1170
        mid_channels = 512
        # ``groups`` has to divide both channel counts; the gcd is the
        # largest value that does.
        n_groups = math.gcd(fused_channels, mid_channels)
        grouped_7x7 = dict(kernel_size=7, stride=1, padding=3, groups=n_groups)
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=True)

        # The attribute is still called ``conv3x3`` (it fixes the state_dict
        # keys) even though the kernel used here is 7x7.
        reduce_conv = nn.Conv2d(fused_channels, mid_channels, **grouped_7x7)
        reduce_norm = ModuleHelper.BNReLU(mid_channels, bn_type=self.bn_type)
        self.conv3x3 = nn.Sequential(reduce_conv, reduce_norm)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=mid_channels,
            key_channels=mid_channels // 2,
            out_channels=mid_channels,
            scale=1,
            dropout=0.05,
            bn_type=self.bn_type,
        )
        self.cls_head = nn.Conv2d(mid_channels, self.num_classes, **pointwise)

        aux_conv = nn.Conv2d(fused_channels, mid_channels, **grouped_7x7)
        aux_norm = ModuleHelper.BNReLU(mid_channels, bn_type=self.bn_type)
        aux_classifier = nn.Conv2d(mid_channels, self.num_classes, **pointwise)
        self.aux_head = nn.Sequential(aux_conv, aux_norm, aux_classifier)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: input tensor; dimensions 2 and 3 are used as the output
                size (presumably an ``(N, C, H, W)`` batch -- confirm
                against the caller).

        Returns:
            Tuple ``(out_aux, out)``: the auxiliary and the final class maps,
            both bilinearly resized to ``x_``'s dimensions 2 and 3.
        """
        stages = self.backbone(x_)
        _, _, height, width = stages[0].size()
        bilinear = dict(mode="bilinear", align_corners=True)

        # Bring stages 1..3 to the resolution of stage 0 before fusing.
        pyramid = [stages[0]]
        for idx in (1, 2, 3):
            pyramid.append(
                F.interpolate(stages[idx], size=(height, width), **bilinear)
            )
        fused = torch.cat(pyramid, 1)

        aux_logits = self.aux_head(fused)
        reduced = self.conv3x3(fused)
        context = self.ocr_gather_head(reduced, aux_logits)
        refined = self.ocr_distri_head(reduced, context)
        logits = self.cls_head(refined)

        full_size = (x_.size(2), x_.size(3))
        out_aux = F.interpolate(aux_logits, size=full_size, **bilinear)
        out = F.interpolate(logits, size=full_size, **bilinear)
        return out_aux, out
class HRT_S_OCR_V2(nn.Module):
    """HRT backbone followed by an OCR head built from plain 3x3 convs.

    The four backbone outputs are resized to the resolution of the first one
    and concatenated; the heads below expect that concatenation to carry
    480 channels.
    """

    def __init__(self, num_classes, backbone='hrt_small', bn_type="torchbn", pretrained=None):
        """Create the backbone and the segmentation heads.

        Args:
            num_classes: number of output channels of both classifiers.
            backbone: backbone name given to ``HRTBackbone_v2``.
            bn_type: normalisation type given to ``ModuleHelper.BNReLU`` and
                ``SpatialOCR_Module``.
            pretrained: second positional argument given to ``HRTBackbone_v2``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.bn_type = bn_type
        backbone_factory = HRTBackbone_v2(backbone, pretrained)
        self.backbone = backbone_factory()

        fused_channels = 480
        same_pad_3x3 = dict(kernel_size=3, stride=1, padding=1)
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=True)

        reduce_conv = nn.Conv2d(fused_channels, 512, **same_pad_3x3)
        reduce_norm = ModuleHelper.BNReLU(512, bn_type=self.bn_type)
        self.conv3x3 = nn.Sequential(reduce_conv, reduce_norm)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=512,
            key_channels=256,
            out_channels=512,
            scale=1,
            dropout=0.05,
            bn_type=self.bn_type,
        )
        self.cls_head = nn.Conv2d(512, self.num_classes, **pointwise)

        aux_conv = nn.Conv2d(fused_channels, 512, **same_pad_3x3)
        aux_norm = ModuleHelper.BNReLU(512, bn_type=self.bn_type)
        aux_classifier = nn.Conv2d(512, self.num_classes, **pointwise)
        self.aux_head = nn.Sequential(aux_conv, aux_norm, aux_classifier)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: input tensor; dimensions 2 and 3 are used as the output
                size (presumably an ``(N, C, H, W)`` batch -- confirm
                against the caller).

        Returns:
            Tuple ``(out_aux, out)``: the auxiliary and the final class maps,
            both bilinearly resized to ``x_``'s dimensions 2 and 3.
        """
        stages = self.backbone(x_)
        _, _, height, width = stages[0].size()
        bilinear = dict(mode="bilinear", align_corners=True)

        # Bring stages 1..3 to the resolution of stage 0 before fusing.
        pyramid = [stages[0]]
        for idx in (1, 2, 3):
            pyramid.append(
                F.interpolate(stages[idx], size=(height, width), **bilinear)
            )
        fused = torch.cat(pyramid, 1)

        aux_logits = self.aux_head(fused)
        reduced = self.conv3x3(fused)
        context = self.ocr_gather_head(reduced, aux_logits)
        refined = self.ocr_distri_head(reduced, context)
        logits = self.cls_head(refined)

        full_size = (x_.size(2), x_.size(3))
        out_aux = F.interpolate(aux_logits, size=full_size, **bilinear)
        out = F.interpolate(logits, size=full_size, **bilinear)
        return out_aux, out
class HRT_SMALL_OCR_V2(nn.Module):
    """Config-driven HRT + OCR head with plain 3x3 convs (480 fused channels).

    The backbone comes from ``BackboneSelector``; the class count and the
    normalisation type are read from the configuration object.
    """

    def __init__(self, configer):
        """Create the backbone and the segmentation heads.

        Args:
            configer: object exposing ``get(section, key)``; read for
                ``data.num_classes`` and ``network.bn_type`` and passed on
                to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        # Looked up once; assumes the config getter has no side effects.
        norm_kind = self.configer.get("network", "bn_type")
        fused_channels = 480
        same_pad_3x3 = dict(kernel_size=3, stride=1, padding=1)
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=True)

        reduce_conv = nn.Conv2d(fused_channels, 512, **same_pad_3x3)
        reduce_norm = ModuleHelper.BNReLU(512, bn_type=norm_kind)
        self.conv3x3 = nn.Sequential(reduce_conv, reduce_norm)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=512,
            key_channels=256,
            out_channels=512,
            scale=1,
            dropout=0.05,
            bn_type=norm_kind,
        )
        self.cls_head = nn.Conv2d(512, self.num_classes, **pointwise)

        aux_conv = nn.Conv2d(fused_channels, 512, **same_pad_3x3)
        aux_norm = ModuleHelper.BNReLU(512, bn_type=norm_kind)
        aux_classifier = nn.Conv2d(512, self.num_classes, **pointwise)
        self.aux_head = nn.Sequential(aux_conv, aux_norm, aux_classifier)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: input tensor; dimensions 2 and 3 are used as the output
                size (presumably an ``(N, C, H, W)`` batch -- confirm
                against the caller).

        Returns:
            Tuple ``(out_aux, out)``: the auxiliary and the final class maps,
            both bilinearly resized to ``x_``'s dimensions 2 and 3.
        """
        stages = self.backbone(x_)
        _, _, height, width = stages[0].size()
        bilinear = dict(mode="bilinear", align_corners=True)

        # Bring stages 1..3 to the resolution of stage 0 before fusing.
        pyramid = [stages[0]]
        for idx in (1, 2, 3):
            pyramid.append(
                F.interpolate(stages[idx], size=(height, width), **bilinear)
            )
        fused = torch.cat(pyramid, 1)

        aux_logits = self.aux_head(fused)
        reduced = self.conv3x3(fused)
        context = self.ocr_gather_head(reduced, aux_logits)
        refined = self.ocr_distri_head(reduced, context)
        logits = self.cls_head(refined)

        full_size = (x_.size(2), x_.size(3))
        out_aux = F.interpolate(aux_logits, size=full_size, **bilinear)
        out = F.interpolate(logits, size=full_size, **bilinear)
        return out_aux, out
class HRT_BASE_OCR_V2(nn.Module):
    """Config-driven HRT + OCR head with plain 3x3 convs (1170 fused channels).

    The backbone comes from ``BackboneSelector``; the class count and the
    normalisation type are read from the configuration object.
    """

    def __init__(self, configer):
        """Create the backbone and the segmentation heads.

        Args:
            configer: object exposing ``get(section, key)``; read for
                ``data.num_classes`` and ``network.bn_type`` and passed on
                to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        # Looked up once; assumes the config getter has no side effects.
        norm_kind = self.configer.get("network", "bn_type")
        fused_channels = 1170
        same_pad_3x3 = dict(kernel_size=3, stride=1, padding=1)
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=True)

        reduce_conv = nn.Conv2d(fused_channels, 512, **same_pad_3x3)
        reduce_norm = ModuleHelper.BNReLU(512, bn_type=norm_kind)
        self.conv3x3 = nn.Sequential(reduce_conv, reduce_norm)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=512,
            key_channels=256,
            out_channels=512,
            scale=1,
            dropout=0.05,
            bn_type=norm_kind,
        )
        self.cls_head = nn.Conv2d(512, self.num_classes, **pointwise)

        aux_conv = nn.Conv2d(fused_channels, 512, **same_pad_3x3)
        aux_norm = ModuleHelper.BNReLU(512, bn_type=norm_kind)
        aux_classifier = nn.Conv2d(512, self.num_classes, **pointwise)
        self.aux_head = nn.Sequential(aux_conv, aux_norm, aux_classifier)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: input tensor; dimensions 2 and 3 are used as the output
                size (presumably an ``(N, C, H, W)`` batch -- confirm
                against the caller).

        Returns:
            Tuple ``(out_aux, out)``: the auxiliary and the final class maps,
            both bilinearly resized to ``x_``'s dimensions 2 and 3.
        """
        stages = self.backbone(x_)
        _, _, height, width = stages[0].size()
        bilinear = dict(mode="bilinear", align_corners=True)

        # Bring stages 1..3 to the resolution of stage 0 before fusing.
        pyramid = [stages[0]]
        for idx in (1, 2, 3):
            pyramid.append(
                F.interpolate(stages[idx], size=(height, width), **bilinear)
            )
        fused = torch.cat(pyramid, 1)

        aux_logits = self.aux_head(fused)
        reduced = self.conv3x3(fused)
        context = self.ocr_gather_head(reduced, aux_logits)
        refined = self.ocr_distri_head(reduced, context)
        logits = self.cls_head(refined)

        full_size = (x_.size(2), x_.size(3))
        out_aux = F.interpolate(aux_logits, size=full_size, **bilinear)
        out = F.interpolate(logits, size=full_size, **bilinear)
        return out_aux, out
class HRT_SMALL_OCR_V3(nn.Module):
    """Config-driven HRT + OCR head with grouped 7x7 convs (480 fused channels).

    The backbone comes from ``BackboneSelector``; the class count and the
    normalisation type are read from the configuration object.
    """

    def __init__(self, configer):
        """Create the backbone and the segmentation heads.

        Args:
            configer: object exposing ``get(section, key)``; read for
                ``data.num_classes`` and ``network.bn_type`` and passed on
                to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        # Looked up once; assumes the config getter has no side effects.
        norm_kind = self.configer.get("network", "bn_type")
        fused_channels = 480
        mid_channels = 512
        # ``groups`` has to divide both channel counts; the gcd is the
        # largest value that does.
        n_groups = math.gcd(fused_channels, mid_channels)
        grouped_7x7 = dict(kernel_size=7, stride=1, padding=3, groups=n_groups)
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=True)

        # The attribute is still called ``conv3x3`` (it fixes the state_dict
        # keys) even though the kernel used here is 7x7.
        reduce_conv = nn.Conv2d(fused_channels, mid_channels, **grouped_7x7)
        reduce_norm = ModuleHelper.BNReLU(mid_channels, bn_type=norm_kind)
        self.conv3x3 = nn.Sequential(reduce_conv, reduce_norm)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=mid_channels,
            key_channels=mid_channels // 2,
            out_channels=mid_channels,
            scale=1,
            dropout=0.05,
            bn_type=norm_kind,
        )
        self.cls_head = nn.Conv2d(mid_channels, self.num_classes, **pointwise)

        aux_conv = nn.Conv2d(fused_channels, mid_channels, **grouped_7x7)
        aux_norm = ModuleHelper.BNReLU(mid_channels, bn_type=norm_kind)
        aux_classifier = nn.Conv2d(mid_channels, self.num_classes, **pointwise)
        self.aux_head = nn.Sequential(aux_conv, aux_norm, aux_classifier)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: input tensor; dimensions 2 and 3 are used as the output
                size (presumably an ``(N, C, H, W)`` batch -- confirm
                against the caller).

        Returns:
            Tuple ``(out_aux, out)``: the auxiliary and the final class maps,
            both bilinearly resized to ``x_``'s dimensions 2 and 3.
        """
        stages = self.backbone(x_)
        _, _, height, width = stages[0].size()
        bilinear = dict(mode="bilinear", align_corners=True)

        # Bring stages 1..3 to the resolution of stage 0 before fusing.
        pyramid = [stages[0]]
        for idx in (1, 2, 3):
            pyramid.append(
                F.interpolate(stages[idx], size=(height, width), **bilinear)
            )
        fused = torch.cat(pyramid, 1)

        aux_logits = self.aux_head(fused)
        reduced = self.conv3x3(fused)
        context = self.ocr_gather_head(reduced, aux_logits)
        refined = self.ocr_distri_head(reduced, context)
        logits = self.cls_head(refined)

        full_size = (x_.size(2), x_.size(3))
        out_aux = F.interpolate(aux_logits, size=full_size, **bilinear)
        out = F.interpolate(logits, size=full_size, **bilinear)
        return out_aux, out
class HRT_BASE_OCR_V3(nn.Module):
    """Config-driven HRT + OCR head with grouped 7x7 convs (1170 fused channels).

    The backbone comes from ``BackboneSelector``; the class count and the
    normalisation type are read from the configuration object.
    """

    def __init__(self, configer):
        """Create the backbone and the segmentation heads.

        Args:
            configer: object exposing ``get(section, key)``; read for
                ``data.num_classes`` and ``network.bn_type`` and passed on
                to ``BackboneSelector``.
        """
        super().__init__()
        self.configer = configer
        self.num_classes = self.configer.get("data", "num_classes")
        self.backbone = BackboneSelector(configer).get_backbone()

        # Looked up once; assumes the config getter has no side effects.
        norm_kind = self.configer.get("network", "bn_type")
        fused_channels = 1170
        mid_channels = 512
        # ``groups`` has to divide both channel counts; the gcd is the
        # largest value that does.
        n_groups = math.gcd(fused_channels, mid_channels)
        grouped_7x7 = dict(kernel_size=7, stride=1, padding=3, groups=n_groups)
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=True)

        # The attribute is still called ``conv3x3`` (it fixes the state_dict
        # keys) even though the kernel used here is 7x7.
        reduce_conv = nn.Conv2d(fused_channels, mid_channels, **grouped_7x7)
        reduce_norm = ModuleHelper.BNReLU(mid_channels, bn_type=norm_kind)
        self.conv3x3 = nn.Sequential(reduce_conv, reduce_norm)

        self.ocr_gather_head = SpatialGather_Module(self.num_classes)
        self.ocr_distri_head = SpatialOCR_Module(
            in_channels=mid_channels,
            key_channels=mid_channels // 2,
            out_channels=mid_channels,
            scale=1,
            dropout=0.05,
            bn_type=norm_kind,
        )
        self.cls_head = nn.Conv2d(mid_channels, self.num_classes, **pointwise)

        aux_conv = nn.Conv2d(fused_channels, mid_channels, **grouped_7x7)
        aux_norm = ModuleHelper.BNReLU(mid_channels, bn_type=norm_kind)
        aux_classifier = nn.Conv2d(mid_channels, self.num_classes, **pointwise)
        self.aux_head = nn.Sequential(aux_conv, aux_norm, aux_classifier)

    def forward(self, x_):
        """Run the network on ``x_``.

        Args:
            x_: input tensor; dimensions 2 and 3 are used as the output
                size (presumably an ``(N, C, H, W)`` batch -- confirm
                against the caller).

        Returns:
            Tuple ``(out_aux, out)``: the auxiliary and the final class maps,
            both bilinearly resized to ``x_``'s dimensions 2 and 3.
        """
        stages = self.backbone(x_)
        _, _, height, width = stages[0].size()
        bilinear = dict(mode="bilinear", align_corners=True)

        # Bring stages 1..3 to the resolution of stage 0 before fusing.
        pyramid = [stages[0]]
        for idx in (1, 2, 3):
            pyramid.append(
                F.interpolate(stages[idx], size=(height, width), **bilinear)
            )
        fused = torch.cat(pyramid, 1)

        aux_logits = self.aux_head(fused)
        reduced = self.conv3x3(fused)
        context = self.ocr_gather_head(reduced, aux_logits)
        refined = self.ocr_distri_head(reduced, context)
        logits = self.cls_head(refined)

        full_size = (x_.size(2), x_.size(3))
        out_aux = F.interpolate(aux_logits, size=full_size, **bilinear)
        out = F.interpolate(logits, size=full_size, **bilinear)
        return out_aux, out