Add Hub results.pandas() method (#2725)

* Add Hub results.pandas() method
New method converts results from torch tensors to pandas DataFrames with column names.
This PR may partially resolve issue https://github.com/ultralytics/yolov5/issues/2703
```python
results = model(imgs)
print(results.pandas().xyxy[0])
```
```
         xmin        ymin        xmax        ymax  confidence  class    name
0   57.068970  391.770599  241.383545  905.797852    0.868964      0  person
1  667.661255  399.303589  810.000000  881.396667    0.851888      0  person
2  222.878387  414.774231  343.804474  857.825073    0.838376      0  person
3    4.205386  234.447678  803.739136  750.023376    0.658006      5     bus
4    0.000000  550.596008   76.681190  878.669922    0.450596      0  person
```
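Since each view is a regular DataFrame, the usual pandas idioms apply. A minimal usage sketch (the Hub model name and image URL here are illustrative):

```python
import torch

# load a small YOLOv5 model from PyTorch Hub (downloads weights on first use)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
results = model('https://ultralytics.com/images/zidane.jpg')

df = results.pandas().xyxy[0]        # pixel-space xyxy boxes for image 0
people = df[df['name'] == 'person']  # ordinary pandas filtering
print(people[['confidence', 'name']])
print(results.pandas().xywhn[0])     # normalized xywh view of the same detections
```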
* Update comments
The torch example input is now shown resized to size=640 and is now a multiple of the P6 stride of 64 (see https://github.com/ultralytics/yolov5/issues/2722#issuecomment-814785930). A quick stride sketch follows the commit list below.
* apply decorators
* PEP8
* Update common.py
* pd.options.display.max_columns = 10
* Update common.py
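On the stride point: P6 models downsample by up to 64x, so example input sides must round up to multiples of 64. A minimal sketch of the rounding involved (this make_divisible mirrors the helper imported in common.py from utils.general):

```python
import math

def make_divisible(x, divisor):
    # round x up to the nearest multiple of divisor (as in utils.general)
    return math.ceil(x / divisor) * divisor

print(make_divisible(720, 64))  # 768 -> the old 720x1280 torch example was not P6-safe
print(make_divisible(640, 64))  # 640 -> already a multiple of the P6 stride
print(make_divisible(320, 64))  # 320 -> matches the new torch.zeros(16,3,320,640) example
```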
Files changed:
- hubconf.py +1 -1
- models/common.py +29 -17
- utils/general.py +2 -0
hubconf.py
```diff
@@ -38,7 +38,7 @@ def create(name, pretrained, channels, classes, autoshape):
             fname = f'{name}.pt'  # checkpoint filename
             attempt_download(fname)  # download if not found locally
             ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
-            msd = model.state_dict()
+            msd = model.state_dict()  # model state_dict
             csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
             csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
             model.load_state_dict(csd, strict=False)  # load
```
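For context, this create() logic transfers only the checkpoint tensors whose shapes match the freshly built model, so checkpoints trained with a different class count still partially load. A standalone sketch of the same filter (toy two-layer models, illustrative only):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))       # fresh model, 3 outputs
ckpt_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 5))  # 'checkpoint', 5 outputs

msd = model.state_dict()                                      # model state_dict
csd = ckpt_model.float().state_dict()                         # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # drop mismatched tensors
model.load_state_dict(csd, strict=False)                      # load what transfers
print(f'transferred {len(csd)}/{len(msd)} items')             # 2/4 here; final layer skipped
```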
models/common.py
```diff
@@ -1,14 +1,15 @@
 # YOLOv5 common modules
 
 import math
+from copy import copy
 from pathlib import Path
 
 import numpy as np
+import pandas as pd
 import requests
 import torch
 import torch.nn as nn
 from PIL import Image
-from torch.cuda import amp
 
 from utils.datasets import letterbox
 from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
```
```diff
@@ -235,14 +236,16 @@ class autoShape(nn.Module):
         print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
         return self
 
+    @torch.no_grad()
+    @torch.cuda.amp.autocast()
     def forward(self, imgs, size=640, augment=False, profile=False):
-        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
         #   filename:   imgs = 'data/samples/zidane.jpg'
         #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
-        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
-        #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
-        #   numpy:           = np.zeros((720,1280,3))  # HWC
-        #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
+        #   numpy:           = np.zeros((640,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
         #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
 
         t = [time_synchronized()]
```
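The two new decorators replace the inline no-grad/autocast context managers removed in the next hunk, so the whole forward() body runs without gradient tracking. A minimal sketch of the decorator pattern (toy module, not the YOLOv5 class):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    @torch.no_grad()  # applies to every call, like the decorated autoShape.forward()
    def predict(self, x):
        return self.fc(x)

y = TinyModel().predict(torch.randn(4, 8))
print(y.requires_grad)  # False: no autograd graph was built inside the method
```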
```diff
@@ -275,15 +278,14 @@ class autoShape(nn.Module):
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
         t.append(time_synchronized())
 
-        with torch.no_grad(), amp.autocast(enabled=p.device.type != 'cpu'):
-            # Inference
-            y = self.model(x, augment, profile)[0]  # forward
-            t.append(time_synchronized())
+        # Inference
+        y = self.model(x, augment, profile)[0]  # forward
+        t.append(time_synchronized())
 
-            # Post-process
-            y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
-            for i in range(n):
-                scale_coords(shape1, y[i][:, :4], shape0[i])
+        # Post-process
+        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+        for i in range(n):
+            scale_coords(shape1, y[i][:, :4], shape0[i])
 
         t.append(time_synchronized())
         return Detections(imgs, y, files, t, self.names, x.shape)
```
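For context, scale_coords maps boxes from the padded letterbox shape used for inference (shape1) back to each original image shape (shape0[i]). A simplified sketch of that arithmetic, mirroring utils.general but omitting clipping:

```python
import torch

def scale_coords_sketch(img1_shape, coords, img0_shape):
    # img1_shape: (h, w) seen by the model; img0_shape: (h, w) of the original image
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # resize gain
    pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2  # letterbox x padding
    pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2  # letterbox y padding
    coords[:, [0, 2]] -= pad_x  # undo x padding
    coords[:, [1, 3]] -= pad_y  # undo y padding
    coords[:, :4] /= gain       # undo resize
    return coords

box = torch.tensor([[172., 112., 492., 312.]])            # xyxy at 384x640 inference size
print(scale_coords_sketch((384, 640), box, (720, 1280)))  # -> [[344., 200., 984., 600.]]
```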
```diff
@@ -347,17 +349,27 @@ class Detections:
         self.display(render=True)  # render results
         return self.imgs
 
-    def __len__(self):
-        return self.n
+    def pandas(self):
+        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
+        new = copy(self)  # return copy
+        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
+        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
+        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
+            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+        return new
 
     def tolist(self):
         # return a list of Detections objects, i.e. 'for result in results.tolist():'
-        x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+        x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
         for d in x:
             for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                 setattr(d, k, getattr(d, k)[0])  # pop out of list
         return x
 
+    def __len__(self):
+        return self.n
+
 
 class Classify(nn.Module):
     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
```
utils/general.py
```diff
@@ -13,6 +13,7 @@ from pathlib import Path
 
 import cv2
 import numpy as np
+import pandas as pd
 import torch
 import torchvision
 import yaml
@@ -24,6 +25,7 @@ from utils.torch_utils import init_torch_seeds
 # Settings
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
+pd.options.display.max_columns = 10
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
 
```
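The new pandas setting ensures the seven detection columns print on one line rather than being elided; a quick check (pandas' default truncation threshold varies by version and terminal):

```python
import pandas as pd

pd.options.display.max_columns = 10  # as set in the utils/general.py settings block
df = pd.DataFrame([[0.0] * 5 + [0, 'person']],
                  columns=['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'])
print(df)  # all 7 columns appear, no '...' column truncation
```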