glenn-jocher
commited on
Commit
•
de9c25b
1
Parent(s):
a297efc
Use `export_formats()` in export.py (#6705)
Browse files* Use `export_formats()` in export.py
* list fix
export.py
CHANGED
@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
|
433 |
conf_thres=0.25 # TF.js NMS: confidence threshold
|
434 |
):
|
435 |
t = time.time()
|
436 |
-
include = [x.lower() for x in include]
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
439 |
|
440 |
# Load PyTorch model
|
441 |
device = select_device(device)
|
@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
|
475 |
# Exports
|
476 |
f = [''] * 10 # exported filenames
|
477 |
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
|
478 |
-
if
|
479 |
f[0] = export_torchscript(model, im, file, optimize)
|
480 |
-
if
|
481 |
f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
|
482 |
-
if
|
483 |
f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
|
484 |
-
if
|
485 |
f[3] = export_openvino(model, im, file)
|
486 |
-
if
|
487 |
_, f[4] = export_coreml(model, im, file)
|
488 |
|
489 |
# TensorFlow Exports
|
490 |
-
if any(
|
491 |
-
pb, tflite, edgetpu, tfjs = tf_exports[1:]
|
492 |
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
|
493 |
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
|
494 |
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
|
|
433 |
conf_thres=0.25 # TF.js NMS: confidence threshold
|
434 |
):
|
435 |
t = time.time()
|
436 |
+
include = [x.lower() for x in include] # to lowercase
|
437 |
+
formats = tuple(export_formats()['Argument'][1:]) # --include arguments
|
438 |
+
flags = [x in include for x in formats]
|
439 |
+
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
|
440 |
+
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
|
441 |
+
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
|
442 |
|
443 |
# Load PyTorch model
|
444 |
device = select_device(device)
|
|
|
478 |
# Exports
|
479 |
f = [''] * 10 # exported filenames
|
480 |
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
|
481 |
+
if jit:
|
482 |
f[0] = export_torchscript(model, im, file, optimize)
|
483 |
+
if engine: # TensorRT required before ONNX
|
484 |
f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
|
485 |
+
if onnx or xml: # OpenVINO requires ONNX
|
486 |
f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
|
487 |
+
if xml: # OpenVINO
|
488 |
f[3] = export_openvino(model, im, file)
|
489 |
+
if coreml:
|
490 |
_, f[4] = export_coreml(model, im, file)
|
491 |
|
492 |
# TensorFlow Exports
|
493 |
+
if any((saved_model, pb, tflite, edgetpu, tfjs)):
|
|
|
494 |
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
|
495 |
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
|
496 |
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
|