glenn-jocher
commited on
Commit
•
eb97b2e
1
Parent(s):
d97d31e
NMS fast mode
Browse files
detect.py
CHANGED
@@ -76,7 +76,7 @@ def detect(save_img=False):
|
|
76 |
|
77 |
# Apply NMS
|
78 |
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
|
79 |
-
|
80 |
|
81 |
# Apply Classifier
|
82 |
if classify:
|
|
|
76 |
|
77 |
# Apply NMS
|
78 |
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
|
79 |
+
fast=True, classes=opt.classes, agnostic=opt.agnostic_nms)
|
80 |
|
81 |
# Apply Classifier
|
82 |
if classify:
|
test.py
CHANGED
@@ -19,7 +19,7 @@ def test(data,
|
|
19 |
augment=False,
|
20 |
model=None,
|
21 |
dataloader=None,
|
22 |
-
|
23 |
verbose=False): # 0 fast, 1 accurate
|
24 |
# Initialize/load model and set device
|
25 |
if model is None:
|
@@ -92,7 +92,7 @@ def test(data,
|
|
92 |
|
93 |
# Run NMS
|
94 |
t = torch_utils.time_synchronized()
|
95 |
-
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres,
|
96 |
t1 += torch_utils.time_synchronized() - t
|
97 |
|
98 |
# Statistics per image
|
|
|
19 |
augment=False,
|
20 |
model=None,
|
21 |
dataloader=None,
|
22 |
+
fast=False,
|
23 |
verbose=False): # 0 fast, 1 accurate
|
24 |
# Initialize/load model and set device
|
25 |
if model is None:
|
|
|
92 |
|
93 |
# Run NMS
|
94 |
t = torch_utils.time_synchronized()
|
95 |
+
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, fast=fast)
|
96 |
t1 += torch_utils.time_synchronized() - t
|
97 |
|
98 |
# Statistics per image
|
train.py
CHANGED
@@ -293,13 +293,13 @@ def train(hyp):
|
|
293 |
final_epoch = epoch + 1 == epochs
|
294 |
if not opt.notest or final_epoch: # Calculate mAP
|
295 |
results, maps, times = test.test(opt.data,
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
|
304 |
# Write
|
305 |
with open(results_file, 'a') as f:
|
@@ -325,10 +325,10 @@ def train(hyp):
|
|
325 |
if save:
|
326 |
with open(results_file, 'r') as f: # create checkpoint
|
327 |
ckpt = {'epoch': epoch,
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
|
333 |
# Save last, best and delete
|
334 |
torch.save(ckpt, last)
|
|
|
293 |
final_epoch = epoch + 1 == epochs
|
294 |
if not opt.notest or final_epoch: # Calculate mAP
|
295 |
results, maps, times = test.test(opt.data,
|
296 |
+
batch_size=batch_size,
|
297 |
+
imgsz=imgsz_test,
|
298 |
+
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
299 |
+
model=ema.ema,
|
300 |
+
single_cls=opt.single_cls,
|
301 |
+
dataloader=testloader,
|
302 |
+
fast=ni > n_burn)
|
303 |
|
304 |
# Write
|
305 |
with open(results_file, 'a') as f:
|
|
|
325 |
if save:
|
326 |
with open(results_file, 'r') as f: # create checkpoint
|
327 |
ckpt = {'epoch': epoch,
|
328 |
+
'best_fitness': best_fitness,
|
329 |
+
'training_results': f.read(),
|
330 |
+
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
331 |
+
'optimizer': None if final_epoch else optimizer.state_dict()}
|
332 |
|
333 |
# Save last, best and delete
|
334 |
torch.save(ckpt, last)
|
utils/utils.py
CHANGED
@@ -19,7 +19,7 @@ import torchvision
|
|
19 |
from scipy.signal import butter, filtfilt
|
20 |
from tqdm import tqdm
|
21 |
|
22 |
-
from . import torch_utils, google_utils #
|
23 |
|
24 |
# Set printoptions
|
25 |
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
@@ -460,29 +460,33 @@ def build_targets(p, targets, model):
|
|
460 |
|
461 |
return tcls, tbox, indices, anch
|
462 |
|
463 |
-
|
464 |
-
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
|
465 |
"""
|
466 |
Performs Non-Maximum Suppression on inference results
|
467 |
Returns detections with shape:
|
468 |
nx6 (x1, y1, x2, y2, conf, cls)
|
469 |
"""
|
|
|
470 |
|
471 |
# Settings
|
472 |
-
merge = True # merge for best mAP
|
473 |
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
474 |
max_det = 300 # maximum number of detections per image
|
475 |
time_limit = 10.0 # seconds to quit after
|
476 |
-
redundant =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
|
478 |
t = time.time()
|
479 |
-
nc = prediction[0].shape[1] - 5 # number of classes
|
480 |
-
multi_label &= nc > 1 # multiple labels per box
|
481 |
output = [None] * prediction.shape[0]
|
482 |
for xi, x in enumerate(prediction): # image index, image inference
|
483 |
# Apply constraints
|
|
|
484 |
x = x[x[:, 4] > conf_thres] # confidence
|
485 |
-
# x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height
|
486 |
|
487 |
# If none remain process next image
|
488 |
if not x.shape[0]:
|
|
|
19 |
from scipy.signal import butter, filtfilt
|
20 |
from tqdm import tqdm
|
21 |
|
22 |
+
from . import torch_utils, google_utils # torch_utils, google_utils
|
23 |
|
24 |
# Set printoptions
|
25 |
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
|
|
460 |
|
461 |
return tcls, tbox, indices, anch
|
462 |
|
463 |
+
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
|
|
|
464 |
"""
|
465 |
Performs Non-Maximum Suppression on inference results
|
466 |
Returns detections with shape:
|
467 |
nx6 (x1, y1, x2, y2, conf, cls)
|
468 |
"""
|
469 |
+
nc = prediction[0].shape[1] - 5 # number of classes
|
470 |
|
471 |
# Settings
|
|
|
472 |
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
|
473 |
max_det = 300 # maximum number of detections per image
|
474 |
time_limit = 10.0 # seconds to quit after
|
475 |
+
redundant = True # require redundant detections
|
476 |
+
fast |= conf_thres > 0.001 # fast mode
|
477 |
+
if fast:
|
478 |
+
merge = False
|
479 |
+
multi_label = False
|
480 |
+
else:
|
481 |
+
merge = True # merge for best mAP (adds 0.5ms/img)
|
482 |
+
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
|
483 |
|
484 |
t = time.time()
|
|
|
|
|
485 |
output = [None] * prediction.shape[0]
|
486 |
for xi, x in enumerate(prediction): # image index, image inference
|
487 |
# Apply constraints
|
488 |
+
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
489 |
x = x[x[:, 4] > conf_thres] # confidence
|
|
|
490 |
|
491 |
# If none remain process next image
|
492 |
if not x.shape[0]:
|