glenn-jocher commited on
Commit
394131c
1 Parent(s): 199c9c7

Use torchvision.ops.nms (#1460)

Browse files
Files changed (1) hide show
  1. utils/general.py +2 -1
utils/general.py CHANGED
@@ -15,6 +15,7 @@ import cv2
15
  import matplotlib
16
  import numpy as np
17
  import torch
 
18
  import yaml
19
 
20
  from utils.google_utils import gsutil_getsize
@@ -323,7 +324,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
323
  # Batched NMS
324
  c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
325
  boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
326
- i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
327
  if i.shape[0] > max_det: # limit detections
328
  i = i[:max_det]
329
  if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
 
15
  import matplotlib
16
  import numpy as np
17
  import torch
18
+ import torchvision
19
  import yaml
20
 
21
  from utils.google_utils import gsutil_getsize
 
324
  # Batched NMS
325
  c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
326
  boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
327
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
328
  if i.shape[0] > max_det: # limit detections
329
  i = i[:max_det]
330
  if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)