glenn-jocher committed
Commit 9728e2b
Parent(s): e9a0ae6

--image_weights bug fix (#1524)

Files changed:
- train.py (+3 -2)
- utils/datasets.py (+7 -6)
train.py
CHANGED
@@ -181,8 +181,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
-                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
-                                            rank=rank, world_size=opt.world_size, workers=opt.workers)
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
+                                            world_size=opt.world_size, workers=opt.workers,
+                                            image_weights=opt.image_weights)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
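Note: a minimal sketch of how the new `image_weights` plumbing can be driven from the training loop. This is not the commit's own train.py code; the helper `labels_to_image_weights` below and the uniform starting class weights are illustrative assumptions. The key point is that the loop overwrites `dataset.indices` once per epoch, and the dataset changes below make `__getitem__`, mosaic and mixup all respect that table.

# Illustrative sketch only (not part of this commit): re-sample dataset.indices
# each epoch so images containing highly weighted classes are drawn more often.
import random
import numpy as np

def labels_to_image_weights(labels, nc, class_weights):
    # labels: one (num_targets, 5) array per image, column 0 = class id (assumed layout)
    # Returns one sampling weight per image: the sum of the class weights of its targets.
    counts = np.array([np.bincount(l[:, 0].astype(int), minlength=nc) for l in labels])
    return (counts * class_weights).sum(1)

# Inside the epoch loop, when --image_weights is set (dataset from create_dataloader above):
#     cw = np.ones(nc)  # class weights; in practice derived from per-class performance
#     iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
#     dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # weighted re-sample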
utils/datasets.py
CHANGED
@@ -55,7 +55,7 @@ def exif_size(img):
 
 
 def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
-                      rank=-1, world_size=1, workers=8):
+                      rank=-1, world_size=1, workers=8, image_weights=False):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -66,7 +66,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                       single_cls=opt.single_cls,
                                       stride=int(stride),
                                       pad=pad,
-                                      rank=rank)
+                                      rank=rank,
+                                      image_weights=image_weights)
 
     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
@@ -392,6 +393,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         nb = bi[-1] + 1  # number of batches
         self.batch = bi  # batch index of image
         self.n = n
+        self.indices = range(n)
 
         # Rectangular Training
         if self.rect:
@@ -485,8 +487,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         # return self
 
     def __getitem__(self, index):
-        if self.image_weights:
-            index = self.indices[index]
+        index = self.indices[index]  # linear, shuffled, or image_weights
 
         hyp = self.hyp
         mosaic = self.mosaic and random.random() < hyp['mosaic']
@@ -497,7 +498,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
 
         # MixUp https://arxiv.org/pdf/1710.09412.pdf
         if random.random() < hyp['mixup']:
-            img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
+            img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
             r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
             img = (img * r + img2 * (1 - r)).astype(np.uint8)
             labels = np.concatenate((labels, labels2), 0)
@@ -619,7 +620,7 @@ def load_mosaic(self, index):
     labels4 = []
     s = self.img_size
     yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
-    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
+    indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)]  # 3 additional image indices
     for i, index in enumerate(indices):
         # Load image
         img, _, (h, w) = load_image(self, index)
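For context, a toy standalone example (not YOLOv5 code) of the indirection this fix standardizes: `__getitem__` always maps the incoming index through `self.indices`, which defaults to the identity mapping `range(n)` and can be re-assigned for weighted sampling; mosaic and mixup now draw their extra images through the same table, so weighted epochs stay consistent.

# Toy demonstration of the index indirection (assumes PyTorch is installed; class name is made up).
import random
from torch.utils.data import Dataset

class IndirectDataset(Dataset):
    def __init__(self, items):
        self.items = items
        self.n = len(items)
        self.indices = range(self.n)  # identity mapping by default, as in this commit

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights
        return self.items[index]

ds = IndirectDataset(list("abcdef"))
print([ds[i] for i in range(len(ds))])             # linear pass: ['a', ..., 'f']
ds.indices = random.choices(range(ds.n), k=ds.n)   # weighted/duplicated re-sample of the same data
print([ds[i] for i in range(len(ds))])             # epoch now visits the re-sampled indices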