glenn-jocher commited on
Commit
0ada058
1 Parent(s): 10c85bf

Generalized regression criterion renaming (#1120)

Browse files
Files changed (6) hide show
  1. data/hyp.finetune.yaml +1 -1
  2. data/hyp.scratch.yaml +1 -1
  3. sotabench.py +1 -1
  4. test.py +1 -1
  5. train.py +9 -9
  6. utils/general.py +9 -9
data/hyp.finetune.yaml CHANGED
@@ -15,7 +15,7 @@ weight_decay: 0.00036
15
  warmup_epochs: 2.0
16
  warmup_momentum: 0.5
17
  warmup_bias_lr: 0.05
18
- giou: 0.0296
19
  cls: 0.243
20
  cls_pw: 0.631
21
  obj: 0.301
 
15
  warmup_epochs: 2.0
16
  warmup_momentum: 0.5
17
  warmup_bias_lr: 0.05
18
+ box: 0.0296
19
  cls: 0.243
20
  cls_pw: 0.631
21
  obj: 0.301
data/hyp.scratch.yaml CHANGED
@@ -10,7 +10,7 @@ weight_decay: 0.0005 # optimizer weight decay 5e-4
10
  warmup_epochs: 3.0 # warmup epochs (fractions ok)
11
  warmup_momentum: 0.8 # warmup initial momentum
12
  warmup_bias_lr: 0.1 # warmup initial bias lr
13
- giou: 0.05 # box loss gain
14
  cls: 0.5 # cls loss gain
15
  cls_pw: 1.0 # cls BCELoss positive_weight
16
  obj: 1.0 # obj loss gain (scale with pixels)
 
10
  warmup_epochs: 3.0 # warmup epochs (fractions ok)
11
  warmup_momentum: 0.8 # warmup initial momentum
12
  warmup_bias_lr: 0.1 # warmup initial bias lr
13
+ box: 0.05 # box loss gain
14
  cls: 0.5 # cls loss gain
15
  cls_pw: 1.0 # cls BCELoss positive_weight
16
  obj: 1.0 # obj loss gain (scale with pixels)
sotabench.py CHANGED
@@ -113,7 +113,7 @@ def test(data,
113
 
114
  # Compute loss
115
  if training: # if model has loss hyperparameters
116
- loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
117
 
118
  # Run NMS
119
  t = time_synchronized()
 
113
 
114
  # Compute loss
115
  if training: # if model has loss hyperparameters
116
+ loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls
117
 
118
  # Run NMS
119
  t = time_synchronized()
test.py CHANGED
@@ -106,7 +106,7 @@ def test(data,
106
 
107
  # Compute loss
108
  if training: # if model has loss hyperparameters
109
- loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
110
 
111
  # Run NMS
112
  t = time_synchronized()
 
106
 
107
  # Compute loss
108
  if training: # if model has loss hyperparameters
109
+ loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls
110
 
111
  # Run NMS
112
  t = time_synchronized()
train.py CHANGED
@@ -195,7 +195,7 @@ def train(hyp, opt, device, tb_writer=None):
195
  hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
196
  model.nc = nc # attach number of classes to model
197
  model.hyp = hyp # attach hyperparameters to model
198
- model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
199
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
200
  model.names = names
201
 
@@ -204,7 +204,7 @@ def train(hyp, opt, device, tb_writer=None):
204
  nw = max(round(hyp['warmup_epochs'] * nb), 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
205
  # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
206
  maps = np.zeros(nc) # mAP per class
207
- results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
208
  scheduler.last_epoch = start_epoch - 1 # do not move
209
  scaler = amp.GradScaler(enabled=cuda)
210
  logger.info('Image sizes %g train, %g test\nUsing %g dataloader workers\nLogging results to %s\n'
@@ -234,7 +234,7 @@ def train(hyp, opt, device, tb_writer=None):
234
  if rank != -1:
235
  dataloader.sampler.set_epoch(epoch)
236
  pbar = enumerate(dataloader)
237
- logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
238
  if rank in [-1, 0]:
239
  pbar = tqdm(pbar, total=nb) # progress bar
240
  optimizer.zero_grad()
@@ -245,7 +245,7 @@ def train(hyp, opt, device, tb_writer=None):
245
  # Warmup
246
  if ni <= nw:
247
  xi = [0, nw] # x interp
248
- # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
249
  accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
250
  for j, x in enumerate(optimizer.param_groups):
251
  # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
@@ -319,21 +319,21 @@ def train(hyp, opt, device, tb_writer=None):
319
 
320
  # Write
321
  with open(results_file, 'a') as f:
322
- f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
323
  if len(opt.name) and opt.bucket:
324
  os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
325
 
326
  # Tensorboard
327
  if tb_writer:
328
- tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss
329
  'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
330
- 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss
331
  'x/lr0', 'x/lr1', 'x/lr2'] # params
332
  for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
333
  tb_writer.add_scalar(tag, x, epoch)
334
 
335
  # Update best mAP
336
- fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
337
  if fi > best_fitness:
338
  best_fitness = fi
339
 
@@ -463,7 +463,7 @@ if __name__ == '__main__':
463
  'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
464
  'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
465
  'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
466
- 'giou': (1, 0.02, 0.2), # GIoU loss gain
467
  'cls': (1, 0.2, 4.0), # cls loss gain
468
  'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
469
  'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
 
195
  hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
196
  model.nc = nc # attach number of classes to model
197
  model.hyp = hyp # attach hyperparameters to model
198
+ model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
199
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
200
  model.names = names
201
 
 
204
  nw = max(round(hyp['warmup_epochs'] * nb), 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
205
  # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
206
  maps = np.zeros(nc) # mAP per class
207
+ results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, [email protected], val_loss(box, obj, cls)
208
  scheduler.last_epoch = start_epoch - 1 # do not move
209
  scaler = amp.GradScaler(enabled=cuda)
210
  logger.info('Image sizes %g train, %g test\nUsing %g dataloader workers\nLogging results to %s\n'
 
234
  if rank != -1:
235
  dataloader.sampler.set_epoch(epoch)
236
  pbar = enumerate(dataloader)
237
+ logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
238
  if rank in [-1, 0]:
239
  pbar = tqdm(pbar, total=nb) # progress bar
240
  optimizer.zero_grad()
 
245
  # Warmup
246
  if ni <= nw:
247
  xi = [0, nw] # x interp
248
+ # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
249
  accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
250
  for j, x in enumerate(optimizer.param_groups):
251
  # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
 
319
 
320
  # Write
321
  with open(results_file, 'a') as f:
322
+ f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, [email protected], val_loss(box, obj, cls)
323
  if len(opt.name) and opt.bucket:
324
  os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
325
 
326
  # Tensorboard
327
  if tb_writer:
328
+ tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
329
  'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
330
+ 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
331
  'x/lr0', 'x/lr1', 'x/lr2'] # params
332
  for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
333
  tb_writer.add_scalar(tag, x, epoch)
334
 
335
  # Update best mAP
336
+ fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, [email protected]]
337
  if fi > best_fitness:
338
  best_fitness = fi
339
 
 
463
  'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
464
  'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
465
  'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
466
+ 'box': (1, 0.02, 0.2), # box loss gain
467
  'cls': (1, 0.2, 4.0), # cls loss gain
468
  'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
469
  'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
utils/general.py CHANGED
@@ -509,11 +509,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
509
  pxy = ps[:, :2].sigmoid() * 2. - 0.5
510
  pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
511
  pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
512
- giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # giou(prediction, target)
513
- lbox += (1.0 - giou).mean() # giou loss
514
 
515
  # Objectness
516
- tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio
517
 
518
  # Classification
519
  if model.nc > 1: # cls loss (only if multiple classes)
@@ -528,7 +528,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
528
  lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
529
 
530
  s = 3 / np # output count scaling
531
- lbox *= h['giou'] * s
532
  lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
533
  lcls *= h['cls'] * s
534
  bs = tobj.shape[0] # batch size
@@ -1234,7 +1234,7 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general im
1234
  def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_results_overlay()
1235
  # Plot training 'results*.txt', overlaying train and val losses
1236
  s = ['train', 'train', 'train', 'Precision', '[email protected]', 'val', 'val', 'val', 'Recall', '[email protected]:0.95'] # legends
1237
- t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
1238
  for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
1239
  results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
1240
  n = results.shape[1] # number of rows
@@ -1254,13 +1254,13 @@ def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_
1254
  fig.savefig(f.replace('.txt', '.png'), dpi=200)
1255
 
1256
 
1257
- def plot_results(start=0, stop=0, bucket='', id=(), labels=(),
1258
- save_dir=''): # from utils.general import *; plot_results()
1259
  # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
1260
  fig, ax = plt.subplots(2, 5, figsize=(12, 6))
1261
  ax = ax.ravel()
1262
- s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
1263
- 'val GIoU', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95']
1264
  if bucket:
1265
  # os.system('rm -rf storage.googleapis.com')
1266
  # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
 
509
  pxy = ps[:, :2].sigmoid() * 2. - 0.5
510
  pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
511
  pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
512
+ iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
513
+ lbox += (1.0 - iou).mean() # iou loss
514
 
515
  # Objectness
516
+ tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
517
 
518
  # Classification
519
  if model.nc > 1: # cls loss (only if multiple classes)
 
528
  lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
529
 
530
  s = 3 / np # output count scaling
531
+ lbox *= h['box'] * s
532
  lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
533
  lcls *= h['cls'] * s
534
  bs = tobj.shape[0] # batch size
 
1234
  def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_results_overlay()
1235
  # Plot training 'results*.txt', overlaying train and val losses
1236
  s = ['train', 'train', 'train', 'Precision', '[email protected]', 'val', 'val', 'val', 'Recall', '[email protected]:0.95'] # legends
1237
+ t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
1238
  for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
1239
  results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
1240
  n = results.shape[1] # number of rows
 
1254
  fig.savefig(f.replace('.txt', '.png'), dpi=200)
1255
 
1256
 
1257
+ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
1258
+ # from utils.general import *; plot_results()
1259
  # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
1260
  fig, ax = plt.subplots(2, 5, figsize=(12, 6))
1261
  ax = ax.ravel()
1262
+ s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
1263
+ 'val Box', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95']
1264
  if bucket:
1265
  # os.system('rm -rf storage.googleapis.com')
1266
  # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]