Spaces:
Build error
Build error
from timm.models import create_model | |
from .swin_transformer import SwinTransformer | |
from . import focalnet | |
def build_model(config, *, is_pretrained=False):
    """Build the image encoder described by ``config``.

    The backbone is chosen by substring match on ``config.TYPE``. The checks
    run in this order: "swin", then "focal", then "vit". Any other type name,
    including "resnet", is handed directly to timm's ``create_model``.

    Args:
        config: Model config node. Reads ``TYPE``, ``IMG_SIZE`` and
            ``DROP_PATH_RATE``. Depending on the branch it also reads
            ``DROP_RATE`` and the ``SWIN`` / ``FOCAL`` sub-sections.
        is_pretrained: Forwarded as ``pretrained=`` to ``create_model`` for
            the vit / resnet / fallback branches. The swin and focal branches
            always build without pretrained timm weights.

    Returns:
        The constructed model. Every branch passes ``num_classes=0``, so no
        classification head is attached.
    """
    model_type = config.TYPE
    print(f"Creating model: {model_type}")

    if "swin" in model_type:
        swin_cfg = config.SWIN
        return SwinTransformer(
            num_classes=0,
            img_size=config.IMG_SIZE,
            patch_size=swin_cfg.PATCH_SIZE,
            in_chans=swin_cfg.IN_CHANS,
            embed_dim=swin_cfg.EMBED_DIM,
            depths=swin_cfg.DEPTHS,
            num_heads=swin_cfg.NUM_HEADS,
            window_size=swin_cfg.WINDOW_SIZE,
            mlp_ratio=swin_cfg.MLP_RATIO,
            qkv_bias=swin_cfg.QKV_BIAS,
            qk_scale=swin_cfg.QK_SCALE,
            drop_rate=config.DROP_RATE,
            drop_path_rate=config.DROP_PATH_RATE,
            ape=swin_cfg.APE,
            patch_norm=swin_cfg.PATCH_NORM,
            use_checkpoint=False,
        )

    if "focal" in model_type:
        # Built through timm's registry. The module-level `from . import
        # focalnet` presumably registers these model names — TODO confirm.
        focal_cfg = config.FOCAL
        return create_model(
            model_type,
            pretrained=False,
            img_size=config.IMG_SIZE,
            num_classes=0,
            drop_path_rate=config.DROP_PATH_RATE,
            use_conv_embed=focal_cfg.USE_CONV_EMBED,
            use_layerscale=focal_cfg.USE_LAYERSCALE,
            use_postln=focal_cfg.USE_POSTLN,
        )

    # NOTE(review): the vit/resnet paths previously read config.DATA.IMG_SIZE
    # and config.MODEL.NUM_CLASSES, which does not match the flat config used
    # above. They now mirror the headless swin/focal encoders (flat IMG_SIZE,
    # num_classes=0). Confirm that a headless encoder is intended here.
    if "vit" in model_type:
        return create_model(
            model_type,
            pretrained=is_pretrained,
            img_size=config.IMG_SIZE,
            num_classes=0,
        )

    # "resnet" and every other timm model name are constructed the same way.
    return create_model(
        model_type,
        pretrained=is_pretrained,
        num_classes=0,
    )