glenn-jocher
commited on
Commit
•
66d73e4
1
Parent(s):
208493d
NMS updates
Browse files- utils/utils.py +6 -6
utils/utils.py
CHANGED
@@ -494,7 +494,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
|
|
494 |
continue
|
495 |
|
496 |
# Compute conf
|
497 |
-
x[
|
498 |
|
499 |
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
500 |
box = xywh2xyxy(x[:, :4])
|
@@ -502,10 +502,10 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
|
|
502 |
# Detections matrix nx6 (xyxy, conf, cls)
|
503 |
if multi_label:
|
504 |
i, j = (x[:, 5:] > conf_thres).nonzero().t()
|
505 |
-
x = torch.cat((box[i], x[i, j + 5]
|
506 |
else: # best class only
|
507 |
-
conf, j = x[:, 5:].max(1)
|
508 |
-
x = torch.cat((box, conf
|
509 |
|
510 |
# Filter by class
|
511 |
if classes:
|
@@ -524,8 +524,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
|
|
524 |
# x = x[x[:, 4].argsort(descending=True)]
|
525 |
|
526 |
# Batched NMS
|
527 |
-
c = x[:, 5] * 0 if agnostic else
|
528 |
-
boxes, scores = x[:, :4]
|
529 |
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
530 |
if i.shape[0] > max_det: # limit detections
|
531 |
i = i[:max_det]
|
|
|
494 |
continue
|
495 |
|
496 |
# Compute conf
|
497 |
+
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
498 |
|
499 |
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
500 |
box = xywh2xyxy(x[:, :4])
|
|
|
502 |
# Detections matrix nx6 (xyxy, conf, cls)
|
503 |
if multi_label:
|
504 |
i, j = (x[:, 5:] > conf_thres).nonzero().t()
|
505 |
+
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
|
506 |
else: # best class only
|
507 |
+
conf, j = x[:, 5:].max(1, keepdim=True)
|
508 |
+
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
|
509 |
|
510 |
# Filter by class
|
511 |
if classes:
|
|
|
524 |
# x = x[x[:, 4].argsort(descending=True)]
|
525 |
|
526 |
# Batched NMS
|
527 |
+
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
528 |
+
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
529 |
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
530 |
if i.shape[0] > max_det: # limit detections
|
531 |
i = i[:max_det]
|