glenn-jocher
committed
Commit • d8f1883
1 Parent(s): bceb57b
Update `profile()` for CUDA Memory allocation (#4239)
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Update profile()
* Cleanup
- tutorial.ipynb +2 -2
- utils/torch_utils.py +45 -31
tutorial.ipynb CHANGED
@@ -1172,11 +1172,11 @@
    },
    "source": [
        "# Profile\n",
-       "from utils.torch_utils import profile
+       "from utils.torch_utils import profile\n",
        "\n",
        "m1 = lambda x: x * torch.sigmoid(x)\n",
        "m2 = torch.nn.SiLU()\n",
-       "profile(
+       "results = profile(input=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
    ],
    "execution_count": null,
    "outputs": []
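As an aside, here is a minimal sketch of how the updated notebook call might be used from a plain script. The input shape, ops, and iteration count are taken straight from the diff above; the loop that unpacks each results entry (params, GFLOPs, GPU memory, forward ms, backward ms, input/output shapes) is an illustrative assumption based on the list appended in utils/torch_utils.py below, and the variable names are not from the commit.

import torch
from utils.torch_utils import profile

# Two ops from the tutorial cell: a hand-rolled SiLU and the built-in module
m1 = lambda x: x * torch.sigmoid(x)
m2 = torch.nn.SiLU()

# Profile both ops over 100 iterations on a 16x3x640x640 input, as in the notebook
results = profile(input=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)

# Each successful entry is [params, GFLOPs, GPU_mem_GB, forward_ms, backward_ms, in_shape, out_shape]
for r in results:
    if r is not None:
        p, flops, mem, tf, tb, s_in, s_out = r
        print(f'{p} params, {flops:.2f} GFLOPs, {mem:.3f} GB, {tf:.2f} ms fwd, {tb:.2f} ms bwd')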
utils/torch_utils.py CHANGED
@@ -98,42 +98,56 @@ def time_sync():
     return time.time()
 
 
-def profile(
-    #
-    #
+def profile(input, ops, n=10, device=None):
+    # YOLOv5 speed/memory/FLOPs profiler
+    #
+    # Usage:
+    #     input = torch.randn(16, 3, 640, 640)
     #     m1 = lambda x: x * torch.sigmoid(x)
     #     m2 = nn.SiLU()
-    #     profile(
+    #     profile(input, [m1, m2], n=100)  # profile over 100 iterations
 
+    results = []
     device = device or select_device()
-    for
-        flops = 0
-        for _ in range(n):
-            t[0] = time_sync()
-            y = m(x)
-            t[1] = time_sync()
+    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
+          f"{'input':>24s}{'output':>24s}")
+
+    for x in input if isinstance(input, list) else [input]:
+        x = x.to(device)
+        x.requires_grad = True
+        for m in ops if isinstance(ops, list) else [ops]:
+            m = m.to(device) if hasattr(m, 'to') else m  # device
+            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
+            tf, tb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
             try:
+                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
+            except:
+                flops = 0
+
+            try:
+                for _ in range(n):
+                    t[0] = time_sync()
+                    y = m(x)
+                    t[1] = time_sync()
+                    try:
+                        _ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
+                        t[2] = time_sync()
+                    except Exception as e:  # no backward method
+                        print(e)
+                        t[2] = float('nan')
+                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
+                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
+                s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
+                s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
+                p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
+                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
+                results.append([p, flops, mem, tf, tb, s_in, s_out])
+            except Exception as e:
+                print(e)
+                results.append(None)
+            torch.cuda.empty_cache()
+    return results
 
 
 def is_parallel(model):
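This diff is what the commit title refers to: the profiler now reports a GPU_mem (GB) column read from torch.cuda.memory_reserved() after each op's timed loop, and calls torch.cuda.empty_cache() between ops so one op's cached allocations do not inflate the next measurement. The standalone sketch below only illustrates that measurement pattern and is not part of the commit; it assumes a CUDA device is present (the function itself falls back to 0 GB on CPU).

import torch

def cuda_mem_gb():
    # Memory reserved by PyTorch's caching allocator, in GB (0 without CUDA)
    return torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(16, 3, 640, 640, device=device)
m = torch.nn.SiLU().to(device)

y = m(x)                  # run the op so its activations/workspace get allocated
print(f'GPU_mem (GB): {cuda_mem_gb():.3f}')

torch.cuda.empty_cache()  # drop unused cached blocks so the next op starts clean
print(f'GPU_mem (GB) after empty_cache: {cuda_mem_gb():.3f}')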