Commit
•
cf4f3c3
1
Parent(s):
d51f9b2
yolo.py profiling updates (#7178)
Browse files* yolo.py profiling updates
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- models/yolo.py +12 -14
models/yolo.py
CHANGED
@@ -25,7 +25,8 @@ from models.experimental import *
|
|
25 |
from utils.autoanchor import check_anchor_order
|
26 |
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
|
27 |
from utils.plots import feature_visualization
|
28 |
-
from utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device,
|
|
|
29 |
|
30 |
try:
|
31 |
import thop # for FLOPs computation
|
@@ -300,8 +301,10 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|
300 |
if __name__ == '__main__':
|
301 |
parser = argparse.ArgumentParser()
|
302 |
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
|
|
|
303 |
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
304 |
parser.add_argument('--profile', action='store_true', help='profile model speed')
|
|
|
305 |
parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
|
306 |
opt = parser.parse_args()
|
307 |
opt.cfg = check_yaml(opt.cfg) # check YAML
|
@@ -309,24 +312,19 @@ if __name__ == '__main__':
|
|
309 |
device = select_device(opt.device)
|
310 |
|
311 |
# Create model
|
|
|
312 |
model = Model(opt.cfg).to(device)
|
313 |
|
314 |
-
#
|
315 |
-
if opt.profile
|
316 |
-
model
|
317 |
-
img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
|
318 |
-
y = model(img, profile=True)
|
319 |
|
320 |
-
#
|
321 |
-
|
|
|
|
|
322 |
for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
|
323 |
try:
|
324 |
_ = Model(cfg)
|
325 |
except Exception as e:
|
326 |
print(f'Error in {cfg}: {e}')
|
327 |
-
|
328 |
-
# Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
|
329 |
-
# from torch.utils.tensorboard import SummaryWriter
|
330 |
-
# tb_writer = SummaryWriter('.')
|
331 |
-
# LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
|
332 |
-
# tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph
|
|
|
25 |
from utils.autoanchor import check_anchor_order
|
26 |
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
|
27 |
from utils.plots import feature_visualization
|
28 |
+
from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
|
29 |
+
time_sync)
|
30 |
|
31 |
try:
|
32 |
import thop # for FLOPs computation
|
|
|
301 |
if __name__ == '__main__':
|
302 |
parser = argparse.ArgumentParser()
|
303 |
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
|
304 |
+
parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
|
305 |
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
306 |
parser.add_argument('--profile', action='store_true', help='profile model speed')
|
307 |
+
parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
|
308 |
parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
|
309 |
opt = parser.parse_args()
|
310 |
opt.cfg = check_yaml(opt.cfg) # check YAML
|
|
|
312 |
device = select_device(opt.device)
|
313 |
|
314 |
# Create model
|
315 |
+
im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
|
316 |
model = Model(opt.cfg).to(device)
|
317 |
|
318 |
+
# Options
|
319 |
+
if opt.line_profile: # profile layer by layer
|
320 |
+
_ = model(im, profile=True)
|
|
|
|
|
321 |
|
322 |
+
elif opt.profile: # profile forward-backward
|
323 |
+
results = profile(input=im, ops=[model], n=3)
|
324 |
+
|
325 |
+
elif opt.test: # test all models
|
326 |
for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
|
327 |
try:
|
328 |
_ = Model(cfg)
|
329 |
except Exception as e:
|
330 |
print(f'Error in {cfg}: {e}')
|
|
|
|
|
|
|
|
|
|
|
|