glenn-jocher
committed on
Commit
•
127cbeb
1
Parent(s):
6f08e8b
hyperparameter expansion to flips, perspective, mixup
Browse files
- train.py +13 -10
- utils/datasets.py +19 -22
train.py
CHANGED
@@ -16,25 +16,29 @@ from utils.datasets import *
|
|
16 |
from utils.utils import *
|
17 |
|
18 |
# Hyperparameters
|
19 |
-
hyp = {'optimizer': 'SGD', # ['
|
20 |
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
|
21 |
'momentum': 0.937, # SGD momentum/Adam beta1
|
22 |
'weight_decay': 5e-4, # optimizer weight decay
|
23 |
-
'giou': 0.05, #
|
24 |
'cls': 0.5, # cls loss gain
|
25 |
'cls_pw': 1.0, # cls BCELoss positive_weight
|
26 |
-
'obj': 1.0, # obj loss gain (
|
27 |
'obj_pw': 1.0, # obj BCELoss positive_weight
|
28 |
-
'iou_t': 0.20, #
|
29 |
'anchor_t': 4.0, # anchor-multiple threshold
|
30 |
-
'fl_gamma': 0.0, # focal loss gamma (efficientDet default
|
31 |
'hsv_h': 0.015, # image HSV-Hue augmentation (fraction)
|
32 |
'hsv_s': 0.7, # image HSV-Saturation augmentation (fraction)
|
33 |
'hsv_v': 0.4, # image HSV-Value augmentation (fraction)
|
34 |
'degrees': 0.0, # image rotation (+/- deg)
|
35 |
'translate': 0.5, # image translation (+/- fraction)
|
36 |
'scale': 0.5, # image scale (+/- gain)
|
37 |
-
'shear': 0.0
|
|
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
def train(hyp, tb_writer, opt, device):
|
@@ -47,8 +51,7 @@ def train(hyp, tb_writer, opt, device):
|
|
47 |
results_file = log_dir + os.sep + 'results.txt'
|
48 |
epochs, batch_size, total_batch_size, weights, rank = \
|
49 |
opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank
|
50 |
-
# TODO:
|
51 |
-
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
|
52 |
|
53 |
# Save run settings
|
54 |
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
|
@@ -99,7 +102,7 @@ def train(hyp, tb_writer, opt, device):
|
|
99 |
else:
|
100 |
pg0.append(v) # all else
|
101 |
|
102 |
-
if hyp['optimizer'] == '
|
103 |
optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
104 |
else:
|
105 |
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
@@ -110,9 +113,9 @@ def train(hyp, tb_writer, opt, device):
|
|
110 |
del pg0, pg1, pg2
|
111 |
|
112 |
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
|
|
113 |
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
|
114 |
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
115 |
-
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
|
116 |
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
117 |
|
118 |
# Load Model
|
|
|
16 |
from utils.utils import *
|
17 |
|
18 |
# Hyperparameters
|
19 |
+
hyp = {'optimizer': 'SGD', # ['Adam', 'SGD', ...] from torch.optim
|
20 |
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
|
21 |
'momentum': 0.937, # SGD momentum/Adam beta1
|
22 |
'weight_decay': 5e-4, # optimizer weight decay
|
23 |
+
'giou': 0.05, # GIoU loss gain
|
24 |
'cls': 0.5, # cls loss gain
|
25 |
'cls_pw': 1.0, # cls BCELoss positive_weight
|
26 |
+
'obj': 1.0, # obj loss gain (scale with pixels)
|
27 |
'obj_pw': 1.0, # obj BCELoss positive_weight
|
28 |
+
'iou_t': 0.20, # IoU training threshold
|
29 |
'anchor_t': 4.0, # anchor-multiple threshold
|
30 |
+
'fl_gamma': 0.0, # focal loss gamma (efficientDet default gamma=1.5)
|
31 |
'hsv_h': 0.015, # image HSV-Hue augmentation (fraction)
|
32 |
'hsv_s': 0.7, # image HSV-Saturation augmentation (fraction)
|
33 |
'hsv_v': 0.4, # image HSV-Value augmentation (fraction)
|
34 |
'degrees': 0.0, # image rotation (+/- deg)
|
35 |
'translate': 0.5, # image translation (+/- fraction)
|
36 |
'scale': 0.5, # image scale (+/- gain)
|
37 |
+
'shear': 0.0, # image shear (+/- deg)
|
38 |
+
'perspective': 0.0, # image perspective (+/- fraction), range 0-0.001
|
39 |
+
'flipud': 0.0, # image flip up-down (probability)
|
40 |
+
'fliplr': 0.5, # image flip left-right (probability)
|
41 |
+
'mixup': 0.0} # image mixup (probability)
|
42 |
|
43 |
|
44 |
def train(hyp, tb_writer, opt, device):
|
|
|
51 |
results_file = log_dir + os.sep + 'results.txt'
|
52 |
epochs, batch_size, total_batch_size, weights, rank = \
|
53 |
opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank
|
54 |
+
# TODO: Use DDP logging. Only the first process is allowed to log.
|
|
|
55 |
|
56 |
# Save run settings
|
57 |
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
|
|
|
102 |
else:
|
103 |
pg0.append(v) # all else
|
104 |
|
105 |
+
if hyp['optimizer'] == 'Adam':
|
106 |
optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
107 |
else:
|
108 |
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
|
|
113 |
del pg0, pg1, pg2
|
114 |
|
115 |
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
|
116 |
+
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
|
117 |
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
|
118 |
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
|
|
119 |
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
120 |
|
121 |
# Load Model
|
utils/datasets.py
CHANGED
@@ -484,11 +484,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
484 |
shapes = None
|
485 |
|
486 |
# MixUp https://arxiv.org/pdf/1710.09412.pdf
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
|
493 |
else:
|
494 |
# Load image
|
@@ -517,7 +517,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
517 |
degrees=hyp['degrees'],
|
518 |
translate=hyp['translate'],
|
519 |
scale=hyp['scale'],
|
520 |
-
shear=hyp['shear']
|
|
|
521 |
|
522 |
# Augment colorspace
|
523 |
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
|
@@ -528,28 +529,23 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
528 |
|
529 |
nL = len(labels) # number of labels
|
530 |
if nL:
|
531 |
-
# convert xyxy to xywh
|
532 |
-
labels[:,
|
533 |
-
|
534 |
-
# Normalize coordinates 0 - 1
|
535 |
-
labels[:, [2, 4]] /= img.shape[0] # height
|
536 |
-
labels[:, [1, 3]] /= img.shape[1] # width
|
537 |
|
538 |
if self.augment:
|
539 |
-
#
|
540 |
-
|
541 |
-
if lr_flip and random.random() < 0.5:
|
542 |
-
img = np.fliplr(img)
|
543 |
-
if nL:
|
544 |
-
labels[:, 1] = 1 - labels[:, 1]
|
545 |
-
|
546 |
-
# random up-down flip
|
547 |
-
ud_flip = False
|
548 |
-
if ud_flip and random.random() < 0.5:
|
549 |
img = np.flipud(img)
|
550 |
if nL:
|
551 |
labels[:, 2] = 1 - labels[:, 2]
|
552 |
|
|
|
|
|
|
|
|
|
|
|
|
|
553 |
labels_out = torch.zeros((nL, 6))
|
554 |
if nL:
|
555 |
labels_out[:, 1:] = torch.from_numpy(labels)
|
@@ -661,6 +657,7 @@ def load_mosaic(self, index):
|
|
661 |
translate=self.hyp['translate'],
|
662 |
scale=self.hyp['scale'],
|
663 |
shear=self.hyp['shear'],
|
|
|
664 |
border=self.mosaic_border) # border to remove
|
665 |
|
666 |
return img4, labels4
|
|
|
484 |
shapes = None
|
485 |
|
486 |
# MixUp https://arxiv.org/pdf/1710.09412.pdf
|
487 |
+
if random.random() < hyp['mixup']:
|
488 |
+
img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
|
489 |
+
r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
|
490 |
+
img = (img * r + img2 * (1 - r)).astype(np.uint8)
|
491 |
+
labels = np.concatenate((labels, labels2), 0)
|
492 |
|
493 |
else:
|
494 |
# Load image
|
|
|
517 |
degrees=hyp['degrees'],
|
518 |
translate=hyp['translate'],
|
519 |
scale=hyp['scale'],
|
520 |
+
shear=hyp['shear'],
|
521 |
+
perspective=hyp['perspective'])
|
522 |
|
523 |
# Augment colorspace
|
524 |
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
|
|
|
529 |
|
530 |
nL = len(labels) # number of labels
|
531 |
if nL:
|
532 |
+
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
|
533 |
+
labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
|
534 |
+
labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
|
|
|
|
|
|
|
535 |
|
536 |
if self.augment:
|
537 |
+
# flip up-down
|
538 |
+
if random.random() < hyp['flipud']:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
539 |
img = np.flipud(img)
|
540 |
if nL:
|
541 |
labels[:, 2] = 1 - labels[:, 2]
|
542 |
|
543 |
+
# flip left-right
|
544 |
+
if random.random() < hyp['fliplr']:
|
545 |
+
img = np.fliplr(img)
|
546 |
+
if nL:
|
547 |
+
labels[:, 1] = 1 - labels[:, 1]
|
548 |
+
|
549 |
labels_out = torch.zeros((nL, 6))
|
550 |
if nL:
|
551 |
labels_out[:, 1:] = torch.from_numpy(labels)
|
|
|
657 |
translate=self.hyp['translate'],
|
658 |
scale=self.hyp['scale'],
|
659 |
shear=self.hyp['shear'],
|
660 |
+
perspective=self.hyp['perspective'],
|
661 |
border=self.mosaic_border) # border to remove
|
662 |
|
663 |
return img4, labels4
|