glenn-jocher commited on
Commit
9728e2b
1 Parent(s): e9a0ae6

--image_weights bug fix (#1524)

Browse files
Files changed (2) hide show
  1. train.py +3 -2
  2. utils/datasets.py +7 -6
train.py CHANGED
@@ -181,8 +181,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
181
 
182
  # Trainloader
183
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
184
- hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
185
- rank=rank, world_size=opt.world_size, workers=opt.workers)
 
186
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
187
  nb = len(dataloader) # number of batches
188
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
 
181
 
182
  # Trainloader
183
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
184
+ hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
185
+ world_size=opt.world_size, workers=opt.workers,
186
+ image_weights=opt.image_weights)
187
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
188
  nb = len(dataloader) # number of batches
189
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
utils/datasets.py CHANGED
@@ -55,7 +55,7 @@ def exif_size(img):
55
 
56
 
57
  def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
58
- rank=-1, world_size=1, workers=8):
59
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
60
  with torch_distributed_zero_first(rank):
61
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -66,7 +66,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
66
  single_cls=opt.single_cls,
67
  stride=int(stride),
68
  pad=pad,
69
- rank=rank)
 
70
 
71
  batch_size = min(batch_size, len(dataset))
72
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
@@ -392,6 +393,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
392
  nb = bi[-1] + 1 # number of batches
393
  self.batch = bi # batch index of image
394
  self.n = n
 
395
 
396
  # Rectangular Training
397
  if self.rect:
@@ -485,8 +487,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
485
  # return self
486
 
487
  def __getitem__(self, index):
488
- if self.image_weights:
489
- index = self.indices[index]
490
 
491
  hyp = self.hyp
492
  mosaic = self.mosaic and random.random() < hyp['mosaic']
@@ -497,7 +498,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
497
 
498
  # MixUp https://arxiv.org/pdf/1710.09412.pdf
499
  if random.random() < hyp['mixup']:
500
- img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
501
  r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
502
  img = (img * r + img2 * (1 - r)).astype(np.uint8)
503
  labels = np.concatenate((labels, labels2), 0)
@@ -619,7 +620,7 @@ def load_mosaic(self, index):
619
  labels4 = []
620
  s = self.img_size
621
  yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] # mosaic center x, y
622
- indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)] # 3 additional image indices
623
  for i, index in enumerate(indices):
624
  # Load image
625
  img, _, (h, w) = load_image(self, index)
 
55
 
56
 
57
  def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
58
+ rank=-1, world_size=1, workers=8, image_weights=False):
59
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
60
  with torch_distributed_zero_first(rank):
61
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
 
66
  single_cls=opt.single_cls,
67
  stride=int(stride),
68
  pad=pad,
69
+ rank=rank,
70
+ image_weights=image_weights)
71
 
72
  batch_size = min(batch_size, len(dataset))
73
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
 
393
  nb = bi[-1] + 1 # number of batches
394
  self.batch = bi # batch index of image
395
  self.n = n
396
+ self.indices = range(n)
397
 
398
  # Rectangular Training
399
  if self.rect:
 
487
  # return self
488
 
489
  def __getitem__(self, index):
490
+ index = self.indices[index] # linear, shuffled, or image_weights
 
491
 
492
  hyp = self.hyp
493
  mosaic = self.mosaic and random.random() < hyp['mosaic']
 
498
 
499
  # MixUp https://arxiv.org/pdf/1710.09412.pdf
500
  if random.random() < hyp['mixup']:
501
+ img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
502
  r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
503
  img = (img * r + img2 * (1 - r)).astype(np.uint8)
504
  labels = np.concatenate((labels, labels2), 0)
 
620
  labels4 = []
621
  s = self.img_size
622
  yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] # mosaic center x, y
623
+ indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)] # 3 additional image indices
624
  for i, index in enumerate(indices):
625
  # Load image
626
  img, _, (h, w) = load_image(self, index)