Commit
•
05a955a
1
Parent(s):
af8aee7
FLOPS computation device bug fix (#1447)
Browse files* Update torch_utils.py
fix issue#113 , inputs device should be same with model parameters' device
* Update torch_utils.py
* Update torch_utils.py
Co-authored-by: Glenn Jocher <[email protected]>
- utils/torch_utils.py +2 -1
utils/torch_utils.py
CHANGED
@@ -153,7 +153,8 @@ def model_info(model, verbose=False, img_size=640):
|
|
153 |
try: # FLOPS
|
154 |
from thop import profile
|
155 |
stride = int(model.stride.max())
|
156 |
-
|
|
|
157 |
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
158 |
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
|
159 |
except (ImportError, Exception):
|
|
|
153 |
try: # FLOPS
|
154 |
from thop import profile
|
155 |
stride = int(model.stride.max())
|
156 |
+
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
|
157 |
+
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
|
158 |
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
159 |
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
|
160 |
except (ImportError, Exception):
|