MakeCartoonVideo / stylegan_prepare_data.py
Arnaudding001's picture
Create stylegan_prepare_data.py
c642d93
raw
history blame
3.01 kB
import argparse
from io import BytesIO
import multiprocessing
from functools import partial
import os
from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn
def resize_and_convert(img, size, resample, quality=100):
img = trans_fn.resize(img, size, resample)
img = trans_fn.center_crop(img, size)
buffer = BytesIO()
img.save(buffer, format="jpeg", quality=quality)
val = buffer.getvalue()
return val
def resize_multiple(
img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
imgs = []
for size in sizes:
imgs.append(resize_and_convert(img, size, resample, quality))
return imgs
def resize_worker(img_file, sizes, resample):
i, file = img_file
img = Image.open(file)
img = img.convert("RGB")
out = resize_multiple(img, sizes=sizes, resample=resample)
return i, out
def prepare(
env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
files = sorted(dataset.imgs, key=lambda x: x[0])
files = [(i, file) for i, (file, label) in enumerate(files)]
total = 0
with multiprocessing.Pool(n_worker) as pool:
for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
for size, img in zip(sizes, imgs):
key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
with env.begin(write=True) as txn:
txn.put(key, img)
total += 1
with env.begin(write=True) as txn:
txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Preprocess images for model training")
parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
parser.add_argument(
"--size",
type=str,
default="128,256,512,1024",
help="resolutions of images for the dataset",
)
parser.add_argument(
"--n_worker",
type=int,
default=8,
help="number of workers for preparing dataset",
)
parser.add_argument(
"--resample",
type=str,
default="lanczos",
help="resampling methods for resizing images",
)
parser.add_argument("path", type=str, help="path to the image dataset")
args = parser.parse_args()
if not os.path.exists(args.out):
os.makedirs(args.out)
resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
resample = resample_map[args.resample]
sizes = [int(s.strip()) for s in args.size.split(",")]
print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
imgset = datasets.ImageFolder(args.path)
with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)