burningdust
Initial commit
d72c37e
class FolderData(Dataset):
    """Image dataset backed by a folder, optionally paired with a caption file.

    Every item is a dict holding the transformed image under ``"image"`` and
    its caption under ``"txt"``; the image path is added under ``"path"``
    when ``return_paths`` is set.
    """

    def __init__(
        self,
        root_dir,
        caption_file=None,
        image_transforms=None,
        ext="jpg",
        default_caption="",
        postprocess=None,
        return_paths=False,
    ) -> None:
        """Create a dataset from a folder of images.

        If you pass in a root directory it will be searched for images
        ending in ext (ext can be a list)

        Args:
            root_dir: Folder searched recursively for images. File names
                from the caption file are resolved relative to it.
            caption_file: Optional ``.json`` file (an object mapping file
                name to caption) or ``.jsonl`` file (one object per line
                with ``"file_name"`` and ``"text"`` entries). When given,
                the dataset is indexed by the caption entries.
            image_transforms: Transform specification handed to
                ``make_tranforms``. ``None`` means an empty list.
            ext: Image extension, or a list/tuple of extensions, without
                the leading dot.
            default_caption: Caption used when there is no caption file or
                the stored caption is ``None``.
            postprocess: Optional callable applied to each item dict. A
                ``DictConfig`` is instantiated with
                ``instantiate_from_config`` first.
            return_paths: When true, items also carry ``"path"``.

        Raises:
            ValueError: If ``caption_file`` is neither ``.json`` nor
                ``.jsonl``.
        """
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess

        if caption_file is not None:
            self.captions = self._load_captions(caption_file)
            # Snapshot the key order once so item lookup does not rebuild
            # the whole key list on every access.
            self._caption_keys = list(self.captions)
        else:
            self.captions = None
            self._caption_keys = None

        # The caption file's suffix is kept separate from `ext`, so the
        # image search below always uses the caller's extensions.
        if not isinstance(ext, (tuple, list, ListConfig)):
            ext = [ext]
        # FolderData itself only indexes `paths` when there is no caption file.
        self.paths = self._find_images(self.root_dir, ext)

        # A fresh list per instance: a shared default list would leak any
        # in-place modification made downstream into later instances.
        if image_transforms is None:
            image_transforms = []
        self.tform = make_tranforms(image_transforms)

    @staticmethod
    def _load_captions(caption_file):
        """Read a ``.json``/``.jsonl`` caption file into a name -> caption dict."""
        # JSON text is read as UTF-8 rather than the platform default.
        with open(caption_file, "rt", encoding="utf-8") as f:
            suffix = Path(caption_file).suffix.lower()
            if suffix == ".json":
                return json.load(f)
            if suffix == ".jsonl":
                # Skip blank lines (e.g. a trailing newline at end of file).
                records = [json.loads(line) for line in f if line.strip()]
                return {r["file_name"]: r["text"].strip("\n") for r in records}
            raise ValueError(f"Unrecognised format: {suffix}")

    @staticmethod
    def _find_images(root_dir, extensions):
        """Collect files under ``root_dir`` recursively, sorted per extension."""
        paths = []
        for e in extensions:
            paths.extend(sorted(root_dir.rglob(f"*.{e}")))
        return paths

    def __len__(self):
        """Return the number of caption entries, or of images without captions."""
        if self.captions is not None:
            return len(self.captions)
        return len(self.paths)

    def __getitem__(self, index):
        """Load, transform and return the item at ``index`` as a dict."""
        data = {}
        if self.captions is not None:
            chosen = self._caption_keys[index]
            caption = self.captions.get(chosen)
            if caption is None:
                caption = self.default_caption
            filename = self.root_dir / chosen
        else:
            caption = self.default_caption
            filename = self.paths[index]

        if self.return_paths:
            data["path"] = str(filename)

        im = Image.open(filename).convert("RGB")
        data["image"] = self.process_im(im)
        data["txt"] = caption

        if self.postprocess is not None:
            data = self.postprocess(data)
        return data

    def process_im(self, im):
        """Convert ``im`` to RGB and run it through the configured transforms."""
        im = im.convert("RGB")
        return self.tform(im)
import random
class TransformDataset():
    """Wrap a dataset, applying one randomly chosen transform to each item.

    The caption of every returned item is extended with ``extra_label`` and
    the name of the transform that was applied.
    """

    def __init__(self, ds, extra_label="sksbspic"):
        """Store the wrapped dataset and build the table of named transforms."""
        self.ds = ds
        self.extra_label = extra_label
        # Insertion order is kept, since the random choice runs over the keys.
        self.transforms = dict(
            align=transforms.Resize(768),
            centerzoom=transforms.CenterCrop(768),
            randzoom=transforms.RandomCrop(768),
        )

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, index):
        """Fetch item ``index`` and return it with a transformed image and tagged caption."""
        sample = self.ds[index]

        # Move the last axis to the front for the transforms
        # (presumably h w c -> c h w; confirm against the wrapped dataset).
        image = sample['image'].permute(2, 0, 1)

        # In case data is smaller than expected
        upscale = transforms.Resize(1024)
        image = upscale(image)

        names = list(self.transforms.keys())
        picked = random.choice(names)
        image = self.transforms[picked](image)

        # Restore the original axis order before handing the item back.
        sample['image'] = image.permute(1, 2, 0)
        sample['txt'] = sample['txt'] + f" {self.extra_label} {picked}"
        return sample
def hf_dataset(
    name,
    image_transforms=None,
    image_column="image",
    text_column="text",
    split='train',
    image_key='image',
    caption_key='txt',
    ):
    """Make huggingface dataset with appropriate list of transforms applied

    Args:
        name: Dataset name passed to ``load_dataset``.
        image_transforms: Transform specification handed to
            ``make_tranforms``. ``None`` means an empty list.
        image_column: Column of the loaded dataset holding the images.
        text_column: Column of the loaded dataset holding the captions.
        split: Split passed to ``load_dataset``.
        image_key: Key under which transformed images are returned.
        caption_key: Key under which captions are returned.

    Returns:
        The loaded dataset with a transform registered through
        ``set_transform`` that yields ``image_key`` and ``caption_key``.

    Raises:
        AssertionError: If ``image_column`` or ``text_column`` is not a
            column of the loaded dataset.
    """
    # A fresh list per call: a shared default list would leak any in-place
    # modification made downstream into later calls.
    if image_transforms is None:
        image_transforms = []

    ds = load_dataset(name, split=split)
    tform = make_tranforms(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"

    def pre_process(examples):
        """Transform a batch of images and pass the captions through unchanged."""
        processed = {}
        processed[image_key] = [tform(im) for im in examples[image_column]]
        processed[caption_key] = examples[text_column]
        return processed

    ds.set_transform(pre_process)
    return ds
class TextOnly(Dataset):
    """Dataset that yields captions paired with a constant dummy image."""

    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images

        Args:
            captions: A sequence of captions, or a ``Path`` to a text file
                with one caption per line.
            output_size: Side length of the square dummy image.
            image_key: Key under which the dummy image is returned.
            caption_key: Key under which the caption is returned.
            n_gpus: When greater than 1, every caption is repeated that many
                times in a row.
        """
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            self.captions = self._load_caption_file(captions)
        else:
            self.captions = captions

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            self.captions = [x for x in self.captions for _ in range(n_gpus)]

    def __len__(self):
        """Return the number of captions (after any per-GPU repetition)."""
        return len(self.captions)

    def __getitem__(self, index):
        """Return a dict with the dummy image and the caption at ``index``."""
        dummy_im = torch.zeros(3, self.output_size, self.output_size)
        # Apply the same ``x * 2 - 1`` scaling and channel-last layout here.
        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        """Read ``filename`` and return its lines without newline characters."""
        with open(filename, 'rt') as f:
            return [line.strip('\n') for line in f]
import random
import json
class IdRetreivalDataset(FolderData):
    """FolderData variant that attaches a retrieved "match" image to each item.

    ``ret_file`` is a JSON file that is looked up by image file name; each
    entry is a collection of candidate file names relative to ``root_dir``.
    """

    def __init__(self, ret_file, *args, **kwargs):
        """Build the underlying FolderData, then load the retrieval table."""
        super().__init__(*args, **kwargs)
        with open(ret_file, "rt") as handle:
            self.ret = json.load(handle)

    def __getitem__(self, index):
        """Return the base item extended with a ``"match"`` entry."""
        item = super().__getitem__(index)

        # NOTE(review): the lookup key comes from self.paths, while the base
        # item may be selected from the caption file; confirm both orderings
        # agree when a caption file is used.
        own_name = self.paths[index].name
        candidates = self.ret[own_name]

        # Fall back to the image itself when it has no candidates.
        picked = random.choice(candidates) if len(candidates) > 0 else own_name

        match_path = self.root_dir / picked
        match_im = self.process_im(Image.open(match_path).convert("RGB"))

        # The original image and its match are joined along the last axis.
        item["match"] = torch.cat((item["image"], match_im), dim=-1)
        return item