glenn-jocher
commited on
Commit
•
3004fb5
1
Parent(s):
0bd9c48
Automatic m.half() profile on x.half()
Browse files- utils/torch_utils.py +2 -1
utils/torch_utils.py
CHANGED
@@ -88,7 +88,8 @@ def profile(x, ops, n=100, device=None):
|
|
88 |
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|
89 |
print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
|
90 |
for m in ops if isinstance(ops, list) else [ops]:
|
91 |
-
m = m.to(device) if hasattr(m, 'to') else m
|
|
|
92 |
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
93 |
try:
|
94 |
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
|
|
|
88 |
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|
89 |
print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
|
90 |
for m in ops if isinstance(ops, list) else [ops]:
|
91 |
+
m = m.to(device) if hasattr(m, 'to') else m # device
|
92 |
+
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
|
93 |
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
94 |
try:
|
95 |
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
|