glenn-jocher
commited on
Commit
•
12b0c04
1
Parent(s):
b810b21
model fusion and onnx export
Browse files- models/common.py +3 -0
- models/onnx_export.py +8 -7
- models/yolo.py +9 -0
models/common.py
CHANGED
@@ -20,6 +20,9 @@ class Conv(nn.Module): # standard convolution
|
|
20 |
def forward(self, x):
|
21 |
return self.act(self.bn(self.conv(x)))
|
22 |
|
|
|
|
|
|
|
23 |
|
24 |
class Bottleneck(nn.Module):
|
25 |
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
|
|
|
20 |
def forward(self, x):
|
21 |
return self.act(self.bn(self.conv(x)))
|
22 |
|
23 |
+
def fuseforward(self, x):
|
24 |
+
return self.act(self.conv(x))
|
25 |
+
|
26 |
|
27 |
class Bottleneck(nn.Module):
|
28 |
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
|
models/onnx_export.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
# Exports a pytorch *.pt model to *.onnx format
|
2 |
-
#
|
3 |
-
# $ python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
|
4 |
|
5 |
import argparse
|
6 |
|
@@ -10,10 +10,11 @@ from models.common import *
|
|
10 |
|
11 |
if __name__ == '__main__':
|
12 |
parser = argparse.ArgumentParser()
|
13 |
-
parser.add_argument('--weights', default='./weights/yolov5s.pt', help='weights path')
|
14 |
-
parser.add_argument('--img-size', default=640, help='inference size (pixels)')
|
15 |
-
parser.add_argument('--batch-size', default=1, help='batch size')
|
16 |
opt = parser.parse_args()
|
|
|
17 |
|
18 |
# Parameters
|
19 |
f = opt.weights.replace('.pt', '.onnx') # onnx filename
|
@@ -23,7 +24,7 @@ if __name__ == '__main__':
|
|
23 |
google_utils.attempt_download(opt.weights)
|
24 |
model = torch.load(opt.weights)['model']
|
25 |
model.eval()
|
26 |
-
# model.fuse()
|
27 |
|
28 |
# Export to onnx
|
29 |
model.model[-1].export = True # set Detect() layer export=True
|
|
|
1 |
+
# Exports a pytorch *.pt model to *.onnx format
|
2 |
+
# Example usage (run from ./yolov5 directory):
|
3 |
+
# $ export PYTHONPATH="$PWD" && python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
|
4 |
|
5 |
import argparse
|
6 |
|
|
|
10 |
|
11 |
if __name__ == '__main__':
|
12 |
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument('--weights', type=str, default='./weights/yolov5s.pt', help='weights path')
|
14 |
+
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
|
15 |
+
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
|
16 |
opt = parser.parse_args()
|
17 |
+
print(opt)
|
18 |
|
19 |
# Parameters
|
20 |
f = opt.weights.replace('.pt', '.onnx') # onnx filename
|
|
|
24 |
google_utils.attempt_download(opt.weights)
|
25 |
model = torch.load(opt.weights)['model']
|
26 |
model.eval()
|
27 |
+
# model.fuse()
|
28 |
|
29 |
# Export to onnx
|
30 |
model.model[-1].export = True # set Detect() layer export=True
|
models/yolo.py
CHANGED
@@ -123,6 +123,15 @@ class Model(nn.Module):
|
|
123 |
b = self.model[f].bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
|
124 |
print(('%g Conv2d.bias:' + '%10.3g' * 6) % (f, *b[:5].mean(1).tolist(), b[5:].mean()))
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
def parse_model(md, ch): # model_dict, input_channels(3)
|
128 |
print('\n%3s%15s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
|
|
|
123 |
b = self.model[f].bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
|
124 |
print(('%g Conv2d.bias:' + '%10.3g' * 6) % (f, *b[:5].mean(1).tolist(), b[5:].mean()))
|
125 |
|
126 |
+
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
127 |
+
print('Fusing layers...')
|
128 |
+
for m in self.model.modules():
|
129 |
+
if type(m) is Conv:
|
130 |
+
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
|
131 |
+
m.bn = None # remove batchnorm
|
132 |
+
m.forward = m.fuseforward # update forward
|
133 |
+
torch_utils.model_info(self)
|
134 |
+
|
135 |
|
136 |
def parse_model(md, ch): # model_dict, input_channels(3)
|
137 |
print('\n%3s%15s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
|