|
from __future__ import annotations |
|
|
|
import json |
|
import math |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from einops import rearrange |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
|
|
|
|
class EditDataset(Dataset):
    """Training dataset of (input image, edited image, edit instruction) samples.

    ``path`` must contain a ``seeds.json`` file holding a list of
    ``[name, seeds]`` pairs. For each pair, the directory ``path/name`` holds a
    ``prompt.json`` with an ``"edit"`` key, and for every seed in ``seeds`` a
    ``{seed}_0.jpg`` (input) and ``{seed}_1.jpg`` (edited) image.
    """

    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
    ):
        """Load the sample index and keep only the entries of ``split``.

        Args:
            path: Dataset root directory.
            split: Which partition to use: ``"train"``, ``"val"`` or ``"test"``.
            splits: Fractions of ``seeds.json`` (in file order) assigned to
                train / val / test; must sum to 1.
            min_resize_res: Smallest side length an image pair is resized to.
            max_resize_res: Largest side length (inclusive) an image pair is
                resized to; a side length is drawn uniformly per sample.
            crop_res: Side length of the random square crop. It should not
                exceed ``min_resize_res``: torchvision's ``RandomCrop`` rejects
                crops larger than the image.
            flip_prob: Probability of horizontally flipping an image pair.

        Raises:
            ValueError: If ``split`` is unknown or ``splits`` does not sum to 1.
        """
        if split not in ("train", "val", "test"):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        # Tolerant comparison: e.g. sum((0.7, 0.2, 0.1)) == 0.9999999999999999.
        if not math.isclose(sum(splits), 1.0):
            raise ValueError(f"splits must sum to 1, got {tuple(splits)!r}")
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob

        with open(Path(self.path, "seeds.json"), encoding="utf-8") as f:
            all_seeds = json.load(f)

        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]

        # Adjacent splits share their boundary expression, so the partitions are
        # contiguous and non-overlapping.
        idx_0 = math.floor(split_0 * len(all_seeds))
        idx_1 = math.floor(split_1 * len(all_seeds))
        self.seeds = all_seeds[idx_0:idx_1]

    def __len__(self) -> int:
        return len(self.seeds)

    def __getitem__(self, i: int) -> dict[str, Any]:
        """Return one randomly augmented sample for prompt entry ``i``.

        The result is ``{"edited": <edited image>, "edit": {"c_concat": <input
        image>, "c_crossattn": <edit instruction>}}``. Images are float tensors
        laid out channels-first with values in [-1, 1].
        """
        name, seeds = self.seeds[i]
        prompt_dir = Path(self.path, name)
        # Pick one of the available seeds for this prompt at random.
        seed = seeds[torch.randint(0, len(seeds), ()).item()]
        with open(prompt_dir.joinpath("prompt.json"), encoding="utf-8") as fp:
            prompt = json.load(fp)["edit"]

        # Both images share one resolution so they stay pixel-aligned.
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = self._load_image(prompt_dir.joinpath(f"{seed}_0.jpg"), resize_res)
        image_1 = self._load_image(prompt_dir.joinpath(f"{seed}_1.jpg"), resize_res)

        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        # Stack along channels so the same random crop/flip applies to both images.
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)

        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))

    @staticmethod
    def _load_image(image_path: Path, res: int) -> torch.Tensor:
        """Load ``image_path`` as RGB, resize to ``res x res``, scale to [-1, 1], channels first."""
        # The context manager closes the underlying file handle; convert() forces
        # the pixel data to be read before that happens.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB").resize((res, res), Image.Resampling.LANCZOS)
        return rearrange(2 * torch.tensor(np.array(image)).float() / 255 - 1, "h w c -> c h w")
|
|
|
|
|
class EditDatasetEval(Dataset):
    """Evaluation dataset yielding an input image plus its prompt texts.

    ``path`` must contain a ``seeds.json`` file holding a list of
    ``[name, seeds]`` pairs. For each pair, the directory ``path/name`` holds a
    ``prompt.json`` with ``"input"``, ``"edit"`` and ``"output"`` keys, and for
    every seed in ``seeds`` a ``{seed}_0.jpg`` input image.
    """

    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        res: int = 256,
    ):
        """Load the sample index and keep only the entries of ``split``.

        Args:
            path: Dataset root directory.
            split: Which partition to use: ``"train"``, ``"val"`` or ``"test"``.
            splits: Fractions of ``seeds.json`` (in file order) assigned to
                train / val / test; must sum to 1.
            res: Side length every input image is resized to.

        Raises:
            ValueError: If ``split`` is unknown or ``splits`` does not sum to 1.
        """
        if split not in ("train", "val", "test"):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        # Tolerant comparison: e.g. sum((0.7, 0.2, 0.1)) == 0.9999999999999999.
        if not math.isclose(sum(splits), 1.0):
            raise ValueError(f"splits must sum to 1, got {tuple(splits)!r}")
        self.path = path
        self.res = res

        with open(Path(self.path, "seeds.json"), encoding="utf-8") as f:
            all_seeds = json.load(f)

        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]

        # Adjacent splits share their boundary expression, so the partitions are
        # contiguous and non-overlapping.
        idx_0 = math.floor(split_0 * len(all_seeds))
        idx_1 = math.floor(split_1 * len(all_seeds))
        self.seeds = all_seeds[idx_0:idx_1]

    def __len__(self) -> int:
        return len(self.seeds)

    def __getitem__(self, i: int) -> dict[str, Any]:
        """Return the input image and prompts for prompt entry ``i``.

        The result has keys ``image_0`` (float tensor, channels-first, values in
        [-1, 1], resized to ``res x res``), ``input_prompt``, ``edit`` and
        ``output_prompt`` (taken from ``prompt.json``).
        """
        name, seeds = self.seeds[i]
        prompt_dir = Path(self.path, name)
        # Pick one of the available seeds for this prompt at random.
        seed = seeds[torch.randint(0, len(seeds), ()).item()]
        with open(prompt_dir.joinpath("prompt.json"), encoding="utf-8") as fp:
            prompt = json.load(fp)
        edit = prompt["edit"]
        input_prompt = prompt["input"]
        output_prompt = prompt["output"]

        image_0 = self._load_image(prompt_dir.joinpath(f"{seed}_0.jpg"), self.res)

        return dict(image_0=image_0, input_prompt=input_prompt, edit=edit, output_prompt=output_prompt)

    @staticmethod
    def _load_image(image_path: Path, res: int) -> torch.Tensor:
        """Load ``image_path`` as RGB, resize to ``res x res``, scale to [-1, 1], channels first."""
        # The context manager closes the underlying file handle; convert() forces
        # the pixel data to be read before that happens.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB").resize((res, res), Image.Resampling.LANCZOS)
        return rearrange(2 * torch.tensor(np.array(image)).float() / 255 - 1, "h w c -> c h w")
|
|