glenn-jocher committed
Commit 6e46617
1 Parent(s): cf298fb

AutoBatch checks against failed solutions (#8159)


* AutoBatch checks against failed solutions

@kalenmike this is a simple improvement to AutoBatch that verifies the returned solution has not already failed, i.e. prevents returning batch-size 8 when batch-size 8 already produced CUDA out of memory during profiling.

This is a halfway fix until I can implement a 'final solution' that actively verifies the solved-for batch size rather than passively assuming it works. A minimal sketch of the check follows the commit notes below.

* Update autobatch.py

* Update autobatch.py
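Concretely, the check rejects a fitted batch size at or above the first candidate size that failed during profiling and falls back to the previous, known-safe candidate. Below is a minimal standalone sketch of that logic, assuming `results` mirrors what `utils.torch_utils.profile` returns here (one entry per candidate batch size, measured memory at index 2, and `None` where the size hit CUDA OOM); the function name and the mocked numbers are illustrative, not part of the patch.

import numpy as np

def pick_batch_size(batch_sizes, results, free_gib, fraction=0.9):
    """Fit memory vs. batch size, then avoid returning a size that already failed."""
    y = [r[2] for r in results if r]                # measured memory (GiB) for successful sizes
    p = np.polyfit(batch_sizes[:len(y)], y, deg=1)  # linear fit: memory ~ p[0] * b + p[1]
    b = int((free_gib * fraction - p[1]) / p[0])    # batch size expected to use `fraction` of free memory
    if None in results:                             # some candidate sizes failed (CUDA OOM)
        i = results.index(None)                     # index of the first failure
        if b >= batch_sizes[i]:                     # fitted size is at or above the failure point
            b = batch_sizes[max(i - 1, 0)]          # fall back to the last safe candidate
    return b

# Illustrative numbers: batch 16 failed during profiling, so a fit suggesting 31 is clamped to 8.
sizes = [1, 2, 4, 8, 16]
results = [(0, 0, 0.5), (0, 0, 0.9), (0, 0, 1.7), (0, 0, 3.3), None]  # mocked entries; only index 2 (GiB) matters here
print(pick_batch_size(sizes, results, free_gib=14.0))  # -> 8

An 'active' final solution would additionally run one more profiling pass at the chosen batch size and back off further if that pass itself runs out of memory, rather than trusting the fit.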

Files changed (1)
  1. utils/autobatch.py +19 -10
utils/autobatch.py CHANGED

@@ -8,7 +8,7 @@ from copy import deepcopy
 import numpy as np
 import torch
 
-from utils.general import LOGGER, colorstr
+from utils.general import LOGGER, colorstr, emojis
 from utils.torch_utils import profile
 
 
@@ -26,6 +26,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
     # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
     # print(autobatch(model))
 
+    # Check device
     prefix = colorstr('AutoBatch: ')
     LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
     device = next(model.parameters()).device  # get model device
@@ -33,25 +34,33 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
         LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
         return batch_size
 
+    # Inspect CUDA memory
     gb = 1 << 30  # bytes to GiB (1024 ** 3)
     d = str(device).upper()  # 'CUDA:0'
     properties = torch.cuda.get_device_properties(device)  # device properties
-    t = properties.total_memory / gb  # (GiB)
-    r = torch.cuda.memory_reserved(device) / gb  # (GiB)
-    a = torch.cuda.memory_allocated(device) / gb  # (GiB)
-    f = t - (r + a)  # free inside reserved
+    t = properties.total_memory / gb  # GiB total
+    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
+    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
+    f = t - (r + a)  # GiB free
     LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
 
+    # Profile batch sizes
     batch_sizes = [1, 2, 4, 8, 16]
     try:
         img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
-        y = profile(img, model, n=3, device=device)
+        results = profile(img, model, n=3, device=device)
     except Exception as e:
         LOGGER.warning(f'{prefix}{e}')
 
-    y = [x[2] for x in y if x]  # memory [2]
-    batch_sizes = batch_sizes[:len(y)]
-    p = np.polyfit(batch_sizes, y, deg=1)  # first degree polynomial fit
+    # Fit a solution
+    y = [x[2] for x in results if x]  # memory [2]
+    p = np.polyfit(batch_sizes[:len(y)], y, deg=1)  # first degree polynomial fit
     b = int((f * fraction - p[1]) / p[0])  # y intercept (optimal batch size)
-    LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%)')
+    if None in results:  # some sizes failed
+        i = results.index(None)  # first fail index
+        if b >= batch_sizes[i]:  # y intercept above failure point
+            b = batch_sizes[max(i - 1, 0)]  # select prior safe point
+
+    fraction = np.polyval(p, b) / t  # actual fraction predicted
+    LOGGER.info(emojis(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅'))
     return b
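For a sense of what the new "actual fraction predicted" reporting changes, here is a small numeric walk-through under assumed figures (a 16 GiB card with 14.5 GiB free and roughly 0.4 GiB of measured memory per image); the numbers are illustrative only.

import numpy as np

t, f, fraction = 16.0, 14.5, 0.9        # assumed GiB total, GiB free, target fraction
batch_sizes = [1, 2, 4, 8, 16]
y = [0.5, 0.9, 1.7, 3.3, 6.5]           # assumed measured GiB per batch size (all succeeded)

p = np.polyfit(batch_sizes, y, deg=1)   # ~[0.4, 0.1]: memory ~ 0.4 * b + 0.1
b = int((f * fraction - p[1]) / p[0])   # -> 32, expected to fill 90% of free memory
fraction = np.polyval(p, b) / t         # -> ~0.81, fraction of total memory the fit predicts at b
print(b, round(fraction, 2))            # 32 0.81

With the patch, the closing log line reports the memory the fit actually predicts for the chosen batch size (here about 12.90G/16.00G, 81%) rather than always printing the target fraction of total memory.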