glenn-jocher
commited on
Commit
•
394131c
1
Parent(s):
199c9c7
Use torchvision.ops.nms (#1460)
Browse files- utils/general.py +2 -1
utils/general.py
CHANGED
@@ -15,6 +15,7 @@ import cv2
|
|
15 |
import matplotlib
|
16 |
import numpy as np
|
17 |
import torch
|
|
|
18 |
import yaml
|
19 |
|
20 |
from utils.google_utils import gsutil_getsize
|
@@ -323,7 +324,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
|
|
323 |
# Batched NMS
|
324 |
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
325 |
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
326 |
-
i =
|
327 |
if i.shape[0] > max_det: # limit detections
|
328 |
i = i[:max_det]
|
329 |
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
|
|
15 |
import matplotlib
|
16 |
import numpy as np
|
17 |
import torch
|
18 |
+
import torchvision
|
19 |
import yaml
|
20 |
|
21 |
from utils.google_utils import gsutil_getsize
|
|
|
324 |
# Batched NMS
|
325 |
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
326 |
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
327 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
328 |
if i.shape[0] > max_det: # limit detections
|
329 |
i = i[:max_det]
|
330 |
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|