Ayush Chaurasia glenn-jocher commited on
Commit
e8fc97a
1 Parent(s): ed2c742

Improved W&B integration (#2125)

Browse files

* Init Commit

* new wandb integration

* Update

* Use data_dict in test

* Updates

* Update: scope of log_img

* Update: scope of log_img

* Update

* Update: Fix logging conditions

* Add tqdm bar, support for .txt dataset format

* Improve Result table Logger

* Init Commit

* new wandb integration

* Update

* Use data_dict in test

* Updates

* Update: scope of log_img

* Update: scope of log_img

* Update

* Update: Fix logging conditions

* Add tqdm bar, support for .txt dataset format

* Improve Result table Logger

* Add dataset creation in training script

* Change scope: self.wandb_run

* Add wandb-artifact:// natively

you can now use --resume with wandb run links

* Add suuport for logging dataset while training

* Cleanup

* Fix: Merge conflict

* Fix: CI tests

* Automatically use wandb config

* Fix: Resume

* Fix: CI

* Enhance: Using val_table

* More resume enhancement

* FIX : CI

* Add alias

* Get useful opt config data

* train.py cleanup

* Cleanup train.py

* more cleanup

* Cleanup| CI fix

* Reformat using PEP8

* FIX:CI

* rebase

* remove uneccesary changes

* remove uneccesary changes

* remove uneccesary changes

* remove unecessary chage from test.py

* FIX: resume from local checkpoint

* FIX:resume

* FIX:resume

* Reformat

* Performance improvement

* Fix local resume

* Fix local resume

* FIX:CI

* Fix: CI

* Imporve image logging

* (:(:Redo CI tests:):)

* Remember epochs when resuming

* Remember epochs when resuming

* Update DDP location

Potential fix for #2405

* PEP8 reformat

* 0.25 confidence threshold

* reset train.py plots syntax to previous

* reset epochs completed syntax to previous

* reset space to previous

* remove brackets

* reset comment to previous

* Update: is_coco check, remove unused code

* Remove redundant print statement

* Remove wandb imports

* remove dsviz logger from test.py

* Remove redundant change from test.py

* remove redundant changes from train.py

* reformat and improvements

* Fix typo

* Add tqdm tqdm progress when scanning files, naming improvements

Co-authored-by: Glenn Jocher <[email protected]>

models/common.py CHANGED
@@ -278,7 +278,7 @@ class Detections:
278
  def print(self):
279
  self.display(pprint=True) # print results
280
  print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
281
- tuple(self.t))
282
 
283
  def show(self):
284
  self.display(show=True) # show results
 
278
  def print(self):
279
  self.display(pprint=True) # print results
280
  print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
281
+ tuple(self.t))
282
 
283
  def show(self):
284
  self.display(show=True) # show results
test.py CHANGED
@@ -35,8 +35,9 @@ def test(data,
35
  save_hybrid=False, # for hybrid auto-labelling
36
  save_conf=False, # save auto-label confidences
37
  plots=True,
38
- log_imgs=0, # number of logged images
39
- compute_loss=None):
 
40
  # Initialize/load model and set device
41
  training = model is not None
42
  if training: # called by train.py
@@ -66,21 +67,19 @@ def test(data,
66
 
67
  # Configure
68
  model.eval()
69
- is_coco = data.endswith('coco.yaml') # is COCO dataset
70
- with open(data) as f:
71
- data = yaml.load(f, Loader=yaml.SafeLoader) # model dict
 
72
  check_dataset(data) # check
73
  nc = 1 if single_cls else int(data['nc']) # number of classes
74
  iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for [email protected]:0.95
75
  niou = iouv.numel()
76
 
77
  # Logging
78
- log_imgs, wandb = min(log_imgs, 100), None # ceil
79
- try:
80
- import wandb # Weights & Biases
81
- except ImportError:
82
- log_imgs = 0
83
-
84
  # Dataloader
85
  if not training:
86
  if device.type != 'cpu':
@@ -147,15 +146,17 @@ def test(data,
147
  with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
148
  f.write(('%g ' * len(line)).rstrip() % line + '\n')
149
 
150
- # W&B logging
151
- if plots and len(wandb_images) < log_imgs:
152
- box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
153
- "class_id": int(cls),
154
- "box_caption": "%s %.3f" % (names[cls], conf),
155
- "scores": {"class_score": conf},
156
- "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
157
- boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
158
- wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))
 
 
159
 
160
  # Append to pycocotools JSON dictionary
161
  if save_json:
@@ -239,9 +240,11 @@ def test(data,
239
  # Plots
240
  if plots:
241
  confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
242
- if wandb and wandb.run:
243
- val_batches = [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
244
- wandb.log({"Images": wandb_images, "Validation": val_batches}, commit=False)
 
 
245
 
246
  # Save JSON
247
  if save_json and len(jdict):
 
35
  save_hybrid=False, # for hybrid auto-labelling
36
  save_conf=False, # save auto-label confidences
37
  plots=True,
38
+ wandb_logger=None,
39
+ compute_loss=None,
40
+ is_coco=False):
41
  # Initialize/load model and set device
42
  training = model is not None
43
  if training: # called by train.py
 
67
 
68
  # Configure
69
  model.eval()
70
+ if isinstance(data, str):
71
+ is_coco = data.endswith('coco.yaml')
72
+ with open(data) as f:
73
+ data = yaml.load(f, Loader=yaml.SafeLoader)
74
  check_dataset(data) # check
75
  nc = 1 if single_cls else int(data['nc']) # number of classes
76
  iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for [email protected]:0.95
77
  niou = iouv.numel()
78
 
79
  # Logging
80
+ log_imgs = 0
81
+ if wandb_logger and wandb_logger.wandb:
82
+ log_imgs = min(wandb_logger.log_imgs, 100)
 
 
 
83
  # Dataloader
84
  if not training:
85
  if device.type != 'cpu':
 
146
  with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
147
  f.write(('%g ' * len(line)).rstrip() % line + '\n')
148
 
149
+ # W&B logging - Media Panel Plots
150
+ if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # Check for test operation
151
+ if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
152
+ box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
153
+ "class_id": int(cls),
154
+ "box_caption": "%s %.3f" % (names[cls], conf),
155
+ "scores": {"class_score": conf},
156
+ "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
157
+ boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
158
+ wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name))
159
+ wandb_logger.log_training_progress(predn, path, names) # logs dsviz tables
160
 
161
  # Append to pycocotools JSON dictionary
162
  if save_json:
 
240
  # Plots
241
  if plots:
242
  confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
243
+ if wandb_logger and wandb_logger.wandb:
244
+ val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
245
+ wandb_logger.log({"Validation": val_batches})
246
+ if wandb_images:
247
+ wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})
248
 
249
  # Save JSON
250
  if save_json and len(jdict):
train.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import argparse
2
  import logging
3
  import math
@@ -33,11 +34,12 @@ from utils.google_utils import attempt_download
33
  from utils.loss import ComputeLoss
34
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
35
  from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
 
36
 
37
  logger = logging.getLogger(__name__)
38
 
39
 
40
- def train(hyp, opt, device, tb_writer=None, wandb=None):
41
  logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
42
  save_dir, epochs, batch_size, total_batch_size, weights, rank = \
43
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
@@ -61,10 +63,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
61
  init_seeds(2 + rank)
62
  with open(opt.data) as f:
63
  data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
64
- with torch_distributed_zero_first(rank):
65
- check_dataset(data_dict) # check
66
- train_path = data_dict['train']
67
- test_path = data_dict['val']
 
 
 
 
 
 
 
68
  nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
69
  names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
70
  assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
@@ -83,6 +92,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
83
  logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
84
  else:
85
  model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
 
 
 
 
86
 
87
  # Freeze
88
  freeze = [] # parameter names to freeze (full or partial)
@@ -126,16 +139,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
126
  scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
127
  # plot_lr_scheduler(optimizer, scheduler, epochs)
128
 
129
- # Logging
130
- if rank in [-1, 0] and wandb and wandb.run is None:
131
- opt.hyp = hyp # add hyperparameters
132
- wandb_run = wandb.init(config=opt, resume="allow",
133
- project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
134
- name=save_dir.stem,
135
- entity=opt.entity,
136
- id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
137
- loggers = {'wandb': wandb} # loggers dict
138
-
139
  # EMA
140
  ema = ModelEMA(model) if rank in [-1, 0] else None
141
 
@@ -326,9 +329,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
326
  # if tb_writer:
327
  # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
328
  # tb_writer.add_graph(model, imgs) # add model to tensorboard
329
- elif plots and ni == 10 and wandb:
330
- wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
331
- if x.exists()]}, commit=False)
332
 
333
  # end batch ------------------------------------------------------------------------------------------------
334
  # end epoch ----------------------------------------------------------------------------------------------------
@@ -343,8 +346,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
343
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
344
  final_epoch = epoch + 1 == epochs
345
  if not opt.notest or final_epoch: # Calculate mAP
346
- results, maps, times = test.test(opt.data,
347
- batch_size=batch_size * 2,
 
348
  imgsz=imgsz_test,
349
  model=ema.ema,
350
  single_cls=opt.single_cls,
@@ -352,8 +356,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
352
  save_dir=save_dir,
353
  verbose=nc < 50 and final_epoch,
354
  plots=plots and final_epoch,
355
- log_imgs=opt.log_imgs if wandb else 0,
356
- compute_loss=compute_loss)
 
357
 
358
  # Write
359
  with open(results_file, 'a') as f:
@@ -369,8 +374,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
369
  for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
370
  if tb_writer:
371
  tb_writer.add_scalar(tag, x, epoch) # tensorboard
372
- if wandb:
373
- wandb.log({tag: x}, step=epoch, commit=tag == tags[-1]) # W&B
374
 
375
  # Update best mAP
376
  fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]]
@@ -386,36 +391,29 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
386
  'ema': deepcopy(ema.ema).half(),
387
  'updates': ema.updates,
388
  'optimizer': optimizer.state_dict(),
389
- 'wandb_id': wandb_run.id if wandb else None}
390
 
391
  # Save last, best and delete
392
  torch.save(ckpt, last)
393
  if best_fitness == fi:
394
  torch.save(ckpt, best)
 
 
 
 
395
  del ckpt
396
-
 
397
  # end epoch ----------------------------------------------------------------------------------------------------
398
  # end training
399
-
400
  if rank in [-1, 0]:
401
- # Strip optimizers
402
- final = best if best.exists() else last # final model
403
- for f in last, best:
404
- if f.exists():
405
- strip_optimizer(f)
406
- if opt.bucket:
407
- os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
408
-
409
  # Plots
410
  if plots:
411
  plot_results(save_dir=save_dir) # save as results.png
412
- if wandb:
413
  files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
414
- wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
415
- if (save_dir / f).exists()]})
416
- if opt.log_artifacts:
417
- wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
418
-
419
  # Test best.pt
420
  logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
421
  if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
@@ -430,13 +428,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
430
  dataloader=testloader,
431
  save_dir=save_dir,
432
  save_json=True,
433
- plots=False)
 
434
 
 
 
 
 
 
 
 
 
 
 
 
435
  else:
436
  dist.destroy_process_group()
437
-
438
- wandb.run.finish() if wandb and wandb.run else None
439
  torch.cuda.empty_cache()
 
440
  return results
441
 
442
 
@@ -464,8 +473,6 @@ if __name__ == '__main__':
464
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
465
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
466
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
467
- parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
468
- parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
469
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
470
  parser.add_argument('--project', default='runs/train', help='save to project/name')
471
  parser.add_argument('--entity', default=None, help='W&B entity')
@@ -473,6 +480,10 @@ if __name__ == '__main__':
473
  parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
474
  parser.add_argument('--quad', action='store_true', help='quad dataloader')
475
  parser.add_argument('--linear-lr', action='store_true', help='linear LR')
 
 
 
 
476
  opt = parser.parse_args()
477
 
478
  # Set DDP variables
@@ -484,7 +495,8 @@ if __name__ == '__main__':
484
  check_requirements()
485
 
486
  # Resume
487
- if opt.resume: # resume an interrupted run
 
488
  ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
489
  assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
490
  apriori = opt.global_rank, opt.local_rank
@@ -517,18 +529,12 @@ if __name__ == '__main__':
517
 
518
  # Train
519
  logger.info(opt)
520
- try:
521
- import wandb
522
- except ImportError:
523
- wandb = None
524
- prefix = colorstr('wandb: ')
525
- logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
526
  if not opt.evolve:
527
  tb_writer = None # init loggers
528
  if opt.global_rank in [-1, 0]:
529
  logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
530
  tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
531
- train(hyp, opt, device, tb_writer, wandb)
532
 
533
  # Evolve hyperparameters (optional)
534
  else:
@@ -602,7 +608,7 @@ if __name__ == '__main__':
602
  hyp[k] = round(hyp[k], 5) # significant digits
603
 
604
  # Train mutation
605
- results = train(hyp.copy(), opt, device, wandb=wandb)
606
 
607
  # Write mutation results
608
  print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
 
1
+
2
  import argparse
3
  import logging
4
  import math
 
34
  from utils.loss import ComputeLoss
35
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
36
  from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
37
+ from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file
38
 
39
  logger = logging.getLogger(__name__)
40
 
41
 
42
+ def train(hyp, opt, device, tb_writer=None):
43
  logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
44
  save_dir, epochs, batch_size, total_batch_size, weights, rank = \
45
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
 
63
  init_seeds(2 + rank)
64
  with open(opt.data) as f:
65
  data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
66
+ is_coco = opt.data.endswith('coco.yaml')
67
+
68
+ # Logging- Doing this before checking the dataset. Might update data_dict
69
+ if rank in [-1, 0]:
70
+ opt.hyp = hyp # add hyperparameters
71
+ run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
72
+ wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
73
+ data_dict = wandb_logger.data_dict
74
+ if wandb_logger.wandb:
75
+ weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
76
+ loggers = {'wandb': wandb_logger.wandb} # loggers dict
77
  nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
78
  names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
79
  assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
 
92
  logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
93
  else:
94
  model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
95
+ with torch_distributed_zero_first(rank):
96
+ check_dataset(data_dict) # check
97
+ train_path = data_dict['train']
98
+ test_path = data_dict['val']
99
 
100
  # Freeze
101
  freeze = [] # parameter names to freeze (full or partial)
 
139
  scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
140
  # plot_lr_scheduler(optimizer, scheduler, epochs)
141
 
 
 
 
 
 
 
 
 
 
 
142
  # EMA
143
  ema = ModelEMA(model) if rank in [-1, 0] else None
144
 
 
329
  # if tb_writer:
330
  # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
331
  # tb_writer.add_graph(model, imgs) # add model to tensorboard
332
+ elif plots and ni == 10 and wandb_logger.wandb:
333
+ wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
334
+ save_dir.glob('train*.jpg') if x.exists()]})
335
 
336
  # end batch ------------------------------------------------------------------------------------------------
337
  # end epoch ----------------------------------------------------------------------------------------------------
 
346
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
347
  final_epoch = epoch + 1 == epochs
348
  if not opt.notest or final_epoch: # Calculate mAP
349
+ wandb_logger.current_epoch = epoch + 1
350
+ results, maps, times = test.test(data_dict,
351
+ batch_size=total_batch_size,
352
  imgsz=imgsz_test,
353
  model=ema.ema,
354
  single_cls=opt.single_cls,
 
356
  save_dir=save_dir,
357
  verbose=nc < 50 and final_epoch,
358
  plots=plots and final_epoch,
359
+ wandb_logger=wandb_logger,
360
+ compute_loss=compute_loss,
361
+ is_coco=is_coco)
362
 
363
  # Write
364
  with open(results_file, 'a') as f:
 
374
  for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
375
  if tb_writer:
376
  tb_writer.add_scalar(tag, x, epoch) # tensorboard
377
+ if wandb_logger.wandb:
378
+ wandb_logger.log({tag: x}) # W&B
379
 
380
  # Update best mAP
381
  fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]]
 
391
  'ema': deepcopy(ema.ema).half(),
392
  'updates': ema.updates,
393
  'optimizer': optimizer.state_dict(),
394
+ 'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
395
 
396
  # Save last, best and delete
397
  torch.save(ckpt, last)
398
  if best_fitness == fi:
399
  torch.save(ckpt, best)
400
+ if wandb_logger.wandb:
401
+ if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
402
+ wandb_logger.log_model(
403
+ last.parent, opt, epoch, fi, best_model=best_fitness == fi)
404
  del ckpt
405
+ wandb_logger.end_epoch(best_result=best_fitness == fi)
406
+
407
  # end epoch ----------------------------------------------------------------------------------------------------
408
  # end training
 
409
  if rank in [-1, 0]:
 
 
 
 
 
 
 
 
410
  # Plots
411
  if plots:
412
  plot_results(save_dir=save_dir) # save as results.png
413
+ if wandb_logger.wandb:
414
  files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
415
+ wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
416
+ if (save_dir / f).exists()]})
 
 
 
417
  # Test best.pt
418
  logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
419
  if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
 
428
  dataloader=testloader,
429
  save_dir=save_dir,
430
  save_json=True,
431
+ plots=False,
432
+ is_coco=is_coco)
433
 
434
+ # Strip optimizers
435
+ final = best if best.exists() else last # final model
436
+ for f in last, best:
437
+ if f.exists():
438
+ strip_optimizer(f) # strip optimizers
439
+ if opt.bucket:
440
+ os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
441
+ if wandb_logger.wandb: # Log the stripped model
442
+ wandb_logger.wandb.log_artifact(str(final), type='model',
443
+ name='run_' + wandb_logger.wandb_run.id + '_model',
444
+ aliases=['last', 'best', 'stripped'])
445
  else:
446
  dist.destroy_process_group()
 
 
447
  torch.cuda.empty_cache()
448
+ wandb_logger.finish_run()
449
  return results
450
 
451
 
 
473
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
474
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
475
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
 
 
476
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
477
  parser.add_argument('--project', default='runs/train', help='save to project/name')
478
  parser.add_argument('--entity', default=None, help='W&B entity')
 
480
  parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
481
  parser.add_argument('--quad', action='store_true', help='quad dataloader')
482
  parser.add_argument('--linear-lr', action='store_true', help='linear LR')
483
+ parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
484
+ parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
485
+ parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
486
+ parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
487
  opt = parser.parse_args()
488
 
489
  # Set DDP variables
 
495
  check_requirements()
496
 
497
  # Resume
498
+ wandb_run = resume_and_get_id(opt)
499
+ if opt.resume and not wandb_run: # resume an interrupted run
500
  ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
501
  assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
502
  apriori = opt.global_rank, opt.local_rank
 
529
 
530
  # Train
531
  logger.info(opt)
 
 
 
 
 
 
532
  if not opt.evolve:
533
  tb_writer = None # init loggers
534
  if opt.global_rank in [-1, 0]:
535
  logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
536
  tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
537
+ train(hyp, opt, device, tb_writer)
538
 
539
  # Evolve hyperparameters (optional)
540
  else:
 
608
  hyp[k] = round(hyp[k], 5) # significant digits
609
 
610
  # Train mutation
611
+ results = train(hyp.copy(), opt, device)
612
 
613
  # Write mutation results
614
  print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
utils/wandb_logging/log_dataset.py CHANGED
@@ -12,20 +12,7 @@ WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
12
  def create_dataset_artifact(opt):
13
  with open(opt.data) as f:
14
  data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
15
- logger = WandbLogger(opt, '', None, data, job_type='create_dataset')
16
- nc, names = (1, ['item']) if opt.single_cls else (int(data['nc']), data['names'])
17
- names = {k: v for k, v in enumerate(names)} # to index dictionary
18
- logger.log_dataset_artifact(LoadImagesAndLabels(data['train']), names, name='train') # trainset
19
- logger.log_dataset_artifact(LoadImagesAndLabels(data['val']), names, name='val') # valset
20
-
21
- # Update data.yaml with artifact links
22
- data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'train')
23
- data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'val')
24
- path = opt.data if opt.overwrite_config else opt.data.replace('.', '_wandb.') # updated data.yaml path
25
- data.pop('download', None) # download via artifact instead of predefined field 'download:'
26
- with open(path, 'w') as f:
27
- yaml.dump(data, f)
28
- print("New Config file => ", path)
29
 
30
 
31
  if __name__ == '__main__':
@@ -33,7 +20,6 @@ if __name__ == '__main__':
33
  parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
34
  parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
35
  parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
36
- parser.add_argument('--overwrite_config', action='store_true', help='overwrite data.yaml')
37
  opt = parser.parse_args()
38
 
39
  create_dataset_artifact(opt)
 
12
  def create_dataset_artifact(opt):
13
  with open(opt.data) as f:
14
  data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
15
+ logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  if __name__ == '__main__':
 
20
  parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
21
  parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
22
  parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
 
23
  opt = parser.parse_args()
24
 
25
  create_dataset_artifact(opt)
utils/wandb_logging/wandb_utils.py CHANGED
@@ -1,13 +1,18 @@
 
1
  import json
 
2
  import shutil
3
  import sys
 
 
4
  from datetime import datetime
5
  from pathlib import Path
6
-
7
- import torch
8
 
9
  sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
10
- from utils.general import colorstr, xywh2xyxy
 
 
11
 
12
  try:
13
  import wandb
@@ -22,87 +27,183 @@ def remove_prefix(from_string, prefix):
22
  return from_string[len(prefix):]
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class WandbLogger():
26
  def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
27
- self.wandb = wandb
28
- self.wandb_run = wandb.init(config=opt, resume="allow",
29
- project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
30
- name=name,
31
- job_type=job_type,
32
- id=run_id) if self.wandb else None
33
-
34
- if job_type == 'Training':
35
- self.setup_training(opt, data_dict)
36
- if opt.bbox_interval == -1:
37
- opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
38
- if opt.save_period == -1:
39
- opt.save_period = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def setup_training(self, opt, data_dict):
42
- self.log_dict = {}
43
- self.train_artifact_path, self.trainset_artifact = \
44
- self.download_dataset_artifact(data_dict['train'], opt.artifact_alias)
45
- self.test_artifact_path, self.testset_artifact = \
46
- self.download_dataset_artifact(data_dict['val'], opt.artifact_alias)
47
- self.result_artifact, self.result_table, self.weights = None, None, None
48
- if self.train_artifact_path is not None:
49
- train_path = Path(self.train_artifact_path) / 'data/images/'
50
- data_dict['train'] = str(train_path)
51
- if self.test_artifact_path is not None:
52
- test_path = Path(self.test_artifact_path) / 'data/images/'
53
- data_dict['val'] = str(test_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
55
  self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
56
- if opt.resume_from_artifact:
57
- modeldir, _ = self.download_model_artifact(opt.resume_from_artifact)
58
- if modeldir:
59
- self.weights = Path(modeldir) / "best.pt"
60
- opt.weights = self.weights
61
 
62
  def download_dataset_artifact(self, path, alias):
63
  if path.startswith(WANDB_ARTIFACT_PREFIX):
64
  dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
65
  assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
66
  datadir = dataset_artifact.download()
67
- labels_zip = Path(datadir) / "data/labels.zip"
68
- shutil.unpack_archive(labels_zip, Path(datadir) / 'data/labels', 'zip')
69
- print("Downloaded dataset to : ", datadir)
70
  return datadir, dataset_artifact
71
  return None, None
72
 
73
- def download_model_artifact(self, name):
74
- model_artifact = wandb.use_artifact(name + ":latest")
75
- assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
76
- modeldir = model_artifact.download()
77
- print("Downloaded model to : ", modeldir)
78
- return modeldir, model_artifact
 
 
 
 
 
79
 
80
- def log_model(self, path, opt, epoch):
81
- datetime_suffix = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
82
  model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
83
  'original_url': str(path),
84
- 'epoch': epoch + 1,
85
  'save period': opt.save_period,
86
  'project': opt.project,
87
- 'datetime': datetime_suffix
 
88
  })
89
  model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
90
- model_artifact.add_file(str(path / 'best.pt'), name='best.pt')
91
- wandb.log_artifact(model_artifact)
92
  print("Saving model artifact on epoch ", epoch + 1)
93
 
94
- def log_dataset_artifact(self, dataset, class_to_id, name='dataset'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  artifact = wandb.Artifact(name=name, type="dataset")
96
- image_path = dataset.path
97
- artifact.add_dir(image_path, name='data/images')
98
- table = wandb.Table(columns=["id", "train_image", "Classes"])
 
 
 
 
 
 
 
 
99
  class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
100
- for si, (img, labels, paths, shapes) in enumerate(dataset):
101
  height, width = shapes[0]
102
- labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4)))
103
- labels[:, 2:] *= torch.Tensor([width, height, width, height])
104
- box_data = []
105
- img_classes = {}
106
  for cls, *xyxy in labels[:, 1:].tolist():
107
  cls = int(cls)
108
  box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
@@ -112,34 +213,52 @@ class WandbLogger():
112
  "domain": "pixel"})
113
  img_classes[cls] = class_to_id[cls]
114
  boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
115
- table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes))
 
116
  artifact.add(table, name)
117
- labels_path = 'labels'.join(image_path.rsplit('images', 1))
118
- zip_path = Path(labels_path).parent / (name + '_labels.zip')
119
- if not zip_path.is_file(): # make_archive won't check if file exists
120
- shutil.make_archive(zip_path.with_suffix(''), 'zip', labels_path)
121
- artifact.add_file(str(zip_path), name='data/labels.zip')
122
- wandb.log_artifact(artifact)
123
- print("Saving data to W&B...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def log(self, log_dict):
126
  if self.wandb_run:
127
  for key, value in log_dict.items():
128
  self.log_dict[key] = value
129
 
130
- def end_epoch(self):
131
- if self.wandb_run and self.log_dict:
132
  wandb.log(self.log_dict)
133
- self.log_dict = {}
 
 
 
 
 
 
 
134
 
135
  def finish_run(self):
136
  if self.wandb_run:
137
- if self.result_artifact:
138
- print("Add Training Progress Artifact")
139
- self.result_artifact.add(self.result_table, 'result')
140
- train_results = wandb.JoinedTable(self.testset_artifact.get("val"), self.result_table, "id")
141
- self.result_artifact.add(train_results, 'joined_result')
142
- wandb.log_artifact(self.result_artifact)
143
  if self.log_dict:
144
  wandb.log(self.log_dict)
145
  wandb.run.finish()
 
1
+ import argparse
2
  import json
3
+ import os
4
  import shutil
5
  import sys
6
+ import torch
7
+ import yaml
8
  from datetime import datetime
9
  from pathlib import Path
10
+ from tqdm import tqdm
 
11
 
12
  sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path
13
+ from utils.datasets import LoadImagesAndLabels
14
+ from utils.datasets import img2label_paths
15
+ from utils.general import colorstr, xywh2xyxy, check_dataset
16
 
17
  try:
18
  import wandb
 
27
  return from_string[len(prefix):]
28
 
29
 
30
+ def check_wandb_config_file(data_config_file):
31
+ wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
32
+ if Path(wandb_config).is_file():
33
+ return wandb_config
34
+ return data_config_file
35
+
36
+
37
+ def resume_and_get_id(opt):
38
+ # It's more elegant to stick to 1 wandb.init call, but as useful config data is overwritten in the WandbLogger's wandb.init call
39
+ if isinstance(opt.resume, str):
40
+ if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
41
+ run_path = Path(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX))
42
+ run_id = run_path.stem
43
+ project = run_path.parent.stem
44
+ model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
45
+ assert wandb, 'install wandb to resume wandb runs'
46
+ # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
47
+ run = wandb.init(id=run_id, project=project, resume='allow')
48
+ opt.resume = model_artifact_name
49
+ return run
50
+ return None
51
+
52
+
53
  class WandbLogger():
54
  def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
55
+ # Pre-training routine --
56
+ self.job_type = job_type
57
+ self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
58
+ if self.wandb:
59
+ self.wandb_run = wandb.init(config=opt,
60
+ resume="allow",
61
+ project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
62
+ name=name,
63
+ job_type=job_type,
64
+ id=run_id) if not wandb.run else wandb.run
65
+ if self.job_type == 'Training':
66
+ if not opt.resume:
67
+ wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
68
+ # Info useful for resuming from artifacts
69
+ self.wandb_run.config.opt = vars(opt)
70
+ self.wandb_run.config.data_dict = wandb_data_dict
71
+ self.data_dict = self.setup_training(opt, data_dict)
72
+ if self.job_type == 'Dataset Creation':
73
+ self.data_dict = self.check_and_upload_dataset(opt)
74
+
75
+ def check_and_upload_dataset(self, opt):
76
+ assert wandb, 'Install wandb to upload dataset'
77
+ check_dataset(self.data_dict)
78
+ config_path = self.log_dataset_artifact(opt.data,
79
+ opt.single_cls,
80
+ 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
81
+ print("Created dataset config file ", config_path)
82
+ with open(config_path) as f:
83
+ wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
84
+ return wandb_data_dict
85
 
86
  def setup_training(self, opt, data_dict):
87
+ self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants
88
+ self.bbox_interval = opt.bbox_interval
89
+ if isinstance(opt.resume, str):
90
+ modeldir, _ = self.download_model_artifact(opt)
91
+ if modeldir:
92
+ self.weights = Path(modeldir) / "last.pt"
93
+ config = self.wandb_run.config
94
+ opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
95
+ self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \
96
+ config.opt['hyp']
97
+ data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
98
+ if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
99
+ self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
100
+ opt.artifact_alias)
101
+ self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
102
+ opt.artifact_alias)
103
+ self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
104
+ if self.train_artifact_path is not None:
105
+ train_path = Path(self.train_artifact_path) / 'data/images/'
106
+ data_dict['train'] = str(train_path)
107
+ if self.val_artifact_path is not None:
108
+ val_path = Path(self.val_artifact_path) / 'data/images/'
109
+ data_dict['val'] = str(val_path)
110
+ self.val_table = self.val_artifact.get("val")
111
+ self.map_val_table_path()
112
+ if self.val_artifact is not None:
113
  self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
114
  self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
115
+ if opt.bbox_interval == -1:
116
+ self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
117
+ return data_dict
 
 
118
 
119
  def download_dataset_artifact(self, path, alias):
120
  if path.startswith(WANDB_ARTIFACT_PREFIX):
121
  dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
122
  assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
123
  datadir = dataset_artifact.download()
 
 
 
124
  return datadir, dataset_artifact
125
  return None, None
126
 
127
+ def download_model_artifact(self, opt):
128
+ if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
129
+ model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
130
+ assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
131
+ modeldir = model_artifact.download()
132
+ epochs_trained = model_artifact.metadata.get('epochs_trained')
133
+ total_epochs = model_artifact.metadata.get('total_epochs')
134
+ assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
135
+ total_epochs)
136
+ return modeldir, model_artifact
137
+ return None, None
138
 
139
+ def log_model(self, path, opt, epoch, fitness_score, best_model=False):
 
140
  model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
141
  'original_url': str(path),
142
+ 'epochs_trained': epoch + 1,
143
  'save period': opt.save_period,
144
  'project': opt.project,
145
+ 'total_epochs': opt.epochs,
146
+ 'fitness_score': fitness_score
147
  })
148
  model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
149
+ wandb.log_artifact(model_artifact,
150
+ aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
151
  print("Saving model artifact on epoch ", epoch + 1)
152
 
153
+ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
154
+ with open(data_file) as f:
155
+ data = yaml.load(f, Loader=yaml.SafeLoader) # data dict
156
+ nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
157
+ names = {k: v for k, v in enumerate(names)} # to index dictionary
158
+ self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
159
+ data['train']), names, name='train') if data.get('train') else None
160
+ self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
161
+ data['val']), names, name='val') if data.get('val') else None
162
+ if data.get('train'):
163
+ data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
164
+ if data.get('val'):
165
+ data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
166
+ path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
167
+ data.pop('download', None)
168
+ with open(path, 'w') as f:
169
+ yaml.dump(data, f)
170
+
171
+ if self.job_type == 'Training': # builds correct artifact pipeline graph
172
+ self.wandb_run.use_artifact(self.val_artifact)
173
+ self.wandb_run.use_artifact(self.train_artifact)
174
+ self.val_artifact.wait()
175
+ self.val_table = self.val_artifact.get('val')
176
+ self.map_val_table_path()
177
+ else:
178
+ self.wandb_run.log_artifact(self.train_artifact)
179
+ self.wandb_run.log_artifact(self.val_artifact)
180
+ return path
181
+
182
+ def map_val_table_path(self):
183
+ self.val_table_map = {}
184
+ print("Mapping dataset")
185
+ for i, data in enumerate(tqdm(self.val_table.data)):
186
+ self.val_table_map[data[3]] = data[0]
187
+
188
+ def create_dataset_table(self, dataset, class_to_id, name='dataset'):
189
+ # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
190
  artifact = wandb.Artifact(name=name, type="dataset")
191
+ for img_file in tqdm([dataset.path]) if Path(dataset.path).is_dir() else tqdm(dataset.img_files):
192
+ if Path(img_file).is_dir():
193
+ artifact.add_dir(img_file, name='data/images')
194
+ labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
195
+ artifact.add_dir(labels_path, name='data/labels')
196
+ else:
197
+ artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
198
+ label_file = Path(img2label_paths([img_file])[0])
199
+ artifact.add_file(str(label_file),
200
+ name='data/labels/' + label_file.name) if label_file.exists() else None
201
+ table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
202
  class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
203
+ for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
204
  height, width = shapes[0]
205
+ labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height])
206
+ box_data, img_classes = [], {}
 
 
207
  for cls, *xyxy in labels[:, 1:].tolist():
208
  cls = int(cls)
209
  box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
 
213
  "domain": "pixel"})
214
  img_classes[cls] = class_to_id[cls]
215
  boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
216
+ table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
217
+ Path(paths).name)
218
  artifact.add(table, name)
219
+ return artifact
220
+
221
+ def log_training_progress(self, predn, path, names):
222
+ if self.val_table and self.result_table:
223
+ class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
224
+ box_data = []
225
+ total_conf = 0
226
+ for *xyxy, conf, cls in predn.tolist():
227
+ if conf >= 0.25:
228
+ box_data.append(
229
+ {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
230
+ "class_id": int(cls),
231
+ "box_caption": "%s %.3f" % (names[cls], conf),
232
+ "scores": {"class_score": conf},
233
+ "domain": "pixel"})
234
+ total_conf = total_conf + conf
235
+ boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
236
+ id = self.val_table_map[Path(path).name]
237
+ self.result_table.add_data(self.current_epoch,
238
+ id,
239
+ wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
240
+ total_conf / max(1, len(box_data))
241
+ )
242
 
243
  def log(self, log_dict):
244
  if self.wandb_run:
245
  for key, value in log_dict.items():
246
  self.log_dict[key] = value
247
 
248
+ def end_epoch(self, best_result=False):
249
+ if self.wandb_run:
250
  wandb.log(self.log_dict)
251
+ self.log_dict = {}
252
+ if self.result_artifact:
253
+ train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
254
+ self.result_artifact.add(train_results, 'result')
255
+ wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
256
+ ('best' if best_result else '')])
257
+ self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
258
+ self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
259
 
260
  def finish_run(self):
261
  if self.wandb_run:
 
 
 
 
 
 
262
  if self.log_dict:
263
  wandb.log(self.log_dict)
264
  wandb.run.finish()