glenn-jocher
commited on
Commit
•
958ab92
1
Parent(s):
0cfc5b2
Remove `opt` from `create_dataloader()`` (#3552)
Browse files- test.py +1 -1
- train.py +9 -8
- utils/datasets.py +3 -3
test.py
CHANGED
@@ -88,7 +88,7 @@ def test(data,
|
|
88 |
if device.type != 'cpu':
|
89 |
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
|
90 |
task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
91 |
-
dataloader = create_dataloader(data[task], imgsz, batch_size, gs,
|
92 |
prefix=colorstr(f'{task}: '))[0]
|
93 |
|
94 |
seen = 0
|
|
|
88 |
if device.type != 'cpu':
|
89 |
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
|
90 |
task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
91 |
+
dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
|
92 |
prefix=colorstr(f'{task}: '))[0]
|
93 |
|
94 |
seen = 0
|
train.py
CHANGED
@@ -41,8 +41,9 @@ logger = logging.getLogger(__name__)
|
|
41 |
|
42 |
def train(hyp, opt, device, tb_writer=None):
|
43 |
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
44 |
-
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
45 |
-
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
|
|
46 |
|
47 |
# Directories
|
48 |
wdir = save_dir / 'weights'
|
@@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|
75 |
if wandb_logger.wandb:
|
76 |
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
77 |
|
78 |
-
nc = 1 if
|
79 |
-
names = ['item'] if
|
80 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
81 |
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
82 |
|
@@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
187 |
logger.info('Using SyncBatchNorm()')
|
188 |
|
189 |
# Trainloader
|
190 |
-
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs,
|
191 |
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
|
192 |
world_size=opt.world_size, workers=opt.workers,
|
193 |
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
@@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
197 |
|
198 |
# Process 0
|
199 |
if rank in [-1, 0]:
|
200 |
-
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs,
|
201 |
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
202 |
world_size=opt.world_size, workers=opt.workers,
|
203 |
pad=0.5, prefix=colorstr('val: '))[0]
|
@@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
357 |
batch_size=batch_size * 2,
|
358 |
imgsz=imgsz_test,
|
359 |
model=ema.ema,
|
360 |
-
single_cls=
|
361 |
dataloader=testloader,
|
362 |
save_dir=save_dir,
|
363 |
save_json=is_coco and final_epoch,
|
@@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
429 |
conf_thres=0.001,
|
430 |
iou_thres=0.7,
|
431 |
model=attempt_load(m, device).half(),
|
432 |
-
single_cls=
|
433 |
dataloader=testloader,
|
434 |
save_dir=save_dir,
|
435 |
save_json=True,
|
|
|
41 |
|
42 |
def train(hyp, opt, device, tb_writer=None):
|
43 |
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
44 |
+
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
|
45 |
+
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
|
46 |
+
opt.single_cls
|
47 |
|
48 |
# Directories
|
49 |
wdir = save_dir / 'weights'
|
|
|
76 |
if wandb_logger.wandb:
|
77 |
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
|
78 |
|
79 |
+
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
80 |
+
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
81 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
82 |
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
83 |
|
|
|
188 |
logger.info('Using SyncBatchNorm()')
|
189 |
|
190 |
# Trainloader
|
191 |
+
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
|
192 |
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
|
193 |
world_size=opt.world_size, workers=opt.workers,
|
194 |
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
|
|
198 |
|
199 |
# Process 0
|
200 |
if rank in [-1, 0]:
|
201 |
+
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
|
202 |
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
|
203 |
world_size=opt.world_size, workers=opt.workers,
|
204 |
pad=0.5, prefix=colorstr('val: '))[0]
|
|
|
358 |
batch_size=batch_size * 2,
|
359 |
imgsz=imgsz_test,
|
360 |
model=ema.ema,
|
361 |
+
single_cls=single_cls,
|
362 |
dataloader=testloader,
|
363 |
save_dir=save_dir,
|
364 |
save_json=is_coco and final_epoch,
|
|
|
430 |
conf_thres=0.001,
|
431 |
iou_thres=0.7,
|
432 |
model=attempt_load(m, device).half(),
|
433 |
+
single_cls=single_cls,
|
434 |
dataloader=testloader,
|
435 |
save_dir=save_dir,
|
436 |
save_json=True,
|
utils/datasets.py
CHANGED
@@ -62,8 +62,8 @@ def exif_size(img):
|
|
62 |
return s
|
63 |
|
64 |
|
65 |
-
def create_dataloader(path, imgsz, batch_size, stride,
|
66 |
-
rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
|
67 |
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
68 |
with torch_distributed_zero_first(rank):
|
69 |
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
@@ -71,7 +71,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|
71 |
hyp=hyp, # augmentation hyperparameters
|
72 |
rect=rect, # rectangular training
|
73 |
cache_images=cache,
|
74 |
-
single_cls=
|
75 |
stride=int(stride),
|
76 |
pad=pad,
|
77 |
image_weights=image_weights,
|
|
|
62 |
return s
|
63 |
|
64 |
|
65 |
+
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
|
66 |
+
rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
|
67 |
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
68 |
with torch_distributed_zero_first(rank):
|
69 |
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
|
|
71 |
hyp=hyp, # augmentation hyperparameters
|
72 |
rect=rect, # rectangular training
|
73 |
cache_images=cache,
|
74 |
+
single_cls=single_cls,
|
75 |
stride=int(stride),
|
76 |
pad=pad,
|
77 |
image_weights=image_weights,
|