Improved W&B integration (#2125)
Browse files* Init Commit
* new wandb integration
* Update
* Use data_dict in test
* Updates
* Update: scope of log_img
* Update: scope of log_img
* Update
* Update: Fix logging conditions
* Add tqdm bar, support for .txt dataset format
* Improve Result table Logger
* Init Commit
* new wandb integration
* Update
* Use data_dict in test
* Updates
* Update: scope of log_img
* Update: scope of log_img
* Update
* Update: Fix logging conditions
* Add tqdm bar, support for .txt dataset format
* Improve Result table Logger
* Add dataset creation in training script
* Change scope: self.wandb_run
* Add wandb-artifact:// natively
you can now use --resume with wandb run links
* Add suuport for logging dataset while training
* Cleanup
* Fix: Merge conflict
* Fix: CI tests
* Automatically use wandb config
* Fix: Resume
* Fix: CI
* Enhance: Using val_table
* More resume enhancement
* FIX : CI
* Add alias
* Get useful opt config data
* train.py cleanup
* Cleanup train.py
* more cleanup
* Cleanup| CI fix
* Reformat using PEP8
* FIX:CI
* rebase
* remove uneccesary changes
* remove uneccesary changes
* remove uneccesary changes
* remove unecessary chage from test.py
* FIX: resume from local checkpoint
* FIX:resume
* FIX:resume
* Reformat
* Performance improvement
* Fix local resume
* Fix local resume
* FIX:CI
* Fix: CI
* Imporve image logging
* (:(:Redo CI tests:):)
* Remember epochs when resuming
* Remember epochs when resuming
* Update DDP location
Potential fix for #2405
* PEP8 reformat
* 0.25 confidence threshold
* reset train.py plots syntax to previous
* reset epochs completed syntax to previous
* reset space to previous
* remove brackets
* reset comment to previous
* Update: is_coco check, remove unused code
* Remove redundant print statement
* Remove wandb imports
* remove dsviz logger from test.py
* Remove redundant change from test.py
* remove redundant changes from train.py
* reformat and improvements
* Fix typo
* Add tqdm tqdm progress when scanning files, naming improvements
Co-authored-by: Glenn Jocher <[email protected]>
- models/common.py +1 -1
- test.py +26 -23
- train.py +61 -55
- utils/wandb_logging/log_dataset.py +1 -15
- utils/wandb_logging/wandb_utils.py +193 -74
@@ -278,7 +278,7 @@ class Detections:
|
|
278 |
def print(self):
|
279 |
self.display(pprint=True) # print results
|
280 |
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
|
281 |
-
tuple(self.t))
|
282 |
|
283 |
def show(self):
|
284 |
self.display(show=True) # show results
|
|
|
278 |
def print(self):
|
279 |
self.display(pprint=True) # print results
|
280 |
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
|
281 |
+
tuple(self.t))
|
282 |
|
283 |
def show(self):
|
284 |
self.display(show=True) # show results
|
@@ -35,8 +35,9 @@ def test(data,
|
|
35 |
save_hybrid=False, # for hybrid auto-labelling
|
36 |
save_conf=False, # save auto-label confidences
|
37 |
plots=True,
|
38 |
-
|
39 |
-
compute_loss=None
|
|
|
40 |
# Initialize/load model and set device
|
41 |
training = model is not None
|
42 |
if training: # called by train.py
|
@@ -66,21 +67,19 @@ def test(data,
|
|
66 |
|
67 |
# Configure
|
68 |
model.eval()
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
72 |
check_dataset(data) # check
|
73 |
nc = 1 if single_cls else int(data['nc']) # number of classes
|
74 |
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for [email protected]:0.95
|
75 |
niou = iouv.numel()
|
76 |
|
77 |
# Logging
|
78 |
-
log_imgs
|
79 |
-
|
80 |
-
|
81 |
-
except ImportError:
|
82 |
-
log_imgs = 0
|
83 |
-
|
84 |
# Dataloader
|
85 |
if not training:
|
86 |
if device.type != 'cpu':
|
@@ -147,15 +146,17 @@ def test(data,
|
|
147 |
with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
|
148 |
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
149 |
|
150 |
-
# W&B logging
|
151 |
-
if
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
159 |
|
160 |
# Append to pycocotools JSON dictionary
|
161 |
if save_json:
|
@@ -239,9 +240,11 @@ def test(data,
|
|
239 |
# Plots
|
240 |
if plots:
|
241 |
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
242 |
-
if
|
243 |
-
val_batches = [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
|
244 |
-
|
|
|
|
|
245 |
|
246 |
# Save JSON
|
247 |
if save_json and len(jdict):
|
|
|
35 |
save_hybrid=False, # for hybrid auto-labelling
|
36 |
save_conf=False, # save auto-label confidences
|
37 |
plots=True,
|
38 |
+
wandb_logger=None,
|
39 |
+
compute_loss=None,
|
40 |
+
is_coco=False):
|
41 |
# Initialize/load model and set device
|
42 |
training = model is not None
|
43 |
if training: # called by train.py
|
|
|
67 |
|
68 |
# Configure
|
69 |
model.eval()
|
70 |
+
if isinstance(data, str):
|
71 |
+
is_coco = data.endswith('coco.yaml')
|
72 |
+
with open(data) as f:
|
73 |
+
data = yaml.load(f, Loader=yaml.SafeLoader)
|
74 |
check_dataset(data) # check
|
75 |
nc = 1 if single_cls else int(data['nc']) # number of classes
|
76 |
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for [email protected]:0.95
|
77 |
niou = iouv.numel()
|
78 |
|
79 |
# Logging
|
80 |
+
log_imgs = 0
|
81 |
+
if wandb_logger and wandb_logger.wandb:
|
82 |
+
log_imgs = min(wandb_logger.log_imgs, 100)
|
|
|
|
|
|
|
83 |
# Dataloader
|
84 |
if not training:
|
85 |
if device.type != 'cpu':
|
|
|
146 |
with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
|
147 |
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
148 |
|
149 |
+
# W&B logging - Media Panel Plots
|
150 |
+
if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # Check for test operation
|
151 |
+
if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
|
152 |
+
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
153 |
+
"class_id": int(cls),
|
154 |
+
"box_caption": "%s %.3f" % (names[cls], conf),
|
155 |
+
"scores": {"class_score": conf},
|
156 |
+
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
|
157 |
+
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
|
158 |
+
wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name))
|
159 |
+
wandb_logger.log_training_progress(predn, path, names) # logs dsviz tables
|
160 |
|
161 |
# Append to pycocotools JSON dictionary
|
162 |
if save_json:
|
|
|
240 |
# Plots
|
241 |
if plots:
|
242 |
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
243 |
+
if wandb_logger and wandb_logger.wandb:
|
244 |
+
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
|
245 |
+
wandb_logger.log({"Validation": val_batches})
|
246 |
+
if wandb_images:
|
247 |
+
wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})
|
248 |
|
249 |
# Save JSON
|
250 |
if save_json and len(jdict):
|
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
import math
|
@@ -33,11 +34,12 @@ from utils.google_utils import attempt_download
|
|
33 |
from utils.loss import ComputeLoss
|
34 |
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
35 |
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
|
|
|
36 |
|
37 |
logger = logging.getLogger(__name__)
|
38 |
|
39 |
|
40 |
-
def train(hyp, opt, device, tb_writer=None
|
41 |
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
42 |
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
43 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
@@ -61,10 +63,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
61 |
init_seeds(2 + rank)
|
62 |
with open(opt.data) as f:
|
63 |
data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
|
69 |
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
70 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
@@ -83,6 +92,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
83 |
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
84 |
else:
|
85 |
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
|
|
|
|
|
|
|
|
86 |
|
87 |
# Freeze
|
88 |
freeze = [] # parameter names to freeze (full or partial)
|
@@ -126,16 +139,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
126 |
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
127 |
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
128 |
|
129 |
-
# Logging
|
130 |
-
if rank in [-1, 0] and wandb and wandb.run is None:
|
131 |
-
opt.hyp = hyp # add hyperparameters
|
132 |
-
wandb_run = wandb.init(config=opt, resume="allow",
|
133 |
-
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
134 |
-
name=save_dir.stem,
|
135 |
-
entity=opt.entity,
|
136 |
-
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
|
137 |
-
loggers = {'wandb': wandb} # loggers dict
|
138 |
-
|
139 |
# EMA
|
140 |
ema = ModelEMA(model) if rank in [-1, 0] else None
|
141 |
|
@@ -326,9 +329,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
326 |
# if tb_writer:
|
327 |
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
328 |
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
329 |
-
elif plots and ni == 10 and wandb:
|
330 |
-
|
331 |
-
|
332 |
|
333 |
# end batch ------------------------------------------------------------------------------------------------
|
334 |
# end epoch ----------------------------------------------------------------------------------------------------
|
@@ -343,8 +346,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
343 |
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
344 |
final_epoch = epoch + 1 == epochs
|
345 |
if not opt.notest or final_epoch: # Calculate mAP
|
346 |
-
|
347 |
-
|
|
|
348 |
imgsz=imgsz_test,
|
349 |
model=ema.ema,
|
350 |
single_cls=opt.single_cls,
|
@@ -352,8 +356,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
352 |
save_dir=save_dir,
|
353 |
verbose=nc < 50 and final_epoch,
|
354 |
plots=plots and final_epoch,
|
355 |
-
|
356 |
-
compute_loss=compute_loss
|
|
|
357 |
|
358 |
# Write
|
359 |
with open(results_file, 'a') as f:
|
@@ -369,8 +374,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
369 |
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
|
370 |
if tb_writer:
|
371 |
tb_writer.add_scalar(tag, x, epoch) # tensorboard
|
372 |
-
if wandb:
|
373 |
-
|
374 |
|
375 |
# Update best mAP
|
376 |
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]]
|
@@ -386,36 +391,29 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
386 |
'ema': deepcopy(ema.ema).half(),
|
387 |
'updates': ema.updates,
|
388 |
'optimizer': optimizer.state_dict(),
|
389 |
-
'wandb_id': wandb_run.id if wandb else None}
|
390 |
|
391 |
# Save last, best and delete
|
392 |
torch.save(ckpt, last)
|
393 |
if best_fitness == fi:
|
394 |
torch.save(ckpt, best)
|
|
|
|
|
|
|
|
|
395 |
del ckpt
|
396 |
-
|
|
|
397 |
# end epoch ----------------------------------------------------------------------------------------------------
|
398 |
# end training
|
399 |
-
|
400 |
if rank in [-1, 0]:
|
401 |
-
# Strip optimizers
|
402 |
-
final = best if best.exists() else last # final model
|
403 |
-
for f in last, best:
|
404 |
-
if f.exists():
|
405 |
-
strip_optimizer(f)
|
406 |
-
if opt.bucket:
|
407 |
-
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
|
408 |
-
|
409 |
# Plots
|
410 |
if plots:
|
411 |
plot_results(save_dir=save_dir) # save as results.png
|
412 |
-
if wandb:
|
413 |
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
414 |
-
|
415 |
-
|
416 |
-
if opt.log_artifacts:
|
417 |
-
wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
|
418 |
-
|
419 |
# Test best.pt
|
420 |
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
421 |
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
|
@@ -430,13 +428,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
430 |
dataloader=testloader,
|
431 |
save_dir=save_dir,
|
432 |
save_json=True,
|
433 |
-
plots=False
|
|
|
434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
else:
|
436 |
dist.destroy_process_group()
|
437 |
-
|
438 |
-
wandb.run.finish() if wandb and wandb.run else None
|
439 |
torch.cuda.empty_cache()
|
|
|
440 |
return results
|
441 |
|
442 |
|
@@ -464,8 +473,6 @@ if __name__ == '__main__':
|
|
464 |
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
|
465 |
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
466 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
467 |
-
parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
|
468 |
-
parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
|
469 |
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
|
470 |
parser.add_argument('--project', default='runs/train', help='save to project/name')
|
471 |
parser.add_argument('--entity', default=None, help='W&B entity')
|
@@ -473,6 +480,10 @@ if __name__ == '__main__':
|
|
473 |
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
|
474 |
parser.add_argument('--quad', action='store_true', help='quad dataloader')
|
475 |
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
|
|
|
|
|
|
|
|
|
476 |
opt = parser.parse_args()
|
477 |
|
478 |
# Set DDP variables
|
@@ -484,7 +495,8 @@ if __name__ == '__main__':
|
|
484 |
check_requirements()
|
485 |
|
486 |
# Resume
|
487 |
-
|
|
|
488 |
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
|
489 |
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
|
490 |
apriori = opt.global_rank, opt.local_rank
|
@@ -517,18 +529,12 @@ if __name__ == '__main__':
|
|
517 |
|
518 |
# Train
|
519 |
logger.info(opt)
|
520 |
-
try:
|
521 |
-
import wandb
|
522 |
-
except ImportError:
|
523 |
-
wandb = None
|
524 |
-
prefix = colorstr('wandb: ')
|
525 |
-
logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
|
526 |
if not opt.evolve:
|
527 |
tb_writer = None # init loggers
|
528 |
if opt.global_rank in [-1, 0]:
|
529 |
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
|
530 |
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
|
531 |
-
train(hyp, opt, device, tb_writer
|
532 |
|
533 |
# Evolve hyperparameters (optional)
|
534 |
else:
|
@@ -602,7 +608,7 @@ if __name__ == '__main__':
|
|
602 |
hyp[k] = round(hyp[k], 5) # significant digits
|
603 |
|
604 |
# Train mutation
|
605 |
-
results = train(hyp.copy(), opt, device
|
606 |
|
607 |
# Write mutation results
|
608 |
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
|
|
|
1 |
+
|
2 |
import argparse
|
3 |
import logging
|
4 |
import math
|
|
|
34 |
from utils.loss import ComputeLoss
|
35 |
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
36 |
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
|
37 |
+
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file
|
38 |
|
39 |
logger = logging.getLogger(__name__)
|
40 |
|
41 |
|
42 |
+
def train(hyp, opt, device, tb_writer=None):
|
43 |
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
44 |
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
45 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
|
|
63 |
init_seeds(2 + rank)
|
64 |
with open(opt.data) as f:
|
65 |
data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
66 |
+
is_coco = opt.data.endswith('coco.yaml')
|
67 |
+
|
68 |
+
# Logging- Doing this before checking the dataset. Might update data_dict
|
69 |
+
if rank in [-1, 0]:
|
70 |
+
opt.hyp = hyp # add hyperparameters
|
71 |
+
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
72 |
+
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
|
73 |
+
data_dict = wandb_logger.data_dict
|
74 |
+
if wandb_logger.wandb:
|
75 |
+
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
76 |
+
loggers = {'wandb': wandb_logger.wandb} # loggers dict
|
77 |
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
|
78 |
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
79 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
|
|
92 |
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
93 |
else:
|
94 |
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
95 |
+
with torch_distributed_zero_first(rank):
|
96 |
+
check_dataset(data_dict) # check
|
97 |
+
train_path = data_dict['train']
|
98 |
+
test_path = data_dict['val']
|
99 |
|
100 |
# Freeze
|
101 |
freeze = [] # parameter names to freeze (full or partial)
|
|
|
139 |
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
140 |
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
# EMA
|
143 |
ema = ModelEMA(model) if rank in [-1, 0] else None
|
144 |
|
|
|
329 |
# if tb_writer:
|
330 |
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
331 |
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
332 |
+
elif plots and ni == 10 and wandb_logger.wandb:
|
333 |
+
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
|
334 |
+
save_dir.glob('train*.jpg') if x.exists()]})
|
335 |
|
336 |
# end batch ------------------------------------------------------------------------------------------------
|
337 |
# end epoch ----------------------------------------------------------------------------------------------------
|
|
|
346 |
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
347 |
final_epoch = epoch + 1 == epochs
|
348 |
if not opt.notest or final_epoch: # Calculate mAP
|
349 |
+
wandb_logger.current_epoch = epoch + 1
|
350 |
+
results, maps, times = test.test(data_dict,
|
351 |
+
batch_size=total_batch_size,
|
352 |
imgsz=imgsz_test,
|
353 |
model=ema.ema,
|
354 |
single_cls=opt.single_cls,
|
|
|
356 |
save_dir=save_dir,
|
357 |
verbose=nc < 50 and final_epoch,
|
358 |
plots=plots and final_epoch,
|
359 |
+
wandb_logger=wandb_logger,
|
360 |
+
compute_loss=compute_loss,
|
361 |
+
is_coco=is_coco)
|
362 |
|
363 |
# Write
|
364 |
with open(results_file, 'a') as f:
|
|
|
374 |
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
|
375 |
if tb_writer:
|
376 |
tb_writer.add_scalar(tag, x, epoch) # tensorboard
|
377 |
+
if wandb_logger.wandb:
|
378 |
+
wandb_logger.log({tag: x}) # W&B
|
379 |
|
380 |
# Update best mAP
|
381 |
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]]
|
|
|
391 |
'ema': deepcopy(ema.ema).half(),
|
392 |
'updates': ema.updates,
|
393 |
'optimizer': optimizer.state_dict(),
|
394 |
+
'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
|
395 |
|
396 |
# Save last, best and delete
|
397 |
torch.save(ckpt, last)
|
398 |
if best_fitness == fi:
|
399 |
torch.save(ckpt, best)
|
400 |
+
if wandb_logger.wandb:
|
401 |
+
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
|
402 |
+
wandb_logger.log_model(
|
403 |
+
last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
404 |
del ckpt
|
405 |
+
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
406 |
+
|
407 |
# end epoch ----------------------------------------------------------------------------------------------------
|
408 |
# end training
|
|
|
409 |
if rank in [-1, 0]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
# Plots
|
411 |
if plots:
|
412 |
plot_results(save_dir=save_dir) # save as results.png
|
413 |
+
if wandb_logger.wandb:
|
414 |
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
415 |
+
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
|
416 |
+
if (save_dir / f).exists()]})
|
|
|
|
|
|
|
417 |
# Test best.pt
|
418 |
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
419 |
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
|
|
|
428 |
dataloader=testloader,
|
429 |
save_dir=save_dir,
|
430 |
save_json=True,
|
431 |
+
plots=False,
|
432 |
+
is_coco=is_coco)
|
433 |
|
434 |
+
# Strip optimizers
|
435 |
+
final = best if best.exists() else last # final model
|
436 |
+
for f in last, best:
|
437 |
+
if f.exists():
|
438 |
+
strip_optimizer(f) # strip optimizers
|
439 |
+
if opt.bucket:
|
440 |
+
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
|
441 |
+
if wandb_logger.wandb: # Log the stripped model
|
442 |
+
wandb_logger.wandb.log_artifact(str(final), type='model',
|
443 |
+
name='run_' + wandb_logger.wandb_run.id + '_model',
|
444 |
+
aliases=['last', 'best', 'stripped'])
|
445 |
else:
|
446 |
dist.destroy_process_group()
|
|
|
|
|
447 |
torch.cuda.empty_cache()
|
448 |
+
wandb_logger.finish_run()
|
449 |
return results
|
450 |
|
451 |
|
|
|
473 |
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
|
474 |
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
475 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
|
|
|
|
476 |
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
|
477 |
parser.add_argument('--project', default='runs/train', help='save to project/name')
|
478 |
parser.add_argument('--entity', default=None, help='W&B entity')
|
|
|
480 |
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
|
481 |
parser.add_argument('--quad', action='store_true', help='quad dataloader')
|
482 |
parser.add_argument('--linear-lr', action='store_true', help='linear LR')
|
483 |
+
parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
|
484 |
+
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
|
485 |
+
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
486 |
+
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
487 |
opt = parser.parse_args()
|
488 |
|
489 |
# Set DDP variables
|
|
|
495 |
check_requirements()
|
496 |
|
497 |
# Resume
|
498 |
+
wandb_run = resume_and_get_id(opt)
|
499 |
+
if opt.resume and not wandb_run: # resume an interrupted run
|
500 |
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
|
501 |
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
|
502 |
apriori = opt.global_rank, opt.local_rank
|
|
|
529 |
|
530 |
# Train
|
531 |
logger.info(opt)
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
if not opt.evolve:
|
533 |
tb_writer = None # init loggers
|
534 |
if opt.global_rank in [-1, 0]:
|
535 |
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
|
536 |
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
|
537 |
+
train(hyp, opt, device, tb_writer)
|
538 |
|
539 |
# Evolve hyperparameters (optional)
|
540 |
else:
|
|
|
608 |
hyp[k] = round(hyp[k], 5) # significant digits
|
609 |
|
610 |
# Train mutation
|
611 |
+
results = train(hyp.copy(), opt, device)
|
612 |
|
613 |
# Write mutation results
|
614 |
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
|
@@ -12,20 +12,7 @@ WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
|
|
12 |
def create_dataset_artifact(opt):
|
13 |
with open(opt.data) as f:
|
14 |
data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
15 |
-
logger = WandbLogger(opt, '', None, data, job_type='
|
16 |
-
nc, names = (1, ['item']) if opt.single_cls else (int(data['nc']), data['names'])
|
17 |
-
names = {k: v for k, v in enumerate(names)} # to index dictionary
|
18 |
-
logger.log_dataset_artifact(LoadImagesAndLabels(data['train']), names, name='train') # trainset
|
19 |
-
logger.log_dataset_artifact(LoadImagesAndLabels(data['val']), names, name='val') # valset
|
20 |
-
|
21 |
-
# Update data.yaml with artifact links
|
22 |
-
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'train')
|
23 |
-
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'val')
|
24 |
-
path = opt.data if opt.overwrite_config else opt.data.replace('.', '_wandb.') # updated data.yaml path
|
25 |
-
data.pop('download', None) # download via artifact instead of predefined field 'download:'
|
26 |
-
with open(path, 'w') as f:
|
27 |
-
yaml.dump(data, f)
|
28 |
-
print("New Config file => ", path)
|
29 |
|
30 |
|
31 |
if __name__ == '__main__':
|
@@ -33,7 +20,6 @@ if __name__ == '__main__':
|
|
33 |
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
|
34 |
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
35 |
parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
|
36 |
-
parser.add_argument('--overwrite_config', action='store_true', help='overwrite data.yaml')
|
37 |
opt = parser.parse_args()
|
38 |
|
39 |
create_dataset_artifact(opt)
|
|
|
12 |
def create_dataset_artifact(opt):
|
13 |
with open(opt.data) as f:
|
14 |
data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
15 |
+
logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
if __name__ == '__main__':
|
|
|
20 |
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
|
21 |
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
22 |
parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
|
|
|
23 |
opt = parser.parse_args()
|
24 |
|
25 |
create_dataset_artifact(opt)
|
@@ -1,13 +1,18 @@
|
|
|
|
1 |
import json
|
|
|
2 |
import shutil
|
3 |
import sys
|
|
|
|
|
4 |
from datetime import datetime
|
5 |
from pathlib import Path
|
6 |
-
|
7 |
-
import torch
|
8 |
|
9 |
sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
|
10 |
-
from utils.
|
|
|
|
|
11 |
|
12 |
try:
|
13 |
import wandb
|
@@ -22,87 +27,183 @@ def remove_prefix(from_string, prefix):
|
|
22 |
return from_string[len(prefix):]
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
class WandbLogger():
|
26 |
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
|
27 |
-
|
28 |
-
self.
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
def setup_training(self, opt, data_dict):
|
42 |
-
self.log_dict = {}
|
43 |
-
self.
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
|
55 |
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
|
56 |
-
if opt.
|
57 |
-
|
58 |
-
|
59 |
-
self.weights = Path(modeldir) / "best.pt"
|
60 |
-
opt.weights = self.weights
|
61 |
|
62 |
def download_dataset_artifact(self, path, alias):
|
63 |
if path.startswith(WANDB_ARTIFACT_PREFIX):
|
64 |
dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
|
65 |
assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
|
66 |
datadir = dataset_artifact.download()
|
67 |
-
labels_zip = Path(datadir) / "data/labels.zip"
|
68 |
-
shutil.unpack_archive(labels_zip, Path(datadir) / 'data/labels', 'zip')
|
69 |
-
print("Downloaded dataset to : ", datadir)
|
70 |
return datadir, dataset_artifact
|
71 |
return None, None
|
72 |
|
73 |
-
def download_model_artifact(self,
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
def log_model(self, path, opt, epoch):
|
81 |
-
datetime_suffix = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
|
82 |
model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
|
83 |
'original_url': str(path),
|
84 |
-
'
|
85 |
'save period': opt.save_period,
|
86 |
'project': opt.project,
|
87 |
-
'
|
|
|
88 |
})
|
89 |
model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
|
90 |
-
|
91 |
-
|
92 |
print("Saving model artifact on epoch ", epoch + 1)
|
93 |
|
94 |
-
def log_dataset_artifact(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
artifact = wandb.Artifact(name=name, type="dataset")
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
|
100 |
-
for si, (img, labels, paths, shapes) in enumerate(dataset):
|
101 |
height, width = shapes[0]
|
102 |
-
labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4)))
|
103 |
-
|
104 |
-
box_data = []
|
105 |
-
img_classes = {}
|
106 |
for cls, *xyxy in labels[:, 1:].tolist():
|
107 |
cls = int(cls)
|
108 |
box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
@@ -112,34 +213,52 @@ class WandbLogger():
|
|
112 |
"domain": "pixel"})
|
113 |
img_classes[cls] = class_to_id[cls]
|
114 |
boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
|
115 |
-
table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes)
|
|
|
116 |
artifact.add(table, name)
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
def log(self, log_dict):
|
126 |
if self.wandb_run:
|
127 |
for key, value in log_dict.items():
|
128 |
self.log_dict[key] = value
|
129 |
|
130 |
-
def end_epoch(self):
|
131 |
-
if self.wandb_run
|
132 |
wandb.log(self.log_dict)
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
def finish_run(self):
|
136 |
if self.wandb_run:
|
137 |
-
if self.result_artifact:
|
138 |
-
print("Add Training Progress Artifact")
|
139 |
-
self.result_artifact.add(self.result_table, 'result')
|
140 |
-
train_results = wandb.JoinedTable(self.testset_artifact.get("val"), self.result_table, "id")
|
141 |
-
self.result_artifact.add(train_results, 'joined_result')
|
142 |
-
wandb.log_artifact(self.result_artifact)
|
143 |
if self.log_dict:
|
144 |
wandb.log(self.log_dict)
|
145 |
wandb.run.finish()
|
|
|
1 |
+
import argparse
|
2 |
import json
|
3 |
+
import os
|
4 |
import shutil
|
5 |
import sys
|
6 |
+
import torch
|
7 |
+
import yaml
|
8 |
from datetime import datetime
|
9 |
from pathlib import Path
|
10 |
+
from tqdm import tqdm
|
|
|
11 |
|
12 |
sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
|
13 |
+
from utils.datasets import LoadImagesAndLabels
|
14 |
+
from utils.datasets import img2label_paths
|
15 |
+
from utils.general import colorstr, xywh2xyxy, check_dataset
|
16 |
|
17 |
try:
|
18 |
import wandb
|
|
|
27 |
return from_string[len(prefix):]
|
28 |
|
29 |
|
30 |
+
def check_wandb_config_file(data_config_file):
|
31 |
+
wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
|
32 |
+
if Path(wandb_config).is_file():
|
33 |
+
return wandb_config
|
34 |
+
return data_config_file
|
35 |
+
|
36 |
+
|
37 |
+
def resume_and_get_id(opt):
|
38 |
+
# It's more elegant to stick to 1 wandb.init call, but as useful config data is overwritten in the WandbLogger's wandb.init call
|
39 |
+
if isinstance(opt.resume, str):
|
40 |
+
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
41 |
+
run_path = Path(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX))
|
42 |
+
run_id = run_path.stem
|
43 |
+
project = run_path.parent.stem
|
44 |
+
model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
|
45 |
+
assert wandb, 'install wandb to resume wandb runs'
|
46 |
+
# Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
|
47 |
+
run = wandb.init(id=run_id, project=project, resume='allow')
|
48 |
+
opt.resume = model_artifact_name
|
49 |
+
return run
|
50 |
+
return None
|
51 |
+
|
52 |
+
|
53 |
class WandbLogger():
|
54 |
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
|
55 |
+
# Pre-training routine --
|
56 |
+
self.job_type = job_type
|
57 |
+
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
|
58 |
+
if self.wandb:
|
59 |
+
self.wandb_run = wandb.init(config=opt,
|
60 |
+
resume="allow",
|
61 |
+
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
62 |
+
name=name,
|
63 |
+
job_type=job_type,
|
64 |
+
id=run_id) if not wandb.run else wandb.run
|
65 |
+
if self.job_type == 'Training':
|
66 |
+
if not opt.resume:
|
67 |
+
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
|
68 |
+
# Info useful for resuming from artifacts
|
69 |
+
self.wandb_run.config.opt = vars(opt)
|
70 |
+
self.wandb_run.config.data_dict = wandb_data_dict
|
71 |
+
self.data_dict = self.setup_training(opt, data_dict)
|
72 |
+
if self.job_type == 'Dataset Creation':
|
73 |
+
self.data_dict = self.check_and_upload_dataset(opt)
|
74 |
+
|
75 |
+
def check_and_upload_dataset(self, opt):
|
76 |
+
assert wandb, 'Install wandb to upload dataset'
|
77 |
+
check_dataset(self.data_dict)
|
78 |
+
config_path = self.log_dataset_artifact(opt.data,
|
79 |
+
opt.single_cls,
|
80 |
+
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
|
81 |
+
print("Created dataset config file ", config_path)
|
82 |
+
with open(config_path) as f:
|
83 |
+
wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
|
84 |
+
return wandb_data_dict
|
85 |
|
86 |
def setup_training(self, opt, data_dict):
|
87 |
+
self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants
|
88 |
+
self.bbox_interval = opt.bbox_interval
|
89 |
+
if isinstance(opt.resume, str):
|
90 |
+
modeldir, _ = self.download_model_artifact(opt)
|
91 |
+
if modeldir:
|
92 |
+
self.weights = Path(modeldir) / "last.pt"
|
93 |
+
config = self.wandb_run.config
|
94 |
+
opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
|
95 |
+
self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
|
96 |
+
config.opt['hyp']
|
97 |
+
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
|
98 |
+
if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
|
99 |
+
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
|
100 |
+
opt.artifact_alias)
|
101 |
+
self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
|
102 |
+
opt.artifact_alias)
|
103 |
+
self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
|
104 |
+
if self.train_artifact_path is not None:
|
105 |
+
train_path = Path(self.train_artifact_path) / 'data/images/'
|
106 |
+
data_dict['train'] = str(train_path)
|
107 |
+
if self.val_artifact_path is not None:
|
108 |
+
val_path = Path(self.val_artifact_path) / 'data/images/'
|
109 |
+
data_dict['val'] = str(val_path)
|
110 |
+
self.val_table = self.val_artifact.get("val")
|
111 |
+
self.map_val_table_path()
|
112 |
+
if self.val_artifact is not None:
|
113 |
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
|
114 |
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
|
115 |
+
if opt.bbox_interval == -1:
|
116 |
+
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
|
117 |
+
return data_dict
|
|
|
|
|
118 |
|
119 |
def download_dataset_artifact(self, path, alias):
|
120 |
if path.startswith(WANDB_ARTIFACT_PREFIX):
|
121 |
dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
|
122 |
assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
|
123 |
datadir = dataset_artifact.download()
|
|
|
|
|
|
|
124 |
return datadir, dataset_artifact
|
125 |
return None, None
|
126 |
|
127 |
+
def download_model_artifact(self, opt):
|
128 |
+
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
129 |
+
model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
|
130 |
+
assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
|
131 |
+
modeldir = model_artifact.download()
|
132 |
+
epochs_trained = model_artifact.metadata.get('epochs_trained')
|
133 |
+
total_epochs = model_artifact.metadata.get('total_epochs')
|
134 |
+
assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
|
135 |
+
total_epochs)
|
136 |
+
return modeldir, model_artifact
|
137 |
+
return None, None
|
138 |
|
139 |
+
def log_model(self, path, opt, epoch, fitness_score, best_model=False):
|
|
|
140 |
model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
|
141 |
'original_url': str(path),
|
142 |
+
'epochs_trained': epoch + 1,
|
143 |
'save period': opt.save_period,
|
144 |
'project': opt.project,
|
145 |
+
'total_epochs': opt.epochs,
|
146 |
+
'fitness_score': fitness_score
|
147 |
})
|
148 |
model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
|
149 |
+
wandb.log_artifact(model_artifact,
|
150 |
+
aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
|
151 |
print("Saving model artifact on epoch ", epoch + 1)
|
152 |
|
153 |
+
def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
|
154 |
+
with open(data_file) as f:
|
155 |
+
data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
|
156 |
+
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
|
157 |
+
names = {k: v for k, v in enumerate(names)} # to index dictionary
|
158 |
+
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
|
159 |
+
data['train']), names, name='train') if data.get('train') else None
|
160 |
+
self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
|
161 |
+
data['val']), names, name='val') if data.get('val') else None
|
162 |
+
if data.get('train'):
|
163 |
+
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
|
164 |
+
if data.get('val'):
|
165 |
+
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
|
166 |
+
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
|
167 |
+
data.pop('download', None)
|
168 |
+
with open(path, 'w') as f:
|
169 |
+
yaml.dump(data, f)
|
170 |
+
|
171 |
+
if self.job_type == 'Training': # builds correct artifact pipeline graph
|
172 |
+
self.wandb_run.use_artifact(self.val_artifact)
|
173 |
+
self.wandb_run.use_artifact(self.train_artifact)
|
174 |
+
self.val_artifact.wait()
|
175 |
+
self.val_table = self.val_artifact.get('val')
|
176 |
+
self.map_val_table_path()
|
177 |
+
else:
|
178 |
+
self.wandb_run.log_artifact(self.train_artifact)
|
179 |
+
self.wandb_run.log_artifact(self.val_artifact)
|
180 |
+
return path
|
181 |
+
|
182 |
+
def map_val_table_path(self):
|
183 |
+
self.val_table_map = {}
|
184 |
+
print("Mapping dataset")
|
185 |
+
for i, data in enumerate(tqdm(self.val_table.data)):
|
186 |
+
self.val_table_map[data[3]] = data[0]
|
187 |
+
|
188 |
+
def create_dataset_table(self, dataset, class_to_id, name='dataset'):
|
189 |
+
# TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
|
190 |
artifact = wandb.Artifact(name=name, type="dataset")
|
191 |
+
for img_file in tqdm([dataset.path]) if Path(dataset.path).is_dir() else tqdm(dataset.img_files):
|
192 |
+
if Path(img_file).is_dir():
|
193 |
+
artifact.add_dir(img_file, name='data/images')
|
194 |
+
labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
|
195 |
+
artifact.add_dir(labels_path, name='data/labels')
|
196 |
+
else:
|
197 |
+
artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
|
198 |
+
label_file = Path(img2label_paths([img_file])[0])
|
199 |
+
artifact.add_file(str(label_file),
|
200 |
+
name='data/labels/' + label_file.name) if label_file.exists() else None
|
201 |
+
table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
|
202 |
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
|
203 |
+
for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
|
204 |
height, width = shapes[0]
|
205 |
+
labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height])
|
206 |
+
box_data, img_classes = [], {}
|
|
|
|
|
207 |
for cls, *xyxy in labels[:, 1:].tolist():
|
208 |
cls = int(cls)
|
209 |
box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
|
|
213 |
"domain": "pixel"})
|
214 |
img_classes[cls] = class_to_id[cls]
|
215 |
boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
|
216 |
+
table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
|
217 |
+
Path(paths).name)
|
218 |
artifact.add(table, name)
|
219 |
+
return artifact
|
220 |
+
|
221 |
+
def log_training_progress(self, predn, path, names):
|
222 |
+
if self.val_table and self.result_table:
|
223 |
+
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
|
224 |
+
box_data = []
|
225 |
+
total_conf = 0
|
226 |
+
for *xyxy, conf, cls in predn.tolist():
|
227 |
+
if conf >= 0.25:
|
228 |
+
box_data.append(
|
229 |
+
{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
|
230 |
+
"class_id": int(cls),
|
231 |
+
"box_caption": "%s %.3f" % (names[cls], conf),
|
232 |
+
"scores": {"class_score": conf},
|
233 |
+
"domain": "pixel"})
|
234 |
+
total_conf = total_conf + conf
|
235 |
+
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
|
236 |
+
id = self.val_table_map[Path(path).name]
|
237 |
+
self.result_table.add_data(self.current_epoch,
|
238 |
+
id,
|
239 |
+
wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
|
240 |
+
total_conf / max(1, len(box_data))
|
241 |
+
)
|
242 |
|
243 |
def log(self, log_dict):
|
244 |
if self.wandb_run:
|
245 |
for key, value in log_dict.items():
|
246 |
self.log_dict[key] = value
|
247 |
|
248 |
+
def end_epoch(self, best_result=False):
|
249 |
+
if self.wandb_run:
|
250 |
wandb.log(self.log_dict)
|
251 |
+
self.log_dict = {}
|
252 |
+
if self.result_artifact:
|
253 |
+
train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
|
254 |
+
self.result_artifact.add(train_results, 'result')
|
255 |
+
wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
|
256 |
+
('best' if best_result else '')])
|
257 |
+
self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
|
258 |
+
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
|
259 |
|
260 |
def finish_run(self):
|
261 |
if self.wandb_run:
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
if self.log_dict:
|
263 |
wandb.log(self.log_dict)
|
264 |
wandb.run.finish()
|