glenn-jocher
committed on
Commit
•
932dc78
1
Parent(s):
99de551
YOLOv5 Export Benchmarks for GPU (#6963)
Browse files
* Add benchmarks.py GPU support
* Updates
* Updates
* Updates
* Updates
* Add --half
* Add TRT requirements
* Cleanup
* Add TF to warmup types
* Update export.py
* Update export.py
* Update benchmarks.py
- export.py +12 -12
- models/common.py +4 -3
- utils/benchmarks.py +15 -3
export.py
CHANGED
@@ -75,18 +75,18 @@ from utils.torch_utils import select_device
|
|
75 |
|
76 |
def export_formats():
|
77 |
# YOLOv5 export formats
|
78 |
-
x = [['PyTorch', '-', '.pt'],
|
79 |
-
['TorchScript', 'torchscript', '.torchscript'],
|
80 |
-
['ONNX', 'onnx', '.onnx'],
|
81 |
-
['OpenVINO', 'openvino', '_openvino_model'],
|
82 |
-
['TensorRT', 'engine', '.engine'],
|
83 |
-
['CoreML', 'coreml', '.mlmodel'],
|
84 |
-
['TensorFlow SavedModel', 'saved_model', '_saved_model'],
|
85 |
-
['TensorFlow GraphDef', 'pb', '.pb'],
|
86 |
-
['TensorFlow Lite', 'tflite', '.tflite'],
|
87 |
-
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'],
|
88 |
-
['TensorFlow.js', 'tfjs', '_web_model']]
|
89 |
-
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix'])
|
90 |
|
91 |
|
92 |
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
|
|
|
75 |
|
76 |
def export_formats():
|
77 |
# YOLOv5 export formats
|
78 |
+
x = [['PyTorch', '-', '.pt', True],
|
79 |
+
['TorchScript', 'torchscript', '.torchscript', True],
|
80 |
+
['ONNX', 'onnx', '.onnx', True],
|
81 |
+
['OpenVINO', 'openvino', '_openvino_model', False],
|
82 |
+
['TensorRT', 'engine', '.engine', True],
|
83 |
+
['CoreML', 'coreml', '.mlmodel', False],
|
84 |
+
['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
|
85 |
+
['TensorFlow GraphDef', 'pb', '.pb', True],
|
86 |
+
['TensorFlow Lite', 'tflite', '.tflite', False],
|
87 |
+
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
|
88 |
+
['TensorFlow.js', 'tfjs', '_web_model', False]]
|
89 |
+
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])
|
90 |
|
91 |
|
92 |
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
|
models/common.py
CHANGED
@@ -464,10 +464,11 @@ class DetectMultiBackend(nn.Module):
|
|
464 |
|
465 |
def warmup(self, imgsz=(1, 3, 640, 640)):
|
466 |
# Warmup model by running inference once
|
467 |
-
if self.pt
|
468 |
-
if
|
469 |
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
470 |
-
self.
|
|
|
471 |
|
472 |
@staticmethod
|
473 |
def model_type(p='path/to/model.pt'):
|
|
|
464 |
|
465 |
def warmup(self, imgsz=(1, 3, 640, 640)):
|
466 |
# Warmup model by running inference once
|
467 |
+
if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types
|
468 |
+
if self.device.type != 'cpu': # only warmup GPU models
|
469 |
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
470 |
+
for _ in range(2 if self.jit else 1): #
|
471 |
+
self.forward(im) # warmup
|
472 |
|
473 |
@staticmethod
|
474 |
def model_type(p='path/to/model.pt'):
|
utils/benchmarks.py
CHANGED
@@ -19,6 +19,7 @@ TensorFlow.js | `tfjs` | yolov5s_web_model/
|
|
19 |
Requirements:
|
20 |
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
|
21 |
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
|
|
22 |
|
23 |
Usage:
|
24 |
$ python utils/benchmarks.py --weights yolov5s.pt --img 640
|
@@ -41,20 +42,29 @@ import export
|
|
41 |
import val
|
42 |
from utils import notebook_init
|
43 |
from utils.general import LOGGER, print_args
|
|
|
44 |
|
45 |
|
46 |
def run(weights=ROOT / 'yolov5s.pt', # weights path
|
47 |
imgsz=640, # inference size (pixels)
|
48 |
batch_size=1, # batch size
|
49 |
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
|
|
|
|
|
50 |
):
|
51 |
y, t = [], time.time()
|
52 |
formats = export.export_formats()
|
53 |
-
|
|
|
54 |
try:
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
assert suffix in str(w), 'export failed'
|
57 |
-
result = val.run(data, w, batch_size, imgsz
|
58 |
metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
|
59 |
speeds = result[2] # times (preprocess, inference, postprocess)
|
60 |
y.append([name, metrics[3], speeds[1]]) # mAP, t_inference
|
@@ -78,6 +88,8 @@ def parse_opt():
|
|
78 |
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
|
79 |
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
80 |
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
|
|
|
|
|
81 |
opt = parser.parse_args()
|
82 |
print_args(FILE.stem, opt)
|
83 |
return opt
|
|
|
19 |
Requirements:
|
20 |
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
|
21 |
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
22 |
+
$ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT
|
23 |
|
24 |
Usage:
|
25 |
$ python utils/benchmarks.py --weights yolov5s.pt --img 640
|
|
|
42 |
import val
|
43 |
from utils import notebook_init
|
44 |
from utils.general import LOGGER, print_args
|
45 |
+
from utils.torch_utils import select_device
|
46 |
|
47 |
|
48 |
def run(weights=ROOT / 'yolov5s.pt', # weights path
|
49 |
imgsz=640, # inference size (pixels)
|
50 |
batch_size=1, # batch size
|
51 |
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
|
52 |
+
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
53 |
+
half=False, # use FP16 half-precision inference
|
54 |
):
|
55 |
y, t = [], time.time()
|
56 |
formats = export.export_formats()
|
57 |
+
device = select_device(device)
|
58 |
+
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable)
|
59 |
try:
|
60 |
+
if device.type != 'cpu':
|
61 |
+
assert gpu, f'{name} inference not supported on GPU'
|
62 |
+
if f == '-':
|
63 |
+
w = weights # PyTorch format
|
64 |
+
else:
|
65 |
+
w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others
|
66 |
assert suffix in str(w), 'export failed'
|
67 |
+
result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half)
|
68 |
metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
|
69 |
speeds = result[2] # times (preprocess, inference, postprocess)
|
70 |
y.append([name, metrics[3], speeds[1]]) # mAP, t_inference
|
|
|
88 |
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
|
89 |
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
90 |
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
|
91 |
+
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
92 |
+
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
|
93 |
opt = parser.parse_args()
|
94 |
print_args(FILE.stem, opt)
|
95 |
return opt
|