glenn-jocher
commited on
Commit
•
9455796
1
Parent(s):
1dcb774
Remove `formats` variable to avoid `pd` conflict (#7993)
Browse files* Remove `formats` variable to avoid `pd` conflict
* Update export.py
- export.py +5 -5
- utils/benchmarks.py +2 -4
export.py
CHANGED
@@ -475,9 +475,9 @@ def run(
|
|
475 |
):
|
476 |
t = time.time()
|
477 |
include = [x.lower() for x in include] # to lowercase
|
478 |
-
|
479 |
-
flags = [x in include for x in
|
480 |
-
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {
|
481 |
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
|
482 |
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
|
483 |
|
@@ -499,7 +499,7 @@ def run(
|
|
499 |
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
500 |
|
501 |
# Update model
|
502 |
-
if half and not
|
503 |
im, model = im.half(), model.half() # to FP16
|
504 |
model.train() if train else model.eval() # training mode = no Detect() layer grid construction
|
505 |
for k, m in model.named_modules():
|
@@ -531,7 +531,7 @@ def run(
|
|
531 |
if any((saved_model, pb, tflite, edgetpu, tfjs)):
|
532 |
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
|
533 |
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
|
534 |
-
assert not
|
535 |
model, f[5] = export_saved_model(model.cpu(),
|
536 |
im,
|
537 |
file,
|
|
|
475 |
):
|
476 |
t = time.time()
|
477 |
include = [x.lower() for x in include] # to lowercase
|
478 |
+
fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
|
479 |
+
flags = [x in include for x in fmts]
|
480 |
+
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
|
481 |
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
|
482 |
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
|
483 |
|
|
|
499 |
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
|
500 |
|
501 |
# Update model
|
502 |
+
if half and not coreml and not xml:
|
503 |
im, model = im.half(), model.half() # to FP16
|
504 |
model.train() if train else model.eval() # training mode = no Detect() layer grid construction
|
505 |
for k, m in model.named_modules():
|
|
|
531 |
if any((saved_model, pb, tflite, edgetpu, tfjs)):
|
532 |
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
|
533 |
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
|
534 |
+
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
535 |
model, f[5] = export_saved_model(model.cpu(),
|
536 |
im,
|
537 |
file,
|
utils/benchmarks.py
CHANGED
@@ -56,9 +56,8 @@ def run(
|
|
56 |
pt_only=False, # test PyTorch only
|
57 |
):
|
58 |
y, t = [], time.time()
|
59 |
-
formats = export.export_formats()
|
60 |
device = select_device(device)
|
61 |
-
for i, (name, f, suffix, gpu) in
|
62 |
try:
|
63 |
assert i != 9, 'Edge TPU not supported'
|
64 |
assert i != 10, 'TF.js not supported'
|
@@ -104,9 +103,8 @@ def test(
|
|
104 |
pt_only=False, # test PyTorch only
|
105 |
):
|
106 |
y, t = [], time.time()
|
107 |
-
formats = export.export_formats()
|
108 |
device = select_device(device)
|
109 |
-
for i, (name, f, suffix, gpu) in
|
110 |
try:
|
111 |
w = weights if f == '-' else \
|
112 |
export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights
|
|
|
56 |
pt_only=False, # test PyTorch only
|
57 |
):
|
58 |
y, t = [], time.time()
|
|
|
59 |
device = select_device(device)
|
60 |
+
for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
|
61 |
try:
|
62 |
assert i != 9, 'Edge TPU not supported'
|
63 |
assert i != 10, 'TF.js not supported'
|
|
|
103 |
pt_only=False, # test PyTorch only
|
104 |
):
|
105 |
y, t = [], time.time()
|
|
|
106 |
device = select_device(device)
|
107 |
+
for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
|
108 |
try:
|
109 |
w = weights if f == '-' else \
|
110 |
export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights
|