glenn-jocher commited on
Commit
a2b3c71
1 Parent(s): 8e5f9dd

Add suffix checks (#4711)

Browse files

* Add suffix checks

* Cleanup

* Cleanup2

* Cleanup3

Files changed (7) hide show
  1. detect.py +6 -4
  2. models/tf.py +3 -3
  3. models/yolo.py +4 -4
  4. train.py +4 -3
  5. utils/datasets.py +3 -3
  6. utils/general.py +16 -1
  7. val.py +5 -3
detect.py CHANGED
@@ -21,8 +21,9 @@ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
21
 
22
  from models.experimental import attempt_load
23
  from utils.datasets import LoadStreams, LoadImages
24
- from utils.general import check_img_size, check_requirements, check_imshow, colorstr, is_ascii, non_max_suppression, \
25
- apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
 
26
  from utils.plots import Annotator, colors
27
  from utils.torch_utils import select_device, load_classifier, time_sync
28
 
@@ -68,8 +69,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)
68
 
69
  # Load model
70
  w = weights[0] if isinstance(weights, list) else weights
71
- classify, suffix = False, Path(w).suffix.lower()
72
- pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
 
73
  stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
74
  if pt:
75
  model = attempt_load(weights, map_location=device) # load FP32 model
 
21
 
22
  from models.experimental import attempt_load
23
  from utils.datasets import LoadStreams, LoadImages
24
+ from utils.general import check_img_size, check_imshow, check_requirements, check_suffix, colorstr, is_ascii, \
25
+ non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, \
26
+ save_one_box
27
  from utils.plots import Annotator, colors
28
  from utils.torch_utils import select_device, load_classifier, time_sync
29
 
 
69
 
70
  # Load model
71
  w = weights[0] if isinstance(weights, list) else weights
72
+ classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
73
+ check_suffix(w, suffixes) # check weights have acceptable suffix
74
+ pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
75
  stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
76
  if pt:
77
  model = attempt_load(weights, map_location=device) # load FP32 model
models/tf.py CHANGED
@@ -53,7 +53,7 @@ from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C
53
  from models.experimental import MixConv2d, CrossConv, attempt_load
54
  from models.yolo import Detect
55
  from utils.datasets import LoadImages
56
- from utils.general import make_divisible, check_file, check_dataset
57
 
58
  logger = logging.getLogger(__name__)
59
 
@@ -447,7 +447,7 @@ if __name__ == "__main__":
447
  parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
448
  parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
449
  opt = parser.parse_args()
450
- opt.cfg = check_file(opt.cfg) # check file
451
  opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
452
  print(opt)
453
 
@@ -534,7 +534,7 @@ if __name__ == "__main__":
534
  if opt.tfl_int8:
535
  # Representative Dataset
536
  if opt.source.endswith('.yaml'):
537
- with open(check_file(opt.source)) as f:
538
  data = yaml.load(f, Loader=yaml.FullLoader) # data dict
539
  check_dataset(data) # check
540
  opt.source = data['train']
 
53
  from models.experimental import MixConv2d, CrossConv, attempt_load
54
  from models.yolo import Detect
55
  from utils.datasets import LoadImages
56
+ from utils.general import check_dataset, check_yaml, make_divisible
57
 
58
  logger = logging.getLogger(__name__)
59
 
 
447
  parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
448
  parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
449
  opt = parser.parse_args()
450
+ opt.cfg = check_yaml(opt.cfg) # check YAML
451
  opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
452
  print(opt)
453
 
 
534
  if opt.tfl_int8:
535
  # Representative Dataset
536
  if opt.source.endswith('.yaml'):
537
+ with open(check_yaml(opt.source)) as f:
538
  data = yaml.load(f, Loader=yaml.FullLoader) # data dict
539
  check_dataset(data) # check
540
  opt.source = data['train']
models/yolo.py CHANGED
@@ -17,10 +17,10 @@ sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
17
  from models.common import *
18
  from models.experimental import *
19
  from utils.autoanchor import check_anchor_order
20
- from utils.general import make_divisible, check_file, set_logging
21
  from utils.plots import feature_visualization
22
- from utils.torch_utils import time_sync, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
23
- select_device, copy_attr
24
 
25
  try:
26
  import thop # for FLOPs computation
@@ -281,7 +281,7 @@ if __name__ == '__main__':
281
  parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
282
  parser.add_argument('--profile', action='store_true', help='profile model speed')
283
  opt = parser.parse_args()
284
- opt.cfg = check_file(opt.cfg) # check file
285
  set_logging()
286
  device = select_device(opt.device)
287
 
 
17
  from models.common import *
18
  from models.experimental import *
19
  from utils.autoanchor import check_anchor_order
20
+ from utils.general import check_yaml, make_divisible, set_logging
21
  from utils.plots import feature_visualization
22
+ from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
23
+ select_device, time_sync
24
 
25
  try:
26
  import thop # for FLOPs computation
 
281
  parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
282
  parser.add_argument('--profile', action='store_true', help='profile model speed')
283
  opt = parser.parse_args()
284
+ opt.cfg = check_yaml(opt.cfg) # check YAML
285
  set_logging()
286
  device = select_device(opt.device)
287
 
train.py CHANGED
@@ -35,8 +35,8 @@ from models.yolo import Model
35
  from utils.autoanchor import check_anchors
36
  from utils.datasets import create_dataloader
37
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
38
- strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
39
- check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
40
  from utils.downloads import attempt_download
41
  from utils.loss import ComputeLoss
42
  from utils.plots import plot_labels, plot_evolve
@@ -484,7 +484,8 @@ def main(opt, callbacks=Callbacks()):
484
  opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
485
  LOGGER.info(f'Resuming training from {ckpt}')
486
  else:
487
- opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
 
488
  assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
489
  if opt.evolve:
490
  opt.project = 'runs/evolve'
 
35
  from utils.autoanchor import check_anchors
36
  from utils.datasets import create_dataloader
37
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
38
+ strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
39
+ check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
40
  from utils.downloads import attempt_download
41
  from utils.loss import ComputeLoss
42
  from utils.plots import plot_labels, plot_evolve
 
484
  opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
485
  LOGGER.info(f'Resuming training from {ckpt}')
486
  else:
487
+ check_suffix(opt.weights, '.pt') # check weights
488
+ opt.data, opt.cfg, opt.hyp = check_yaml(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs
489
  assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
490
  if opt.evolve:
491
  opt.project = 'runs/evolve'
utils/datasets.py CHANGED
@@ -26,8 +26,8 @@ from torch.utils.data import Dataset
26
  from tqdm import tqdm
27
 
28
  from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
29
- from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
30
- xyn2xy, segments2boxes, clean_str
31
  from utils.torch_utils import torch_distributed_zero_first
32
 
33
  # Parameters
@@ -938,7 +938,7 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profil
938
  im.save(im_dir / Path(f).name, quality=75) # save
939
 
940
  zipped, data_dir, yaml_path = unzip(Path(path))
941
- with open(check_file(yaml_path), errors='ignore') as f:
942
  data = yaml.safe_load(f) # data dict
943
  if zipped:
944
  data['path'] = data_dir # TODO: should this be dir.resolve()?
 
26
  from tqdm import tqdm
27
 
28
  from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
29
+ from utils.general import check_dataset, check_requirements, check_yaml, clean_str, segments2boxes, \
30
+ xywh2xyxy, xywhn2xyxy, xyxy2xywhn, xyn2xy
31
  from utils.torch_utils import torch_distributed_zero_first
32
 
33
  # Parameters
 
938
  im.save(im_dir / Path(f).name, quality=75) # save
939
 
940
  zipped, data_dir, yaml_path = unzip(Path(path))
941
+ with open(check_yaml(yaml_path), errors='ignore') as f:
942
  data = yaml.safe_load(f) # data dict
943
  if zipped:
944
  data['path'] = data_dir # TODO: should this be dir.resolve()?
utils/general.py CHANGED
@@ -242,8 +242,23 @@ def check_imshow():
242
  return False
243
 
244
 
245
- def check_file(file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # Search/download file (if necessary) and return path
 
247
  file = str(file) # convert to str()
248
  if Path(file).is_file() or file == '': # exists
249
  return file
 
242
  return False
243
 
244
 
245
+ def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
246
+ # Check file(s) for acceptable suffixes
247
+ if any(suffix):
248
+ if isinstance(suffix, str):
249
+ suffix = [suffix]
250
+ for f in file if isinstance(file, (list, tuple)) else [file]:
251
+ assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"
252
+
253
+
254
+ def check_yaml(file, suffix=('.yaml', '.yml')):
255
+ # Check YAML file(s) for acceptable suffixes
256
+ return check_file(file, suffix)
257
+
258
+
259
+ def check_file(file, suffix=''):
260
  # Search/download file (if necessary) and return path
261
+ check_suffix(file, suffix)
262
  file = str(file) # convert to str()
263
  if Path(file).is_file() or file == '': # exists
264
  return file
val.py CHANGED
@@ -22,8 +22,9 @@ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
22
 
23
  from models.experimental import attempt_load
24
  from utils.datasets import create_dataloader
25
- from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
26
- box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr
 
27
  from utils.metrics import ap_per_class, ConfusionMatrix
28
  from utils.plots import plot_images, output_to_target, plot_study_txt
29
  from utils.torch_utils import select_device, time_sync
@@ -116,6 +117,7 @@ def run(data,
116
  (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
117
 
118
  # Load model
 
119
  model = attempt_load(weights, map_location=device) # load FP32 model
120
  gs = max(int(model.stride.max()), 32) # grid size (max stride)
121
  imgsz = check_img_size(imgsz, s=gs) # check image size
@@ -316,7 +318,7 @@ def parse_opt():
316
  opt = parser.parse_args()
317
  opt.save_json |= opt.data.endswith('coco.yaml')
318
  opt.save_txt |= opt.save_hybrid
319
- opt.data = check_file(opt.data) # check file
320
  return opt
321
 
322
 
 
22
 
23
  from models.experimental import attempt_load
24
  from utils.datasets import create_dataloader
25
+ from utils.general import coco80_to_coco91_class, check_dataset, check_img_size, check_requirements, \
26
+ check_suffix, check_yaml, box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, \
27
+ increment_path, colorstr
28
  from utils.metrics import ap_per_class, ConfusionMatrix
29
  from utils.plots import plot_images, output_to_target, plot_study_txt
30
  from utils.torch_utils import select_device, time_sync
 
117
  (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
118
 
119
  # Load model
120
+ check_suffix(weights, '.pt')
121
  model = attempt_load(weights, map_location=device) # load FP32 model
122
  gs = max(int(model.stride.max()), 32) # grid size (max stride)
123
  imgsz = check_img_size(imgsz, s=gs) # check image size
 
318
  opt = parser.parse_args()
319
  opt.save_json |= opt.data.endswith('coco.yaml')
320
  opt.save_txt |= opt.save_hybrid
321
+ opt.data = check_yaml(opt.data) # check YAML
322
  return opt
323
 
324