Spaces:
Runtime error
Runtime error
File size: 4,575 Bytes
c05d22e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# -*- coding: utf-8 -*-
import numpy as np
import os
import pytorch_lightning as pl
import torch
import webdataset as wds
from torchvision.transforms import transforms
from ldm.util import instantiate_from_config
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of dict samples into a single dict batch, preserving keys.

    Only keys present in *every* sample are kept; they appear in the order of
    the first sample. For each kept key, the per-sample values are combined
    according to the type of the first sample's value:

    * ``int`` / ``float`` -> ``np.ndarray`` when ``combine_scalars`` is true
    * ``torch.Tensor`` -> ``torch.stack(...)`` when ``combine_tensors`` is true
    * ``np.ndarray`` -> ``np.array(...)`` when ``combine_tensors`` is true
    * anything else -> plain ``list``

    When the relevant combine flag is false, the values are returned as a
    plain ``list`` rather than being dropped from the batch.

    :param list samples: list of samples, each a ``dict``
    :param bool combine_tensors: stack tensors / ndarrays into one batch object
    :param bool combine_scalars: turn Python scalars into a single ndarray
    :returns: single sample consisting of a batch (``{}`` for empty input)
    :rtype: dict
    """
    if not samples:
        # set.intersection(*[]) would raise TypeError; an empty batch is {}.
        return {}

    first, rest = samples[0], samples[1:]
    # Keys common to all samples, in the first sample's order (deterministic).
    keys = [key for key in first if all(key in s for s in rest)]

    result = {}
    for key in keys:
        values = [s[key] for s in samples]
        head = values[0]
        if isinstance(head, (int, float)) and combine_scalars:
            result[key] = np.array(values)
        elif isinstance(head, torch.Tensor) and combine_tensors:
            result[key] = torch.stack(values)
        elif isinstance(head, np.ndarray) and combine_tensors:
            result[key] = np.array(values)
        else:
            # Uncombinable types, or combining disabled: keep the raw values.
            result[key] = values
    return result
class WebDataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module that streams training batches from WebDataset shards.

    Only the training split is served: ``val_dataloader`` and
    ``test_dataloader`` return ``None``, and the ``validation`` / ``test``
    configs are stored but not used.

    :param tar_base: base path joined (``os.path.join``) with each split's
        ``shards`` pattern
    :param int batch_size: samples per batch, batched inside the pipeline
    :param train: config of the training split (see ``make_loader``)
    :param validation: stored, currently unused
    :param test: stored, currently unused
    :param int num_workers: worker count passed to ``wds.WebLoader``
    :param bool multinode: use ``split_by_node`` for shards instead of
        ``single_node_only``
    :param min_size: if not None, drop samples whose ``json`` metadata reports
        an ``original_width`` or ``original_height`` below this value
    :param float max_pwatermark: if below 1.0, drop samples whose ``json``
        ``pwatermark`` exceeds it
    :param kwargs: accepted and ignored
    """

    def __init__(self,
                 tar_base,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 num_workers=4,
                 multinode=True,
                 min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config):
        """Build a ``wds.WebLoader`` that yields collated batches for one split.

        ``dataset_config`` is read for ``image_transforms`` (list of configs for
        ``instantiate_from_config``, composed and applied to ``jpg``),
        ``process`` (config for a per-sample callable), ``shards`` (pattern
        relative to ``tar_base``) and optionally ``shuffle`` (buffer size,
        default 0).
        """
        image_transforms = transforms.Compose(
            [instantiate_from_config(tt) for tt in dataset_config.image_transforms])
        process = instantiate_from_config(dataset_config['process'])

        shuffle = dataset_config.get('shuffle', 0)
        # Shard-order shuffling is enabled exactly when sample shuffling is.
        shardshuffle = shuffle > 0
        nodesplitter = (wds.shardlists.split_by_node if self.multinode
                        else wds.shardlists.single_node_only)

        tars = os.path.join(self.tar_base, dataset_config.shards)
        # `.repeat()` without an epoch count: the stream does not end on its own.
        dset = wds.WebDataset(
            tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        dset = (dset.select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(jpg=image_transforms, handler=wds.warn_and_continue)
                .map(process))
        # Batch inside the pipeline; partial=False drops incomplete batches.
        dset = dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn)
        # batch_size=None so the loader does not re-batch already batched samples.
        loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)
        return loader

    def filter_size(self, x):
        """Return True if sample ``x`` passes the enabled metadata filters.

        The two filters are applied independently:

        * size filter, enabled when ``min_size`` is not None: both
          ``original_width`` and ``original_height`` in ``x['json']`` must be
          >= ``min_size``;
        * watermark filter, enabled when ``max_pwatermark`` is not None and
          below 1.0: ``x['json']['pwatermark']`` must be <= ``max_pwatermark``.
          1.0 is treated as "disabled", presumably because ``pwatermark`` is a
          probability in [0, 1] (TODO confirm against the dataset).

        A sample whose metadata is missing or malformed for an enabled filter
        is rejected.
        """
        check_size = self.min_size is not None
        check_watermark = self.max_pwatermark is not None and self.max_pwatermark < 1.0
        if not (check_size or check_watermark):
            return True
        try:
            meta = x['json']
            if check_size and not (meta['original_width'] >= self.min_size
                                   and meta['original_height'] >= self.min_size):
                return False
            if check_watermark and not (meta['pwatermark'] <= self.max_pwatermark):
                return False
        except Exception:
            # Deliberately broad: any bad metadata rejects the sample instead of
            # killing the data-loading worker.
            return False
        return True

    def filter_keys(self, x):
        """Keep only samples that contain both a ``jpg`` and a ``txt`` entry."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        """Return the streaming loader built from the ``train`` config."""
        return self.make_loader(self.train)

    def val_dataloader(self):
        """No validation loader is provided."""
        return None

    def test_dataloader(self):
        """No test loader is provided."""
        return None
if __name__ == '__main__':
    # Smoke test: build the training loader from the canny config and print
    # the batch layout. The pipeline is `.repeat()`ed, so this runs until
    # interrupted.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
    data_params = cfg["data"]["params"]
    module = WebDataModuleFromConfig(**data_params)
    loader = module.train_dataloader()
    for sample_batch in loader:
        print(sample_batch.keys())
        print(sample_batch['jpg'].shape)
|