glenn-jocher commited on
Commit
de9c25b
1 Parent(s): a297efc

Use `export_formats()` in export.py (#6705)

Browse files

* Use `export_formats()` in export.py

* list fix

Files changed (1) hide show
  1. export.py +12 -10
export.py CHANGED
@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
433
  conf_thres=0.25 # TF.js NMS: confidence threshold
434
  ):
435
  t = time.time()
436
- include = [x.lower() for x in include]
437
- tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
438
- file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
 
 
 
439
 
440
  # Load PyTorch model
441
  device = select_device(device)
@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
475
  # Exports
476
  f = [''] * 10 # exported filenames
477
  warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
478
- if 'torchscript' in include:
479
  f[0] = export_torchscript(model, im, file, optimize)
480
- if 'engine' in include: # TensorRT required before ONNX
481
  f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
482
- if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
483
  f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
484
- if 'openvino' in include:
485
  f[3] = export_openvino(model, im, file)
486
- if 'coreml' in include:
487
  _, f[4] = export_coreml(model, im, file)
488
 
489
  # TensorFlow Exports
490
- if any(tf_exports):
491
- pb, tflite, edgetpu, tfjs = tf_exports[1:]
492
  if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
493
  check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
494
  assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
 
433
  conf_thres=0.25 # TF.js NMS: confidence threshold
434
  ):
435
  t = time.time()
436
+ include = [x.lower() for x in include] # to lowercase
437
+ formats = tuple(export_formats()['Argument'][1:]) # --include arguments
438
+ flags = [x in include for x in formats]
439
+ assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
440
+ jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
441
+ file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
442
 
443
  # Load PyTorch model
444
  device = select_device(device)
 
478
  # Exports
479
  f = [''] * 10 # exported filenames
480
  warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
481
+ if jit:
482
  f[0] = export_torchscript(model, im, file, optimize)
483
+ if engine: # TensorRT required before ONNX
484
  f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
485
+ if onnx or xml: # OpenVINO requires ONNX
486
  f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
487
+ if xml: # OpenVINO
488
  f[3] = export_openvino(model, im, file)
489
+ if coreml:
490
  _, f[4] = export_coreml(model, im, file)
491
 
492
  # TensorFlow Exports
493
+ if any((saved_model, pb, tflite, edgetpu, tfjs)):
 
494
  if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
495
  check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
496
  assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'