import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from segmenter_model.utils import padding, unpadding


class Segmenter(nn.Module):
    """Patch-based (ViT) encoder followed by a segmentation decoder."""

    def __init__(
        self,
        encoder,
        decoder,
        n_cls,
    ):
        super().__init__()
        self.n_cls = n_cls
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        def append_prefix_no_weight_decay(prefix, module):
            return set(map(lambda x: prefix + x, module.no_weight_decay()))

        # Merge the no-weight-decay parameter names of encoder and decoder,
        # prefixed so they match this model's fully qualified parameter names.
        nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union(
            append_prefix_no_weight_decay("decoder.", self.decoder)
        )
        return nwd_params
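
    # A minimal sketch of how no_weight_decay() is typically consumed when
    # building optimizer parameter groups (standard PyTorch training code,
    # not part of this file; names are illustrative):
    #
    #   nwd = model.no_weight_decay()
    #   decay, no_decay = [], []
    #   for name, p in model.named_parameters():
    #       (no_decay if name in nwd else decay).append(p)
    #   optimizer = torch.optim.AdamW(
    #       [{"params": decay, "weight_decay": 0.01},
    #        {"params": no_decay, "weight_decay": 0.0}], lr=1e-4)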

    def forward(self, im, decoder_features=False, no_upsample=False, encoder_features=False, no_rearrange=False,
                cls_only=False, encoder_only=False):
        H_ori, W_ori = im.size(2), im.size(3)
        if not no_upsample:
            # Pad so the spatial dimensions are divisible by the patch size.
            im = padding(im, self.patch_size)
        H, W = im.size(2), im.size(3)

        x = self.encoder(im, return_features=True)

        # Remove the CLS token (plus the distillation token, if present).
        num_extra_tokens = 1 + self.encoder.distilled
        if cls_only:
            return x[:, 0]
        x = x[:, num_extra_tokens:]

        if encoder_features:
            enc_fts = x.clone()
            if not no_rearrange:
                # Fold the patch-token sequence back into a 2D feature map.
                GS = H // self.patch_size
                enc_fts = rearrange(enc_fts, "b (h w) c -> b c h w", h=GS)
            if encoder_only:
                return enc_fts

        if decoder_features:
            output = self.decoder(x, (H, W), features_only=True, no_rearrange=no_rearrange)
            if no_rearrange:
                if encoder_features:
                    output = (enc_fts, output)
                return output
        else:
            output = self.decoder(x, (H, W))

        if not no_upsample:
            # Upsample patch-level predictions to the padded input size,
            # then crop away the padding to restore the original resolution.
            output = F.interpolate(output, size=(H, W), mode="bilinear")
            output = unpadding(output, (H_ori, W_ori))

        if encoder_features:
            output = (enc_fts, output)
        return output
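
    # Sketch of the main forward() modes (shapes assume a non-distilled ViT
    # with patch size P; names are illustrative, not part of this file):
    #
    #   masks = model(im)                              # (B, n_cls, H, W) logits
    #   cls_tok = model(im, cls_only=True)             # (B, C) CLS embedding
    #   enc = model(im, encoder_features=True,
    #               encoder_only=True)                 # (B, C, H/P, W/P) features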

    def get_attention_map_enc(self, im, layer_id):
        return self.encoder.get_attention_map(im, layer_id)

    def get_attention_map_dec(self, im, layer_id):
        x = self.encoder(im, return_features=True)

        # Remove the CLS token (plus the distillation token, if present).
        num_extra_tokens = 1 + self.encoder.distilled
        x = x[:, num_extra_tokens:]

        return self.decoder.get_attention_map(x, layer_id)
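

if __name__ == "__main__":
    # Smoke test with stub modules. These stubs are illustrative assumptions,
    # not the real encoder/decoder shipped with this repository; they only
    # mimic the interface Segmenter relies on (patch_size, distilled,
    # return_features, features_only/no_rearrange).
    class StubEncoder(nn.Module):
        def __init__(self, patch_size=16, dim=8):
            super().__init__()
            self.patch_size = patch_size
            self.distilled = 0  # no distillation token
            self.proj = nn.Linear(3 * patch_size ** 2, dim)

        def forward(self, im, return_features=True):
            # Patchify, embed, and prepend a zero CLS token.
            p = self.patch_size
            x = rearrange(im, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
            x = self.proj(x)
            cls_tok = x.new_zeros(x.size(0), 1, x.size(2))
            return torch.cat([cls_tok, x], dim=1)

        def no_weight_decay(self):
            return set()

    class StubDecoder(nn.Module):
        def __init__(self, dim=8, n_cls=4, patch_size=16):
            super().__init__()
            self.head = nn.Linear(dim, n_cls)
            self.patch_size = patch_size

        def forward(self, x, im_size, features_only=False, no_rearrange=False):
            if features_only:
                return x
            H, _ = im_size
            x = self.head(x)  # per-patch class scores
            return rearrange(x, "b (h w) c -> b c h w", h=H // self.patch_size)

        def no_weight_decay(self):
            return set()

    model = Segmenter(StubEncoder(), StubDecoder(), n_cls=4)
    out = model(torch.randn(2, 3, 64, 64))
    print(out.shape)  # expected: torch.Size([2, 4, 64, 64])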