whoam glenn-jocher commited on
Commit
b292837
1 Parent(s): 41cc7ca

Fix ONNX export using --grid --simplify --dynamic simultaneously (#2982)

Browse files

* Update yolo.py

* Update export.py

* fix export grid

* Update export.py, remove detect export attribute

* rearrange if order

* remove --grid, default inplace=False

* rename exp_dynamic to onnx_dynamic, comment

* replace bs with 1 in anchor_grid[i] index 0

* Update export.py

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (2) hide show
  1. models/export.py +6 -4
  2. models/yolo.py +3 -4
models/export.py CHANGED
@@ -26,9 +26,9 @@ if __name__ == '__main__':
26
  parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
27
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
28
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
29
- parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
30
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
31
  parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
 
32
  parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
33
  parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
34
  opt = parser.parse_args()
@@ -60,9 +60,11 @@ if __name__ == '__main__':
60
  m.act = Hardswish()
61
  elif isinstance(m.act, nn.SiLU):
62
  m.act = SiLU()
63
- # elif isinstance(m, models.yolo.Detect):
64
- # m.forward = m.forward_export # assign forward (optional)
65
- model.model[-1].export = not opt.grid # set Detect() layer grid export
 
 
66
  for _ in range(2):
67
  y = model(img) # dry runs
68
  print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")
 
26
  parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
27
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
28
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
 
29
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
30
  parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
31
+ parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
32
  parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
33
  parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
34
  opt = parser.parse_args()
 
60
  m.act = Hardswish()
61
  elif isinstance(m.act, nn.SiLU):
62
  m.act = SiLU()
63
+ elif isinstance(m, models.yolo.Detect):
64
+ m.inplace = opt.inplace
65
+ m.onnx_dynamic = opt.dynamic
66
+ # m.forward = m.forward_export # assign forward (optional)
67
+
68
  for _ in range(2):
69
  y = model(img) # dry runs
70
  print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")
models/yolo.py CHANGED
@@ -24,7 +24,7 @@ except ImportError:
24
 
25
  class Detect(nn.Module):
26
  stride = None # strides computed during build
27
- export = False # onnx export
28
 
29
  def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
30
  super(Detect, self).__init__()
@@ -42,14 +42,13 @@ class Detect(nn.Module):
42
  def forward(self, x):
43
  # x = x.copy() # for profiling
44
  z = [] # inference output
45
- self.training |= self.export
46
  for i in range(self.nl):
47
  x[i] = self.m[i](x[i]) # conv
48
  bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
49
  x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
50
 
51
  if not self.training: # inference
52
- if self.grid[i].shape[2:4] != x[i].shape[2:4]:
53
  self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
54
 
55
  y = x[i].sigmoid()
@@ -58,7 +57,7 @@ class Detect(nn.Module):
58
  y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
59
  else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
60
  xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
61
- wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
62
  y = torch.cat((xy, wh, y[..., 4:]), -1)
63
  z.append(y.view(bs, -1, self.no))
64
 
 
24
 
25
  class Detect(nn.Module):
26
  stride = None # strides computed during build
27
+ onnx_dynamic = False # ONNX export parameter
28
 
29
  def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
30
  super(Detect, self).__init__()
 
42
  def forward(self, x):
43
  # x = x.copy() # for profiling
44
  z = [] # inference output
 
45
  for i in range(self.nl):
46
  x[i] = self.m[i](x[i]) # conv
47
  bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
48
  x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
49
 
50
  if not self.training: # inference
51
+ if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
52
  self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
53
 
54
  y = x[i].sigmoid()
 
57
  y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
58
  else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
59
  xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
60
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
61
  y = torch.cat((xy, wh, y[..., 4:]), -1)
62
  z.append(y.view(bs, -1, self.no))
63