# Source: EscherNet/croco/datasets/crops/extract_crops_from_images.py
# (copied from a web file view: author kxhit, commit 5f093a6 "update", 5.49 kB)
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Extracting crops for pre-training
# --------------------------------------------------------
import os
import argparse
from tqdm import tqdm
from PIL import Image
import functools
from multiprocessing import Pool
import math
def arg_parser():
    """Build the command-line parser for crop extraction.

    Returns:
        argparse.ArgumentParser exposing --crops, --root-dir, --output-dir,
        --imsize, --nthread, --max-subdir-levels and
        --ideal-number-pairs-in-dir.
    """
    # The sentence is a description, not the program name: passing it as the
    # first positional argument would set ``prog`` and make the usage line read
    # "usage: Generate cropped image pairs ... [-h] ...".
    parser = argparse.ArgumentParser(description='Generate cropped image pairs from image crop list')
    parser.add_argument('--crops', type=str, required=True, help='crop file')
    parser.add_argument('--root-dir', type=str, required=True, help='root directory')
    parser.add_argument('--output-dir', type=str, required=True, help='output directory')
    parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
    # NOTE: values > 1 start a multiprocessing.Pool (worker processes), see main().
    parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
    parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
    parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
    return parser
def _directory_layout(num_crops, ideal_per_dir, max_levels):
    """Return ``(num_levels, num_pairs_in_dir)`` for spreading crops over a directory tree.

    Aims for roughly ``ideal_per_dir`` entries per directory, with at most
    ``max_levels`` levels. At least one level is always returned, so tiny
    inputs (0 or 1 crop) no longer hit ``math.log(0)`` or ``1 / 0``.
    """
    if num_crops <= 1:
        return 1, 1
    num_levels = max(1, min(math.ceil(math.log(num_crops, ideal_per_dir)), max_levels))
    num_pairs_in_dir = math.ceil(num_crops ** (1 / num_levels))
    return num_levels, num_pairs_in_dir


def _write_listing(listing_path, results, total):
    """Write the listing file: a header line, then one relative pair stem per line.

    ``results`` yields one list of stems per job; ``total`` is the job count
    used for the progress bar.
    """
    with open(listing_path, 'w') as listing:
        listing.write('# pair_path\n')
        for paths in tqdm(results, total=total):
            for path in paths:
                listing.write(f'{path}\n')


def main(args):
    """Extract every crop listed in ``args.crops`` and write ``listing.txt``.

    Crops are written under ``args.output_dir`` in a directory tree sized from
    the total number of crops. The stems of all written pairs are listed in
    ``<output_dir>/listing.txt``.

    Args:
        args: namespace produced by ``arg_parser().parse_args()``.
    """
    listing_path = os.path.join(args.output_dir, 'listing.txt')

    print(f'Loading list of crops ... ({args.nthread} threads)')
    crops, num_crops_to_generate = load_crop_file(args.crops)

    print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
    num_levels, num_pairs_in_dir = _directory_layout(
        num_crops_to_generate, args.ideal_number_pairs_in_dir, args.max_subdir_levels)
    jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
    del crops  # only ``jobs`` is needed from here on

    os.makedirs(args.output_dir, exist_ok=True)
    call = functools.partial(save_image_crops, args)

    print(f"Generating cropped images to {args.output_dir} ...")
    if args.nthread > 1:
        # The context manager tears the worker processes down (previously the
        # pool was never closed). All results are consumed inside the block.
        with Pool(args.nthread) as pool:
            _write_listing(listing_path, pool.imap_unordered(call, jobs), len(jobs))
    else:
        _write_listing(listing_path, map(call, jobs), len(jobs))
    print('Finished writing listing to', listing_path)
def load_crop_file(path):
    """Parse a crop list file.

    Lines starting with ``#`` and blank lines are ignored. Every other line is
    split on ``', '``:

    * fewer than 8 fields: a new image pair ``img1, img2, rotation``;
    * otherwise: eight ints ``l1, r1, t1, b1, l2, r2, t2, b2`` describing one
      rectangle in each image, attached to the most recent pair line.

    Args:
        path: path of the crop list file.

    Returns:
        ``(pairs, num_crops_to_generate)``. ``pairs`` is a list of
        ``(img1, img2, rotation, [(rect1, rect2), ...])`` with each rect
        reordered to ``(left, top, right, bottom)`` for ``Image.crop``.
        ``num_crops_to_generate`` is the total number of rectangle pairs.

    Raises:
        ValueError: on a malformed line, including a crop line that appears
            before any pair line.
    """
    with open(path) as f:
        data = f.read().splitlines()
    pairs = []
    num_crops_to_generate = 0
    for lineno, line in enumerate(tqdm(data), 1):
        # Blank lines used to crash the unpacking below; skip them like comments.
        if not line.strip() or line.startswith('#'):
            continue
        fields = line.split(', ')
        if len(fields) < 8:
            img1, img2, rotation = fields
            pairs.append((img1, img2, int(rotation), []))
        else:
            if not pairs:
                raise ValueError(f'{path}:{lineno}: crop line appears before any image pair line')
            l1, r1, t1, b1, l2, r2, t2, b2 = map(int, fields)
            rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
            pairs[-1][-1].append((rect1, rect2))
            num_crops_to_generate += 1
    return pairs, num_crops_to_generate
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
    """Give every crop a unique relative output path and bundle work per image pair.

    Crops are numbered consecutively across all pairs. Crop number ``k`` gets
    a path of ``num_levels - 1`` directory components, followed by ``k``
    itself, all in lowercase hex and joined with ``/``. The directory
    components are the successive quotients of ``k`` by
    ``num_pairs_in_dir ** (num_levels - 1)``, ..., ``num_pairs_in_dir ** 1``.
    A rotation value within [-60, 60] is reset to 0.

    Returns:
        list of ``((img1, img2), rotation, crops, paths)``, one per pair.
    """
    # Place values, most significant first; the last one (== 1) is never used
    # as a divisor because the final component is the full index.
    divisors = [num_pairs_in_dir ** exponent for exponent in range(num_levels - 1, -1, -1)]

    def get_path(index):
        components = []
        remainder = index
        for divisor in divisors[:num_levels - 1]:
            quotient, remainder = divmod(remainder, divisor)
            components.append(quotient)
        # Using the whole index as the basename keeps file names globally unique.
        components.append(index)
        return '/'.join(hex(component)[2:] for component in components)

    jobs = []
    next_index = 0
    for img1, img2, rotation, crops in tqdm(pairs):
        if -60 <= rotation <= 60:
            rotation = 0  # most likely not a true rotation
        crop_count = len(crops)
        paths = list(map(get_path, range(next_index, next_index + crop_count)))
        next_index += crop_count
        jobs.append(((img1, img2), rotation, crops, paths))
    return jobs
def load_image(path):
    """Open the image at ``path`` and return an RGB copy of it.

    Any failure while opening or decoding is printed as ``skipping <path>
    <error>`` and re-raised as ``OSError``, chained to the original
    exception, so callers can skip the image with a single ``except OSError``.

    Raises:
        OSError: if the image cannot be opened or converted.
    """
    try:
        # ``convert`` returns a new image, so the opened file can be closed
        # right away instead of lingering until garbage collection.
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception as e:  # deliberately broad: any unreadable image is skipped
        print('skipping', path, e)
        raise OSError(f'could not load image {path!r}') from e
def save_image_crops(args, data):
    """Cut, resize, rotate and save every crop of one image pair.

    Args:
        args: parsed CLI namespace; reads ``root_dir``, ``output_dir`` and
            ``imsize``.
        data: one job from ``prepare_jobs``:
            ``((img1_path, img2_path), rot, crops, paths)``, where ``crops`` is
            a list of ``(rect1, rect2)`` and ``paths`` holds one relative output
            stem per crop. ``rot`` is applied to the second image's crops only.

    Returns:
        list of the stems that were written (``<stem>_1.jpg`` / ``<stem>_2.jpg``
        under ``args.output_dir``), or ``[]`` if either image failed to load.

    Raises:
        FileExistsError: if an output file already exists.
    """
    img_pair, rot, crops, paths = data
    try:
        img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
    except OSError:
        # Best effort: load_image already reported the failure; skip this pair.
        return []

    tgt_size = (args.imsize, args.imsize)
    tgt_area = tgt_size[0] * tgt_size[1]
    transposes = {
        90: Image.Transpose.ROTATE_90,
        180: Image.Transpose.ROTATE_180,
        270: Image.Transpose.ROTATE_270,
    }

    def prepare_crop(img, rect, rot=0):
        """Crop ``rect``, resize to ``tgt_size``, then rotate by ``round(rot / 90)`` quarter turns."""
        img = img.crop(rect)
        # Lanczos when the crop is more than 4x the target area, bicubic otherwise.
        crop_area = img.size[0] * img.size[1]
        interp = Image.Resampling.LANCZOS if crop_area > 4 * tgt_area else Image.Resampling.BICUBIC
        img = img.resize(tgt_size, resample=interp)
        rot90 = (round(rot / 90) % 4) * 90
        if rot90 in transposes:
            img = img.transpose(transposes[rot90])
        return img

    results = []
    for (rect1, rect2), path in zip(crops, paths):
        fullpath1 = os.path.join(args.output_dir, path + '_1.jpg')
        fullpath2 = os.path.join(args.output_dir, path + '_2.jpg')
        # Never overwrite existing outputs. Raised explicitly rather than with
        # ``assert`` so the guard still works under ``python -O``.
        for fullpath in (fullpath1, fullpath2):
            if os.path.isfile(fullpath):
                raise FileExistsError(fullpath)
        os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
        crop1 = prepare_crop(img1, rect1)
        crop2 = prepare_crop(img2, rect2, rot)
        crop1.save(fullpath1)
        crop2.save(fullpath2)
        results.append(path)
    return results
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the extraction.
    cli_args = arg_parser().parse_args()
    main(cli_args)