glenn-jocher
committed on
Commit
•
4b5f480
1
Parent(s):
3b394b9
Update datasets.py (#494)
Browse files- utils/datasets.py +29 -49
utils/datasets.py
CHANGED
@@ -17,7 +17,7 @@ from tqdm import tqdm
|
|
17 |
from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
|
18 |
|
19 |
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
20 |
-
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff','.dng']
|
21 |
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']
|
22 |
|
23 |
# Get orientation exif tag
|
@@ -46,17 +46,18 @@ def exif_size(img):
|
|
46 |
return s
|
47 |
|
48 |
|
49 |
-
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
|
|
|
50 |
# Make sure only the first process in DDP processes the dataset first, so that the following processes can use the cache.
|
51 |
with torch_distributed_zero_first(local_rank):
|
52 |
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
|
61 |
batch_size = min(batch_size, len(dataset))
|
62 |
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers
|
@@ -305,7 +306,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
305 |
f += glob.iglob(p + os.sep + '*.*')
|
306 |
else:
|
307 |
raise Exception('%s does not exist' % p)
|
308 |
-
self.img_files = sorted(
|
|
|
309 |
except Exception as e:
|
310 |
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
|
311 |
|
@@ -566,6 +568,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
566 |
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
|
567 |
|
568 |
|
|
|
569 |
def load_image(self, index):
|
570 |
# loads 1 image from dataset, returns img, original hw, resized hw
|
571 |
img = self.imgs[index]
|
@@ -766,26 +769,28 @@ def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10,
|
|
766 |
# h = (xy[:, 3] - xy[:, 1]) * reduction
|
767 |
# xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
|
768 |
|
769 |
-
#
|
770 |
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
|
771 |
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
|
772 |
-
w = xy[:, 2] - xy[:, 0]
|
773 |
-
h = xy[:, 3] - xy[:, 1]
|
774 |
-
area = w * h
|
775 |
-
area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
|
776 |
-
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) # aspect ratio
|
777 |
-
i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)
|
778 |
|
|
|
|
|
779 |
targets = targets[i]
|
780 |
targets[:, 1:5] = xy[i]
|
781 |
|
782 |
return img, targets
|
783 |
|
784 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
785 |
def cutout(image, labels):
|
786 |
-
# https://arxiv.org/abs/1708.04552
|
787 |
-
# https://github.com/hysts/pytorch_cutout/blob/master/dataloader.py
|
788 |
-
# https://towardsdatascience.com/when-conventional-wisdom-fails-revisiting-data-augmentation-for-self-driving-cars-4831998c5509
|
789 |
h, w = image.shape[:2]
|
790 |
|
791 |
def bbox_ioa(box1, box2):
|
@@ -804,7 +809,6 @@ def cutout(image, labels):
|
|
804 |
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
|
805 |
|
806 |
# Intersection over box2 area
|
807 |
-
|
808 |
return inter_area / box2_area
|
809 |
|
810 |
# create random masks
|
@@ -831,7 +835,7 @@ def cutout(image, labels):
|
|
831 |
return labels
|
832 |
|
833 |
|
834 |
-
def reduce_img_size(path='../data/sm4/images', img_size=1024):  # from utils.datasets import *; reduce_img_size()
|
835 |
# creates a new ./images_reduced folder with reduced size images of maximum size img_size
|
836 |
path_new = path + '_reduced' # reduced images path
|
837 |
create_folder(path_new)
|
@@ -848,31 +852,7 @@ def reduce_img_size(path='../data/sm4/images', img_size=1024): # from utils.dat
|
|
848 |
print('WARNING: image failure %s' % f)
|
849 |
|
850 |
|
851 |
-
def
|
852 |
-
# Save images
|
853 |
-
formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
|
854 |
-
# for path in ['../coco/images/val2014', '../coco/images/train2014']:
|
855 |
-
for path in ['../data/sm4/images', '../data/sm4/background']:
|
856 |
-
create_folder(path + 'bmp')
|
857 |
-
for ext in formats: # ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
|
858 |
-
for f in tqdm(glob.glob('%s/*%s' % (path, ext)), desc='Converting %s' % ext):
|
859 |
-
cv2.imwrite(f.replace(ext.lower(), '.bmp').replace(path, path + 'bmp'), cv2.imread(f))
|
860 |
-
|
861 |
-
# Save labels
|
862 |
-
# for path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
|
863 |
-
for file in ['../data/sm4/out_train.txt', '../data/sm4/out_test.txt']:
|
864 |
-
with open(file, 'r') as f:
|
865 |
-
lines = f.read()
|
866 |
-
# lines = f.read().replace('2014/', '2014bmp/') # coco
|
867 |
-
lines = lines.replace('/images', '/imagesbmp')
|
868 |
-
lines = lines.replace('/background', '/backgroundbmp')
|
869 |
-
for ext in formats:
|
870 |
-
lines = lines.replace(ext, '.bmp')
|
871 |
-
with open(file.replace('.txt', 'bmp.txt'), 'w') as f:
|
872 |
-
f.write(lines)
|
873 |
-
|
874 |
-
|
875 |
-
def recursive_dataset2bmp(dataset='../data/sm4_bmp'): # from utils.datasets import *; recursive_dataset2bmp()
|
876 |
# Converts dataset to bmp (for faster training)
|
877 |
formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
|
878 |
for a, b, files in os.walk(dataset):
|
@@ -892,7 +872,7 @@ def recursive_dataset2bmp(dataset='../data/sm4_bmp'): # from utils.datasets imp
|
|
892 |
os.system("rm '%s'" % p)
|
893 |
|
894 |
|
895 |
-
def imagelist2folder(path='data/coco_64img.txt'):  # from utils.datasets import *; imagelist2folder()
|
896 |
# Copies all the images in a text file (list of images) into a folder
|
897 |
create_folder(path[:-4])
|
898 |
with open(path, 'r') as f:
|
@@ -901,7 +881,7 @@ def imagelist2folder(path='data/coco_64img.txt'): # from utils.datasets import
|
|
901 |
print(line)
|
902 |
|
903 |
|
904 |
-
def create_folder(path='./
|
905 |
# Create folder
|
906 |
if os.path.exists(path):
|
907 |
shutil.rmtree(path) # delete output folder
|
|
|
17 |
from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
|
18 |
|
19 |
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
20 |
+
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
21 |
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']
|
22 |
|
23 |
# Get orientation exif tag
|
|
|
46 |
return s
|
47 |
|
48 |
|
49 |
+
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
|
50 |
+
local_rank=-1, world_size=1):
|
51 |
# Make sure only the first process in DDP processes the dataset first, so that the following processes can use the cache.
|
52 |
with torch_distributed_zero_first(local_rank):
|
53 |
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
54 |
+
augment=augment, # augment images
|
55 |
+
hyp=hyp, # augmentation hyperparameters
|
56 |
+
rect=rect, # rectangular training
|
57 |
+
cache_images=cache,
|
58 |
+
single_cls=opt.single_cls,
|
59 |
+
stride=int(stride),
|
60 |
+
pad=pad)
|
61 |
|
62 |
batch_size = min(batch_size, len(dataset))
|
63 |
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers
|
|
|
306 |
f += glob.iglob(p + os.sep + '*.*')
|
307 |
else:
|
308 |
raise Exception('%s does not exist' % p)
|
309 |
+
self.img_files = sorted(
|
310 |
+
[x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
|
311 |
except Exception as e:
|
312 |
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
|
313 |
|
|
|
568 |
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
|
569 |
|
570 |
|
571 |
+
# Ancillary functions --------------------------------------------------------------------------------------------------
|
572 |
def load_image(self, index):
|
573 |
# loads 1 image from dataset, returns img, original hw, resized hw
|
574 |
img = self.imgs[index]
|
|
|
769 |
# h = (xy[:, 3] - xy[:, 1]) * reduction
|
770 |
# xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
|
771 |
|
772 |
+
# clip boxes
|
773 |
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
|
774 |
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
|
|
|
|
|
|
|
|
|
|
|
|
|
775 |
|
776 |
+
# filter candidates
|
777 |
+
i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
|
778 |
targets = targets[i]
|
779 |
targets[:, 1:5] = xy[i]
|
780 |
|
781 |
return img, targets
|
782 |
|
783 |
|
784 |
+
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.2, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    #
    # Arguments
    #   box1, box2: indexable as [x1, y1, x2, y2] along the first axis (4 rows, one column per box)
    #   wh_thr:     box2 width AND height must both be strictly greater than this
    #   ar_thr:     box2 aspect ratio, max(w/h, h/w), must be strictly less than this
    #   area_thr:   area(box2) / area(box1) must be strictly greater than this
    #   eps:        small constant added to denominators so zero-sized boxes do not divide by zero
    #               (default 1e-16 reproduces the previously hard-coded value)
    # Returns
    #   Elementwise boolean mask (one entry per box), True where box2 passes all of the tests above
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
|
790 |
+
|
791 |
+
|
792 |
def cutout(image, labels):
|
793 |
+
# Applies image cutout augmentation https://arxiv.org/abs/1708.04552
|
|
|
|
|
794 |
h, w = image.shape[:2]
|
795 |
|
796 |
def bbox_ioa(box1, box2):
|
|
|
809 |
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
|
810 |
|
811 |
# Intersection over box2 area
|
|
|
812 |
return inter_area / box2_area
|
813 |
|
814 |
# create random masks
|
|
|
835 |
return labels
|
836 |
|
837 |
|
838 |
+
def reduce_img_size(path='path/images', img_size=1024): # from utils.datasets import *; reduce_img_size()
|
839 |
# creates a new ./images_reduced folder with reduced size images of maximum size img_size
|
840 |
path_new = path + '_reduced' # reduced images path
|
841 |
create_folder(path_new)
|
|
|
852 |
print('WARNING: image failure %s' % f)
|
853 |
|
854 |
|
855 |
+
def recursive_dataset2bmp(dataset='path/dataset_bmp'): # from utils.datasets import *; recursive_dataset2bmp()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
856 |
# Converts dataset to bmp (for faster training)
|
857 |
formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
|
858 |
for a, b, files in os.walk(dataset):
|
|
|
872 |
os.system("rm '%s'" % p)
|
873 |
|
874 |
|
875 |
+
def imagelist2folder(path='path/images.txt'): # from utils.datasets import *; imagelist2folder()
|
876 |
# Copies all the images in a text file (list of images) into a folder
|
877 |
create_folder(path[:-4])
|
878 |
with open(path, 'r') as f:
|
|
|
881 |
print(line)
|
882 |
|
883 |
|
884 |
+
def create_folder(path='./new'):
|
885 |
# Create folder
|
886 |
if os.path.exists(path):
|
887 |
shutil.rmtree(path) # delete output folder
|