# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
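"""Dataset classes for Sana training.

``SanaImgDataset`` reads images and sidecar ``.txt`` captions from plain
directories; ``SanaWebDataset`` reads WebDataset shards through ``wids``.
Both can merge external caption / CLIP-score json files and sample a caption
type weighted by its CLIP score (see ``weighted_sample_clipscore``).
"""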
import getpass
import hashlib
import json
import os
import os.path as osp
import random

import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from termcolor import colored
from torch.utils.data import Dataset

from diffusion.data.builder import DATASETS, get_data_path
from diffusion.data.wids import ShardListDataset, ShardListDatasetMulti, lru_json_load
from diffusion.utils.logger import get_root_logger


class SanaImgDataset(torch.utils.data.Dataset):
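    """Image-folder dataset for small-scale (toy) training runs.

    Each ``data_dir`` must contain a ``meta_data.json`` listing ``img_names``;
    every image is paired with a same-named ``.txt`` caption file and,
    optionally, external ``<name><suffix>.json`` caption / CLIP-score files.
    """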
    def __init__(
        self,
        data_dir="",
        transform=None,
        resolution=256,
        load_vae_feat=False,
        load_text_feat=False,
        max_length=300,
        config=None,
        caption_proportion=None,
        external_caption_suffixes=None,
        external_clipscore_suffixes=None,
        clip_thr=0.0,
        clip_thr_temperature=1.0,
        img_extension=".png",
        **kwargs,
    ):
        if external_caption_suffixes is None:
            external_caption_suffixes = []
        if external_clipscore_suffixes is None:
            external_clipscore_suffixes = []

        self.logger = (
            get_root_logger() if config is None else get_root_logger(osp.join(config.work_dir, "train_log.log"))
        )
        self.transform = transform if not load_vae_feat else None
        self.load_vae_feat = load_vae_feat
        self.load_text_feat = load_text_feat
        self.resolution = resolution
        self.max_length = max_length
        self.caption_proportion = caption_proportion if caption_proportion is not None else {"prompt": 1.0}
        self.external_caption_suffixes = external_caption_suffixes
        self.external_clipscore_suffixes = external_clipscore_suffixes
        self.clip_thr = clip_thr
        self.clip_thr_temperature = clip_thr_temperature
        self.default_prompt = "prompt"
        self.img_extension = img_extension

        self.data_dirs = data_dir if isinstance(data_dir, list) else [data_dir]
        # self.meta_datas = [osp.join(data_dir, "meta_data.json") for data_dir in self.data_dirs]
        self.dataset = []
        for data_dir in self.data_dirs:
            with open(osp.join(data_dir, "meta_data.json")) as f:
                meta_data = json.load(f)
            self.dataset.extend([osp.join(data_dir, i) for i in meta_data["img_names"]])
        self.dataset = self.dataset * 2000
        self.logger.info(colored("Dataset is repeated 2000 times for the toy dataset", "red", attrs=["bold"]))
        self.ori_imgs_nums = len(self)

        self.logger.info(f"Dataset samples: {len(self.dataset)}")
        self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
        self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
        self.logger.info(f"External caption CLIP-score threshold: {clip_thr}, temperature: {clip_thr_temperature}")
        self.logger.info(f"T5 max token length: {self.max_length}")

    def getdata(self, idx):
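        """Load one sample: the image, its caption (or precomputed T5
        features), an attention mask, and bookkeeping info."""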
        data = self.dataset[idx]
        self.key = data.split("/")[-1]

        # info = json.load(open(f"{data}.json"))[self.key]
        info = {}
        with open(f"{data}.txt") as f:
            info[self.default_prompt] = f.readline().strip()

        # Merge captions from external sidecar json files, keyed by file name.
        for suffix in self.external_caption_suffixes:
            caption_json_path = f"{data}{suffix}.json"
            if os.path.exists(caption_json_path):
                try:
                    caption_json = lru_json_load(caption_json_path)
                except Exception:
                    caption_json = {}
                if self.key in caption_json:
                    info.update(caption_json[self.key])

        caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
        caption_type = caption_type if caption_type in info else self.default_prompt
        txt_fea = "" if info[caption_type] is None else info[caption_type]

        data_info = {
            "img_hw": torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
            "aspect_ratio": torch.tensor(1.0),
        }

        if self.load_vae_feat:
            raise ValueError("Load VAE is not supported now")
        else:
            img = f"{data}{self.img_extension}"
            img = Image.open(img)
            if self.transform:
                img = self.transform(img)

        attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16)  # 1x1xT
        if self.load_text_feat:
            npz_path = f"{self.key}.npz"
            txt_info = np.load(npz_path)
            txt_fea = torch.from_numpy(txt_info["caption_feature"])  # 1xTx4096
            if "attention_mask" in txt_info:
                attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
            # Pad features and mask to max_length (assumes features are no longer than max_length).
            if txt_fea.shape[1] != self.max_length:
                txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
                attention_mask = torch.cat(
                    [attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
                )

        return (
            img,
            txt_fea,
            attention_mask.to(torch.int16),
            data_info,
            idx,
            caption_type,
            "",
            str(caption_clipscore),
        )

    def __getitem__(self, idx):
        # Retry on bad samples, moving to the next index (wrapping around) each time.
        for _ in range(10):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = (idx + 1) % len(self.dataset)
        raise RuntimeError("Too many bad data samples.")

    def __len__(self):
        return len(self.dataset)

    def weighted_sample_fix_prob(self):
        labels = list(self.caption_proportion.keys())
        weights = list(self.caption_proportion.values())
        sampled_label = random.choices(labels, weights=weights, k=1)[0]
        return sampled_label

    def weighted_sample_clipscore(self, data, info):
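        """Sample a caption type weighted by its CLIP score.

        Caption types scoring below ``clip_thr`` are filtered out; the
        survivors are sampled with probability proportional to
        ``score ** (1 / clip_thr_temperature)``, so lower temperatures sharpen
        the distribution. If nothing passes the threshold, fall back to the
        best-scoring caption type, and failing that to the default prompt.
        Returns ``(caption_type, clip_score)``.
        """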
        labels = []
        weights = []
        fallback_label = None
        max_clip_score = float("-inf")

        for suffix in self.external_clipscore_suffixes:
            clipscore_json_path = f"{data}{suffix}.json"
            if os.path.exists(clipscore_json_path):
                try:
                    clipscore_json = lru_json_load(clipscore_json_path)
                except Exception:
                    clipscore_json = {}
                if self.key in clipscore_json:
                    clip_scores = clipscore_json[self.key]
                    for caption_type, clip_score in clip_scores.items():
                        clip_score = float(clip_score)
                        if caption_type in info:
                            if clip_score >= self.clip_thr:
                                labels.append(caption_type)
                                weights.append(clip_score)
                            if clip_score > max_clip_score:
                                max_clip_score = clip_score
                                fallback_label = caption_type

        if not labels and fallback_label:
            return fallback_label, max_clip_score
        if not labels:
            return self.default_prompt, 0.0

        # Temperature-sharpened sampling; clamp the temperature to avoid division by zero.
        adjusted_weights = np.array(weights) ** (1.0 / max(self.clip_thr_temperature, 0.01))
        normalized_weights = adjusted_weights / np.sum(adjusted_weights)
        sampled_label = random.choices(labels, weights=normalized_weights, k=1)[0]
        # sampled_label = random.choices(labels, weights=[1]*len(weights), k=1)[0]
        index = labels.index(sampled_label)
        original_weight = weights[index]
        return sampled_label, original_weight


class SanaWebDataset(torch.utils.data.Dataset):
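    """WebDataset-backed dataset read through ``wids`` shard lists.

    Shard metadata is resolved from ``meta_path``, a ``wids-meta.json`` next
    to the data, or a cached ``.wdsmeta.json`` under ``cache_dir``. External
    caption / CLIP-score json files are looked up next to each ``.tar`` shard.
    """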
    def __init__(
        self,
        data_dir="",
        meta_path=None,
        cache_dir="/cache/data/sana-webds-meta",
        max_shards_to_load=None,
        transform=None,
        resolution=256,
        load_vae_feat=False,
        load_text_feat=False,
        max_length=300,
        config=None,
        caption_proportion=None,
        sort_dataset=False,
        num_replicas=None,
        external_caption_suffixes=None,
        external_clipscore_suffixes=None,
        clip_thr=0.0,
        clip_thr_temperature=1.0,
        **kwargs,
    ):
        if external_caption_suffixes is None:
            external_caption_suffixes = []
        if external_clipscore_suffixes is None:
            external_clipscore_suffixes = []

        self.logger = (
            get_root_logger() if config is None else get_root_logger(osp.join(config.work_dir, "train_log.log"))
        )
        self.transform = transform if not load_vae_feat else None
        self.load_vae_feat = load_vae_feat
        self.load_text_feat = load_text_feat
        self.resolution = resolution
        self.max_length = max_length
        self.caption_proportion = caption_proportion if caption_proportion is not None else {"prompt": 1.0}
        self.external_caption_suffixes = external_caption_suffixes
        self.external_clipscore_suffixes = external_clipscore_suffixes
        self.clip_thr = clip_thr
        self.clip_thr_temperature = clip_thr_temperature
        self.default_prompt = "prompt"

        data_dirs = data_dir if isinstance(data_dir, list) else [data_dir]
        meta_paths = meta_path if isinstance(meta_path, list) else [meta_path] * len(data_dirs)
        self.meta_paths = []
        for data_path, meta_path in zip(data_dirs, meta_paths):
            self.data_path = osp.expanduser(data_path)
            self.meta_path = osp.expanduser(meta_path) if meta_path is not None else None

            _local_meta_path = osp.join(self.data_path, "wids-meta.json")
            if meta_path is None and osp.exists(_local_meta_path):
                self.logger.info(f"loading from {_local_meta_path}")
                self.meta_path = meta_path = _local_meta_path

            if meta_path is None:
                # Fall back to a cached meta file derived from the data path.
                self.meta_path = osp.join(
                    osp.expanduser(cache_dir),
                    self.data_path.replace("/", "--") + f".max_shards:{max_shards_to_load}" + ".wdsmeta.json",
                )
            assert osp.exists(self.meta_path), f"meta path not found in [{self.meta_path}] or [{_local_meta_path}]"
            self.logger.info(f"[SimplyInternal] Loading meta information {self.meta_path}")
            self.meta_paths.append(self.meta_path)

        self._initialize_dataset(num_replicas, sort_dataset)

        self.logger.info(f"Loading external caption json from: original_filename{external_caption_suffixes}.json")
        self.logger.info(f"Loading external clipscore json from: original_filename{external_clipscore_suffixes}.json")
        self.logger.info(f"External caption CLIP-score threshold: {clip_thr}, temperature: {clip_thr_temperature}")
        self.logger.info(f"T5 max token length: {self.max_length}")
        self.logger.warning(f"Sort the dataset: {sort_dataset}")

    def _initialize_dataset(self, num_replicas, sort_dataset):
        # Short, stable id for the cache directory, derived from the meta path.
        # uuid = abs(hash(self.meta_path)) % (10 ** 8)
        uuid = hashlib.sha256(self.meta_path.encode()).hexdigest()[:8]
        if len(self.meta_paths) > 0:
            self.dataset = ShardListDatasetMulti(
                self.meta_paths,
                cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"),
                sort_data_inseq=sort_dataset,
                num_replicas=num_replicas or dist.get_world_size(),
            )
        else:
            # TODO: tmp to ensure there is no bug
            self.dataset = ShardListDataset(
                self.meta_path,
                cache_dir=osp.expanduser(f"~/.cache/_wids_cache/{getpass.getuser()}-{uuid}"),
            )
        self.ori_imgs_nums = len(self)
        self.logger.info(f"{self.dataset.data_info}")

    def getdata(self, idx):
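        """Load one sample from the shard, mirroring ``SanaImgDataset.getdata``
        but with captions taken from the per-sample ``.json`` entry and
        external files located next to the ``.tar`` shard."""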
        data = self.dataset[idx]
        info = data[".json"]
        self.key = data["__key__"]
        dataindex_info = {
            "index": data["__index__"],
            "shard": "/".join(data["__shard__"].rsplit("/", 2)[-2:]),
            "shardindex": data["__shardindex__"],
        }

        # Merge captions from external sidecar json files stored next to the shard.
        for suffix in self.external_caption_suffixes:
            caption_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
            if os.path.exists(caption_json_path):
                try:
                    caption_json = lru_json_load(caption_json_path)
                except Exception:
                    caption_json = {}
                if self.key in caption_json:
                    info.update(caption_json[self.key])

        caption_type, caption_clipscore = self.weighted_sample_clipscore(data, info)
        caption_type = caption_type if caption_type in info else self.default_prompt
        txt_fea = "" if info[caption_type] is None else info[caption_type]

        data_info = {
            "img_hw": torch.tensor([self.resolution, self.resolution], dtype=torch.float32),
            "aspect_ratio": torch.tensor(1.0),
        }

        if self.load_vae_feat:
            img = data[".npy"]
        else:
            img = data[".png"] if ".png" in data else data[".jpg"]
            if self.transform:
                img = self.transform(img)

        attention_mask = torch.ones(1, 1, self.max_length, dtype=torch.int16)  # 1x1xT
        if self.load_text_feat:
            npz_path = f"{self.key}.npz"
            txt_info = np.load(npz_path)
            txt_fea = torch.from_numpy(txt_info["caption_feature"])  # 1xTx4096
            if "attention_mask" in txt_info:
                attention_mask = torch.from_numpy(txt_info["attention_mask"])[None]
            # Pad features and mask to max_length (assumes features are no longer than max_length).
            if txt_fea.shape[1] != self.max_length:
                txt_fea = torch.cat([txt_fea, txt_fea[:, -1:].repeat(1, self.max_length - txt_fea.shape[1], 1)], dim=1)
                attention_mask = torch.cat(
                    [attention_mask, torch.zeros(1, 1, self.max_length - attention_mask.shape[-1])], dim=-1
                )

        return (
            img,
            txt_fea,
            attention_mask.to(torch.int16),
            data_info,
            idx,
            caption_type,
            dataindex_info,
            str(caption_clipscore),
        )

    def __getitem__(self, idx):
        # Retry on bad samples, moving to the next index (wrapping around) each time.
        for _ in range(10):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = (idx + 1) % len(self.dataset)
        raise RuntimeError("Too many bad data samples.")

    def __len__(self):
        return len(self.dataset)

    def weighted_sample_fix_prob(self):
        labels = list(self.caption_proportion.keys())
        weights = list(self.caption_proportion.values())
        sampled_label = random.choices(labels, weights=weights, k=1)[0]
        return sampled_label

    def weighted_sample_clipscore(self, data, info):
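        """Same CLIP-score-weighted caption sampling as
        ``SanaImgDataset.weighted_sample_clipscore``, except the CLIP-score
        json files are located next to the ``.tar`` shard."""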
        labels = []
        weights = []
        fallback_label = None
        max_clip_score = float("-inf")

        for suffix in self.external_clipscore_suffixes:
            clipscore_json_path = data["__shard__"].replace(".tar", f"{suffix}.json")
            if os.path.exists(clipscore_json_path):
                try:
                    clipscore_json = lru_json_load(clipscore_json_path)
                except Exception:
                    clipscore_json = {}
                if self.key in clipscore_json:
                    clip_scores = clipscore_json[self.key]
                    for caption_type, clip_score in clip_scores.items():
                        clip_score = float(clip_score)
                        if caption_type in info:
                            if clip_score >= self.clip_thr:
                                labels.append(caption_type)
                                weights.append(clip_score)
                            if clip_score > max_clip_score:
                                max_clip_score = clip_score
                                fallback_label = caption_type

        if not labels and fallback_label:
            return fallback_label, max_clip_score
        if not labels:
            return self.default_prompt, 0.0

        # Temperature-sharpened sampling; clamp the temperature to avoid division by zero.
        adjusted_weights = np.array(weights) ** (1.0 / max(self.clip_thr_temperature, 0.01))
        normalized_weights = adjusted_weights / np.sum(adjusted_weights)
        sampled_label = random.choices(labels, weights=normalized_weights, k=1)[0]
        # sampled_label = random.choices(labels, weights=[1]*len(weights), k=1)[0]
        index = labels.index(sampled_label)
        original_weight = weights[index]
        return sampled_label, original_weight

    def get_data_info(self, idx):
        try:
            data = self.dataset[idx]
            info = data[".json"]
            key = data["__key__"]
            version = info.get("version", "others")
            return {"height": info["height"], "width": info["width"], "version": version, "key": key}
        except Exception as e:
            print(f"Error details: {str(e)}")
            return None


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    from diffusion.data.transforms import get_transform

    image_size = 1024  # 256
    transform = get_transform("default_train", image_size)
    train_dataset = SanaWebDataset(
        data_dir="debug_data_train/vaef32c32/debug_data",
        resolution=image_size,
        transform=transform,
        max_length=300,
        load_vae_feat=True,
        num_replicas=1,
    )
    dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=4)

    for data in dataloader:
        # getdata returns an 8-tuple; unpack the first four fields and ignore the rest.
        img, txt_fea, attention_mask, data_info, *_ = data
        print(txt_fea)
        break