glenn-jocher commited on
Commit
86897e3
1 Parent(s): e9b3de4

Update train.py test batch_size (#2148)

Browse files

* Update train.py

* Update loss.py

Files changed (2) hide show
  1. train.py +2 -2
  2. utils/loss.py +1 -2
train.py CHANGED
@@ -190,7 +190,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
190
  # Process 0
191
  if rank in [-1, 0]:
192
  ema.updates = start_epoch * nb // accumulate # set EMA updates
193
- testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader
194
  hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
195
  world_size=opt.world_size, workers=opt.workers,
196
  pad=0.5, prefix=colorstr('val: '))[0]
@@ -338,7 +338,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
338
  final_epoch = epoch + 1 == epochs
339
  if not opt.notest or final_epoch: # Calculate mAP
340
  results, maps, times = test.test(opt.data,
341
- batch_size=total_batch_size,
342
  imgsz=imgsz_test,
343
  model=ema.ema,
344
  single_cls=opt.single_cls,
 
190
  # Process 0
191
  if rank in [-1, 0]:
192
  ema.updates = start_epoch * nb // accumulate # set EMA updates
193
+ testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
194
  hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
195
  world_size=opt.world_size, workers=opt.workers,
196
  pad=0.5, prefix=colorstr('val: '))[0]
 
338
  final_epoch = epoch + 1 == epochs
339
  if not opt.notest or final_epoch: # Calculate mAP
340
  results, maps, times = test.test(opt.data,
341
+ batch_size=batch_size * 2,
342
  imgsz=imgsz_test,
343
  model=ema.ema,
344
  single_cls=opt.single_cls,
utils/loss.py CHANGED
@@ -105,8 +105,7 @@ class ComputeLoss:
105
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
106
 
107
  det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
108
- self.balance = {3: [3.67, 1.0, 0.43], 4: [3.78, 1.0, 0.39, 0.22], 5: [3.88, 1.0, 0.37, 0.17, 0.10]}[det.nl]
109
- # self.balance = [1.0] * det.nl
110
  self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index
111
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
112
  for k in 'na', 'nc', 'nl', 'anchors':
 
105
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
106
 
107
  det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
108
+ self.balance = {3: [3.67, 1.0, 0.43], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl]
 
109
  self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index
110
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
111
  for k in 'na', 'nc', 'nl', 'anchors':