glenn-jocher
commited on
Commit
•
f542926
1
Parent(s):
92c9b72
PyTorch Hub and autoShape update (#1415)
Browse files* PyTorch Hub and autoShape update
* comment x for imgs
* reduce comment
- detect.py +1 -1
- hubconf.py +8 -8
- models/common.py +71 -22
- test.py +1 -1
- utils/general.py +3 -3
detect.py
CHANGED
@@ -89,7 +89,7 @@ def detect(save_img=False):
|
|
89 |
txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
|
90 |
s += '%gx%g ' % img.shape[2:] # print string
|
91 |
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
92 |
-
if
|
93 |
# Rescale boxes from img_size to im0 size
|
94 |
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
|
95 |
|
|
|
89 |
txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
|
90 |
s += '%gx%g ' % img.shape[2:] # print string
|
91 |
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
92 |
+
if len(det):
|
93 |
# Rescale boxes from img_size to im0 size
|
94 |
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
|
95 |
|
hubconf.py
CHANGED
@@ -5,15 +5,16 @@ Usage:
|
|
5 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
|
6 |
"""
|
7 |
|
8 |
-
dependencies = ['torch', 'yaml']
|
9 |
from pathlib import Path
|
10 |
|
11 |
import torch
|
|
|
12 |
|
13 |
from models.yolo import Model
|
14 |
from utils.general import set_logging
|
15 |
from utils.google_utils import attempt_download
|
16 |
|
|
|
17 |
set_logging()
|
18 |
|
19 |
|
@@ -41,7 +42,7 @@ def create(name, pretrained, channels, classes):
|
|
41 |
model.load_state_dict(state_dict, strict=False) # load
|
42 |
if len(ckpt['model'].names) == classes:
|
43 |
model.names = ckpt['model'].names # set class names attribute
|
44 |
-
# model = model.autoshape() # for
|
45 |
return model
|
46 |
|
47 |
except Exception as e:
|
@@ -108,11 +109,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
|
|
108 |
|
109 |
if __name__ == '__main__':
|
110 |
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
111 |
-
model = model.fuse().
|
112 |
|
113 |
# Verify inference
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
print(y[0].shape)
|
|
|
5 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
|
6 |
"""
|
7 |
|
|
|
8 |
from pathlib import Path
|
9 |
|
10 |
import torch
|
11 |
+
from PIL import Image
|
12 |
|
13 |
from models.yolo import Model
|
14 |
from utils.general import set_logging
|
15 |
from utils.google_utils import attempt_download
|
16 |
|
17 |
+
dependencies = ['torch', 'yaml', 'pillow']
|
18 |
set_logging()
|
19 |
|
20 |
|
|
|
42 |
model.load_state_dict(state_dict, strict=False) # load
|
43 |
if len(ckpt['model'].names) == classes:
|
44 |
model.names = ckpt['model'].names # set class names attribute
|
45 |
+
# model = model.autoshape() # for PIL/cv2/np inputs and NMS
|
46 |
return model
|
47 |
|
48 |
except Exception as e:
|
|
|
109 |
|
110 |
if __name__ == '__main__':
|
111 |
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
112 |
+
model = model.fuse().autoshape() # for PIL/cv2/np inputs and NMS
|
113 |
|
114 |
# Verify inference
|
115 |
+
imgs = [Image.open(x) for x in Path('data/images').glob('*.jpg')]
|
116 |
+
results = model(imgs)
|
117 |
+
results.show()
|
118 |
+
results.print()
|
|
models/common.py
CHANGED
@@ -5,9 +5,11 @@ import math
|
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
import torch.nn as nn
|
|
|
8 |
|
9 |
from utils.datasets import letterbox
|
10 |
-
from utils.general import non_max_suppression, make_divisible, scale_coords
|
|
|
11 |
|
12 |
|
13 |
def autopad(k, p=None): # kernel, padding
|
@@ -125,47 +127,94 @@ class autoShape(nn.Module):
|
|
125 |
|
126 |
def __init__(self, model):
|
127 |
super(autoShape, self).__init__()
|
128 |
-
self.model = model
|
129 |
|
130 |
-
def forward(self,
|
131 |
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
|
132 |
-
# opencv:
|
133 |
-
# PIL:
|
134 |
-
# numpy:
|
135 |
-
# torch:
|
136 |
-
# multiple:
|
137 |
|
138 |
p = next(self.model.parameters()) # for device and type
|
139 |
-
if isinstance(
|
140 |
-
return self.model(
|
141 |
|
142 |
# Pre-process
|
143 |
-
if not isinstance(
|
144 |
-
|
145 |
shape0, shape1 = [], [] # image and inference shapes
|
146 |
-
batch = range(len(
|
147 |
for i in batch:
|
148 |
-
|
149 |
-
|
150 |
-
s =
|
151 |
shape0.append(s) # image shape
|
152 |
g = (size / max(s)) # gain
|
153 |
shape1.append([y * g for y in s])
|
154 |
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
|
155 |
-
x = [letterbox(
|
156 |
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
|
157 |
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
|
158 |
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
|
159 |
|
160 |
# Inference
|
161 |
-
|
162 |
-
|
|
|
163 |
|
164 |
# Post-process
|
165 |
for i in batch:
|
166 |
-
if
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
|
171 |
class Flatten(nn.Module):
|
|
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
+
from PIL import Image, ImageDraw
|
9 |
|
10 |
from utils.datasets import letterbox
|
11 |
+
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
|
12 |
+
from utils.plots import color_list
|
13 |
|
14 |
|
15 |
def autopad(k, p=None): # kernel, padding
|
|
|
127 |
|
128 |
def __init__(self, model):
|
129 |
super(autoShape, self).__init__()
|
130 |
+
self.model = model.eval()
|
131 |
|
132 |
+
def forward(self, imgs, size=640, augment=False, profile=False):
|
133 |
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
|
134 |
+
# opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
|
135 |
+
# PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3)
|
136 |
+
# numpy: imgs = np.zeros((720,1280,3)) # HWC
|
137 |
+
# torch: imgs = torch.zeros(16,3,720,1280) # BCHW
|
138 |
+
# multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
139 |
|
140 |
p = next(self.model.parameters()) # for device and type
|
141 |
+
if isinstance(imgs, torch.Tensor): # torch
|
142 |
+
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
143 |
|
144 |
# Pre-process
|
145 |
+
if not isinstance(imgs, list):
|
146 |
+
imgs = [imgs]
|
147 |
shape0, shape1 = [], [] # image and inference shapes
|
148 |
+
batch = range(len(imgs)) # batch size
|
149 |
for i in batch:
|
150 |
+
imgs[i] = np.array(imgs[i]) # to numpy
|
151 |
+
imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input
|
152 |
+
s = imgs[i].shape[:2] # HWC
|
153 |
shape0.append(s) # image shape
|
154 |
g = (size / max(s)) # gain
|
155 |
shape1.append([y * g for y in s])
|
156 |
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
|
157 |
+
x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
|
158 |
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
|
159 |
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
|
160 |
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
|
161 |
|
162 |
# Inference
|
163 |
+
with torch.no_grad():
|
164 |
+
y = self.model(x, augment, profile)[0] # forward
|
165 |
+
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
|
166 |
|
167 |
# Post-process
|
168 |
for i in batch:
|
169 |
+
if y[i] is not None:
|
170 |
+
y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])
|
171 |
+
|
172 |
+
return Detections(imgs, y, self.names)
|
173 |
+
|
174 |
+
|
175 |
+
class Detections:
|
176 |
+
# detections class for YOLOv5 inference results
|
177 |
+
def __init__(self, imgs, pred, names=None):
|
178 |
+
super(Detections, self).__init__()
|
179 |
+
self.imgs = imgs # list of images as numpy arrays
|
180 |
+
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
181 |
+
self.names = names # class names
|
182 |
+
self.xyxy = pred # xyxy pixels
|
183 |
+
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
|
184 |
+
gn = [torch.Tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.]) for im in imgs] # normalization gains
|
185 |
+
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
186 |
+
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
187 |
+
|
188 |
+
def display(self, pprint=False, show=False, save=False):
|
189 |
+
colors = color_list()
|
190 |
+
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
|
191 |
+
str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
|
192 |
+
if pred is not None:
|
193 |
+
for c in pred[:, -1].unique():
|
194 |
+
n = (pred[:, -1] == c).sum() # detections per class
|
195 |
+
str += f'{n} {self.names[int(c)]}s, ' # add to string
|
196 |
+
if show or save:
|
197 |
+
img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
|
198 |
+
for *box, conf, cls in pred: # xyxy, confidence, class
|
199 |
+
# str += '%s %.2f, ' % (names[int(cls)], conf) # label
|
200 |
+
ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
|
201 |
+
if save:
|
202 |
+
f = f'results{i}.jpg'
|
203 |
+
str += f"saved to '{f}'"
|
204 |
+
img.save(f) # save
|
205 |
+
if show:
|
206 |
+
img.show(f'Image {i}') # show
|
207 |
+
if pprint:
|
208 |
+
print(str)
|
209 |
+
|
210 |
+
def print(self):
|
211 |
+
self.display(pprint=True) # print results
|
212 |
+
|
213 |
+
def show(self):
|
214 |
+
self.display(show=True) # show results
|
215 |
+
|
216 |
+
def save(self):
|
217 |
+
self.display(save=True) # save results
|
218 |
|
219 |
|
220 |
class Flatten(nn.Module):
|
test.py
CHANGED
@@ -126,7 +126,7 @@ def test(data,
|
|
126 |
tcls = labels[:, 0].tolist() if nl else [] # target class
|
127 |
seen += 1
|
128 |
|
129 |
-
if pred
|
130 |
if nl:
|
131 |
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
|
132 |
continue
|
|
|
126 |
tcls = labels[:, 0].tolist() if nl else [] # target class
|
127 |
seen += 1
|
128 |
|
129 |
+
if len(pred) == 0:
|
130 |
if nl:
|
131 |
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
|
132 |
continue
|
utils/general.py
CHANGED
@@ -142,7 +142,7 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
|
|
142 |
|
143 |
def xyxy2xywh(x):
|
144 |
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
|
145 |
-
y =
|
146 |
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
147 |
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
148 |
y[:, 2] = x[:, 2] - x[:, 0] # width
|
@@ -152,7 +152,7 @@ def xyxy2xywh(x):
|
|
152 |
|
153 |
def xywh2xyxy(x):
|
154 |
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
155 |
-
y =
|
156 |
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
157 |
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
158 |
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
@@ -280,7 +280,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
|
|
280 |
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
|
281 |
|
282 |
t = time.time()
|
283 |
-
output = [
|
284 |
for xi, x in enumerate(prediction): # image index, image inference
|
285 |
# Apply constraints
|
286 |
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
|
|
142 |
|
143 |
def xyxy2xywh(x):
|
144 |
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
|
145 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
146 |
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
147 |
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
148 |
y[:, 2] = x[:, 2] - x[:, 0] # width
|
|
|
152 |
|
153 |
def xywh2xyxy(x):
|
154 |
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
155 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
156 |
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
157 |
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
158 |
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
|
|
280 |
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
|
281 |
|
282 |
t = time.time()
|
283 |
+
output = [torch.zeros(0, 6)] * prediction.shape[0]
|
284 |
for xi, x in enumerate(prediction): # image index, image inference
|
285 |
# Apply constraints
|
286 |
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
|