glenn-jocher
commited on
Commit
•
b8557f8
1
Parent(s):
3b06225
add stride to datasets.py
Browse files- test.py +1 -0
- train.py +4 -2
- utils/datasets.py +2 -2
test.py
CHANGED
@@ -73,6 +73,7 @@ def test(data,
|
|
73 |
batch_size,
|
74 |
rect=True, # rectangular inference
|
75 |
single_cls=opt.single_cls, # single class mode
|
|
|
76 |
pad=0.5) # padding
|
77 |
batch_size = min(batch_size, len(dataset))
|
78 |
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
|
|
73 |
batch_size,
|
74 |
rect=True, # rectangular inference
|
75 |
single_cls=opt.single_cls, # single class mode
|
76 |
+
stride=int(max(model.stride)), # model stride
|
77 |
pad=0.5) # padding
|
78 |
batch_size = min(batch_size, len(dataset))
|
79 |
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
|
train.py
CHANGED
@@ -160,7 +160,8 @@ def train(hyp):
|
|
160 |
hyp=hyp, # augmentation hyperparameters
|
161 |
rect=opt.rect, # rectangular training
|
162 |
cache_images=opt.cache_images,
|
163 |
-
single_cls=opt.single_cls
|
|
|
164 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
165 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
166 |
|
@@ -179,7 +180,8 @@ def train(hyp):
|
|
179 |
hyp=hyp,
|
180 |
rect=True,
|
181 |
cache_images=opt.cache_images,
|
182 |
-
single_cls=opt.single_cls
|
|
|
183 |
batch_size=batch_size,
|
184 |
num_workers=nw,
|
185 |
pin_memory=True,
|
|
|
160 |
hyp=hyp, # augmentation hyperparameters
|
161 |
rect=opt.rect, # rectangular training
|
162 |
cache_images=opt.cache_images,
|
163 |
+
single_cls=opt.single_cls,
|
164 |
+
stride=gs)
|
165 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
166 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
|
167 |
|
|
|
180 |
hyp=hyp,
|
181 |
rect=True,
|
182 |
cache_images=opt.cache_images,
|
183 |
+
single_cls=opt.single_cls,
|
184 |
+
stride=gs),
|
185 |
batch_size=batch_size,
|
186 |
num_workers=nw,
|
187 |
pin_memory=True,
|
utils/datasets.py
CHANGED
@@ -258,7 +258,7 @@ class LoadStreams: # multiple IP or RTSP cameras
|
|
258 |
|
259 |
class LoadImagesAndLabels(Dataset): # for training/testing
|
260 |
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
|
261 |
-
cache_images=False, single_cls=False, pad=0.0):
|
262 |
try:
|
263 |
path = str(Path(path)) # os-agnostic
|
264 |
parent = str(Path(path).parent) + os.sep
|
@@ -325,7 +325,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
325 |
elif mini > 1:
|
326 |
shapes[i] = [1, 1 / mini]
|
327 |
|
328 |
-
self.batch_shapes = np.ceil(np.array(shapes) * img_size /
|
329 |
|
330 |
# Cache labels
|
331 |
self.imgs = [None] * n
|
|
|
258 |
|
259 |
class LoadImagesAndLabels(Dataset): # for training/testing
|
260 |
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
|
261 |
+
cache_images=False, single_cls=False, stride=32, pad=0.0):
|
262 |
try:
|
263 |
path = str(Path(path)) # os-agnostic
|
264 |
parent = str(Path(path).parent) + os.sep
|
|
|
325 |
elif mini > 1:
|
326 |
shapes[i] = [1, 1 / mini]
|
327 |
|
328 |
+
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
|
329 |
|
330 |
# Cache labels
|
331 |
self.imgs = [None] * n
|