import os

import torch

from utils.dataset_utils import *


class CachedDataset(Dataset):
    """Dataset that serves pre-computed ``.pt`` files from a cache directory.

    Every entry in ``cache_dir`` whose name ends with ``.pt`` becomes one
    sample. The file list is collected once, at construction time, and ordered
    by sorted path string. Each item is whatever object ``torch.load`` returns
    for that file. Presumably this is a cached latent tensor or a structure of
    tensors, but it depends on how the cache was written (TODO confirm).

    Args:
        cache_dir: Directory to scan for ``.pt`` files. An empty string means
            the current working directory.
        map_location: Forwarded to ``torch.load`` to choose where loaded
            tensors are placed. Defaults to ``'cuda:0'``, which matches the
            previous hard-coded behavior. Pass ``'cpu'`` on machines without a
            GPU, or when items are loaded inside ``DataLoader`` worker
            processes.
    """

    def __init__(self, cache_dir: str = '', *, map_location='cuda:0'):
        self.cache_dir = cache_dir
        self.map_location = map_location
        # Snapshot of the cache contents; files added later are not picked up.
        self.cached_data_list = self.get_files_list()

    def get_files_list(self):
        """Return the sorted paths of all ``.pt`` entries in ``cache_dir``."""
        # os.listdir('') raises FileNotFoundError, so fall back to the current
        # directory for the empty-string default. os.path.join('', name)
        # yields a plain relative ``name`` in that case.
        directory = self.cache_dir or os.curdir
        tensors_list = [
            os.path.join(self.cache_dir, name)
            for name in os.listdir(directory)
            if name.endswith('.pt')
        ]
        return sorted(tensors_list)

    def __len__(self):
        """Return the number of cached files found at construction time."""
        return len(self.cached_data_list)

    def __getitem__(self, index):
        """Load and return the cached object stored at position ``index``."""
        # NOTE(review): torch.load unpickles the file; only point cache_dir at
        # trusted data.
        return torch.load(
            self.cached_data_list[index], map_location=self.map_location
        )