# rap-sam/app/configs/rap_sam_r50_12e_adaptor.py
from mmdet.models import ResNet, MaskFormerFusionHead, CrossEntropyLoss, DiceLoss
from app.models.detectors import YOSOVideoSam
from app.models.heads import RapSAMVideoHead
from app.models.necks import YOSONeck
# Class counts for COCO panoptic: 80 "thing" + 53 "stuff" classes.
num_things_classes = 80
num_stuff_classes = 53
# Names used to select the open-vocabulary classifier weights (combined below
# into ov_classifier_name).
ov_model_name = 'convnext_large_d_320'
ov_datasets_name = 'CocoPanopticOVDataset'
num_classes = num_things_classes + num_stuff_classes
model = dict(
    type=YOSOVideoSam,
    data_preprocessor=None,
    backbone=dict(
        type=ResNet,
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        init_cfg=None,
    ),
    neck=dict(
        type=YOSONeck,
        agg_dim=128,
        hidden_dim=256,
        backbone_shape=[256, 512, 1024, 2048],
    ),
    panoptic_head=dict(
        type=RapSAMVideoHead,
        prompt_with_kernel_updator=False,
        panoptic_with_kernel_updator=True,
        use_adaptor=True,
        use_kernel_updator=True,
        sphere_cls=True,
        ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}',
        num_stages=3,
        feat_channels=256,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        num_queries=100,
        loss_cls=dict(
            type=CrossEntropyLoss,
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type=CrossEntropyLoss,
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type=DiceLoss,
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0)
    ),
    panoptic_fusion_head=dict(
        type=MaskFormerFusionHead,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        loss_panoptic=None,
        init_cfg=None
    ),
    train_cfg=None,
    test_cfg=dict(
        panoptic_on=True,
        # For now, the dataset does not support evaluating the
        # semantic segmentation metric.
        semantic_on=False,
        instance_on=True,
        # max_per_image applies to instance segmentation only.
        max_per_image=100,
        iou_thr=0.8,
        # Mask2Former's panoptic post-processing filters out mask
        # regions whose score is below 0.5.
        filter_low_score=True),
    init_cfg=dict(
        type='Pretrained',
        checkpoint='models/rapsam_r50_12e.pth'
    )
)
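
# Usage sketch (not part of the original config; the file path and the
# registry registration of the custom app.models modules are assumptions):
# with MMEngine's new-style Python configs, a file like this can be read
# with Config.fromfile and the detector built through MMDetection's registry.
#
#   from mmengine.config import Config
#   from mmdet.registry import MODELS
#
#   cfg = Config.fromfile('app/configs/rap_sam_r50_12e_adaptor.py')
#   rap_sam = MODELS.build(cfg.model)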