Jean-Baptiste Martin glenn-jocher commited on
Commit
1cad0ce
1 Parent(s): deb434a

Allow `multi_label` option for NMS with PyTorch Hub (#4728)

Browse files

* Allow specifying multi_label option for NMS when using torch hub

* Reformat

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. models/common.py +3 -1
models/common.py CHANGED
@@ -278,6 +278,7 @@ class AutoShape(nn.Module):
278
  conf = 0.25 # NMS confidence threshold
279
  iou = 0.45 # NMS IoU threshold
280
  classes = None # (optional list) filter by class
 
281
  max_det = 1000 # maximum number of detections per image
282
 
283
  def __init__(self, model):
@@ -337,7 +338,8 @@ class AutoShape(nn.Module):
337
  t.append(time_sync())
338
 
339
  # Post-process
340
- y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
 
341
  for i in range(n):
342
  scale_coords(shape1, y[i][:, :4], shape0[i])
343
 
 
278
  conf = 0.25 # NMS confidence threshold
279
  iou = 0.45 # NMS IoU threshold
280
  classes = None # (optional list) filter by class
281
+ multi_label = False # NMS multiple labels per box
282
  max_det = 1000 # maximum number of detections per image
283
 
284
  def __init__(self, model):
 
338
  t.append(time_sync())
339
 
340
  # Post-process
341
+ y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes,
342
+ multi_label=self.multi_label, max_det=self.max_det) # NMS
343
  for i in range(n):
344
  scale_coords(shape1, y[i][:, :4], shape0[i])
345