import torch import torch.nn as nn from ..basic.conv import Conv2d class DecoupledHead(nn.Module): def __init__(self, cfg): super().__init__() print('==============================') print('Head: Decoupled Head') self.num_cls_heads = cfg['num_cls_heads'] self.num_reg_heads = cfg['num_reg_heads'] self.act_type = cfg['head_act'] self.norm_type = cfg['head_norm'] self.head_dim = cfg['head_dim'] self.depthwise = cfg['head_depthwise'] self.cls_head = nn.Sequential(*[ Conv2d(self.head_dim, self.head_dim, k=3, p=1, s=1, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise) for _ in range(self.num_cls_heads)]) self.reg_head = nn.Sequential(*[ Conv2d(self.head_dim, self.head_dim, k=3, p=1, s=1, act_type=self.act_type, norm_type=self.norm_type, depthwise=self.depthwise) for _ in range(self.num_reg_heads)]) def forward(self, cls_feat, reg_feat): cls_feats = self.cls_head(cls_feat) reg_feats = self.reg_head(reg_feat) return cls_feats, reg_feats def build_head(cfg): return DecoupledHead(cfg)