Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from collections import OrderedDict | |
from isegm.utils.serialization import serialize | |
from .is_model import ISModel | |
from isegm.model.modifiers import LRMult | |
from .modeling.hrformer import HRT_B_OCR_V3 | |
class HRFormerModel(ISModel): | |
def __init__( | |
self, | |
num_classes=1, | |
in_ch=6, | |
backbone_lr_mult=0.1, | |
**kwargs | |
): | |
super().__init__(**kwargs) | |
self.feature_extractor = HRT_B_OCR_V3(num_classes, in_ch) | |
self.feature_extractor.apply(LRMult(backbone_lr_mult)) | |
def backbone_forward(self, image, coord_features=None): | |
backbone_features = self.feature_extractor(image) | |
return {'instances': backbone_features[0], 'instances_aux': backbone_features[1]} | |
def init_weight(self, pretrained=None): | |
if pretrained is not None: | |
state_dict = torch.load(pretrained)['model'] | |
state_dict_rename = OrderedDict() | |
for k, v in state_dict.items(): | |
state_dict_rename['backbone.' + k] = v | |
ori_proj_weight = state_dict_rename['backbone.conv1.weight'] | |
state_dict_rename['backbone.conv1.weight'] = torch.cat([ori_proj_weight, ori_proj_weight], dim=1) | |
self.feature_extractor.load_state_dict(state_dict_rename, False) | |
print('Successfully loaded pretrained model.') | |