|
""" |
|
Code adapted from timm https://github.com/huggingface/pytorch-image-models |
|
|
|
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich |
|
""" |
|
|
|
import os |
|
from typing import Any, Dict, Optional, Union |
|
|
|
import timm |
|
|
|
|
|
from mivolo.model.mivolo_model import * |
|
from timm.layers import set_layer_config |
|
from timm.models._factory import parse_model_name |
|
from timm.models._helpers import load_state_dict, remap_checkpoint |
|
from timm.models._hub import load_model_config_from_hf |
|
from timm.models._pretrained import PretrainedCfg, split_model_name_tag |
|
from timm.models._registry import is_model, model_entrypoint |
|
|
|
|
|
def load_checkpoint(
    model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
):
    """Load weights from ``checkpoint_path`` into ``model``.

    Args:
        model: Target model. It must provide ``load_state_dict``; for ``.npz`` /
            ``.npy`` checkpoints it must provide ``load_pretrained`` instead.
        checkpoint_path: Path of the checkpoint file. Its extension (compared
            case-insensitively) selects the numpy or the state-dict code path.
        use_ema: Forwarded to timm's ``load_state_dict`` helper.
        strict: Passed to ``model.load_state_dict``. Forced to ``False`` whenever
            ``filter_keys`` is not ``None`` (presumably because filtered-out keys
            would otherwise be reported as missing).
        remap: When true, the loaded state dict is first passed through timm's
            ``remap_checkpoint`` together with the model.
        filter_keys: Optional iterable of substrings; every checkpoint entry whose
            key contains at least one of them is dropped before loading.
        state_dict_map: Optional ``{new_substring: old_substring}`` mapping. Every
            checkpoint key containing ``old_substring`` is stored under the key
            obtained by replacing it with ``new_substring``; the old key is removed.

    Returns:
        The value returned by ``model.load_state_dict``, or ``None`` when a numpy
        checkpoint was handled by ``model.load_pretrained``.

    Raises:
        NotImplementedError: The checkpoint is a numpy file and the model has no
            ``load_pretrained`` method.
    """
    if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
        if not hasattr(model, "load_pretrained"):
            raise NotImplementedError("Model cannot load numpy checkpoint")
        # Numpy checkpoints are delegated to the model's own loader. The
        # module-level timm builder helper of the same name expects the model as
        # its first argument, so it must not be called with just the path.
        model.load_pretrained(checkpoint_path)
        return None

    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)

    if filter_keys:
        for key in list(state_dict):
            if any(fragment in key for fragment in filter_keys):
                del state_dict[key]

    if state_dict_map is not None:
        # Collect the renamed entries separately and merge them only after the
        # stale keys are gone; otherwise a rename that lands on another matched
        # key (swap / chain / identity mappings) would be deleted again.
        renamed = {}
        stale = set()
        for key in list(state_dict):
            for new_part, old_part in state_dict_map.items():
                if old_part in key:
                    renamed[key.replace(old_part, new_part)] = state_dict[key]
                    stale.add(key)
        for key in stale:
            del state_dict[key]
        state_dict.update(renamed)

    # Any filter_keys value other than None (even an empty one) disables strict loading.
    return model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
|
|
|
|
|
def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: str = "",
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    filter_keys=None,
    state_dict_map=None,
    **kwargs,
):
    """Build a registered model by name and optionally load a checkpoint into it.

    Args:
        model_name: Model identifier. It is split by ``parse_model_name`` into a
            source and a name; for the ``"hf-hub"`` source the configuration is
            fetched with ``load_model_config_from_hf``, otherwise the name is split
            into an architecture name and a pretrained tag.
        pretrained: Forwarded to the model's entrypoint function.
        pretrained_cfg: Forwarded to the entrypoint. Must be unset for ``hf-hub``
            models; for other sources it takes priority over the tag embedded in
            ``model_name``.
        pretrained_cfg_overlay: Forwarded to the entrypoint.
        checkpoint_path: When non-empty, ``load_checkpoint`` is called on the
            freshly built model with this path.
        scriptable: Passed to ``set_layer_config`` for the duration of construction.
        exportable: Passed to ``set_layer_config`` for the duration of construction.
        no_jit: Passed to ``set_layer_config`` for the duration of construction.
        filter_keys: Forwarded to ``load_checkpoint``.
        state_dict_map: Forwarded to ``load_checkpoint``.
        **kwargs: Extra entrypoint arguments; entries whose value is ``None`` are
            discarded before the call.

    Returns:
        The model object produced by the entrypoint.

    Raises:
        AssertionError: ``pretrained_cfg`` was given for an ``hf-hub`` model.
        RuntimeError: The resolved model name is not a registered model.
    """
    # Unset (None) options are removed so they are never passed to the entrypoint.
    entrypoint_kwargs = {name: value for name, value in kwargs.items() if value is not None}

    source, arch_name = parse_model_name(model_name)
    if source == "hf-hub":
        assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
        # The hub lookup supplies both the pretrained config and the model name.
        pretrained_cfg, arch_name = load_model_config_from_hf(arch_name)
    else:
        arch_name, name_tag = split_model_name_tag(arch_name)
        # An explicitly supplied pretrained_cfg wins over the tag in the name.
        pretrained_cfg = pretrained_cfg or name_tag

    if not is_model(arch_name):
        raise RuntimeError("Unknown model (%s)" % arch_name)

    build_model = model_entrypoint(arch_name)
    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = build_model(
            pretrained=pretrained,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **entrypoint_kwargs,
        )

    if not checkpoint_path:
        return model

    load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
    return model
|
|