glenn-jocher commited on
Commit
68e6ab6
1 Parent(s): 791dadb

Hub device mismatch bug fix (#1619)

Browse files
Files changed (1) hide show
  1. utils/general.py +2 -2
utils/general.py CHANGED
@@ -265,7 +265,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None,
265
  detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
266
  """
267
 
268
- nc = prediction[0].shape[1] - 5 # number of classes
269
  xc = prediction[..., 4] > conf_thres # candidates
270
 
271
  # Settings
@@ -277,7 +277,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None,
277
  merge = False # use merge-NMS
278
 
279
  t = time.time()
280
- output = [torch.zeros(0, 6)] * prediction.shape[0]
281
  for xi, x in enumerate(prediction): # image index, image inference
282
  # Apply constraints
283
  # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
 
265
  detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
266
  """
267
 
268
+ nc = prediction.shape[2] - 5 # number of classes
269
  xc = prediction[..., 4] > conf_thres # candidates
270
 
271
  # Settings
 
277
  merge = False # use merge-NMS
278
 
279
  t = time.time()
280
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
281
  for xi, x in enumerate(prediction): # image index, image inference
282
  # Apply constraints
283
  # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height