glenn-jocher
committed on
Merge remote-tracking branch 'origin/master'
- README.md +5 -4
- data/coco.yaml +3 -3
- data/coco128.yaml +6 -6
- data/get_coco2017.sh +3 -2
- data/get_voc.sh +2 -1
- data/voc.yaml +5 -4
- test.py +2 -2
- utils/datasets.py +27 -16
- utils/torch_utils.py +2 -0
README.md
CHANGED
@@ -21,7 +21,7 @@ This repository represents Ultralytics open-source research into future object d
 | [YOLOv5m](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 43.4 | 43.4 | 62.4 | 3.0ms | 333 || 21.8M | 39.4B
 | [YOLOv5l](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 46.6 | 46.7 | 65.4 | 3.9ms | 256 || 47.8M | 88.1B
 | [YOLOv5x](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | **48.4** | **48.4** | **66.9** | 6.1ms | 164 || 89.0M | 166.4B
-| [YOLOv3-SPP](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)
+| [YOLOv3-SPP](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 45.6 | 45.5 | 65.2 | 4.5ms | 222 || 63.0M | 118.0B


 ** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.

@@ -54,10 +54,11 @@ $ pip install -U -r requirements.txt

 Inference can be run on most common media formats. Model [checkpoints](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) are downloaded automatically if available. Results are saved to `./inference/output`.
 ```bash
-$ python detect.py --source
+$ python detect.py --source 0  # webcam
+                            file.jpg  # image
                             file.mp4  # video
-
-
+                            path/  # directory
+                            path/*.jpg  # glob
                             rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa  # rtsp stream
                             http://112.50.243.8/PLTV/88888888/224/3221225900/1.m3u8  # http stream
 ```
data/coco.yaml
CHANGED
@@ -1,13 +1,13 @@
 # COCO 2017 dataset http://cocodataset.org
 # Download command: bash yolov5/data/get_coco2017.sh
-# Train command: python train.py --data
-#
+# Train command: python train.py --data coco.yaml
+# Default dataset location is next to /yolov5:
 #   /parent_folder
 #     /coco
 #     /yolov5


-# train and val
+# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
 train: ../coco/train2017.txt  # 118k images
 val: ../coco/val2017.txt  # 5k images
 test: ../coco/test-dev2017.txt  # 20k images for submission to https://competitions.codalab.org/competitions/20794
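The `train`/`val` fields above accept three forms. A minimal sketch of how such a spec could be resolved to an image list; `resolve_sources` is a hypothetical helper for illustration, not the repository's actual loader:

```python
# Hypothetical helper illustrating the three train/val source forms; not YOLOv5's loader.
import glob
import os


def resolve_sources(spec):
    """Resolve 1) a directory, 2) a .txt list file, or 3) a list of either into image paths."""
    if isinstance(spec, list):  # 3) list: [path1/images/, path2/images/]
        files = []
        for s in spec:
            files += resolve_sources(s)
        return files
    if os.path.isdir(spec):  # 1) directory: path/images/
        return sorted(glob.glob(os.path.join(spec, '*.*')))
    if spec.endswith('.txt'):  # 2) file: path/images.txt, one image path per line
        with open(spec) as f:
            return [line.strip() for line in f if line.strip()]
    raise ValueError('unsupported train/val source: %s' % spec)
```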
data/coco128.yaml
CHANGED
@@ -1,15 +1,15 @@
 # COCO 2017 dataset http://cocodataset.org - first 128 training images
-# Download command: python -c "from yolov5.utils.google_utils import
-# Train command: python train.py --data
-#
+# Download command: python -c "from yolov5.utils.google_utils import *; gdrive_download('1n_oKgR81BJtqk75b00eAjdv03qVCQn2f', 'coco128.zip')"
+# Train command: python train.py --data coco128.yaml
+# Default dataset location is next to /yolov5:
 #   /parent_folder
 #     /coco128
 #     /yolov5


-# train and val
-train: ../coco128/images/train2017/
-val: ../coco128/images/train2017/
+# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
+train: ../coco128/images/train2017/  # 128 images
+val: ../coco128/images/train2017/  # 128 images

 # number of classes
 nc: 80
data/get_coco2017.sh
CHANGED
@@ -1,12 +1,13 @@
 #!/bin/bash
 # COCO 2017 dataset http://cocodataset.org
 # Download command: bash yolov5/data/get_coco2017.sh
-# Train command: python train.py --data
-#
+# Train command: python train.py --data coco.yaml
+# Default dataset location is next to /yolov5:
 #   /parent_folder
 #     /coco
 #     /yolov5

+
 # Download labels from Google Drive, accepting presented query
 filename="coco2017labels.zip"
 fileid="1cXZR_ckHki6nddOmcysCuuJFM--T-Q6L"
data/get_voc.sh
CHANGED
@@ -1,11 +1,12 @@
 # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
 # Download command: bash ./data/get_voc.sh
 # Train command: python train.py --data voc.yaml
-#
+# Default dataset location is next to /yolov5:
 #   /parent_folder
 #     /VOC
 #     /yolov5

+
 start=`date +%s`

 # handle optional download dir
data/voc.yaml
CHANGED
@@ -1,14 +1,15 @@
 # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
 # Download command: bash ./data/get_voc.sh
 # Train command: python train.py --data voc.yaml
-#
+# Default dataset location is next to /yolov5:
 #   /parent_folder
 #     /VOC
 #     /yolov5

-
-train:
-
+
+# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
+train: ../VOC/images/train/  # 16551 images
+val: ../VOC/images/val/  # 4952 images

 # number of classes
 nc: 20
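A quick sanity check on a data file like this one; it assumes the file also defines the conventional `names` list of class labels, which this diff does not show:

```python
# Sketch: verify nc agrees with the class-name list, assuming a 'names' key exists
# and the script runs from the repository root. Illustrative only.
import yaml

with open('data/voc.yaml') as f:
    data = yaml.safe_load(f)
assert data['nc'] == len(data['names']), 'nc (%d) != len(names) (%d)' % (data['nc'], len(data['names']))
```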
test.py
CHANGED
@@ -41,9 +41,9 @@ def test(data,
         # model = nn.DataParallel(model)

     # Half
-    half = device.type != 'cpu'
+    half = device.type != 'cpu'  # half precision only supported on CUDA
     if half:
-        model.half()
+        model.half()

     # Configure
     model.eval()
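The `half = device.type != 'cpu'` guard exists because FP16 inference is only supported on CUDA. A minimal standalone sketch of the same pattern; the module and tensor shapes are illustrative, not from test.py:

```python
# Sketch of the CUDA-only FP16 inference pattern; model and shapes are illustrative.
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Conv2d(3, 16, 3).to(device).eval()

half = device.type != 'cpu'  # half precision only supported on CUDA
if half:
    model.half()

img = torch.zeros(1, 3, 64, 64, device=device)
if half:
    img = img.half()  # input dtype must match model dtype
with torch.no_grad():
    out = model(img)
print(out.dtype)  # torch.float16 on CUDA, torch.float32 on CPU
```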
utils/datasets.py
CHANGED
@@ -68,35 +68,39 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa

 class LoadImages:  # for inference
     def __init__(self, path, img_size=640):
-
-
-        if
-            files = sorted(glob.glob(
-        elif os.path.
-            files =
+        p = str(Path(path))  # os-agnostic
+        p = os.path.abspath(p)  # absolute path
+        if '*' in p:
+            files = sorted(glob.glob(p))  # glob
+        elif os.path.isdir(p):
+            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
+        elif os.path.isfile(p):
+            files = [p]  # files
+        else:
+            raise Exception('ERROR: %s does not exist' % p)

         images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
         videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
-
+        ni, nv = len(images), len(videos)

         self.img_size = img_size
         self.files = images + videos
-        self.
-        self.video_flag = [False] *
+        self.nf = ni + nv  # number of files
+        self.video_flag = [False] * ni + [True] * nv
         self.mode = 'images'
         if any(videos):
             self.new_video(videos[0])  # new video
         else:
             self.cap = None
-        assert self.
-            (
+        assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
+                            (p, img_formats, vid_formats)

     def __iter__(self):
         self.count = 0
         return self

     def __next__(self):
-        if self.count == self.
+        if self.count == self.nf:
             raise StopIteration
         path = self.files[self.count]

@@ -107,7 +111,7 @@ class LoadImages:  # for inference
             if not ret_val:
                 self.count += 1
                 self.cap.release()
-                if self.count == self.
+                if self.count == self.nf:  # last video
                     raise StopIteration
                 else:
                     path = self.files[self.count]
@@ -115,14 +119,14 @@ class LoadImages:  # for inference
                 ret_val, img0 = self.cap.read()

             self.frame += 1
-            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.
+            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

         else:
             # Read image
             self.count += 1
             img0 = cv2.imread(path)  # BGR
             assert img0 is not None, 'Image Not Found ' + path
-            print('image %g/%g %s: ' % (self.count, self.
+            print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

         # Padded resize
         img = letterbox(img0, new_shape=self.img_size)[0]
@@ -140,7 +144,7 @@ class LoadImages:  # for inference
         self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

     def __len__(self):
-        return self.
+        return self.nf  # number of files


 class LoadWebcam:  # for inference
@@ -470,6 +474,13 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             img, labels = load_mosaic(self, index)
             shapes = None

+            # MixUp https://arxiv.org/pdf/1710.09412.pdf
+            # if random.random() < 0.5:
+            #     img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
+            #     r = np.random.beta(0.3, 0.3)  # mixup ratio, alpha=beta=0.3
+            #     img = (img * r + img2 * (1 - r)).astype(np.uint8)
+            #     labels = np.concatenate((labels, labels2), 0)
+
         else:
             # Load image
             img, (h0, w0), (h, w) = load_image(self, index)
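With the new source handling, `LoadImages` accepts a glob pattern, a directory, or a single file. A usage sketch; the path is illustrative:

```python
# Usage sketch for the updated LoadImages source handling; the path is illustrative.
from utils.datasets import LoadImages

dataset = LoadImages('inference/images/*.jpg', img_size=640)  # glob, dir, or file
for path, img, img0, vid_cap in dataset:
    # img is the letterboxed CHW array, img0 the original BGR image, vid_cap a cv2 capture or None
    print(path, img.shape, img0.shape)
```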
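The commented-out MixUp block draws a ratio r from Beta(0.3, 0.3) and blends two mosaics. A tiny numeric sketch of that blend on dummy images; NumPy only, labels omitted:

```python
# Sketch of the MixUp blend from the commented block above, on dummy data.
import numpy as np

img = np.full((4, 4, 3), 200, dtype=np.uint8)  # stand-in for one mosaic
img2 = np.full((4, 4, 3), 50, dtype=np.uint8)  # stand-in for another
r = np.random.beta(0.3, 0.3)  # mixup ratio; Beta(0.3, 0.3) favors values near 0 or 1
mixed = (img * r + img2 * (1 - r)).astype(np.uint8)
print(round(float(r), 3), mixed[0, 0])
```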
utils/torch_utils.py
CHANGED
@@ -195,6 +195,8 @@ class ModelEMA:
     def __init__(self, model, decay=0.9999, updates=0):
         # Create EMA
         self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
+        if next(model.parameters()).device.type != 'cpu':
+            self.ema.half()  # FP16 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
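For reference, the decay lambda above ramps toward 0.9999 as updates accumulate, and the EMA weights then follow the usual exponential update. A small standalone sketch of that rule, implied by the lambda; the scalar weights are illustrative:

```python
# Sketch of the EMA update rule implied by ModelEMA's decay lambda; values illustrative.
import math

decay = 0.9999
ema_w, model_w = 1.0, 0.0  # one scalar "weight" for illustration
for updates in range(1, 4):
    d = decay * (1 - math.exp(-updates / 2000))  # small early in training, -> 0.9999 later
    ema_w = ema_w * d + model_w * (1 - d)        # ema = d * ema + (1 - d) * model
    print(updates, round(d, 6), round(ema_w, 6))
```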