MotionInversion / dataset /cached_dataset.py
ziyangmai's picture
page demo
113884e
raw
history blame contribute delete
588 Bytes
from utils.dataset_utils import *
class CachedDataset(Dataset):
    """Dataset over the ``.pt`` files found directly inside ``cache_dir``.

    Each item is whatever ``torch.load`` returns for one file (presumably a
    cached latent, per the class name -- verify against the code that writes
    the cache). Files are ordered by sorted path so indexing is deterministic.
    """

    def __init__(self, cache_dir: str = '', map_location='cuda:0'):
        """
        Args:
            cache_dir: Directory scanned (non-recursively) for ``*.pt`` files.
                The directory is listed immediately, so it must exist.
            map_location: Forwarded to ``torch.load``. Defaults to
                ``'cuda:0'`` for backward compatibility; pass ``'cpu'`` (or
                another device) on machines without that GPU.
        """
        self.cache_dir = cache_dir
        self.map_location = map_location
        self.cached_data_list = self.get_files_list()

    def get_files_list(self):
        """Return the sorted paths of all ``.pt`` entries in ``cache_dir``."""
        return sorted(
            os.path.join(self.cache_dir, name)
            for name in os.listdir(self.cache_dir)
            if name.endswith('.pt')
        )

    def __len__(self):
        """Return the number of cached files."""
        return len(self.cached_data_list)

    def __getitem__(self, index):
        """Load and return the object stored in the ``index``-th cached file.

        NOTE(review): ``torch.load`` unpickles the file; only point
        ``cache_dir`` at trusted caches.
        """
        return torch.load(self.cached_data_list[index], map_location=self.map_location)