glenn-jocher committed · Commit 5d66e48 · 1 Parent(s): 3fef117

Train from `--data path/to/dataset.zip` feature (#4185)

* Train from `--data path/to/dataset.zip` feature
* Update dataset_stats()
* cleanup
* cleanup2
- data/{Argoverse_HD.yaml → Argoverse.yaml} +1 -1
- hubconf.py +1 -1
- models/experimental.py +1 -1
- train.py +4 -7
- utils/datasets.py +50 -16
- utils/{google_utils.py → downloads.py} +5 -1
- utils/general.py +29 -11
- utils/loggers/wandb/wandb_utils.py +30 -32
- val.py +1 -3
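Taken together, these changes let `--data` point at a `.zip` archive that bundles its own `data.yaml` (e.g. `python train.py --data path/to/coco128_with_yaml.zip`): `check_dataset()` now downloads and unzips the archive if needed, resolves the relative paths inside the YAML, and returns the parsed dictionary that train.py, val.py and the W&B logger consume. A minimal usage sketch (the zip path is illustrative, not part of this commit; the keys shown are the ones used in this diff):

```python
# Minimal sketch of the new check_dataset() behaviour.
from utils.general import check_dataset

data_dict = check_dataset('path/to/coco128_with_yaml.zip')  # or a plain 'coco128.yaml'
print(data_dict['nc'], data_dict['names'][:3])   # class count and first few class names
print(data_dict['train'], data_dict['val'])      # resolved train/val paths
```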
data/{Argoverse_HD.yaml → Argoverse.yaml}
RENAMED
@@ -1,6 +1,6 @@
 # YOLOv5 🚀 by Ultralytics https://ultralytics.com, licensed under GNU GPL v3.0
 # Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/
-# Example usage: python train.py --data Argoverse_HD.yaml
+# Example usage: python train.py --data Argoverse.yaml
 # parent
 # ├── yolov5
 # └── datasets
hubconf.py
CHANGED
@@ -27,7 +27,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
 
     from models.yolo import Model, attempt_load
     from utils.general import check_requirements, set_logging
-    from utils.google_utils import attempt_download
+    from utils.downloads import attempt_download
     from utils.torch_utils import select_device
 
     file = Path(__file__).absolute()
models/experimental.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 from models.common import Conv, DWConv
-from utils.google_utils import attempt_download
+from utils.downloads import attempt_download
 
 
 class CrossConv(nn.Module):
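Both import fixes above follow from the `utils/{google_utils.py → downloads.py}` rename further down in this commit; the helper itself is untouched. A hedged usage sketch (the weights filename is illustrative):

```python
# The function moved modules; only the import path changes for callers.
from utils.downloads import attempt_download  # previously: from utils.google_utils import attempt_download

attempt_download('yolov5s.pt')  # fetch the file if it is not already present locally (illustrative argument)
```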
train.py
CHANGED
@@ -35,7 +35,7 @@ from utils.datasets import create_dataloader
 from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
     strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
     check_requirements, print_mutation, set_logging, one_cycle, colorstr
-from utils.google_utils import attempt_download
+from utils.downloads import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_labels, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
@@ -78,9 +78,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(1 + RANK)
-    with open(data, encoding='ascii', errors='ignore') as f:
-        data_dict = yaml.safe_load(f)  # data dict
-
+    with torch_distributed_zero_first(RANK):
+        data_dict = check_dataset(data)  # check
+    train_path, val_path = data_dict['train'], data_dict['val']
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
@@ -106,9 +106,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-    with torch_distributed_zero_first(RANK):
-        check_dataset(data_dict)  # check
-    train_path, val_path = data_dict['train'], data_dict['val']
 
     # Freeze
     freeze = []  # parameter names to freeze (full or partial)
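The new `with torch_distributed_zero_first(RANK):` block above makes rank 0 run `check_dataset()` (and therefore any download/unzip work) while the other DDP ranks wait at a barrier, so a dataset zip is fetched only once per node. A minimal sketch of how such a guard is commonly built — an illustration of the pattern, not the exact body of `utils.torch_utils.torch_distributed_zero_first`:

```python
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def zero_first(local_rank: int):
    """Make all non-zero ranks wait until rank 0 has executed the guarded block."""
    if local_rank not in (-1, 0):
        dist.barrier()  # non-zero ranks block here until rank 0 finishes
    yield               # rank 0 (or a single-process run) does the work, e.g. check_dataset(data)
    if local_rank == 0:
        dist.barrier()  # rank 0 releases the waiting ranks
```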
utils/datasets.py
CHANGED
@@ -884,11 +884,11 @@ def verify_image_label(args):
     return [None, None, None, None, nm, nf, ne, nc, msg]
 
 
-def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
+def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
     """ Return dataset statistics dictionary with images and instances counts per split per class
-
-
-
+    To run in parent directory: export PYTHONPATH="$PWD/yolov5"
+    Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
+    Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128_with_yaml.zip')
     Arguments
         path:           Path to data.yaml or data.zip (with data.yaml inside data.zip)
         autodownload:   Attempt to download dataset if not found locally
@@ -897,35 +897,42 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
 
     def round_labels(labels):
         # Update labels to integer class and 6 decimal place floats
-        return [[int(c), *[round(x, 6) for x in points]] for c, *points in labels]
+        return [[int(c), *[round(x, 4) for x in points]] for c, *points in labels]
 
     def unzip(path):
         # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
         if str(path).endswith('.zip'):  # path is data.zip
+            assert Path(path).is_file(), f'Error unzipping {path}, file not found'
             assert os.system(f'unzip -q {path} -d {path.parent}') == 0, f'Error unzipping {path}'
-
-            return True,
+            dir = path.with_suffix('')  # dataset directory
+            return True, str(dir), next(dir.rglob('*.yaml'))  # zipped, data_dir, yaml_path
         else:  # path is data.yaml
             return False, None, path
 
+    def hub_ops(f, max_dim=1920):
+        # HUB ops for 1 image 'f'
+        im = Image.open(f)
+        r = max_dim / max(im.height, im.width)  # ratio
+        if r < 1.0:  # image too large
+            im = im.resize((int(im.width * r), int(im.height * r)))
+        im.save(im_dir / Path(f).name, quality=75)  # save
+
     zipped, data_dir, yaml_path = unzip(Path(path))
     with open(check_file(yaml_path), encoding='ascii', errors='ignore') as f:
         data = yaml.safe_load(f)  # data dict
     if zipped:
         data['path'] = data_dir  # TODO: should this be dir.resolve()?
     check_dataset(data, autodownload)  # download dataset if missing
-    nc = data['nc']  # number of classes
-    stats = {'nc': nc, 'names': data['names']}  # statistics dictionary
+    hub_dir = Path(data['path'] + ('-hub' if hub else ''))
+    stats = {'nc': data['nc'], 'names': data['names']}  # statistics dictionary
     for split in 'train', 'val', 'test':
         if data.get(split) is None:
             stats[split] = None  # i.e. no test set
             continue
         x = []
-        dataset = LoadImagesAndLabels(data[split]
-        if split == 'train':
-            cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache')  # *.cache path
+        dataset = LoadImagesAndLabels(data[split])  # load dataset
         for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
-            x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
+            x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
         x = np.array(x)  # shape(128x80)
         stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
                         'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
@@ -933,10 +940,37 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
                         'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
                                    zip(dataset.img_files, dataset.labels)]}
 
+    if hub:
+        im_dir = hub_dir / 'images'
+        im_dir.mkdir(parents=True, exist_ok=True)
+        for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.img_files), total=dataset.n, desc='HUB Ops'):
+            pass
+
+    # Profile
+    stats_path = hub_dir / 'stats.json'
+    if profile:
+        for _ in range(1):
+            file = stats_path.with_suffix('.npy')
+            t1 = time.time()
+            np.save(file, stats)
+            t2 = time.time()
+            x = np.load(file, allow_pickle=True)
+            print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
+
+            file = stats_path.with_suffix('.json')
+            t1 = time.time()
+            with open(file, 'w') as f:
+                json.dump(stats, f)  # save stats *.json
+            t2 = time.time()
+            with open(file, 'r') as f:
+                x = json.load(f)  # load hyps dict
+            print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
+
     # Save, print and return
-
-
+    if hub:
+        print(f'Saving {stats_path.resolve()}...')
+        with open(stats_path, 'w') as f:
+            json.dump(stats, f)  # save stats.json
     if verbose:
         print(json.dumps(stats, indent=2, sort_keys=False))
-        # print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
     return stats
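The reworked `dataset_stats()` also accepts a dataset zip and gains `profile` and `hub` flags; the returned dictionary carries `nc`, `names` and per-split `instance_stats`, `image_stats` and `labels` entries, as shown in the diff. A usage sketch following the Usage1/Usage2 lines of the new docstring:

```python
# Usage sketch for the updated dataset_stats(); paths follow the docstring examples.
from utils.datasets import dataset_stats

stats = dataset_stats('coco128.yaml', autodownload=True)      # from a data.yaml
# stats = dataset_stats('../datasets/coco128_with_yaml.zip')  # or from a zip containing data.yaml

print(stats['nc'], stats['names'][:3])
print(stats['train']['image_stats'])  # e.g. {'total': ..., 'unlabelled': ..., ...}
# hub=True additionally writes resized image copies and a stats.json next to the dataset.
```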
utils/{google_utils.py → downloads.py}
RENAMED
@@ -1,4 +1,4 @@
-# Google utils: https://cloud.google.com/storage/docs/reference/libraries
+# Download utils
 
 import os
 import platform
@@ -115,6 +115,10 @@ def get_token(cookie="./cookie"):
                 return line.split()[-1]
     return ""
 
+
+# Google utils: https://cloud.google.com/storage/docs/reference/libraries ----------------------------------------------
+#
+#
 # def upload_blob(bucket_name, source_file_name, destination_blob_name):
 #     # Uploads a file to a bucket
 #     # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python
utils/general.py
CHANGED
@@ -24,7 +24,7 @@ import torch
 import torchvision
 import yaml
 
-from utils.google_utils import gsutil_getsize
+from utils.downloads import gsutil_getsize
 from utils.metrics import box_iou, fitness
 from utils.torch_utils import init_torch_seeds
 
@@ -224,16 +224,30 @@ def check_file(file):
 
 
 def check_dataset(data, autodownload=True):
-    # Download dataset if not found locally
-
-
-
-
-
+    # Download and/or unzip dataset if not found locally
+    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
+
+    # Download (optional)
+    extract_dir = ''
+    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
+        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
+        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
+        extract_dir, autodownload = data.parent, False
+
+    # Read yaml (optional)
+    if isinstance(data, (str, Path)):
+        with open(data, encoding='ascii', errors='ignore') as f:
+            data = yaml.safe_load(f)  # dictionary
+
+    # Parse yaml
+    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
+    for k in 'train', 'val', 'test':
+        if data.get(k):  # prepend path
+            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
 
     assert 'nc' in data, "Dataset 'nc' key missing."
     if 'names' not in data:
-        data['names'] = [
+        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
     train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
     if val:
         val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
@@ -256,13 +270,17 @@ def check_dataset(data, autodownload=True):
             else:
                 raise Exception('Dataset not found.')
 
+    return data  # dictionary
+
 
 def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
-    # Multi-threaded file download and unzip function
+    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
-        if not f.exists():
+        if Path(url).is_file():  # exists in current path
+            Path(url).rename(f)  # move to dir
+        elif not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
@@ -286,7 +304,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
         pool.close()
         pool.join()
     else:
-        for u in
+        for u in [url] if isinstance(url, (str, Path)) else url:
             download_one(u, dir)
 
 
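The `download()` helper above is what `check_dataset()` now calls when it is handed a `.zip`, and it newly moves an already-local file into `dir` instead of re-downloading it. A usage sketch mirroring that call (the URL is the one referenced in the new `check_dataset()` comment):

```python
# Usage sketch mirroring the download() call made inside check_dataset() for zip inputs.
from utils.general import download

download('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip',
         dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
```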
utils/loggers/wandb/wandb_utils.py
CHANGED
@@ -100,7 +100,7 @@ class WandbLogger():
     """
 
     def __init__(self, opt, run_id, data_dict, job_type='Training'):
-
+        """
        - Initialize WandbLogger instance
        - Upload dataset if opt.upload_dataset is True
        - Setup trainig processes if job_type is 'Training'
@@ -111,7 +111,7 @@ class WandbLogger():
        data_dict (Dict) -- Dictionary conataining info about the dataset to be used
        job_type (str) -- To set the job_type for this run
 
-
+        """
        # Pre-training routine --
        self.job_type = job_type
        self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
@@ -157,7 +157,7 @@ class WandbLogger():
            self.data_dict = self.check_and_upload_dataset(opt)
 
    def check_and_upload_dataset(self, opt):
-
+        """
        Check if the dataset format is compatible and upload it as W&B artifact
 
        arguments:
@@ -165,7 +165,7 @@ class WandbLogger():
 
        returns:
        Updated dataset info dictionary where local dataset paths are replaced by WAND_ARFACT_PREFIX links.
-
+        """
        assert wandb, 'Install wandb to upload dataset'
        config_path = self.log_dataset_artifact(check_file(opt.data),
                                                opt.single_cls,
@@ -176,7 +176,7 @@ class WandbLogger():
        return wandb_data_dict
 
    def setup_training(self, opt, data_dict):
-
+        """
        Setup the necessary processes for training YOLO models:
        - Attempt to download model checkpoint and dataset artifacts if opt.resume stats with WANDB_ARTIFACT_PREFIX
        - Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded
@@ -188,7 +188,7 @@ class WandbLogger():
 
        returns:
        data_dict (Dict) -- contains the updated info about the dataset to be used for training
-
+        """
        self.log_dict, self.current_epoch = {}, 0
        self.bbox_interval = opt.bbox_interval
        if isinstance(opt.resume, str):
@@ -224,7 +224,7 @@ class WandbLogger():
        return data_dict
 
    def download_dataset_artifact(self, path, alias):
-
+        """
        download the model checkpoint artifact if the path starts with WANDB_ARTIFACT_PREFIX
 
        arguments:
@@ -234,7 +234,7 @@ class WandbLogger():
        returns:
        (str, wandb.Artifact) -- path of the downladed dataset and it's corresponding artifact object if dataset
        is found otherwise returns (None, None)
-
+        """
        if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
            artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
            dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
@@ -244,12 +244,12 @@ class WandbLogger():
        return None, None
 
    def download_model_artifact(self, opt):
-
+        """
        download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX
 
        arguments:
        opt (namespace) -- Commandline arguments for this run
-
+        """
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
            assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
@@ -262,7 +262,7 @@ class WandbLogger():
        return None, None
 
    def log_model(self, path, opt, epoch, fitness_score, best_model=False):
-
+        """
        Log the model checkpoint as W&B artifact
 
        arguments:
@@ -271,7 +271,7 @@ class WandbLogger():
        epoch (int) -- Current epoch number
        fitness_score (float) -- fitness score for current epoch
        best_model (boolean) -- Boolean representing if the current checkpoint is the best yet.
-
+        """
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
            'original_url': str(path),
            'epochs_trained': epoch + 1,
@@ -286,7 +286,7 @@ class WandbLogger():
            print("Saving model artifact on epoch ", epoch + 1)
 
    def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
-
+        """
        Log the dataset as W&B artifact and return the new data file with W&B links
 
        arguments:
@@ -298,10 +298,8 @@ class WandbLogger():
 
        returns:
        the new .yaml file with artifact links. it can be used to start training directly from artifacts
-
-        with open(data_file, encoding='ascii', errors='ignore') as f:
-            data = yaml.safe_load(f)  # data dict
-        check_dataset(data)
+        """
+        data = check_dataset(data_file)  # parse and check
        nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
        names = {k: v for k, v in enumerate(names)}  # to index dictionary
        self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
@@ -330,17 +328,17 @@ class WandbLogger():
        return path
 
    def map_val_table_path(self):
-
+        """
        Map the validation dataset Table like name of file -> it's id in the W&B Table.
        Useful for - referencing artifacts for evaluation.
-
+        """
        self.val_table_path_map = {}
        print("Mapping dataset")
        for i, data in enumerate(tqdm(self.val_table.data)):
            self.val_table_path_map[data[3]] = data[0]
 
    def create_dataset_table(self, dataset, class_to_id, name='dataset'):
-
+        """
        Create and return W&B artifact containing W&B Table of the dataset.
 
        arguments:
@@ -350,7 +348,7 @@ class WandbLogger():
 
        returns:
        dataset artifact to be logged or used
-
+        """
        # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
        artifact = wandb.Artifact(name=name, type="dataset")
        img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
@@ -382,14 +380,14 @@ class WandbLogger():
        return artifact
 
    def log_training_progress(self, predn, path, names):
-
+        """
        Build evaluation Table. Uses reference from validation dataset table.
 
        arguments:
        predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class]
        path (str): local path of the current evaluation image
        names (dict(int, str)): hash map that maps class ids to labels
-
+        """
        class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
        box_data = []
        total_conf = 0
@@ -412,17 +410,17 @@ class WandbLogger():
        )
 
    def val_one_image(self, pred, predn, path, names, im):
-
+        """
        Log validation data for one image. updates the result Table if validation dataset is uploaded and log bbox media panel
 
        arguments:
        pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
        predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class]
        path (str): local path of the current evaluation image
-
+        """
        if self.val_table and self.result_table:  # Log Table if Val dataset is uploaded as artifact
            self.log_training_progress(predn, path, names)
-
+
        if len(self.bbox_media_panel_images) < self.max_imgs_to_log and self.current_epoch > 0:
            if self.current_epoch % self.bbox_interval == 0:
                box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
@@ -434,23 +432,23 @@ class WandbLogger():
                self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))
 
    def log(self, log_dict):
-
+        """
        save the metrics to the logging dictionary
 
        arguments:
        log_dict (Dict) -- metrics/media to be logged in current step
-
+        """
        if self.wandb_run:
            for key, value in log_dict.items():
                self.log_dict[key] = value
 
    def end_epoch(self, best_result=False):
-
+        """
        commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.
 
        arguments:
        best_result (boolean): Boolean representing if the result of this evaluation is best or not
-
+        """
        if self.wandb_run:
            with all_logging_disabled():
                if self.bbox_media_panel_images:
@@ -468,9 +466,9 @@ class WandbLogger():
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
 
    def finish_run(self):
-
+        """
        Log metrics if any and finish the current W&B run
-
+        """
        if self.wandb_run:
            if self.log_dict:
                with all_logging_disabled():
val.py
CHANGED
@@ -123,9 +123,7 @@ def run(data,
     # model = nn.DataParallel(model)
 
     # Data
-    with open(data, encoding='ascii', errors='ignore') as f:
-        data = yaml.safe_load(f)
-    check_dataset(data)  # check
+    data = check_dataset(data)  # check
 
     # Half
     half &= device.type != 'cpu'  # half precision only supported on CUDA