# GraCo / isegm/model/is_hrformer_model.py
# Uploaded by zhaoyian01 in commit 6d1366a ("Add application file"), 1.4 kB.
import torch
import torch.nn as nn
from collections import OrderedDict
from isegm.utils.serialization import serialize
from .is_model import ISModel
from isegm.model.modifiers import LRMult
from .modeling.hrformer import HRT_B_OCR_V3
class HRFormerModel(ISModel):
    """Interactive segmentation model whose feature extractor is ``HRT_B_OCR_V3``.

    Every module of the feature extractor receives a learning-rate multiplier
    through ``LRMult`` (see ``isegm.model.modifiers`` for its exact effect).
    """

    @serialize
    def __init__(
        self,
        num_classes=1,
        in_ch=6,
        backbone_lr_mult=0.1,
        **kwargs
    ):
        """Build the HRFormer feature extractor.

        Args:
            num_classes: Output channel count passed to ``HRT_B_OCR_V3``.
            in_ch: Input channel count passed to ``HRT_B_OCR_V3``. The default
                of 6 presumably means RGB plus 3 interaction channels --
                TODO confirm against ``ISModel``.
            backbone_lr_mult: Multiplier applied to every module of the
                feature extractor via ``LRMult``.
            **kwargs: Forwarded unchanged to ``ISModel.__init__``.
        """
        super().__init__(**kwargs)
        self.feature_extractor = HRT_B_OCR_V3(num_classes, in_ch)
        self.feature_extractor.apply(LRMult(backbone_lr_mult))

    def backbone_forward(self, image, coord_features=None):
        """Run the feature extractor on ``image``.

        ``coord_features`` is accepted for interface compatibility but is not
        used here. Presumably ``image`` already carries any interaction
        channels -- TODO confirm in ``ISModel``.

        Returns:
            dict with ``'instances'`` (first extractor output) and
            ``'instances_aux'`` (second extractor output).
        """
        backbone_features = self.feature_extractor(image)
        return {'instances': backbone_features[0], 'instances_aux': backbone_features[1]}

    @staticmethod
    def _tile_input_channels(weight, target_in_ch):
        """Repeat ``weight`` along dim 1 until it has ``target_in_ch`` channels.

        If ``target_in_ch`` is not larger than the current channel count, the
        weight is returned unchanged so that ``load_state_dict`` reports any
        genuine mismatch instead of silently dropping pretrained channels.
        """
        src_in_ch = weight.shape[1]
        if target_in_ch <= src_in_ch:
            return weight
        reps = -(-target_in_ch // src_in_ch)  # ceiling division
        return torch.cat([weight] * reps, dim=1)[:, :target_in_ch]

    def init_weight(self, pretrained=None):
        """Load pretrained backbone weights into the feature extractor.

        Args:
            pretrained: Path to a checkpoint whose ``'model'`` entry is a state
                dict for the backbone. Nothing happens when ``None``.

        Every checkpoint key is prefixed with ``'backbone.'``. The first
        convolution (``conv1.weight``) is tiled along its input-channel axis so
        that it matches the feature extractor's own ``backbone.conv1.weight``.
        The state dict is then loaded with ``strict=False``, so missing or
        unexpected keys are tolerated but shape mismatches still raise.

        Raises:
            KeyError: if the checkpoint has no ``'model'`` entry or no
                ``conv1.weight`` tensor.
        """
        if pretrained is None:
            return

        # map_location='cpu' lets checkpoints saved on GPU load on CPU-only
        # hosts; load_state_dict then copies tensors onto each parameter's device.
        state_dict = torch.load(pretrained, map_location='cpu')['model']
        state_dict_rename = OrderedDict(
            ('backbone.' + k, v) for k, v in state_dict.items()
        )

        proj_key = 'backbone.conv1.weight'
        ori_proj_weight = state_dict_rename[proj_key]
        model_proj_weight = self.feature_extractor.state_dict().get(proj_key)
        if model_proj_weight is not None:
            target_in_ch = model_proj_weight.shape[1]
        else:
            # Model exposes no such key (it will be ignored by strict=False);
            # keep the historical behaviour of doubling the input channels.
            target_in_ch = 2 * ori_proj_weight.shape[1]
        state_dict_rename[proj_key] = self._tile_input_channels(
            ori_proj_weight, target_in_ch
        )

        self.feature_extractor.load_state_dict(state_dict_rename, strict=False)
        print('Successfully loaded pretrained model.')