File size: 1,414 Bytes
beb7843 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import torch.nn as nn
from ..basic.conv import Conv2d
class DecoupledHead(nn.Module):
    """Detection head with independent classification and regression branches.

    Each branch is an ``nn.Sequential`` stack of the project's ``Conv2d``
    blocks, all built with the same settings; only the number of stacked
    blocks differs between the two branches.

    Args:
        cfg: Mapping that provides the keys ``num_cls_heads``,
            ``num_reg_heads``, ``head_act``, ``head_norm``, ``head_dim``
            and ``head_depthwise``. A missing key raises ``KeyError``.
    """

    def __init__(self, cfg):
        super().__init__()
        print('==============================')
        print('Head: Decoupled Head')

        # Copy the relevant configuration entries onto the module.
        self.num_cls_heads = cfg['num_cls_heads']
        self.num_reg_heads = cfg['num_reg_heads']
        self.act_type = cfg['head_act']
        self.norm_type = cfg['head_norm']
        self.head_dim = cfg['head_dim']
        self.depthwise = cfg['head_depthwise']

        # The classification branch is created before the regression branch.
        # NOTE(review): presumably this order matters for reproducible
        # parameter initialisation under a fixed seed -- keep it.
        self.cls_head = self._make_branch(self.num_cls_heads)
        self.reg_head = self._make_branch(self.num_reg_heads)

    def _make_branch(self, num_layers):
        """Return an ``nn.Sequential`` of ``num_layers`` Conv2d blocks.

        Every block receives ``head_dim`` as its first two positional
        arguments (presumably input and output channels -- confirm against
        ``Conv2d``) together with ``k=3, p=1, s=1`` and the activation,
        normalisation and depthwise settings stored on this module.
        With ``num_layers == 0`` the result is an empty ``nn.Sequential``.
        """
        return nn.Sequential(*[
            Conv2d(
                self.head_dim,
                self.head_dim,
                k=3, p=1, s=1,
                act_type=self.act_type,
                norm_type=self.norm_type,
                depthwise=self.depthwise,
            )
            for _ in range(num_layers)
        ])

    def forward(self, cls_feat, reg_feat):
        """Run each input through its own branch.

        Args:
            cls_feat: Input passed to the classification branch.
            reg_feat: Input passed to the regression branch.

        Returns:
            Tuple ``(cls_feats, reg_feats)`` holding the outputs of the
            classification and regression branches, in that order.
        """
        cls_feats = self.cls_head(cls_feat)
        reg_feats = self.reg_head(reg_feat)
        return cls_feats, reg_feats
def build_head(cfg):
    """Create the detection head for the given configuration.

    Args:
        cfg: Configuration mapping forwarded unchanged to ``DecoupledHead``.

    Returns:
        A newly constructed ``DecoupledHead`` instance.
    """
    head = DecoupledHead(cfg)
    return head
|