Add TensorFlow and TFLite export (#1127)
* Add models/tf.py for TensorFlow and TFLite export
* Set auto=False for int8 calibration
* Update requirements.txt for TensorFlow and TFLite export
* Read anchors directly from PyTorch weights
* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export
* Remove check_anchor_order, check_file, set_logging from import
* Reformat code and optimize imports
* Autodownload model and check cfg
* update --source path, img-size to 320, single output
* Adjust representative_dataset
* detect.py TF inference
* Put representative dataset in tfl_int8 block
* weights to string
* weights to string
* cleanup tf.py
* Add --dynamic-batch-size
* Add xywh normalization to reduce calibration error
* Update requirements.txt
TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error
* Fix imports
Move C3 from models.experimental to models.common
* implement C3() and SiLU()
* Fix reshape dim to support dynamic batching
* Add epsilon argument in tf_BN, which is different between TF and PT
* Set stride to None if not using PyTorch, and do not warmup without PyTorch
* Add list support in check_img_size()
* Add list input support in detect.py
* sys.path.append('./') to run from yolov5/
* Add int8 quantization support for TensorFlow 2.5
* Add get_coco128.sh
* Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU)
* Update requirements.txt
* Replace torch.load() with attempt_load()
* Update requirements.txt
* Add --tf-raw-resize to set half_pixel_centers=False
* Add --agnostic-nms for TF class-agnostic NMS
* Cleanup after merge
* Cleanup2 after merge
* Cleanup3 after merge
* Add tf.py docstring with credit and usage
* pb saved_model and tflite use only one model in detect.py
* Add use cases in docstring of tf.py
* Remove redundant `stride` definition
* Remove keras direct import
* Fix `check_requirements(('tensorflow>=2.4.1',))`
Co-authored-by: Glenn Jocher <[email protected]>
- detect.py +54 -10
- models/experimental.py +6 -2
- models/tf.py +558 -0
- requirements.txt +1 -0
- utils/datasets.py +7 -5
detect.py
@@ -12,6 +12,7 @@ import time
 from pathlib import Path

 import cv2
+import numpy as np
 import torch
 import torch.backends.cudnn as cudnn

@@ -51,6 +52,7 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
         hide_labels=False,  # hide labels
         hide_conf=False,  # hide confidences
         half=False,  # use FP16 half-precision inference
+        tfl_int8=False,  # INT8 quantized TFLite model
         ):
     save_img = not nosave and not source.endswith('.txt')  # save inference images
     webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
@@ -68,7 +70,7 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
     # Load model
     w = weights[0] if isinstance(weights, list) else weights
     classify, suffix = False, Path(w).suffix.lower()
-    pt, onnx, tflite, pb,
+    pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', ''])  # backend
     stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
     if pt:
         model = attempt_load(weights, map_location=device)  # load FP32 model
@@ -83,30 +85,49 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
         check_requirements(('onnx', 'onnxruntime'))
         import onnxruntime
         session = onnxruntime.InferenceSession(w, None)
+    else:  # TensorFlow models
+        check_requirements(('tensorflow>=2.4.1',))
+        import tensorflow as tf
+        if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+            def wrap_frozen_graph(gd, inputs, outputs):
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped import
+                return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
+                               tf.nest.map_structure(x.graph.as_graph_element, outputs))
+
+            graph_def = tf.Graph().as_graph_def()
+            graph_def.ParseFromString(open(w, 'rb').read())
+            frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
+        elif saved_model:
+            model = tf.keras.models.load_model(w)
+        elif tflite:
+            interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
+            interpreter.allocate_tensors()  # allocate
+            input_details = interpreter.get_input_details()  # inputs
+            output_details = interpreter.get_output_details()  # outputs
     imgsz = check_img_size(imgsz, s=stride)  # check image size

     # Dataloader
     if webcam:
         view_img = check_imshow()
         cudnn.benchmark = True  # set True to speed up constant image size inference
-        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
+        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
         bs = len(dataset)  # batch_size
     else:
-        dataset = LoadImages(source, img_size=imgsz, stride=stride)
+        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
         bs = 1  # batch_size
     vid_path, vid_writer = [None] * bs, [None] * bs

     # Run inference
     if pt and device.type != 'cpu':
-        model(torch.zeros(1, 3, imgsz
+        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once
     t0 = time.time()
     for path, img, im0s, vid_cap in dataset:
-        if
+        if onnx:
+            img = img.astype('float32')
+        else:
             img = torch.from_numpy(img).to(device)
             img = img.half() if half else img.float()  # uint8 to fp16/32
-            img = img.astype('float32')
-        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+        img = img / 255.0  # 0 - 255 to 0.0 - 1.0
         if len(img.shape) == 3:
             img = img[None]  # expand for batch dim

@@ -117,6 +138,27 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
             pred = model(img, augment=augment, visualize=visualize)[0]
         elif onnx:
             pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
+        else:  # tensorflow model (tflite, pb, saved_model)
+            imn = img.permute(0, 2, 3, 1).cpu().numpy()  # image in numpy
+            if pb:
+                pred = frozen_func(x=tf.constant(imn)).numpy()
+            elif saved_model:
+                pred = model(imn, training=False).numpy()
+            elif tflite:
+                if tfl_int8:
+                    scale, zero_point = input_details[0]['quantization']
+                    imn = (imn / scale + zero_point).astype(np.uint8)
+                interpreter.set_tensor(input_details[0]['index'], imn)
+                interpreter.invoke()
+                pred = interpreter.get_tensor(output_details[0]['index'])
+                if tfl_int8:
+                    scale, zero_point = output_details[0]['quantization']
+                    pred = (pred.astype(np.float32) - zero_point) * scale
+            pred[..., 0] *= imgsz[1]  # x
+            pred[..., 1] *= imgsz[0]  # y
+            pred[..., 2] *= imgsz[1]  # w
+            pred[..., 3] *= imgsz[0]  # h
+            pred = torch.tensor(pred)

         # NMS
         pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
@@ -202,9 +244,9 @@ def run(weights='yolov5s.pt',  # model.pt path(s)

 def parse_opt():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pb', help='model.pt path(s)')
     parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
-    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
     parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
     parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
     parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
@@ -226,7 +268,9 @@ def parse_opt():
     parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
     parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
     opt = parser.parse_args()
+    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
     return opt
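Note on the int8 TFLite path above: the interpreter expects uint8 input, so detect.py quantizes the image with the input tensor's (scale, zero_point) and dequantizes the raw output before scaling xywh back to pixels. A minimal standalone sketch of that round trip, assuming a hypothetical yolov5s-int8.tflite export and a dummy 320x320 image:

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path='yolov5s-int8.tflite')  # hypothetical int8 export
    interpreter.allocate_tensors()
    inp, out = interpreter.get_input_details()[0], interpreter.get_output_details()[0]

    imn = np.zeros((1, 320, 320, 3), dtype=np.float32)  # dummy NHWC image, values in 0-1
    scale, zero_point = inp['quantization']
    interpreter.set_tensor(inp['index'], (imn / scale + zero_point).astype(np.uint8))  # float -> uint8
    interpreter.invoke()
    pred = interpreter.get_tensor(out['index'])
    scale, zero_point = out['quantization']
    pred = (pred.astype(np.float32) - zero_point) * scale  # back to float; xywh are normalized to 0-1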
models/experimental.py
@@ -85,14 +85,18 @@ class Ensemble(nn.ModuleList):
         return y, None  # inference, train output


-def attempt_load(weights, map_location=None, inplace=True):
+def attempt_load(weights, map_location=None, inplace=True, fuse=True):
     from models.yolo import Detect, Model

     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
-        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
+        if fuse:
+            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
+        else:
+            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # without layer fuse
+

     # Compatibility updates
     for m in model.modules():
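The new fuse argument exists because models/tf.py needs the un-fused checkpoint: tf_Conv copies the Conv weights and the BatchNorm statistics separately, so the BN layers must still be present. A quick sketch of the difference, assuming a local yolov5s.pt:

    import torch
    import torch.nn as nn
    from models.experimental import attempt_load

    unfused = attempt_load('yolov5s.pt', map_location=torch.device('cpu'), fuse=False)  # BN layers kept
    fused = attempt_load('yolov5s.pt', map_location=torch.device('cpu'))  # default fuse=True folds BN into Conv
    print(sum(isinstance(m, nn.BatchNorm2d) for m in unfused.modules()))  # > 0
    print(sum(isinstance(m, nn.BatchNorm2d) for m in fused.modules()))    # 0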
models/tf.py (new file, 558 lines)
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
TensorFlow/Keras and TFLite versions of YOLOv5
Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127

Usage:
    $ python models/tf.py --weights yolov5s.pt --cfg yolov5s.yaml

Export int8 TFLite models:
    $ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --tfl-int8 \
        --source path/to/images/ --ncalib 100

Detection:
    $ python detect.py --weights yolov5s.pb --img 320
    $ python detect.py --weights yolov5s_saved_model --img 320
    $ python detect.py --weights yolov5s-fp16.tflite --img 320
    $ python detect.py --weights yolov5s-int8.tflite --img 320 --tfl-int8

For TensorFlow.js:
    $ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --img 320 --tf-nms --agnostic-nms
    $ pip install tensorflowjs
    $ tensorflowjs_converter \
          --input_format=tf_frozen_model \
          --output_node_names='Identity,Identity_1,Identity_2,Identity_3' \
          yolov5s.pb \
          web_model
    $ # Edit web_model/model.json to sort Identity* in ascending order
    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
    $ npm install
    $ ln -s ../../yolov5/web_model public/web_model
    $ npm start
"""

import argparse
import logging
import os
import sys
import traceback
from copy import deepcopy
from pathlib import Path

sys.path.append('./')  # to run '$ python *.py' files in subdirectories

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import yaml
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.experimental import MixConv2d, CrossConv, attempt_load
from models.yolo import Detect
from utils.datasets import LoadImages
from utils.general import make_divisible, check_file, check_dataset

logger = logging.getLogger(__name__)


class tf_BN(keras.layers.Layer):
    # TensorFlow BatchNormalization wrapper
    def __init__(self, w=None):
        super(tf_BN, self).__init__()
        self.bn = keras.layers.BatchNormalization(
            beta_initializer=keras.initializers.Constant(w.bias.numpy()),
            gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
            moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
            moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
            epsilon=w.eps)

    def call(self, inputs):
        return self.bn(inputs)


class tf_Pad(keras.layers.Layer):
    def __init__(self, pad):
        super(tf_Pad, self).__init__()
        self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])

    def call(self, inputs):
        return tf.pad(inputs, self.pad, mode='constant', constant_values=0)


class tf_Conv(keras.layers.Layer):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        # ch_in, ch_out, weights, kernel, stride, padding, groups
        super(tf_Conv, self).__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
        # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
        # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch

        conv = keras.layers.Conv2D(
            c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False,
            kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()))
        self.conv = conv if s == 1 else keras.Sequential([tf_Pad(autopad(k, p)), conv])
        self.bn = tf_BN(w.bn) if hasattr(w, 'bn') else tf.identity

        # YOLOv5 activations
        if isinstance(w.act, nn.LeakyReLU):
            self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
        elif isinstance(w.act, nn.Hardswish):
            self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
        elif isinstance(w.act, nn.SiLU):
            self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity

    def call(self, inputs):
        return self.act(self.bn(self.conv(inputs)))


class tf_Focus(keras.layers.Layer):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        # ch_in, ch_out, kernel, stride, padding, groups
        super(tf_Focus, self).__init__()
        self.conv = tf_Conv(c1 * 4, c2, k, s, p, g, act, w.conv)

    def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
        # inputs = inputs / 255.  # normalize 0-255 to 0-1
        return self.conv(tf.concat([inputs[:, ::2, ::2, :],
                                    inputs[:, 1::2, ::2, :],
                                    inputs[:, ::2, 1::2, :],
                                    inputs[:, 1::2, 1::2, :]], 3))


class tf_Bottleneck(keras.layers.Layer):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):  # ch_in, ch_out, shortcut, groups, expansion
        super(tf_Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = tf_Conv(c_, c2, 3, 1, g=g, w=w.cv2)
        self.add = shortcut and c1 == c2

    def call(self, inputs):
        return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class tf_Conv2d(keras.layers.Layer):
    # Substitution for PyTorch nn.Conv2D
    def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
        super(tf_Conv2d, self).__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        self.conv = keras.layers.Conv2D(
            c2, k, s, 'VALID', use_bias=bias,
            kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
            bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )

    def call(self, inputs):
        return self.conv(inputs)


class tf_BottleneckCSP(keras.layers.Layer):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super(tf_BottleneckCSP, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = tf_Conv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
        self.cv3 = tf_Conv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
        self.cv4 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv4)
        self.bn = tf_BN(w.bn)
        self.act = lambda x: keras.activations.relu(x, alpha=0.1)
        self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

    def call(self, inputs):
        y1 = self.cv3(self.m(self.cv1(inputs)))
        y2 = self.cv2(inputs)
        return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))


class tf_C3(keras.layers.Layer):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super(tf_C3, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = tf_Conv(c1, c_, 1, 1, w=w.cv2)
        self.cv3 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv3)
        self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

    def call(self, inputs):
        return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class tf_SPP(keras.layers.Layer):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13), w=None):
        super(tf_SPP, self).__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = tf_Conv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
        self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]

    def call(self, inputs):
        x = self.cv1(inputs)
        return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))


class tf_Detect(keras.layers.Layer):
    def __init__(self, nc=80, anchors=(), ch=(), w=None):  # detection layer
        super(tf_Detect, self).__init__()
        self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [tf.zeros(1)] * self.nl  # init grid
        self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
        self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
                                      [self.nl, 1, -1, 1, 2])
        self.m = [tf_Conv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
        self.export = False  # onnx export
        self.training = True  # set to False after building model
        for i in range(self.nl):
            ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
            self.grid[i] = self._make_grid(nx, ny)

    def call(self, inputs):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        x = []
        for i in range(self.nl):
            x.append(self.m[i](inputs[i]))
            # x(bs,20,20,255) to x(bs,3,20,20,85)
            ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
            x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])

            if not self.training:  # inference
                y = tf.sigmoid(x[i])
                xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
                wh /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
                y = tf.concat([xy, wh, y[..., 4:]], -1)
                z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no]))

        return x if self.training else (tf.concat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
        xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
        return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)


class tf_Upsample(keras.layers.Layer):
    def __init__(self, size, scale_factor, mode, w=None):
        super(tf_Upsample, self).__init__()
        assert scale_factor == 2, "scale_factor must be 2"
        # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
        if opt.tf_raw_resize:
            # with default arguments: align_corners=False, half_pixel_centers=False
            self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
                                                                       size=(x.shape[1] * 2, x.shape[2] * 2))
        else:
            self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)

    def call(self, inputs):
        return self.upsample(inputs)


class tf_Concat(keras.layers.Layer):
    def __init__(self, dimension=1, w=None):
        super(tf_Concat, self).__init__()
        assert dimension == 1, "convert only NCHW to NHWC concat"
        self.d = 3

    def call(self, inputs):
        return tf.concat(inputs, self.d)


def parse_model(d, ch, model):  # model_dict, input_channels(3)
    logger.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m_str = m
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except:
                pass

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
            c1, c2 = ch[f], args[0]
            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3]:
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
        elif m is Detect:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        else:
            c2 = ch[f]

        tf_m = eval('tf_' + m_str.replace('nn.', ''))
        m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
            else tf_m(*args, w=model.model[i])  # module

        torch_m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in torch_m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return keras.Sequential(layers), sorted(save)


class tf_Model():
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None):  # model, input channels, number of classes
        super(tf_Model, self).__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict

        # Define model
        if nc and nc != self.yaml['nc']:
            print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
            self.yaml['nc'] = nc  # override yaml value
        self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model)  # model, savelist, ch_out

    def predict(self, inputs, profile=False):
        y = []  # outputs
        x = inputs
        for i, m in enumerate(self.model.layers):
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            x = m(x)  # run
            y.append(x if m.i in self.savelist else None)  # save output

        # Add TensorFlow NMS
        if opt.tf_nms:
            boxes = xywh2xyxy(x[0][..., :4])
            probs = x[0][:, :, 4:5]
            classes = x[0][:, :, 5:]
            scores = probs * classes
            if opt.agnostic_nms:
                nms = agnostic_nms_layer()((boxes, classes, scores))
                return nms, x[1]
            else:
                boxes = tf.expand_dims(boxes, 2)
                nms = tf.image.combined_non_max_suppression(
                    boxes, scores, opt.topk_per_class, opt.topk_all, opt.iou_thres, opt.score_thres, clip_boxes=False)
                return nms, x[1]

        return x[0]  # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
        # x = x[0][0]  # [x(1,6300,85), ...] to x(6300,85)
        # xywh = x[..., :4]  # x(6300,4) boxes
        # conf = x[..., 4:5]  # x(6300,1) confidences
        # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1))  # x(6300,1) classes
        # return tf.concat([conf, cls, xywh], 1)


class agnostic_nms_layer(keras.layers.Layer):
    # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
    def call(self, input):
        return tf.map_fn(agnostic_nms, input,
                         fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
                         name='agnostic_nms')


def agnostic_nms(x):
    boxes, classes, scores = x
    class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
    scores_inp = tf.reduce_max(scores, -1)
    selected_inds = tf.image.non_max_suppression(
        boxes, scores_inp, max_output_size=opt.topk_all, iou_threshold=opt.iou_thres, score_threshold=opt.score_thres)
    selected_boxes = tf.gather(boxes, selected_inds)
    padded_boxes = tf.pad(selected_boxes,
                          paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
                          mode="CONSTANT", constant_values=0.0)
    selected_scores = tf.gather(scores_inp, selected_inds)
    padded_scores = tf.pad(selected_scores,
                           paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
                           mode="CONSTANT", constant_values=-1.0)
    selected_classes = tf.gather(class_inds, selected_inds)
    padded_classes = tf.pad(selected_classes,
                            paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
                            mode="CONSTANT", constant_values=-1.0)
    valid_detections = tf.shape(selected_inds)[0]
    return padded_boxes, padded_scores, padded_classes, valid_detections


def xywh2xyxy(xywh):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
    return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)


def representative_dataset_gen():
    # Representative dataset for use with converter.representative_dataset
    n = 0
    for path, img, im0s, vid_cap in dataset:
        # Get sample input data as a numpy array in a method of your choosing.
        n += 1
        input = np.transpose(img, [1, 2, 0])
        input = np.expand_dims(input, axis=0).astype(np.float32)
        input /= 255.0
        yield [input]
        if n >= opt.ncalib:
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='cfg path')
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size')  # height, width
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--dynamic-batch-size', action='store_true', help='dynamic batch size')
    parser.add_argument('--source', type=str, default='../data/coco128.yaml', help='dir of images or data.yaml file')
    parser.add_argument('--ncalib', type=int, default=100, help='number of calibration images')
    parser.add_argument('--tfl-int8', action='store_true', dest='tfl_int8', help='export TFLite int8 model')
    parser.add_argument('--tf-nms', action='store_true', dest='tf_nms', help='TF NMS (without TFLite export)')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--tf-raw-resize', action='store_true', dest='tf_raw_resize',
                        help='use tf.raw_ops.ResizeNearestNeighbor for resize')
    parser.add_argument('--topk-per-class', type=int, default=100, help='topk per class to keep in NMS')
    parser.add_argument('--topk-all', type=int, default=100, help='topk for all classes to keep in NMS')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
    opt = parser.parse_args()
    opt.cfg = check_file(opt.cfg)  # check file
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand
    print(opt)

    # Input
    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size(1,3,320,192) iDetection

    # Load PyTorch model
    model = attempt_load(opt.weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
    model.model[-1].export = False  # set Detect() layer export=True
    y = model(img)  # dry run
    nc = y[0].shape[-1] - 5

    # TensorFlow saved_model export
    try:
        print('\nStarting TensorFlow saved_model export with TensorFlow %s...' % tf.__version__)
        tf_model = tf_Model(opt.cfg, model=model, nc=nc)
        img = tf.zeros((opt.batch_size, *opt.img_size, 3))  # NHWC Input for TensorFlow

        m = tf_model.model.layers[-1]
        assert isinstance(m, tf_Detect), "the last layer must be Detect"
        m.training = False
        y = tf_model.predict(img)

        inputs = keras.Input(shape=(*opt.img_size, 3), batch_size=None if opt.dynamic_batch_size else opt.batch_size)
        keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs))
        keras_model.summary()
        path = opt.weights.replace('.pt', '_saved_model')  # filename
        keras_model.save(path, save_format='tf')
        print('TensorFlow saved_model export success, saved as %s' % path)
    except Exception as e:
        print('TensorFlow saved_model export failure: %s' % e)
        traceback.print_exc(file=sys.stdout)

    # TensorFlow GraphDef export
    try:
        print('\nStarting TensorFlow GraphDef export with TensorFlow %s...' % tf.__version__)

        # https://github.com/leimao/Frozen_Graph_TensorFlow
        full_model = tf.function(lambda x: keras_model(x))
        full_model = full_model.get_concrete_function(
            tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))

        frozen_func = convert_variables_to_constants_v2(full_model)
        frozen_func.graph.as_graph_def()
        f = opt.weights.replace('.pt', '.pb')  # filename
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                          logdir=os.path.dirname(f),
                          name=os.path.basename(f),
                          as_text=False)

        print('TensorFlow GraphDef export success, saved as %s' % f)
    except Exception as e:
        print('TensorFlow GraphDef export failure: %s' % e)
        traceback.print_exc(file=sys.stdout)

    # TFLite model export
    if not opt.tf_nms:
        try:
            print('\nStarting TFLite export with TensorFlow %s...' % tf.__version__)

            # fp32 TFLite model export ---------------------------------------------------------------------------------
            # converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
            # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
            # converter.allow_custom_ops = False
            # converter.experimental_new_converter = True
            # tflite_model = converter.convert()
            # f = opt.weights.replace('.pt', '.tflite')  # filename
            # open(f, "wb").write(tflite_model)

            # fp16 TFLite model export ---------------------------------------------------------------------------------
            converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # converter.representative_dataset = representative_dataset_gen
            # converter.target_spec.supported_types = [tf.float16]
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
            converter.allow_custom_ops = False
            converter.experimental_new_converter = True
            tflite_model = converter.convert()
            f = opt.weights.replace('.pt', '-fp16.tflite')  # filename
            open(f, "wb").write(tflite_model)
            print('\nTFLite export success, saved as %s' % f)

            # int8 TFLite model export ---------------------------------------------------------------------------------
            if opt.tfl_int8:
                # Representative Dataset
                if opt.source.endswith('.yaml'):
                    with open(check_file(opt.source)) as f:
                        data = yaml.load(f, Loader=yaml.FullLoader)  # data dict
                        check_dataset(data)  # check
                    opt.source = data['train']
                dataset = LoadImages(opt.source, img_size=opt.img_size, auto=False)
                converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.representative_dataset = representative_dataset_gen
                converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
                converter.inference_input_type = tf.uint8  # or tf.int8
                converter.inference_output_type = tf.uint8  # or tf.int8
                converter.allow_custom_ops = False
                converter.experimental_new_converter = True
                converter.experimental_new_quantizer = False
                tflite_model = converter.convert()
                f = opt.weights.replace('.pt', '-int8.tflite')  # filename
                open(f, "wb").write(tflite_model)
                print('\nTFLite (int8) export success, saved as %s' % f)

        except Exception as e:
            print('\nTFLite export failure: %s' % e)
            traceback.print_exc(file=sys.stdout)
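After export, the SavedModel can be sanity-checked in pure TensorFlow before running detect.py. A minimal sketch, assuming the default 320x320 yolov5s export; the single output tensor is [1, 6300, 85] with xywh normalized to 0-1 as noted in tf_Model.predict:

    import numpy as np
    import tensorflow as tf

    model = tf.keras.models.load_model('yolov5s_saved_model')  # directory written by models/tf.py
    img = np.zeros((1, 320, 320, 3), dtype=np.float32)  # NHWC, values in 0-1
    pred = model(img, training=False).numpy()
    print(pred.shape)  # (1, 6300, 85): [x, y, w, h, conf, 80 class scores] per prediction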
requirements.txt
@@ -23,6 +23,7 @@ pandas
 # coremltools>=4.1
 # onnx>=1.9.0
 # scikit-learn==0.19.2  # for coreml quantization
+# tensorflow==2.4.1  # for TFLite export

 # extras --------------------------------------
 # Cython  # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
utils/datasets.py
@@ -155,7 +155,7 @@ class _RepeatSampler(object):


 class LoadImages:  # for inference
-    def __init__(self, path, img_size=640, stride=32):
+    def __init__(self, path, img_size=640, stride=32, auto=True):
         p = str(Path(path).absolute())  # os-agnostic absolute path
         if '*' in p:
             files = sorted(glob.glob(p, recursive=True))  # glob
@@ -176,6 +176,7 @@ class LoadImages:  # for inference
         self.nf = ni + nv  # number of files
         self.video_flag = [False] * ni + [True] * nv
         self.mode = 'image'
+        self.auto = auto
         if any(videos):
             self.new_video(videos[0])  # new video
         else:
@@ -217,7 +218,7 @@ class LoadImages:  # for inference
         print(f'image {self.count}/{self.nf} {path}: ', end='')

         # Padded resize
-        img = letterbox(img0, self.img_size, stride=self.stride)[0]
+        img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

         # Convert
         img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
@@ -276,7 +277,7 @@ class LoadWebcam:  # for inference


 class LoadStreams:  # multiple IP or RTSP cameras
-    def __init__(self, sources='streams.txt', img_size=640, stride=32):
+    def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
         self.mode = 'stream'
         self.img_size = img_size
         self.stride = stride
@@ -290,6 +291,7 @@ class LoadStreams:  # multiple IP or RTSP cameras
         n = len(sources)
         self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
         self.sources = [clean_str(x) for x in sources]  # clean source names for later
+        self.auto = auto
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             print(f'{i + 1}/{n}: {s}... ', end='')
@@ -312,7 +314,7 @@ class LoadStreams:  # multiple IP or RTSP cameras
         print('')  # newline

         # check for common shapes
-        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
+        s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs], 0)  # shapes
         self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
         if not self.rect:
             print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
@@ -341,7 +343,7 @@ class LoadStreams:  # multiple IP or RTSP cameras

         # Letterbox
         img0 = self.imgs.copy()
-        img = [letterbox(x, self.img_size,
+        img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]

         # Stack
         img = np.stack(img, 0)
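The auto flag added above controls letterbox padding: auto=True (PyTorch path) pads only to a stride multiple for rectangular inference, while auto=False pads to the exact requested size, which the fixed-shape TF/TFLite graphs need. A small sketch, assuming letterbox is imported from utils.augmentations (in older revisions it may live in utils.datasets):

    import numpy as np
    from utils.augmentations import letterbox

    im = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy HWC frame
    print(letterbox(im, 320, stride=32, auto=True)[0].shape)   # rectangular, e.g. (192, 320, 3)
    print(letterbox(im, 320, stride=32, auto=False)[0].shape)  # exact square (320, 320, 3)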