Spaces:
Running
Running
import torch.nn as nn | |
from isegm.utils.serialization import serialize | |
from .is_model import ISModel | |
from isegm.model.modifiers import LRMult | |
from .modeling.segformer import MixVisionTransformer, SegformerHead | |
class SegformerModel(ISModel):
    """Interactive-segmentation model built from a SegFormer-style backbone and head.

    The backbone is a ``MixVisionTransformer`` whose submodules all get
    ``LRMult(backbone_lr_mult)`` applied. Its output is decoded by a
    ``SegformerHead``.
    """

    # NOTE(review): `serialize` is imported by this module but was otherwise
    # unused. It is presumably meant to record the constructor arguments so the
    # model config can be saved and rebuilt, as other isegm ISModel subclasses
    # do. TODO confirm against isegm.utils.serialization.
    @serialize
    def __init__(
        self,
        backbone_params=None,
        decode_head_params=None,
        backbone_lr_mult=0.1,
        **kwargs
    ):
        """Build the backbone and the decode head.

        Args:
            backbone_params: Keyword arguments forwarded to
                ``MixVisionTransformer``. ``None`` is treated as an empty
                mapping, so the component's own defaults apply.
            decode_head_params: Keyword arguments forwarded to
                ``SegformerHead``. ``None`` is treated as an empty mapping.
            backbone_lr_mult: Value passed to ``LRMult`` and applied to every
                backbone submodule. It is presumably a learning-rate
                multiplier; see ``isegm.model.modifiers``.
            **kwargs: Forwarded unchanged to ``ISModel.__init__``.
        """
        super().__init__(**kwargs)

        # `**None` raises TypeError, so normalise the optional configs first.
        if backbone_params is None:
            backbone_params = {}
        if decode_head_params is None:
            decode_head_params = {}

        self.feature_extractor = MixVisionTransformer(**backbone_params)
        # nn.Module.apply visits every submodule recursively, so the LRMult
        # modifier reaches the whole backbone.
        self.feature_extractor.apply(LRMult(backbone_lr_mult))

        self.head = SegformerHead(**decode_head_params)

    def backbone_forward(self, image, coord_features=None):
        """Run the backbone and decode head on an image.

        Args:
            image: Input image tensor. Presumably (B, C, H, W); confirm
                against ``MixVisionTransformer.forward``.
            coord_features: Optional extra features passed to the backbone
                alongside ``image``. Presumably the encoded user clicks;
                verify against the caller in ``ISModel``.

        Returns:
            A dict whose ``'instances'`` entry is the head's output on the
            backbone features. ``'instances_aux'`` is always ``None`` because
            this model has no auxiliary head.
        """
        backbone_features = self.feature_extractor(image, coord_features)
        return {'instances': self.head(backbone_features), 'instances_aux': None}