Commit
•
1c5e92a
1
Parent(s):
b92430a
Add generator and worker seed (#8602)
Browse files* Add generator and worker seed
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update dataloaders.py
* Update dataloaders.py
* Update dataloaders.py
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
- utils/dataloaders.py +12 -1
utils/dataloaders.py
CHANGED
@@ -91,6 +91,13 @@ def exif_transpose(image):
|
|
91 |
return image
|
92 |
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
def create_dataloader(path,
|
95 |
imgsz,
|
96 |
batch_size,
|
@@ -130,13 +137,17 @@ def create_dataloader(path,
|
|
130 |
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
131 |
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
132 |
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
|
|
|
|
133 |
return loader(dataset,
|
134 |
batch_size=batch_size,
|
135 |
shuffle=shuffle and sampler is None,
|
136 |
num_workers=nw,
|
137 |
sampler=sampler,
|
138 |
pin_memory=True,
|
139 |
-
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn
|
|
|
|
|
140 |
|
141 |
|
142 |
class InfiniteDataLoader(dataloader.DataLoader):
|
|
|
91 |
return image
|
92 |
|
93 |
|
94 |
+
def seed_worker(worker_id):
|
95 |
+
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
|
96 |
+
worker_seed = torch.initial_seed() % 2 ** 32
|
97 |
+
np.random.seed(worker_seed)
|
98 |
+
random.seed(worker_seed)
|
99 |
+
|
100 |
+
|
101 |
def create_dataloader(path,
|
102 |
imgsz,
|
103 |
batch_size,
|
|
|
137 |
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
138 |
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
139 |
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
140 |
+
generator = torch.Generator()
|
141 |
+
generator.manual_seed(0)
|
142 |
return loader(dataset,
|
143 |
batch_size=batch_size,
|
144 |
shuffle=shuffle and sampler is None,
|
145 |
num_workers=nw,
|
146 |
sampler=sampler,
|
147 |
pin_memory=True,
|
148 |
+
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
|
149 |
+
worker_init_fn=seed_worker,
|
150 |
+
generator=generator), dataset
|
151 |
|
152 |
|
153 |
class InfiniteDataLoader(dataloader.DataLoader):
|