|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from models.revcol import * |
|
|
|
|
|
|
|
def build_model(config):
    """Instantiate the RevCol variant named by ``config.MODEL.TYPE``.

    Args:
        config: Configuration object. It is read via attribute access:
            ``MODEL.TYPE`` and ``MODEL.NUM_CLASSES``, plus
            ``REVCOL.SAVEMM``, ``REVCOL.INTER_SUPV``, ``REVCOL.DROP_PATH``
            and ``REVCOL.KERNEL_SIZE``. ``REVCOL.HEAD_INIT_SCALE`` is also
            read, but only for the large and xlarge variants.

    Returns:
        Whatever the selected ``revcol_*`` factory returns (presumably a
        ``torch.nn.Module`` -- the factories live in ``models.revcol``).

    Raises:
        NotImplementedError: if ``config.MODEL.TYPE`` is not one of
            ``revcol_tiny``, ``revcol_small``, ``revcol_base``,
            ``revcol_large`` or ``revcol_xlarge``.
    """
    model_type = config.MODEL.TYPE

    # Resolve the factory first so that an unknown type fails fast,
    # before any REVCOL settings are touched.
    if model_type == "revcol_tiny":
        factory = revcol_tiny
    elif model_type == "revcol_small":
        factory = revcol_small
    elif model_type == "revcol_base":
        factory = revcol_base
    elif model_type == "revcol_large":
        factory = revcol_large
    elif model_type == "revcol_xlarge":
        factory = revcol_xlarge
    else:
        raise NotImplementedError(f"Unknown model: {model_type}")

    # Arguments shared by every variant.
    kwargs = {
        "save_memory": config.REVCOL.SAVEMM,
        "inter_supv": config.REVCOL.INTER_SUPV,
        "drop_path": config.REVCOL.DROP_PATH,
        "num_classes": config.MODEL.NUM_CLASSES,
        "kernel_size": config.REVCOL.KERNEL_SIZE,
    }
    # Only the large/xlarge variants were given head_init_scale originally;
    # keep that asymmetry so smaller variants use their factory defaults.
    if model_type in ("revcol_large", "revcol_xlarge"):
        kwargs["head_init_scale"] = config.REVCOL.HEAD_INIT_SCALE

    return factory(**kwargs)
|
|
|
|
|
|
|
|
|
|