glenn-jocher pre-commit-ci[bot] commited on
Commit
cf4f3c3
1 Parent(s): d51f9b2

yolo.py profiling updates (#7178)

Browse files

* yolo.py profiling updates

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (1) hide show
  1. models/yolo.py +12 -14
models/yolo.py CHANGED
@@ -25,7 +25,8 @@ from models.experimental import *
25
  from utils.autoanchor import check_anchor_order
26
  from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
27
  from utils.plots import feature_visualization
28
- from utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, time_sync
 
29
 
30
  try:
31
  import thop # for FLOPs computation
@@ -300,8 +301,10 @@ def parse_model(d, ch): # model_dict, input_channels(3)
300
  if __name__ == '__main__':
301
  parser = argparse.ArgumentParser()
302
  parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
 
303
  parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
304
  parser.add_argument('--profile', action='store_true', help='profile model speed')
 
305
  parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
306
  opt = parser.parse_args()
307
  opt.cfg = check_yaml(opt.cfg) # check YAML
@@ -309,24 +312,19 @@ if __name__ == '__main__':
309
  device = select_device(opt.device)
310
 
311
  # Create model
 
312
  model = Model(opt.cfg).to(device)
313
 
314
- # Profile
315
- if opt.profile:
316
- model.eval().fuse()
317
- img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
318
- y = model(img, profile=True)
319
 
320
- # Test all models
321
- if opt.test:
 
 
322
  for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
323
  try:
324
  _ = Model(cfg)
325
  except Exception as e:
326
  print(f'Error in {cfg}: {e}')
327
-
328
- # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
329
- # from torch.utils.tensorboard import SummaryWriter
330
- # tb_writer = SummaryWriter('.')
331
- # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
332
- # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph
 
25
  from utils.autoanchor import check_anchor_order
26
  from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
27
  from utils.plots import feature_visualization
28
+ from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
29
+ time_sync)
30
 
31
  try:
32
  import thop # for FLOPs computation
 
301
  if __name__ == '__main__':
302
  parser = argparse.ArgumentParser()
303
  parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
304
+ parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
305
  parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
306
  parser.add_argument('--profile', action='store_true', help='profile model speed')
307
+ parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
308
  parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
309
  opt = parser.parse_args()
310
  opt.cfg = check_yaml(opt.cfg) # check YAML
 
312
  device = select_device(opt.device)
313
 
314
  # Create model
315
+ im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
316
  model = Model(opt.cfg).to(device)
317
 
318
+ # Options
319
+ if opt.line_profile: # profile layer by layer
320
+ _ = model(im, profile=True)
 
 
321
 
322
+ elif opt.profile: # profile forward-backward
323
+ results = profile(input=im, ops=[model], n=3)
324
+
325
+ elif opt.test: # test all models
326
  for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
327
  try:
328
  _ = Model(cfg)
329
  except Exception as e:
330
  print(f'Error in {cfg}: {e}')