'''
Author: Chris Xiao [email protected]
Date: 2023-09-16 19:47:31
LastEditors: Chris Xiao [email protected]
LastEditTime: 2023-09-30 16:56:23
FilePath: /EndoSAM/endoSAM/model.py
Description: EndoSAM model adapter
I Love IU
Copyright (c) 2023 by Chris Xiao [email protected], All Rights Reserved.
'''
import torch
import torch.nn as nn
from einops import rearrange

from utils import postprocess_masks


class EndoSAMAdapter(nn.Module):
    def __init__(self, device,
                 num_classes,
                 sam_mask_encoder,
                 sam_prompt_encoder,
                 sam_mask_decoder,
                 num_token=8,
                 ):
        super(EndoSAMAdapter, self).__init__()
        self.device = device
        self.num_classes = num_classes - 1  # exclude the background class
        self.num_token = num_token
        # Pre-trained SAM components (despite its name, sam_mask_encoder
        # encodes the input image into embeddings).
        self.sam_mask_encoder = sam_mask_encoder.to(self.device)
        self.sam_prompt_encoder = sam_prompt_encoder.to(self.device)
        self.sam_mask_decoder = sam_mask_decoder.to(self.device)
        self.prototype_prompt_encoder = Prototype_Prompt_Encoder(feat_dim=256,
                                                                 hidden_dim_dense=128,
                                                                 hidden_dim_sparse=128,
                                                                 size=64,
                                                                 num_tokens=self.num_token).to(self.device)
        self.learnable_prototypes_model = Learnable_Prototypes(num_classes=self.num_classes, feat_dim=256).to(self.device)
        self.prototypes = self.learnable_prototypes_model()

        # Trainable: the prototype prompt encoder, the learnable class
        # prototypes, and the SAM mask decoder.
        for _, param in self.prototype_prompt_encoder.named_parameters():
            param.requires_grad = True

        for _, param in self.learnable_prototypes_model.named_parameters():
            param.requires_grad = True

        for _, param in self.sam_mask_decoder.named_parameters():
            param.requires_grad = True

        # Frozen: SAM's image encoder and prompt encoder keep their
        # pre-trained weights.
        for _, param in self.sam_mask_encoder.named_parameters():
            param.requires_grad = False

        for _, param in self.sam_prompt_encoder.named_parameters():
            param.requires_grad = False

    def forward(self, x):
        # Encode the image with the frozen SAM encoder: (B, 256, 64, 64).
        sam_features = self.sam_mask_encoder(x)
        sam_features = rearrange(sam_features, 'b c h w -> b (h w) c')
        # Single foreground class: every sample is assigned class id 1.
        cls_ids = torch.tensor(1).repeat(sam_features.shape[0]).to(self.device)
        dense_embeddings, sparse_embeddings = self.prototype_prompt_encoder(sam_features, self.prototypes, cls_ids, self.num_classes)
        pred = []
        pred_quality = []
        sam_features = rearrange(sam_features, 'b (h w) c -> b c h w', h=64, w=64)
        # The SAM mask decoder operates on one image at a time.
        for dense_embedding, sparse_embedding, features_per_image in zip(dense_embeddings.unsqueeze(1), sparse_embeddings.unsqueeze(1), sam_features):
            low_res_masks_per_image, mask_quality_per_image = self.sam_mask_decoder(
                image_embeddings=features_per_image.unsqueeze(0),
                image_pe=self.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embedding,
                dense_prompt_embeddings=dense_embedding,
                multimask_output=True,
            )
            # Upsample the low-res masks and restore the original resolution.
            pred_per_image = postprocess_masks(
                low_res_masks_per_image,
                input_size=(819, 1024),
                original_size=(1024, 1280),
            )
            pred.append(pred_per_image)
            pred_quality.append(mask_quality_per_image)

        pred = torch.cat(pred, dim=0)
        pred_quality = torch.cat(pred_quality, dim=0)
        return pred, pred_quality

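# Construction sketch (not part of the original file): a minimal example
# assuming the official `segment_anything` package, whose loaded SAM model
# exposes `image_encoder`, `prompt_encoder`, and `mask_decoder`. The
# checkpoint path and `images` tensor are illustrative placeholders.
#
# from segment_anything import sam_model_registry
#
# sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
# model = EndoSAMAdapter(torch.device('cuda'), 2,
#                        sam.image_encoder, sam.prompt_encoder, sam.mask_decoder)
# pred, pred_quality = model(images)  # images: (B, 3, 1024, 1024), preprocessed
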
class Prototype_Prompt_Encoder(nn.Module):
    def __init__(self, feat_dim=256,
                 hidden_dim_dense=128,
                 hidden_dim_sparse=128,
                 size=64,
                 num_tokens=8):
        super(Prototype_Prompt_Encoder, self).__init__()
        # Dense path: 1x1 convolutions over the prototype-activated feature map.
        self.dense_fc_1 = nn.Conv2d(feat_dim, hidden_dim_dense, 1)
        self.dense_fc_2 = nn.Conv2d(hidden_dim_dense, feat_dim, 1)

        self.relu = nn.ReLU()

        # Sparse path: 1D convolutions that compress the size*size spatial
        # positions down to num_tokens prompt tokens.
        self.sparse_fc_1 = nn.Conv1d(size * size, hidden_dim_sparse, 1)
        self.sparse_fc_2 = nn.Conv1d(hidden_dim_sparse, num_tokens, 1)

        # Token embeddings for negative (index 0) and positive (index 1) classes.
        pn_cls_embeddings = [nn.Embedding(num_tokens, feat_dim) for _ in range(2)]
        self.pn_cls_embeddings = nn.ModuleList(pn_cls_embeddings)

    def forward(self, feat, prototypes, cls_ids, num_classes):
        # Prototypes as per-class prompts: (num_cls, C) -> (B, num_cls, C, 1).
        cls_prompts = prototypes.unsqueeze(-1)
        cls_prompts = torch.stack([cls_prompts for _ in range(feat.size(0))], dim=0)

        # Replicate the features once per class: (B, num_cls, H*W, C).
        feat = torch.stack([feat for _ in range(cls_prompts.size(1))], dim=1)

        # Prototype-feature similarity, used to re-weight the features.
        sim = torch.matmul(feat, cls_prompts)
        feat = feat + feat * sim
        feat_sparse = feat.clone()

        # Select each sample's target-class features for the dense embeddings.
        one_hot = torch.nn.functional.one_hot(cls_ids - 1, num_classes)
        feat = feat[one_hot == 1]
        feat = rearrange(feat.squeeze(1), 'b (h w) c -> b c h w', h=64, w=64)
        dense_embeddings = self.dense_fc_2(self.relu(self.dense_fc_1(feat)))

        # Compress the spatial positions into num_tokens sparse prompt tokens.
        feat_sparse = rearrange(feat_sparse, 'b num_cls hw c -> (b num_cls) hw c')
        sparse_embeddings = self.sparse_fc_2(self.relu(self.sparse_fc_1(feat_sparse)))
        sparse_embeddings = rearrange(sparse_embeddings, '(b num_cls) n c -> b num_cls n c', num_cls=1)

        # Add positive token embeddings for the target class and negative
        # ones for all other classes.
        pos_embed = self.pn_cls_embeddings[1].weight.unsqueeze(0).unsqueeze(0) * one_hot.unsqueeze(-1).unsqueeze(-1)
        neg_embed = self.pn_cls_embeddings[0].weight.unsqueeze(0).unsqueeze(0) * (1 - one_hot).unsqueeze(-1).unsqueeze(-1)
        sparse_embeddings = sparse_embeddings + pos_embed.detach() + neg_embed.detach()

        sparse_embeddings = rearrange(sparse_embeddings, 'b num_cls n c -> b (num_cls n) c')

        return dense_embeddings, sparse_embeddings

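# Shape sketch (illustrative, not from the original file): with batch size B,
# a single foreground class, and SAM's 64x64 feature grid,
#     feat:       (B, 4096, 256)  ->  dense_embeddings:  (B, 256, 64, 64)
#     prototypes: (1, 256)        ->  sparse_embeddings: (B, num_tokens, 256)
# i.e. the dense and sparse prompt formats expected by SAM's mask decoder.
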
class Learnable_Prototypes(nn.Module):
    def __init__(self, num_classes=7, feat_dim=256):
        super(Learnable_Prototypes, self).__init__()
        # One learnable prototype vector per class.
        self.class_embeddings = nn.Embedding(num_classes, feat_dim)

    def forward(self):
        return self.class_embeddings.weight

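# Minimal smoke test (not part of the original module): feeds random features
# through the prototype prompt encoder to verify the prompt shapes. Assumes a
# single foreground class and SAM's 64x64, 256-channel feature grid.
if __name__ == '__main__':
    prototypes_model = Learnable_Prototypes(num_classes=1, feat_dim=256)
    prompt_encoder = Prototype_Prompt_Encoder(feat_dim=256,
                                              hidden_dim_dense=128,
                                              hidden_dim_sparse=128,
                                              size=64,
                                              num_tokens=8)
    feat = torch.randn(2, 64 * 64, 256)        # (B, H*W, C) mock image features
    cls_ids = torch.ones(2, dtype=torch.long)  # class id 1 for every sample
    dense, sparse = prompt_encoder(feat, prototypes_model(), cls_ids, 1)
    print(dense.shape)   # torch.Size([2, 256, 64, 64])
    print(sparse.shape)  # torch.Size([2, 8, 256])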