UnglvKitDe pre-commit-ci[bot] glenn-jocher commited on
Commit
1c5e92a
1 Parent(s): b92430a

Add generator and worker seed (#8602)

Browse files

* Add generator and worker seed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataloaders.py

* Update dataloaders.py

* Update dataloaders.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. utils/dataloaders.py +12 -1
utils/dataloaders.py CHANGED
@@ -91,6 +91,13 @@ def exif_transpose(image):
91
  return image
92
 
93
 
 
 
 
 
 
 
 
94
  def create_dataloader(path,
95
  imgsz,
96
  batch_size,
@@ -130,13 +137,17 @@ def create_dataloader(path,
130
  nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
131
  sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
132
  loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
 
 
133
  return loader(dataset,
134
  batch_size=batch_size,
135
  shuffle=shuffle and sampler is None,
136
  num_workers=nw,
137
  sampler=sampler,
138
  pin_memory=True,
139
- collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
 
 
140
 
141
 
142
  class InfiniteDataLoader(dataloader.DataLoader):
 
91
  return image
92
 
93
 
94
+ def seed_worker(worker_id):
95
+ # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
96
+ worker_seed = torch.initial_seed() % 2 ** 32
97
+ np.random.seed(worker_seed)
98
+ random.seed(worker_seed)
99
+
100
+
101
  def create_dataloader(path,
102
  imgsz,
103
  batch_size,
 
137
  nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
138
  sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
139
  loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
140
+ generator = torch.Generator()
141
+ generator.manual_seed(0)
142
  return loader(dataset,
143
  batch_size=batch_size,
144
  shuffle=shuffle and sampler is None,
145
  num_workers=nw,
146
  sampler=sampler,
147
  pin_memory=True,
148
+ collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
149
+ worker_init_fn=seed_worker,
150
+ generator=generator), dataset
151
 
152
 
153
  class InfiniteDataLoader(dataloader.DataLoader):