Spaces:
Running
Running
"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models

Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""
import os | |
from typing import Any, Dict, Optional, Union | |
import timm | |
# register new models | |
from mivolo.model.mivolo_model import * # noqa: F403, F401 | |
from timm.layers import set_layer_config | |
from timm.models._factory import parse_model_name | |
from timm.models._helpers import load_state_dict, remap_checkpoint | |
from timm.models._hub import load_model_config_from_hf | |
from timm.models._pretrained import PretrainedCfg, split_model_name_tag | |
from timm.models._registry import is_model, model_entrypoint | |
def load_checkpoint(
    model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
):
    """Load weights from ``checkpoint_path`` into ``model``.

    Checkpoints ending in ``.npz`` / ``.npy`` are delegated to the model's own
    ``load_pretrained`` method. Any other checkpoint is read with
    ``load_state_dict``, optionally remapped / filtered / renamed, and then
    passed to ``model.load_state_dict``.

    Args:
        model: Model to receive the weights.
        checkpoint_path: Path to the checkpoint file; its extension selects
            the loading strategy.
        use_ema: Forwarded to ``load_state_dict``.
        strict: Forwarded to ``model.load_state_dict``. It is forced to
            ``False`` whenever ``filter_keys`` is not ``None``.
        remap: If true, the state dict is passed through ``remap_checkpoint``
            before any filtering or renaming.
        filter_keys: Optional iterable of substrings; every state-dict key
            containing any of them is dropped.
        state_dict_map: Optional ``{new_substring: old_substring}`` mapping.
            Each key containing ``old_substring`` is re-registered with that
            substring replaced by ``new_substring`` and the old key is removed,
            e.g. ``{'patch_embed.conv1.': 'patch_embed.conv.'}``.

    Returns:
        ``None`` for numpy checkpoints, otherwise whatever
        ``model.load_state_dict`` returns.

    Raises:
        NotImplementedError: For a numpy checkpoint when the model has no
            ``load_pretrained`` method.
    """
    if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if not hasattr(model, "load_pretrained"):
            raise NotImplementedError("Model cannot load numpy checkpoint")
        # Call the model's own loader: the timm builder-level `load_pretrained`
        # expects a model as its first argument, not a checkpoint path.
        model.load_pretrained(checkpoint_path)
        return None

    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)

    if filter_keys:
        # Iterate over a snapshot of the keys because entries are deleted in the loop.
        for sd_key in list(state_dict.keys()):
            if any(filter_key in sd_key for filter_key in filter_keys):
                del state_dict[sd_key]

    if state_dict_map is not None:
        # 'patch_embed.conv1.' : 'patch_embed.conv.'
        renamed = []
        for state_k in list(state_dict.keys()):
            for target_k, target_v in state_dict_map.items():
                if target_v not in state_k:
                    continue
                target_name = state_k.replace(target_v, target_k)
                if target_name == state_k:
                    # No-op rename: scheduling the key for removal would drop the weight.
                    continue
                state_dict[target_name] = state_dict[state_k]
                renamed.append(state_k)
        # Old names are removed only after every new name has been registered.
        for old_key in renamed:
            state_dict.pop(old_key, None)

    incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
    return incompatible_keys
def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: str = "",
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    filter_keys=None,
    state_dict_map=None,
    **kwargs,
):
    """Build a registered model by name and optionally load a checkpoint into it.

    The name is resolved to an entrypoint function, which is called with the
    pretrained settings and the remaining keyword arguments. When
    ``checkpoint_path`` is non-empty, ``load_checkpoint`` is then invoked with
    ``filter_keys`` and ``state_dict_map``.

    Raises:
        RuntimeError: If the resolved model name is not a registered model.
    """
    # Options that are unset (None) are dropped so that models which do not
    # accept them keep working and their own defaults stay in effect.
    model_kwargs = {}
    for arg_name, arg_value in kwargs.items():
        if arg_value is not None:
            model_kwargs[arg_name] = arg_value

    model_source, model_name = parse_model_name(model_name)
    if model_source != "hf-hub":
        model_name, pretrained_tag = split_model_name_tag(model_name)
        # An explicitly supplied pretrained_cfg wins over the tag embedded in the name.
        if not pretrained_cfg:
            pretrained_cfg = pretrained_tag
    else:
        assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
        # Names of the form `hf-hub:path/architecture_name@revision` take both
        # the pretrained_cfg and the architecture name from the hub config.
        pretrained_cfg, model_name = load_model_config_from_hf(model_name)

    if not is_model(model_name):
        raise RuntimeError("Unknown model (%s)" % model_name)

    entrypoint = model_entrypoint(model_name)
    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = entrypoint(
            pretrained=pretrained,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **model_kwargs,
        )

    if not checkpoint_path:
        return model

    load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
    return model