Ayush Chaurasia
committed on
Commit
•
1bf9365
1
Parent(s):
0d891c6
W&B DDP fix (#2574)
Browse files
- train.py +5 -3
- utils/wandb_logging/wandb_utils.py +4 -1
train.py
CHANGED
@@ -66,14 +66,16 @@ def train(hyp, opt, device, tb_writer=None):
|
|
66 |
is_coco = opt.data.endswith('coco.yaml')
|
67 |
|
68 |
# Logging- Doing this before checking the dataset. Might update data_dict
|
|
|
69 |
if rank in [-1, 0]:
|
70 |
opt.hyp = hyp # add hyperparameters
|
71 |
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
72 |
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
|
|
|
73 |
data_dict = wandb_logger.data_dict
|
74 |
if wandb_logger.wandb:
|
75 |
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
76 |
-
|
77 |
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
|
78 |
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
79 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
@@ -381,6 +383,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
381 |
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
|
382 |
if fi > best_fitness:
|
383 |
best_fitness = fi
|
|
|
384 |
|
385 |
# Save model
|
386 |
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
|
@@ -402,7 +405,6 @@ def train(hyp, opt, device, tb_writer=None):
|
|
402 |
wandb_logger.log_model(
|
403 |
last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
404 |
del ckpt
|
405 |
-
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
406 |
|
407 |
# end epoch ----------------------------------------------------------------------------------------------------
|
408 |
# end training
|
@@ -442,10 +444,10 @@ def train(hyp, opt, device, tb_writer=None):
|
|
442 |
wandb_logger.wandb.log_artifact(str(final), type='model',
|
443 |
name='run_' + wandb_logger.wandb_run.id + '_model',
|
444 |
aliases=['last', 'best', 'stripped'])
|
|
|
445 |
else:
|
446 |
dist.destroy_process_group()
|
447 |
torch.cuda.empty_cache()
|
448 |
-
wandb_logger.finish_run()
|
449 |
return results
|
450 |
|
451 |
|
|
|
66 |
is_coco = opt.data.endswith('coco.yaml')
|
67 |
|
68 |
# Logging- Doing this before checking the dataset. Might update data_dict
|
69 |
+
loggers = {'wandb': None} # loggers dict
|
70 |
if rank in [-1, 0]:
|
71 |
opt.hyp = hyp # add hyperparameters
|
72 |
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
73 |
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
|
74 |
+
loggers['wandb'] = wandb_logger.wandb
|
75 |
data_dict = wandb_logger.data_dict
|
76 |
if wandb_logger.wandb:
|
77 |
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
78 |
+
|
79 |
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
|
80 |
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
81 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
|
|
383 |
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
|
384 |
if fi > best_fitness:
|
385 |
best_fitness = fi
|
386 |
+
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
387 |
|
388 |
# Save model
|
389 |
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
|
|
|
405 |
wandb_logger.log_model(
|
406 |
last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
407 |
del ckpt
|
|
|
408 |
|
409 |
# end epoch ----------------------------------------------------------------------------------------------------
|
410 |
# end training
|
|
|
444 |
wandb_logger.wandb.log_artifact(str(final), type='model',
|
445 |
name='run_' + wandb_logger.wandb_run.id + '_model',
|
446 |
aliases=['last', 'best', 'stripped'])
|
447 |
+
wandb_logger.finish_run()
|
448 |
else:
|
449 |
dist.destroy_process_group()
|
450 |
torch.cuda.empty_cache()
|
|
|
451 |
return results
|
452 |
|
453 |
|
utils/wandb_logging/wandb_utils.py
CHANGED
@@ -16,9 +16,9 @@ from utils.general import colorstr, xywh2xyxy, check_dataset
|
|
16 |
|
17 |
try:
|
18 |
import wandb
|
|
|
19 |
except ImportError:
|
20 |
wandb = None
|
21 |
-
print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
|
22 |
|
23 |
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
|
24 |
|
@@ -71,6 +71,9 @@ class WandbLogger():
|
|
71 |
self.data_dict = self.setup_training(opt, data_dict)
|
72 |
if self.job_type == 'Dataset Creation':
|
73 |
self.data_dict = self.check_and_upload_dataset(opt)
|
|
|
|
|
|
|
74 |
|
75 |
def check_and_upload_dataset(self, opt):
|
76 |
assert wandb, 'Install wandb to upload dataset'
|
|
|
16 |
|
17 |
try:
|
18 |
import wandb
|
19 |
+
from wandb import init, finish
|
20 |
except ImportError:
|
21 |
wandb = None
|
|
|
22 |
|
23 |
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
|
24 |
|
|
|
71 |
self.data_dict = self.setup_training(opt, data_dict)
|
72 |
if self.job_type == 'Dataset Creation':
|
73 |
self.data_dict = self.check_and_upload_dataset(opt)
|
74 |
+
else:
|
75 |
+
print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
|
76 |
+
|
77 |
|
78 |
def check_and_upload_dataset(self, opt):
|
79 |
assert wandb, 'Install wandb to upload dataset'
|