yujun glenn-jocher commited on
Commit
05a955a
1 Parent(s): af8aee7

FLOPS computation device bug fix (#1447)

Browse files

* Update torch_utils.py

fix issue#113 , inputs device should be same with model parameters' device

* Update torch_utils.py

* Update torch_utils.py

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. utils/torch_utils.py +2 -1
utils/torch_utils.py CHANGED
@@ -153,7 +153,8 @@ def model_info(model, verbose=False, img_size=640):
153
  try: # FLOPS
154
  from thop import profile
155
  stride = int(model.stride.max())
156
- flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
 
157
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
158
  fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
159
  except (ImportError, Exception):
 
153
  try: # FLOPS
154
  from thop import profile
155
  stride = int(model.stride.max())
156
+ img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
157
+ flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
158
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
159
  fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
160
  except (ImportError, Exception):