glenn-jocher committed
Commit
d8f1883
1 Parent(s): bceb57b

Update `profile()` for CUDA Memory allocation (#4239)


* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Update profile()

* Cleanup

Files changed (2)
  1. tutorial.ipynb +2 -2
  2. utils/torch_utils.py +45 -31
tutorial.ipynb CHANGED
@@ -1172,11 +1172,11 @@
   },
   "source": [
    "# Profile\n",
-   "from utils.torch_utils import profile \n",
+   "from utils.torch_utils import profile\n",
    "\n",
    "m1 = lambda x: x * torch.sigmoid(x)\n",
    "m2 = torch.nn.SiLU()\n",
-   "profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
+   "results = profile(input=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
   ],
   "execution_count": null,
   "outputs": []
utils/torch_utils.py CHANGED
@@ -98,42 +98,56 @@ def time_sync():
     return time.time()
 
 
-def profile(x, ops, n=100, device=None):
-    # profile a pytorch module or list of modules. Example usage:
-    #     x = torch.randn(16, 3, 640, 640)  # input
+def profile(input, ops, n=10, device=None):
+    # YOLOv5 speed/memory/FLOPs profiler
+    #
+    # Usage:
+    #     input = torch.randn(16, 3, 640, 640)
     #     m1 = lambda x: x * torch.sigmoid(x)
     #     m2 = nn.SiLU()
-    #     profile(x, [m1, m2], n=100)  # profile speed over 100 iterations
+    #     profile(input, [m1, m2], n=100)  # profile over 100 iterations
 
+    results = []
     device = device or select_device()
-    x = x.to(device)
-    x.requires_grad = True
-    print(f"{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
-    for m in ops if isinstance(ops, list) else [ops]:
-        m = m.to(device) if hasattr(m, 'to') else m  # device
-        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
-        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
-        try:
-            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
-        except:
-            flops = 0
-
-        for _ in range(n):
-            t[0] = time_sync()
-            y = m(x)
-            t[1] = time_sync()
+    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
+          f"{'input':>24s}{'output':>24s}")
+
+    for x in input if isinstance(input, list) else [input]:
+        x = x.to(device)
+        x.requires_grad = True
+        for m in ops if isinstance(ops, list) else [ops]:
+            m = m.to(device) if hasattr(m, 'to') else m  # device
+            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
+            tf, tb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
             try:
-                _ = y.sum().backward()
-                t[2] = time_sync()
-            except:  # no backward method
-                t[2] = float('nan')
-            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
-            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
-
-        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
-        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
-        p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
-        print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
+                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
+            except:
+                flops = 0
+
+            try:
+                for _ in range(n):
+                    t[0] = time_sync()
+                    y = m(x)
+                    t[1] = time_sync()
+                    try:
+                        _ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
+                        t[2] = time_sync()
+                    except Exception as e:  # no backward method
+                        print(e)
+                        t[2] = float('nan')
+                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
+                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
+                s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
+                s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
+                p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
+                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
+                results.append([p, flops, mem, tf, tb, s_in, s_out])
+            except Exception as e:
+                print(e)
+                results.append(None)
+            torch.cuda.empty_cache()
+    return results
 
 
 def is_parallel(model):
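
The new GPU_mem column reads torch.cuda.memory_reserved(), i.e. memory held by PyTorch's caching allocator rather than exact tensor usage, and torch.cuda.empty_cache() is called after each op so one measurement does not inflate the next. A standalone sketch of that pattern, independent of YOLOv5 (the SiLU module and 16x3x640x640 input here are placeholders):

# Sketch of the per-op memory accounting used above; reports 0 GB without a CUDA device, as profile() does.
import torch

def reserved_gb():
    # allocator-reserved CUDA memory in GB
    return torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0

device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = torch.nn.SiLU().to(device)                                        # placeholder op
x = torch.randn(16, 3, 640, 640, device=device, requires_grad=True)   # placeholder input

y = m(x)
y.sum().backward()
print(f'reserved after forward+backward: {reserved_gb():.3f} GB')

del x, y
torch.cuda.empty_cache()  # drop cached blocks between measurements, as profile() does per op
print(f'reserved after empty_cache: {reserved_gb():.3f} GB')
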