glenn-jocher commited on
Commit
a9553c0
1 Parent(s): c6deb73

Refactor test.py arguments (#3558)

Browse files

* remove opt from test()

* pass kwargs

* update comments

* revert accidental default change

* multiple --img options

* add comments

Files changed (2) hide show
  1. detect.py +1 -1
  2. test.py +18 -28
detect.py CHANGED
@@ -33,7 +33,7 @@ def detect(opt):
33
  # Load model
34
  model = attempt_load(weights, map_location=device) # load FP32 model
35
  stride = int(model.stride.max()) # model stride
36
- imgsz = check_img_size(imgsz, s=stride) # check img_size
37
  names = model.module.names if hasattr(model, 'module') else model.names # get class names
38
  if half:
39
  model.half() # to FP16
 
33
  # Load model
34
  model = attempt_load(weights, map_location=device) # load FP32 model
35
  stride = int(model.stride.max()) # model stride
36
+ imgsz = check_img_size(imgsz, s=stride) # check image size
37
  names = model.module.names if hasattr(model, 'module') else model.names # get class names
38
  if half:
39
  model.half() # to FP16
test.py CHANGED
@@ -22,9 +22,9 @@ from utils.torch_utils import select_device, time_synchronized
22
  def test(data,
23
  weights=None,
24
  batch_size=32,
25
- imgsz=640,
26
- conf_thres=0.001,
27
- iou_thres=0.6, # for NMS
28
  save_json=False,
29
  single_cls=False,
30
  augment=False,
@@ -38,8 +38,12 @@ def test(data,
38
  plots=True,
39
  wandb_logger=None,
40
  compute_loss=None,
41
- half=True,
42
- opt=None):
 
 
 
 
43
  # Initialize/load model and set device
44
  training = model is not None
45
  if training: # called by train.py
@@ -47,16 +51,16 @@ def test(data,
47
 
48
  else: # called directly
49
  set_logging()
50
- device = select_device(opt.device, batch_size=batch_size)
51
 
52
  # Directories
53
- save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
54
  (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
55
 
56
  # Load model
57
  model = attempt_load(weights, map_location=device) # load FP32 model
58
  gs = max(int(model.stride.max()), 32) # grid size (max stride)
59
- imgsz = check_img_size(imgsz, s=gs) # check img_size
60
 
61
  # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
62
  # if device.type != 'cpu' and torch.cuda.device_count() > 1:
@@ -86,7 +90,7 @@ def test(data,
86
  if not training:
87
  if device.type != 'cpu':
88
  model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
89
- task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
90
  dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
91
  prefix=colorstr(f'{task}: '))[0]
92
 
@@ -294,7 +298,7 @@ if __name__ == '__main__':
294
  parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
295
  parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
296
  parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
297
- parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
298
  parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
299
  parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
300
  parser.add_argument('--task', default='val', help='train, val, test, speed or study')
@@ -312,31 +316,17 @@ if __name__ == '__main__':
312
  parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
313
  opt = parser.parse_args()
314
  opt.save_json |= opt.data.endswith('coco.yaml')
 
315
  opt.data = check_file(opt.data) # check file
316
  print(opt)
317
  check_requirements(exclude=('tensorboard', 'thop'))
318
 
319
  if opt.task in ('train', 'val', 'test'): # run normally
320
- test(opt.data,
321
- opt.weights,
322
- opt.batch_size,
323
- opt.img_size,
324
- opt.conf_thres,
325
- opt.iou_thres,
326
- opt.save_json,
327
- opt.single_cls,
328
- opt.augment,
329
- opt.verbose,
330
- save_txt=opt.save_txt | opt.save_hybrid,
331
- save_hybrid=opt.save_hybrid,
332
- save_conf=opt.save_conf,
333
- half=opt.half,
334
- opt=opt
335
- )
336
 
337
  elif opt.task == 'speed': # speed benchmarks
338
  for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
339
- test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, opt=opt)
340
 
341
  elif opt.task == 'study': # run over a range of settings and save/plot
342
  # python test.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
@@ -347,7 +337,7 @@ if __name__ == '__main__':
347
  for i in x: # img-size
348
  print(f'\nRunning {f} point {i}...')
349
  r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
350
- plots=False, opt=opt)
351
  y.append(r + t) # results and times
352
  np.savetxt(f, y, fmt='%10.4g') # save
353
  os.system('zip -r study.zip study_*.txt')
 
22
  def test(data,
23
  weights=None,
24
  batch_size=32,
25
+ imgsz=640, # image size
26
+ conf_thres=0.001, # confidence threshold
27
+ iou_thres=0.6, # NMS IoU threshold
28
  save_json=False,
29
  single_cls=False,
30
  augment=False,
 
38
  plots=True,
39
  wandb_logger=None,
40
  compute_loss=None,
41
+ half=True, # FP16 half-precision inference
42
+ project='runs/test',
43
+ name='exp',
44
+ exist_ok=False,
45
+ task='val',
46
+ device=''):
47
  # Initialize/load model and set device
48
  training = model is not None
49
  if training: # called by train.py
 
51
 
52
  else: # called directly
53
  set_logging()
54
+ device = select_device(device, batch_size=batch_size)
55
 
56
  # Directories
57
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
58
  (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
59
 
60
  # Load model
61
  model = attempt_load(weights, map_location=device) # load FP32 model
62
  gs = max(int(model.stride.max()), 32) # grid size (max stride)
63
+ imgsz = check_img_size(imgsz, s=gs) # check image size
64
 
65
  # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
66
  # if device.type != 'cpu' and torch.cuda.device_count() > 1:
 
90
  if not training:
91
  if device.type != 'cpu':
92
  model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
93
+ task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
94
  dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
95
  prefix=colorstr(f'{task}: '))[0]
96
 
 
298
  parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
299
  parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
300
  parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
301
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
302
  parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
303
  parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
304
  parser.add_argument('--task', default='val', help='train, val, test, speed or study')
 
316
  parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
317
  opt = parser.parse_args()
318
  opt.save_json |= opt.data.endswith('coco.yaml')
319
+ opt.save_txt |= opt.save_hybrid
320
  opt.data = check_file(opt.data) # check file
321
  print(opt)
322
  check_requirements(exclude=('tensorboard', 'thop'))
323
 
324
  if opt.task in ('train', 'val', 'test'): # run normally
325
+ test(**vars(opt))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  elif opt.task == 'speed': # speed benchmarks
328
  for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
329
+ test(opt.data, w, opt.batch_size, opt.imgsz, 0.25, 0.45, save_json=False, plots=False)
330
 
331
  elif opt.task == 'study': # run over a range of settings and save/plot
332
  # python test.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
 
337
  for i in x: # img-size
338
  print(f'\nRunning {f} point {i}...')
339
  r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
340
+ plots=False)
341
  y.append(r + t) # results and times
342
  np.savetxt(f, y, fmt='%10.4g') # save
343
  os.system('zip -r study.zip study_*.txt')