glenn-jocher
commited on
Commit
•
86897e3
1
Parent(s):
e9b3de4
Update train.py test batch_size (#2148)
Browse files* Update train.py
* Update loss.py
- train.py +2 -2
- utils/loss.py +1 -2
train.py
CHANGED
@@ -190,7 +190,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
190 |
# Process 0
|
191 |
if rank in [-1, 0]:
|
192 |
ema.updates = start_epoch * nb // accumulate # set EMA updates
|
193 |
-
testloader = create_dataloader(test_path, imgsz_test,
|
194 |
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
195 |
world_size=opt.world_size, workers=opt.workers,
|
196 |
pad=0.5, prefix=colorstr('val: '))[0]
|
@@ -338,7 +338,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
338 |
final_epoch = epoch + 1 == epochs
|
339 |
if not opt.notest or final_epoch: # Calculate mAP
|
340 |
results, maps, times = test.test(opt.data,
|
341 |
-
batch_size=
|
342 |
imgsz=imgsz_test,
|
343 |
model=ema.ema,
|
344 |
single_cls=opt.single_cls,
|
|
|
190 |
# Process 0
|
191 |
if rank in [-1, 0]:
|
192 |
ema.updates = start_epoch * nb // accumulate # set EMA updates
|
193 |
+
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
|
194 |
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
195 |
world_size=opt.world_size, workers=opt.workers,
|
196 |
pad=0.5, prefix=colorstr('val: '))[0]
|
|
|
338 |
final_epoch = epoch + 1 == epochs
|
339 |
if not opt.notest or final_epoch: # Calculate mAP
|
340 |
results, maps, times = test.test(opt.data,
|
341 |
+
batch_size=batch_size * 2,
|
342 |
imgsz=imgsz_test,
|
343 |
model=ema.ema,
|
344 |
single_cls=opt.single_cls,
|
utils/loss.py
CHANGED
@@ -105,8 +105,7 @@ class ComputeLoss:
|
|
105 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
106 |
|
107 |
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
|
108 |
-
self.balance = {3: [3.67, 1.0, 0.43], 4: [
|
109 |
-
# self.balance = [1.0] * det.nl
|
110 |
self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index
|
111 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
|
112 |
for k in 'na', 'nc', 'nl', 'anchors':
|
|
|
105 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
106 |
|
107 |
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
|
108 |
+
self.balance = {3: [3.67, 1.0, 0.43], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl]
|
|
|
109 |
self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index
|
110 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
|
111 |
for k in 'na', 'nc', 'nl', 'anchors':
|