Commit
•
b292837
1
Parent(s):
41cc7ca
Fix ONNX export using --grid --simplify --dynamic simultaneously (#2982)
Browse files
* Update yolo.py
* Update export.py
* fix export grid
* Update export.py, remove detect export attribute
* rearrange if order
* remove --grid, default inplace=False
* rename exp_dynamic to onnx_dynamic, comment
* replace bs with 1 in anchor_grid[i] index 0
* Update export.py
Co-authored-by: Glenn Jocher <[email protected]>
- models/export.py +6 -4
- models/yolo.py +3 -4
models/export.py
CHANGED
@@ -26,9 +26,9 @@ if __name__ == '__main__':
|
|
26 |
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
|
27 |
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
|
28 |
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
29 |
-
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
|
30 |
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
31 |
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
|
|
|
32 |
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
|
33 |
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
|
34 |
opt = parser.parse_args()
|
@@ -60,9 +60,11 @@ if __name__ == '__main__':
|
|
60 |
m.act = Hardswish()
|
61 |
elif isinstance(m.act, nn.SiLU):
|
62 |
m.act = SiLU()
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
66 |
for _ in range(2):
|
67 |
y = model(img) # dry runs
|
68 |
print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")
|
|
|
26 |
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
|
27 |
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
|
28 |
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
|
|
29 |
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
30 |
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
|
31 |
+
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
|
32 |
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
|
33 |
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
|
34 |
opt = parser.parse_args()
|
|
|
60 |
m.act = Hardswish()
|
61 |
elif isinstance(m.act, nn.SiLU):
|
62 |
m.act = SiLU()
|
63 |
+
elif isinstance(m, models.yolo.Detect):
|
64 |
+
m.inplace = opt.inplace
|
65 |
+
m.onnx_dynamic = opt.dynamic
|
66 |
+
# m.forward = m.forward_export # assign forward (optional)
|
67 |
+
|
68 |
for _ in range(2):
|
69 |
y = model(img) # dry runs
|
70 |
print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")
|
models/yolo.py
CHANGED
@@ -24,7 +24,7 @@ except ImportError:
|
|
24 |
|
25 |
class Detect(nn.Module):
|
26 |
stride = None # strides computed during build
|
27 |
-
|
28 |
|
29 |
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
30 |
super(Detect, self).__init__()
|
@@ -42,14 +42,13 @@ class Detect(nn.Module):
|
|
42 |
def forward(self, x):
|
43 |
# x = x.copy() # for profiling
|
44 |
z = [] # inference output
|
45 |
-
self.training |= self.export
|
46 |
for i in range(self.nl):
|
47 |
x[i] = self.m[i](x[i]) # conv
|
48 |
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
|
49 |
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
50 |
|
51 |
if not self.training: # inference
|
52 |
-
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
|
53 |
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
54 |
|
55 |
y = x[i].sigmoid()
|
@@ -58,7 +57,7 @@ class Detect(nn.Module):
|
|
58 |
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
59 |
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
|
60 |
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
61 |
-
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
62 |
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
63 |
z.append(y.view(bs, -1, self.no))
|
64 |
|
|
|
24 |
|
25 |
class Detect(nn.Module):
|
26 |
stride = None # strides computed during build
|
27 |
+
onnx_dynamic = False # ONNX export parameter
|
28 |
|
29 |
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
30 |
super(Detect, self).__init__()
|
|
|
42 |
def forward(self, x):
|
43 |
# x = x.copy() # for profiling
|
44 |
z = [] # inference output
|
|
|
45 |
for i in range(self.nl):
|
46 |
x[i] = self.m[i](x[i]) # conv
|
47 |
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
|
48 |
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
49 |
|
50 |
if not self.training: # inference
|
51 |
+
if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
|
52 |
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
|
53 |
|
54 |
y = x[i].sigmoid()
|
|
|
57 |
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
|
58 |
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
|
59 |
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
|
60 |
+
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
|
61 |
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
62 |
z.append(y.view(bs, -1, self.no))
|
63 |
|