Merge branch 'ultralytics:master' into main
- export.py +15 -6
- models/common.py +16 -6
- utils/metrics.py +6 -0
- val.py +4 -0
export.py
CHANGED
@@ -216,8 +216,9 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
         return None, None
 
 
-def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, verbose=False):
     # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
+    prefix = colorstr('TensorRT:')
     try:
         assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
         try:
@@ -230,11 +231,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
         if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
             grid = model.model[-1].anchor_grid
             model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
-            export_onnx(model, im, file, 12, train, False, simplify)  # opset 12
+            export_onnx(model, im, file, 12, train, dynamic, simplify)  # opset 12
             model.model[-1].anchor_grid = grid
         else:  # TensorRT >= 8
             check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
-            export_onnx(model, im, file, 13, train, False, simplify)  # opset 13
+            export_onnx(model, im, file, 13, train, dynamic, simplify)  # opset 13
         onnx = file.with_suffix('.onnx')
 
         LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
@@ -263,6 +264,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
         for out in outputs:
             LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
 
+        if dynamic:
+            if im.shape[0] <= 1:
+                LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
+            profile = builder.create_optimization_profile()
+            for inp in inputs:
+                profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
+            config.add_optimization_profile(profile)
+
         LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
         if builder.platform_has_fast_fp16 and half:
             config.set_flag(trt.BuilderFlag.FP16)
@@ -460,7 +469,7 @@ def run(
         keras=False,  # use Keras
         optimize=False,  # TorchScript: optimize for mobile
         int8=False,  # CoreML/TF INT8 quantization
-        dynamic=False,  # ONNX/TF: dynamic axes
+        dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
         simplify=False,  # ONNX: simplify model
         opset=12,  # ONNX: opset version
         verbose=False,  # TensorRT: verbose log
@@ -520,7 +529,7 @@ def run(
     if jit:
         f[0] = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
-        f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
+        f[1] = export_engine(model, im, file, train, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
     if xml:  # OpenVINO
@@ -579,7 +588,7 @@ def parse_opt():
     parser.add_argument('--keras', action='store_true', help='TF: use Keras')
     parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
     parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
-    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
+    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
     parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
     parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
     parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
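For reference, the added `if dynamic:` block registers a TensorRT optimization profile so a single engine accepts any batch size from 1 up to the --batch-size given at export time. A minimal standalone sketch of the same pattern, assuming a hypothetical model.onnx with a dynamic batch dimension on an input named 'images' (TensorRT 8.x-era Python API):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
assert parser.parse_from_file('model.onnx'), 'failed to load ONNX file'  # hypothetical path

config = builder.create_builder_config()
profile = builder.create_optimization_profile()
# (min, opt, max) input shapes: batch 1 up to batch 16 of 3x640x640 images (illustrative sizes)
profile.set_shape('images', (1, 3, 640, 640), (8, 3, 640, 640), (16, 3, 640, 640))
config.add_optimization_profile(profile)
engine = builder.build_engine(network, config)  # deprecated in later TensorRT releases

An engine built this way can then be fed any batch size between the min and max profile shapes at inference time, which is what the models/common.py changes below rely on.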
models/common.py
CHANGED
@@ -384,19 +384,24 @@ class DetectMultiBackend(nn.Module):
             logger = trt.Logger(trt.Logger.INFO)
             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                 model = runtime.deserialize_cuda_engine(f.read())
+            context = model.create_execution_context()
             bindings = OrderedDict()
             fp16 = False  # default updated below
+            dynamic_input = False
             for index in range(model.num_bindings):
                 name = model.get_binding_name(index)
                 dtype = trt.nptype(model.get_binding_dtype(index))
-                shape = tuple(model.get_binding_shape(index))
+                if model.binding_is_input(index):
+                    if -1 in tuple(model.get_binding_shape(index)):  # dynamic
+                        dynamic_input = True
+                        context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2]))
+                    if dtype == np.float16:
+                        fp16 = True
+                shape = tuple(context.get_binding_shape(index))
                 data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                 bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
-                if model.binding_is_input(index) and dtype == np.float16:
-                    fp16 = True
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
-            context = model.create_execution_context()
-            batch_size = bindings['images'].shape[0]
+            batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
         elif coreml:  # CoreML
             LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
@@ -466,7 +471,12 @@ class DetectMultiBackend(nn.Module):
             im = im.cpu().numpy()  # FP32
             y = self.executable_network([im])[self.output_layer]
         elif self.engine:  # TensorRT
-            assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
+            if im.shape != self.bindings['images'].shape and self.dynamic_input:
+                self.context.set_binding_shape(self.model.get_binding_index('images'), im.shape)  # reshape if dynamic
+                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
+            assert im.shape == self.bindings['images'].shape, (
+                f"image shape {im.shape} exceeds model max shape {self.bindings['images'].shape}" if self.dynamic_input
+                else f"image shape {im.shape} does not match model shape {self.bindings['images'].shape}")
             self.binding_addrs['images'] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = self.bindings['output'].data
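Two details in the loader above are easy to miss: an input binding containing a -1 dimension marks the engine as dynamic, and Binding is a namedtuple, so forward() must update the recorded shape with _replace rather than by mutation. A small sketch of that bookkeeping, reusing the same Binding layout from models/common.py (buffer values are illustrative):

from collections import namedtuple

import numpy as np
import torch

Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))

data = torch.zeros(16, 3, 640, 640)  # buffer allocated at the max profile shape
b = Binding('images', np.float32, tuple(data.shape), data, int(data.data_ptr()))

# namedtuples are immutable: _replace returns a copy with the field updated,
# which is how forward() records the smaller batch actually passed in
b = b._replace(shape=(4, 3, 640, 640))
print(b.shape)  # (4, 3, 640, 640)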
utils/metrics.py
CHANGED
@@ -139,6 +139,12 @@ class ConfusionMatrix:
         Returns:
             None, updates confusion matrix accordingly
         """
+        if detections is None:
+            gt_classes = labels.int()
+            for i, gc in enumerate(gt_classes):
+                self.matrix[self.nc, gc] += 1  # background FN
+            return
+
         detections = detections[detections[:, 4] > self.conf]
         gt_classes = labels[:, 0].int()
         detection_classes = detections[:, 5].int()
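The early return above books every ground-truth label of a prediction-free image as a background false negative, in row nc of the matrix. A tiny numeric sketch of that accounting, with an illustrative 3-class matrix:

import numpy as np
import torch

nc = 3  # number of classes (illustrative)
matrix = np.zeros((nc + 1, nc + 1))  # last row/column represent background

# no detections for this image: each GT class increments a background FN,
# mirroring the new `detections is None` branch
gt_classes = torch.tensor([0, 2, 2]).int()
for gc in gt_classes:
    matrix[nc, int(gc)] += 1

print(matrix[nc])  # [1. 0. 2. 0.]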
val.py
CHANGED
@@ -228,6 +228,8 @@ def run(
             if npr == 0:
                 if nl:
                     stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))
+                    if plots:
+                        confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
                 continue
 
             # Predictions
@@ -273,6 +275,8 @@ def run(
     # Print results
     pf = '%20s' + '%11i' * 2 + '%11.3g' * 4  # print format
     LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
+    if nt.sum() == 0:
+        LOGGER.warning(emojis(f'WARNING: no labels found in {task} set, can not compute metrics without labels ⚠️'))
 
     # Print results per class
     if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):