# GraCo / isegm / model / is_segformer_model.py
# zhaoyian01's picture
# Add application file
# 6d1366a
# raw
# history blame
# 869 Bytes
import torch.nn as nn
from isegm.utils.serialization import serialize
from .is_model import ISModel
from isegm.model.modifiers import LRMult
from .modeling.segformer import MixVisionTransformer, SegformerHead
class SegformerModel(ISModel):
    """ISModel variant built from a MixVisionTransformer backbone and a SegformerHead.

    The backbone is stored as ``self.feature_extractor`` and the decode head
    as ``self.head``; ``backbone_forward`` chains the two.
    """

    @serialize
    def __init__(
        self,
        backbone_params=None,
        decode_head_params=None,
        backbone_lr_mult=0.1,
        **kwargs
    ):
        """Construct the backbone and the decode head.

        Args:
            backbone_params: Keyword arguments forwarded to
                ``MixVisionTransformer``. ``None`` is treated as an empty
                mapping.
            decode_head_params: Keyword arguments forwarded to
                ``SegformerHead``. ``None`` is treated as an empty mapping.
            backbone_lr_mult: Value passed to ``LRMult``; the resulting
                object is applied to the backbone via ``apply``.
            **kwargs: Forwarded unchanged to ``ISModel.__init__``.
        """
        super().__init__(**kwargs)

        # The defaults are None, and ``**None`` raises
        # "argument after ** must be a mapping, not NoneType". Substitute an
        # empty mapping so any missing required arguments are reported by the
        # constructors themselves instead.
        backbone_kwargs = {} if backbone_params is None else backbone_params
        head_kwargs = {} if decode_head_params is None else decode_head_params

        self.feature_extractor = MixVisionTransformer(**backbone_kwargs)
        # NOTE(review): presumably LRMult tags backbone parameters with a
        # learning-rate multiplier read by the optimizer setup - confirm in
        # isegm.model.modifiers.
        self.feature_extractor.apply(LRMult(backbone_lr_mult))
        self.head = SegformerHead(**head_kwargs)

    def backbone_forward(self, image, coord_features=None):
        """Run the backbone on the inputs and decode its features with the head.

        Args:
            image: Input passed as the first argument to the backbone.
            coord_features: Optional second argument passed to the backbone
                (presumably the encoded user-interaction maps - confirm
                against the caller in ISModel).

        Returns:
            dict: ``'instances'`` holds the head output; ``'instances_aux'``
            is always ``None`` for this model.
        """
        backbone_features = self.feature_extractor(image, coord_features)
        return {'instances': self.head(backbone_features), 'instances_aux': None}