glenn-jocher commited on
Commit
66d73e4
1 Parent(s): 208493d

NMS updates

Browse files
Files changed (1) hide show
  1. utils/utils.py +6 -6
utils/utils.py CHANGED
@@ -494,7 +494,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
494
  continue
495
 
496
  # Compute conf
497
- x[..., 5:] *= x[..., 4:5] # conf = obj_conf * cls_conf
498
 
499
  # Box (center x, center y, width, height) to (x1, y1, x2, y2)
500
  box = xywh2xyxy(x[:, :4])
@@ -502,10 +502,10 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
502
  # Detections matrix nx6 (xyxy, conf, cls)
503
  if multi_label:
504
  i, j = (x[:, 5:] > conf_thres).nonzero().t()
505
- x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
506
  else: # best class only
507
- conf, j = x[:, 5:].max(1)
508
- x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf > conf_thres]
509
 
510
  # Filter by class
511
  if classes:
@@ -524,8 +524,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
524
  # x = x[x[:, 4].argsort(descending=True)]
525
 
526
  # Batched NMS
527
- c = x[:, 5] * 0 if agnostic else x[:, 5] # classes
528
- boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
529
  i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
530
  if i.shape[0] > max_det: # limit detections
531
  i = i[:max_det]
 
494
  continue
495
 
496
  # Compute conf
497
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
498
 
499
  # Box (center x, center y, width, height) to (x1, y1, x2, y2)
500
  box = xywh2xyxy(x[:, :4])
 
502
  # Detections matrix nx6 (xyxy, conf, cls)
503
  if multi_label:
504
  i, j = (x[:, 5:] > conf_thres).nonzero().t()
505
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
506
  else: # best class only
507
+ conf, j = x[:, 5:].max(1, keepdim=True)
508
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
509
 
510
  # Filter by class
511
  if classes:
 
524
  # x = x[x[:, 4].argsort(descending=True)]
525
 
526
  # Batched NMS
527
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
528
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
529
  i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
530
  if i.shape[0] > max_det: # limit detections
531
  i = i[:max_det]