File size: 1,609 Bytes
beb7843 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
from .yowo import YOWO
from .loss import build_criterion
# build YOWO detector
def build_yowo(args,
               d_cfg,
               m_cfg,
               device,
               num_classes=3,
               trainable=False,
               resume=None,
               *,
               conf_thresh=0.15,
               nms_thresh=0.5,
               topk=40):
    """Build a YOWO detector and, when training, its loss criterion.

    Args:
        args: Run configuration object. ``args.version`` is printed (upper
            case). When ``trainable`` is true, ``args.freeze_backbone_2d``
            and ``args.freeze_backbone_3d`` are read, and ``args`` is
            forwarded to ``build_criterion``.
        d_cfg: Dataset config mapping. ``d_cfg['multi_hot']`` is always
            read; ``d_cfg['train_size']`` is read only when ``trainable``
            is true.
        m_cfg: Model config, passed to ``YOWO`` as ``cfg``.
        device: Passed through to ``YOWO``.
        num_classes: Number of classes for the detector and the criterion.
        trainable: If true, optionally freeze backbones, optionally restore
            weights from ``resume``, and build the criterion.
        resume: Path to a checkpoint whose ``"model"`` entry is loaded into
            the model with ``load_state_dict``. Ignored unless ``trainable``
            is true.
        conf_thresh: Confidence threshold passed to ``YOWO``. The default
            keeps the previously hard-coded value.
        nms_thresh: NMS threshold passed to ``YOWO``. The default keeps the
            previously hard-coded value.
        topk: ``topk`` value passed to ``YOWO``. The default keeps the
            previously hard-coded value.

    Returns:
        A tuple ``(model, criterion)``. ``criterion`` is ``None`` when
        ``trainable`` is false.
    """
    print('==============================')
    print('Build {} ...'.format(args.version.upper()))

    # build YOWO
    model = YOWO(
        cfg=m_cfg,
        device=device,
        num_classes=num_classes,
        conf_thresh=conf_thresh,
        nms_thresh=nms_thresh,
        topk=topk,
        trainable=trainable,
        multi_hot=d_cfg['multi_hot'],
    )

    # Inference-only build: no freezing, no resume, no criterion.
    if not trainable:
        return model, None

    # Freeze backbones on request by disabling gradients on their parameters.
    if args.freeze_backbone_2d:
        print('Freeze 2D Backbone ...')
        _freeze_parameters(model.backbone_2d)
    if args.freeze_backbone_3d:
        print('Freeze 3D Backbone ...')
        _freeze_parameters(model.backbone_3d)

    # keep training from a previous checkpoint (only the model weights are restored here)
    if resume is not None:
        print('keep training: ', resume)
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(resume, map_location='cpu')
        model.load_state_dict(checkpoint["model"])

    # build criterion
    criterion = build_criterion(
        args, d_cfg['train_size'], num_classes, d_cfg['multi_hot'])
    return model, criterion


def _freeze_parameters(module):
    """Set ``requires_grad = False`` on every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = False
|