glenn-jocher committed
Commit: 1f69d12
1 Parent(s): 75c0ff4

Update 4 main ops for paths and .run() (#3715)

* Add yolov5/ to path
* rename functions to run()
* cleanup
* rename fix
* CI fix
* cleanup find models/export.py
- .github/workflows/ci-testing.yml +1 -1
- .github/workflows/greetings.yml +1 -1
- detect.py +35 -25
- models/export.py → export.py +16 -16
- test.py +41 -31
- train.py +34 -24
- tutorial.ipynb +2 -2
.github/workflows/ci-testing.yml
CHANGED

@@ -74,5 +74,5 @@ jobs:
 
         python hubconf.py  # hub
         python models/yolo.py --cfg ${{ matrix.model }}.yaml  # inspect
-        python models/export.py --img 128 --batch 1 --weights ${{ matrix.model }}.pt  # export
+        python export.py --img 128 --batch 1 --weights ${{ matrix.model }}.pt  # export
        shell: bash
.github/workflows/greetings.yml
CHANGED

@@ -52,5 +52,5 @@ jobs:
 
         ![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)
 
-        If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/models/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.
+        If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.
detect.py
CHANGED

@@ -1,4 +1,11 @@
+"""Run inference with a YOLOv5 model on images, videos, directories, streams
+
+Usage:
+    $ python path/to/detect.py --source path/to/img.jpg --weights yolov5s.pt --img 640
+"""
+
 import argparse
+import sys
 import time
 from pathlib import Path
 
@@ -6,6 +13,9 @@ import cv2
 import torch
 import torch.backends.cudnn as cudnn
 
+FILE = Path(__file__).absolute()
+sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path
+
 from models.experimental import attempt_load
 from utils.datasets import LoadStreams, LoadImages
 from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
 
@@ -15,30 +25,30 @@ from utils.torch_utils import select_device, load_classifier, time_synchronized
 
 
 @torch.no_grad()
-def detect(...)
+def run(weights='yolov5s.pt',  # model.pt path(s)
+        source='data/images',  # file/dir/URL/glob, 0 for webcam
+        imgsz=640,  # inference size (pixels)
+        conf_thres=0.25,  # confidence threshold
+        iou_thres=0.45,  # NMS IOU threshold
+        max_det=1000,  # maximum detections per image
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        view_img=False,  # show results
+        save_txt=False,  # save results to *.txt
+        save_conf=False,  # save confidences in --save-txt labels
+        save_crop=False,  # save cropped prediction boxes
+        nosave=False,  # do not save images/videos
+        classes=None,  # filter by class: --class 0, or --class 0 2 3
+        agnostic_nms=False,  # class-agnostic NMS
+        augment=False,  # augmented inference
+        update=False,  # update all models
+        project='runs/detect',  # save results to project/name
+        name='exp',  # save results to project/name
+        exist_ok=False,  # existing project/name ok, do not increment
+        line_thickness=3,  # bounding box thickness (pixels)
+        hide_labels=False,  # hide labels
+        hide_conf=False,  # hide confidences
+        half=False,  # use FP16 half-precision inference
+        ):
     save_img = not nosave and not source.endswith('.txt')  # save inference images
     webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
         ('rtsp://', 'rtmp://', 'http://', 'https://'))
 
@@ -204,7 +214,7 @@ def parse_opt():
 def main(opt):
     print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
     check_requirements(exclude=('tensorboard', 'thop'))
+    run(**vars(opt))
 
 
 if __name__ == "__main__":
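With detect() renamed to run() and yolov5/ appended to sys.path, the same entry point is callable from Python as well as the CLI. A minimal sketch (not part of this commit), assuming the repository root is the working directory or already on sys.path and that the weights and source paths exist; keyword names are taken from the run() signature above:

    import detect  # yolov5/detect.py

    # roughly equivalent to: python detect.py --source data/images --weights yolov5s.pt --img 640
    detect.run(weights='yolov5s.pt',   # model.pt path(s)
               source='data/images',   # file/dir/URL/glob, 0 for webcam
               imgsz=640,              # inference size (pixels)
               conf_thres=0.25,        # confidence threshold
               device='')              # '' lets YOLOv5 pick CUDA if available, else CPU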
models/export.py → export.py
RENAMED

@@ -1,7 +1,7 @@
 """Export a YOLOv5 *.pt model to TorchScript, ONNX, CoreML formats
 
 Usage:
-    $ python path/to/models/export.py --weights yolov5s.pt --img 640 --batch 1
+    $ python path/to/export.py --weights yolov5s.pt --img 640 --batch 1
 """
 
 import argparse
 
@@ -14,7 +14,7 @@ import torch.nn as nn
 from torch.utils.mobile_optimizer import optimize_for_mobile
 
 FILE = Path(__file__).absolute()
-sys.path.append(FILE.parents[1].as_posix())  # add yolov5/ to path
+sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path
 
 from models.common import Conv
 from models.yolo import Detect
 
@@ -24,19 +24,19 @@ from utils.general import colorstr, check_img_size, check_requirements, file_size
 from utils.torch_utils import select_device
 
 
-def export(...)
+def run(weights='./yolov5s.pt',  # weights path
+        img_size=(640, 640),  # image (height, width)
+        batch_size=1,  # batch size
+        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        include=('torchscript', 'onnx', 'coreml'),  # include formats
+        half=False,  # FP16 half-precision export
+        inplace=False,  # set YOLOv5 Detect() inplace=True
+        train=False,  # model.train() mode
+        optimize=False,  # TorchScript: optimize for mobile
+        dynamic=False,  # ONNX: dynamic axes
+        simplify=False,  # ONNX: simplify model
+        opset_version=12,  # ONNX: opset version
+        ):
     t = time.time()
     include = [x.lower() for x in include]
     img_size *= 2 if len(img_size) == 1 else 1  # expand
 
@@ -165,7 +165,7 @@ def parse_opt():
 def main(opt):
     set_logging()
     print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
+    run(**vars(opt))
 
 
 if __name__ == "__main__":
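export.py exposes the same pattern after the rename. A hedged sketch of the programmatic equivalent of the new CI step (python export.py --img 128 --batch 1 --weights yolov5s.pt), assuming the repo root is on sys.path and the checkpoint is available locally:

    import export  # yolov5/export.py (formerly models/export.py)

    export.run(weights='yolov5s.pt',             # weights path
               img_size=(128, 128),              # (height, width)
               batch_size=1,
               device='cpu',
               include=('torchscript', 'onnx'))  # subset of the default ('torchscript', 'onnx', 'coreml')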
test.py
CHANGED

@@ -1,6 +1,13 @@
+"""Test a trained YOLOv5 model accuracy on a custom dataset
+
+Usage:
+    $ python path/to/test.py --data coco128.yaml --weights yolov5s.pt --img 640
+"""
+
 import argparse
 import json
 import os
+import sys
 from pathlib import Path
 from threading import Thread
 
@@ -9,6 +16,9 @@ import torch
 import yaml
 from tqdm import tqdm
 
+FILE = Path(__file__).absolute()
+sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path
+
 from models.experimental import attempt_load
 from utils.datasets import create_dataloader
 from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
 
@@ -19,32 +29,32 @@ from utils.torch_utils import select_device, time_synchronized
 
 
 @torch.no_grad()
-def test(...)
+def run(data,
+        weights=None,  # model.pt path(s)
+        batch_size=32,  # batch size
+        imgsz=640,  # inference size (pixels)
+        conf_thres=0.001,  # confidence threshold
+        iou_thres=0.6,  # NMS IoU threshold
+        task='val',  # train, val, test, speed or study
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        single_cls=False,  # treat as single-class dataset
+        augment=False,  # augmented inference
+        verbose=False,  # verbose output
+        save_txt=False,  # save results to *.txt
+        save_hybrid=False,  # save label+prediction hybrid results to *.txt
+        save_conf=False,  # save confidences in --save-txt labels
+        save_json=False,  # save a cocoapi-compatible JSON results file
+        project='runs/test',  # save to project/name
+        name='exp',  # save to project/name
+        exist_ok=False,  # existing project/name ok, do not increment
+        half=True,  # use FP16 half-precision inference
+        model=None,
+        dataloader=None,
+        save_dir=Path(''),
+        plots=True,
+        wandb_logger=None,
+        compute_loss=None,
+        ):
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
 
@@ -327,12 +337,12 @@ def main(opt):
     check_requirements(exclude=('tensorboard', 'thop'))
 
     if opt.task in ('train', 'val', 'test'):  # run normally
+        run(**vars(opt))
 
     elif opt.task == 'speed':  # speed benchmarks
         for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
+            run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=opt.imgsz, conf_thres=.25, iou_thres=.45,
+                save_json=False, plots=False)
 
     elif opt.task == 'study':  # run over a range of settings and save/plot
         # python test.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
 
@@ -342,8 +352,8 @@ def main(opt):
         y = []  # y axis
         for i in x:  # img-size
             print(f'\nRunning {f} point {i}...')
-            r, _, t = test(...)
+            r, _, t = run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=i, conf_thres=opt.conf_thres,
+                          iou_thres=opt.iou_thres, save_json=opt.save_json, plots=False)
             y.append(r + t)  # results and times
         np.savetxt(f, y, fmt='%10.4g')  # save
         os.system('zip -r study.zip study_*.txt')
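As the study branch above shows (r, _, t = run(...)), the renamed function still returns a (results, maps, times) tuple, so metrics can be collected directly from Python. A small sketch, assuming coco128.yaml and yolov5s.pt are available and the repo root is on sys.path:

    import test  # yolov5/test.py

    results, maps, times = test.run('data/coco128.yaml',  # dataset yaml (first positional argument)
                                    weights='yolov5s.pt',
                                    imgsz=640,
                                    conf_thres=0.001,
                                    iou_thres=0.6,
                                    plots=False)
    print(results, times)  # same tuple that train.py and the --task study branch consume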
train.py
CHANGED

@@ -1,8 +1,15 @@
+"""Train a YOLOv5 model on a custom dataset
+
+Usage:
+    $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
+"""
+
 import argparse
 import logging
 import math
 import os
 import random
+import sys
 import time
 import warnings
 from copy import deepcopy
 
@@ -22,6 +29,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
+FILE = Path(__file__).absolute()
+sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path
+
 import test  # for end-of-epoch mAP
 from models.experimental import attempt_load
 from models.yolo import Model
 
@@ -89,7 +99,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # W&B
     opt.hyp = hyp  # add hyperparameters
     run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
-    run_id = run_id if opt.resume else None
+    run_id = run_id if opt.resume else None  # start fresh run if transfer learning
     wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
     loggers['wandb'] = wandb_logger.wandb
     if loggers['wandb']:
 
@@ -375,18 +385,18 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         final_epoch = epoch + 1 == epochs
         if not notest or final_epoch:  # Calculate mAP
             wandb_logger.current_epoch = epoch + 1
-            results, maps, _ = test.test(...)
+            results, maps, _ = test.run(data_dict,
+                                        batch_size=batch_size // WORLD_SIZE * 2,
+                                        imgsz=imgsz_test,
+                                        model=ema.ema,
+                                        single_cls=single_cls,
+                                        dataloader=testloader,
+                                        save_dir=save_dir,
+                                        save_json=is_coco and final_epoch,
+                                        verbose=nc < 50 and final_epoch,
+                                        plots=plots and final_epoch,
+                                        wandb_logger=wandb_logger,
+                                        compute_loss=compute_loss)
 
         # Write
         with open(results_file, 'a') as f:
 
@@ -443,17 +453,17 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if not evolve:
         if is_coco:  # COCO dataset
             for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
-                results, _, _ = test.test(...)
+                results, _, _ = test.run(data,
+                                         batch_size=batch_size // WORLD_SIZE * 2,
+                                         imgsz=imgsz_test,
+                                         conf_thres=0.001,
+                                         iou_thres=0.7,
+                                         model=attempt_load(m, device).half(),
+                                         single_cls=single_cls,
+                                         dataloader=testloader,
+                                         save_dir=save_dir,
+                                         save_json=True,
+                                         plots=False)
 
     # Strip optimizers
     for f in last, best:
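The two-line FILE / sys.path block added to each script is what keeps the absolute imports (from models..., from utils..., import test) working when a script is launched from outside the repository. A minimal sketch of the same pattern in a hypothetical helper script placed in the yolov5/ root:

    import sys
    from pathlib import Path

    FILE = Path(__file__).absolute()
    sys.path.append(FILE.parents[0].as_posix())  # add this file's directory (yolov5/) to sys.path

    from utils.general import colorstr  # repo-relative import now resolves from any working directory

    print(colorstr('bootstrap: ') + 'yolov5/ is on sys.path')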
tutorial.ipynb
CHANGED

@@ -1125,7 +1125,7 @@
     "\n",
     "![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)\n",
     "\n",
-    "If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/models/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.\n"
+    "If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.\n"
    ]
   },
   {
 
@@ -1212,7 +1212,7 @@
     " done\n",
     " python hubconf.py # hub\n",
     " python models/yolo.py --cfg $m.yaml # inspect\n",
-    " python models/export.py --weights $m.pt --img 640 --batch 1 # export\n",
+    " python export.py --weights $m.pt --img 640 --batch 1 # export\n",
     "done"
    ],
    "execution_count": null,
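The notebook's shell loop can also be expressed against the new Python entry points. A hedged sketch of the export sweep, assuming the four standard checkpoints are present locally (or can be fetched by attempt_load) and the repo root is the working directory:

    import export  # yolov5/export.py

    for m in ('yolov5s', 'yolov5m', 'yolov5l', 'yolov5x'):
        # same step as the notebook cell: python export.py --weights $m.pt --img 640 --batch 1
        export.run(weights=f'{m}.pt', img_size=(640, 640), batch_size=1)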