"""Factory helpers for the segmenter: build the ViT encoder, the decoder, and the full
Segmenter model, and load trained checkpoints together with their `variant.yml` config."""
from pathlib import Path
import yaml
import torch
import math
import os
import torch.nn as nn

from timm.models.helpers import load_pretrained, load_custom_pretrained
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
from timm.models.registry import register_model
from timm.models.vision_transformer import _create_vision_transformer
# NOTE: DecoderLinear and DeepLabHead are referenced in create_decoder below; importing them
# from segmenter_model.decoder is an assumption about where they are defined.
from segmenter_model.decoder import DecoderLinear, DeepLabHead, MaskTransformer
from segmenter_model.segmenter import Segmenter
import segmenter_model.torch as ptu

from segmenter_model.vit_dino import vit_small, VisionTransformer


@register_model
def vit_base_patch8_384(pretrained=False, **kwargs):
    """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        "vit_base_patch8_384",
        pretrained=pretrained,
        default_cfg=dict(
            url="",
            input_size=(3, 384, 384),
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            num_classes=1000,
        ),
        **model_kwargs,
    )
    return model


def create_vit(model_cfg):
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")
    if 'pretrained_weights' in model_cfg:
        pretrained_weights = model_cfg.pop('pretrained_weights')

    if 'dino' in backbone:
        if backbone.lower() == 'dino_vits16':
            # The DINO ViT-S/16 implementation expects `drop_rate` instead of `dropout`.
            model_cfg['drop_rate'] = model_cfg['dropout']
            model = vit_small(**model_cfg)
            # Checkpoint locations are hard-coded for the two clusters used during development.
            ciirc_path = '/home/vobecant/PhD/weights/dino/dino_deitsmall16_pretrain.pth'
            karolina_path = '/scratch/project/dd-21-20/pretrained_weights/dino/dino_deitsmall16_pretrain.pth'
            if os.path.exists(ciirc_path):
                pretrained_weights = ciirc_path
            elif os.path.exists(karolina_path):
                pretrained_weights = karolina_path
            else:
                raise Exception('DINO weights not found!')
            model.load_state_dict(torch.load(pretrained_weights, map_location='cpu'), strict=True)
        else:
            model = torch.hub.load('facebookresearch/dino:main', backbone)
        # Expose the attributes that the Segmenter wrapper and decoder expect.
        setattr(model, 'd_model', model.num_features)
        setattr(model, 'patch_size', model.patch_embed.patch_size)
        setattr(model, 'distilled', False)
        # Segmenter calls `encoder(im, return_features=True)`, so the forward pass is
        # redirected to the output of the last transformer block.
        model.forward = lambda x, return_features: model.get_intermediate_layers(x, n=1)[0]
    else:
        normalization = model_cfg.pop("normalization")
        model_cfg["n_cls"] = 1000
        mlp_expansion_ratio = 4
        model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

        if backbone in default_cfgs:
            default_cfg = default_cfgs[backbone]
        else:
            default_cfg = dict(
                pretrained=False,
                num_classes=1000,
                drop_rate=0.0,
                drop_path_rate=0.0,
                drop_block_rate=None,
            )

        default_cfg["input_size"] = (
            3,
            model_cfg["image_size"][0],
            model_cfg["image_size"][1],
        )
        model = VisionTransformer(**model_cfg)
        if backbone == "vit_base_patch8_384":
            path = os.path.expandvars("/home/vobecant/PhD/weights/vit_base_patch8_384.pth")
            state_dict = torch.load(path, map_location="cpu")
            filtered_dict = checkpoint_filter_fn(state_dict, model)
            model.load_state_dict(filtered_dict, strict=True)
        elif "deit" in backbone:
            load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
        else:
            load_custom_pretrained(model, default_cfg)

    return model


def create_decoder(encoder, decoder_cfg):
    decoder_cfg = decoder_cfg.copy()
    name = decoder_cfg.pop("name")
    decoder_cfg["d_encoder"] = encoder.d_model
    decoder_cfg["patch_size"] = encoder.patch_size

    if "linear" in name:
        decoder = DecoderLinear(**decoder_cfg)
    elif name == "mask_transformer":
        dim = encoder.d_model
        n_heads = dim // 64
        decoder_cfg["n_heads"] = n_heads
        decoder_cfg["d_model"] = dim
        decoder_cfg["d_ff"] = 4 * dim
        decoder = MaskTransformer(**decoder_cfg)
    elif 'deeplab' in name:
        decoder = DeepLabHead(in_channels=encoder.d_model, num_classes=decoder_cfg["n_cls"],
                              patch_size=decoder_cfg["patch_size"])
    else:
        raise ValueError(f"Unknown decoder: {name}")
    return decoder


def create_segmenter(model_cfg):
    model_cfg = model_cfg.copy()
    decoder_cfg = model_cfg.pop("decoder")
    decoder_cfg["n_cls"] = model_cfg["n_cls"]

    if 'weights_path' in model_cfg.keys():
        weights_path = model_cfg.pop('weights_path')
    else:
        weights_path = None

    encoder = create_vit(model_cfg)
    decoder = create_decoder(encoder, decoder_cfg)
    model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"])

    if weights_path is not None:
        raise Exception('Loading weights for the complete segmenter inside create_segmenter is not supported!')
        # NOTE: the lines below are unreachable; kept for reference should full-model loading be re-enabled.
        state_dict = torch.load(weights_path, map_location="cpu")
        if 'model' in state_dict:
            state_dict = state_dict['model']
        msg = model.load_state_dict(state_dict, strict=False)
        print(msg)

    return model
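

# ------------------------------------------------------------------------
# For reference, `create_segmenter` expects a config dict shaped roughly like
# the sketch below. The concrete keys and values are illustrative assumptions
# (they depend on the backbone and on the `variant.yml` of the checkpoint),
# not an exact configuration:
#
#   net_kwargs = {
#       "backbone": "vit_base_patch8_384",   # or e.g. "dino_vits16"
#       "image_size": (384, 384),
#       "patch_size": 8,
#       "d_model": 768,
#       "n_heads": 12,
#       "n_layers": 12,
#       "normalization": "vit",
#       "dropout": 0.0,
#       "drop_path_rate": 0.1,
#       "n_cls": 19,                         # hypothetical number of classes
#       "decoder": {
#           "name": "mask_transformer",
#           "n_layers": 2,
#           "drop_path_rate": 0.0,
#           "dropout": 0.1,
#       },
#   }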


def load_model(model_path, decoder_only=False, variant_path=None):
    """Rebuild a Segmenter from a checkpoint, reading its configuration from the
    `variant.yml` stored next to the checkpoint (or from `variant_path`)."""
    variant_path = Path(model_path).parent / "variant.yml" if variant_path is None else variant_path
    with open(variant_path, "r") as f:
        variant = yaml.load(f, Loader=yaml.FullLoader)
    net_kwargs = variant["net_kwargs"]

    model = create_segmenter(net_kwargs)
    data = torch.load(model_path, map_location=ptu.device)
    checkpoint = data["model"]

    if decoder_only:
        model.decoder.load_state_dict(checkpoint, strict=True)
    else:
        model.load_state_dict(checkpoint, strict=True)

    return model, variant
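

if __name__ == "__main__":
    # Minimal usage sketch: the checkpoint path below is a hypothetical placeholder;
    # a real checkpoint is expected to have its `variant.yml` stored next to it.
    example_ckpt = "/path/to/segmenter_checkpoint.pth"
    if os.path.exists(example_ckpt):
        model, variant = load_model(example_ckpt)
        model.eval()
        print("Loaded segmenter with", variant["net_kwargs"]["n_cls"], "classes.")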