from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
# from timm.models.layers import trunc_normal_

from segmenter_model.utils import padding, unpadding


class Segmenter(nn.Module):
    """Patch-based segmentation model: a ViT encoder followed by a mask decoder."""

    def __init__(
        self,
        encoder,
        decoder,
        n_cls,
    ):
        super().__init__()
        self.n_cls = n_cls
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        # Collect parameter names excluded from weight decay, prefixed with the
        # submodule name so they match the keys of this model's state dict.
        def append_prefix_no_weight_decay(prefix, module):
            return set(map(lambda x: prefix + x, module.no_weight_decay()))

        nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union(
            append_prefix_no_weight_decay("decoder.", self.decoder)
        )
        return nwd_params

    def forward(self, im, decoder_features=False, no_upsample=False, encoder_features=False, no_rearrange=False,
                cls_only=False, encoder_only=False):
        H_ori, W_ori = im.size(2), im.size(3)
        # Pad the image so that H and W are divisible by the patch size.
        if not no_upsample:
            im = padding(im, self.patch_size)
        H, W = im.size(2), im.size(3)

        x = self.encoder(im, return_features=True)  # patch tokens, self.patch_size times smaller than im

        # remove CLS/DIST tokens for decoding
        num_extra_tokens = 1 + self.encoder.distilled
        if cls_only:
            return x[:, 0]
        x = x[:, num_extra_tokens:]

        # Encoder features are needed both when returning them alongside the output
        # and when returning them alone (otherwise enc_fts would be undefined below).
        if encoder_features or encoder_only:
            enc_fts = x.clone()
            if not no_rearrange:
                GS = H // self.patch_size
                enc_fts = rearrange(enc_fts, "b (h w) c -> b c h w", h=GS)
        if encoder_only:
            return enc_fts

        if decoder_features:
            output = self.decoder(x, (H, W), features_only=True, no_rearrange=no_rearrange)
            if no_rearrange:
                if encoder_features:
                    output = (enc_fts, output)
                return output
        else:
            output = self.decoder(x, (H, W))  # shape (BS, NCLS, H/self.patch_size, W/self.patch_size)

        if not no_upsample:
            output = F.interpolate(output, size=(H, W), mode="bilinear")  # upsample self.patch_size times
            output = unpadding(output, (H_ori, W_ori))

        if encoder_features:
            output = (enc_fts, output)

        return output

    def get_attention_map_enc(self, im, layer_id):
        return self.encoder.get_attention_map(im, layer_id)

    def get_attention_map_dec(self, im, layer_id):
        x = self.encoder(im, return_features=True)

        # remove CLS/DIST tokens for decoding
        num_extra_tokens = 1 + self.encoder.distilled
        x = x[:, num_extra_tokens:]

        return self.decoder.get_attention_map(x, layer_id)
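

# Minimal usage sketch (not part of the original module): it assumes only the interface
# exercised above, i.e. an encoder/decoder pair exposing `patch_size`, `distilled`,
# `no_weight_decay()` and the call signatures used in `forward`. The dummy modules
# below are hypothetical stand-ins for the real ViT encoder and mask decoder.
if __name__ == "__main__":
    class _DummyEncoder(nn.Module):
        patch_size, distilled, d = 16, 0, 32

        def __init__(self):
            super().__init__()
            self.proj = nn.Conv2d(3, self.d, kernel_size=self.patch_size, stride=self.patch_size)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d))

        def forward(self, im, return_features=True):
            x = self.proj(im).flatten(2).transpose(1, 2)  # (B, N, d) patch tokens
            cls = self.cls_token.expand(x.size(0), -1, -1)
            return torch.cat([cls, x], dim=1)  # prepend CLS token

        def no_weight_decay(self):
            return {"cls_token"}

    class _DummyDecoder(nn.Module):
        def __init__(self, d=32, n_cls=19, patch_size=16):
            super().__init__()
            self.head = nn.Linear(d, n_cls)
            self.patch_size = patch_size

        def forward(self, x, im_size):
            H, W = im_size
            x = self.head(x)  # (B, N, n_cls) per-patch class logits
            return rearrange(x, "b (h w) c -> b c h w", h=H // self.patch_size)

        def no_weight_decay(self):
            return set()

    model = Segmenter(_DummyEncoder(), _DummyDecoder(), n_cls=19)
    logits = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # torch.Size([2, 19, 224, 224])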