import os
from glob import glob
import random

import numpy as np
from PIL import Image
import cv2
import torch
from torch.utils.data import Dataset
import torchvision.datasets.folder
import torchvision.transforms as transforms
from einops import rearrange

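# compute_distance_transform: for binary masks of shape Bx1xHxW, returns a
# Bx2xHxW tensor stacking the foreground-to-boundary and background-to-boundary
# L2 distance maps.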
def compute_distance_transform(mask):
    mask_dt = []
    for m in mask:
        dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
        inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
        mask_dt += [torch.stack([dt, inv_dt], 0)]
    return torch.stack(mask_dt, 0)  # Bx2xHxW

def crop_image(image, boxs, size):
    # Crop the image to each (x0, y0, w, h) box and resize the crop to `size`.
    crops = []
    for box in boxs:
        crop_x0, crop_y0, crop_w, crop_h = box
        crop = transforms.functional.resized_crop(image, crop_y0, crop_x0, crop_h, crop_w, size)
        crop = transforms.functional.to_tensor(crop)
        crops += [crop]
    return torch.stack(crops, 0)

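# A box.txt row holds 9 values: frame_id, crop_x0, crop_y0, crop_w, crop_h,
# full_w, full_h, sharpness, label (the layout unpacked in horizontal_flip_box
# below).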
def box_loader(fpath):
    box = np.loadtxt(fpath, dtype=str)
    box[0] = box[0].split('_')[0]  # keep only the leading frame id before any '_' suffix
    return box.astype(np.float32)

def read_feat_from_img(path, n_channels):
    feat = np.array(Image.open(path))
    return decode_feat_from_img(feat, n_channels)

def decode_feat_from_img(img, n_channels):
    # Features are stored as horizontally tiled 3-channel images: undo the
    # tiling, drop the padding channels, and rescale from uint8 to [0, 1].
    n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels
    n_tiles = int((n_channels + n_addon_channels) / 3)
    feat = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3)
    if n_addon_channels > 0:
        feat = feat[:, :, :-n_addon_channels]  # slicing with :-0 would drop all channels
    feat = feat.astype('float32') / 255
    return feat.transpose(2, 0, 1)  # CxHxW

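# A minimal sketch of the assumed inverse encoding, for reference only: a CxHxW
# feature map in [0, 1] is zero-padded to a multiple of 3 channels, split into
# 3-channel groups, and the groups are tiled horizontally into a regular
# Hx(t*W)x3 uint8 image. This helper is hypothetical and not part of the
# original pipeline.
def encode_feat_to_img(feat):
    n_channels = feat.shape[0]
    n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels
    feat = feat.transpose(1, 2, 0)  # CxHxW -> HxWxC
    if n_addon_channels > 0:
        pad = np.zeros((*feat.shape[:2], n_addon_channels), dtype=feat.dtype)
        feat = np.concatenate([feat, pad], 2)  # trailing pad channels, dropped again on decode
    img = rearrange(feat, 'h w (t c) -> h (t w) c', c=3)  # tile 3-channel groups side by side
    return (img * 255).astype(np.uint8)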
def dino_loader(fpath, n_channels):
    return read_feat_from_img(fpath, n_channels)

def get_valid_mask(boxs, image_size):
    # Mark which pixels of each crop fall inside the original full image, so
    # regions padded in during cropping can be excluded downstream.
    valid_masks = []
    for box in boxs:
        crop_x0, crop_y0, crop_w, crop_h, full_w, full_h = box[1:7].int().numpy()
        # Discard a small margin near the boundary.
        margin_w = int(crop_w * 0.02)
        margin_h = int(crop_h * 0.02)
        mask_full = torch.ones(full_h - margin_h * 2, full_w - margin_w * 2)
        mask_full_pad = torch.nn.functional.pad(mask_full, (crop_w + margin_w, crop_w + margin_w, crop_h + margin_h, crop_h + margin_h), mode='constant', value=0.0)
        mask_full_crop = mask_full_pad[crop_y0 + crop_h:crop_y0 + crop_h * 2, crop_x0 + crop_w:crop_x0 + crop_w * 2]
        mask_crop = torch.nn.functional.interpolate(mask_full_crop[None, None, :, :], image_size, mode='nearest')[0, 0]
        valid_masks += [mask_crop]
    return torch.stack(valid_masks, 0)  # NxHxW

def horizontal_flip_box(box):
    frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = box.unbind(1)
    box[:, 1] = full_w - crop_x0 - crop_w  # mirror x0 about the full-image width
    return box

def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None):
    images = images.flip(3)  # NxCxHxW
    masks = masks.flip(3)  # NxCxHxW
    mask_dt = mask_dt.flip(3)  # NxCxHxW
    mask_valid = mask_valid.flip(2)  # NxHxW
    if flows.dim() > 1:
        flows = flows.flip(3)  # (N-1)x(x,y)xHxW
        flows[:, 0] *= -1  # flipping also inverts the sign of the x displacement
    bboxs = horizontal_flip_box(bboxs)  # NxK
    bg_images = bg_images.flip(3)  # NxCxHxW
    if dino_features is not None and dino_features.dim() > 1:
        dino_features = dino_features.flip(3)
    if dino_clusters is not None and dino_clusters.dim() > 1:
        dino_clusters = dino_clusters.flip(3)
    return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters

class BaseSequenceDataset(Dataset):
    def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False):
        super().__init__()
        self.skip_beginning = skip_beginning
        self.skip_end = skip_end
        self.min_seq_len = min_seq_len
        # Per-frame files follow the pattern "{:07d}_{}", e.g. "0000000_rgb.png".
        self.sequences = self._make_sequences(root)
        if debug_seq:
            # Debug mode: train on a single random sequence of sufficient length.
            seq_len = 0
            while seq_len < min_seq_len:
                i = np.random.randint(len(self.sequences))
                rand_seq = self.sequences[i]
                seq_len = len(rand_seq)
            self.sequences = [rand_seq]
        self.samples = []

    def _make_sequences(self, path):
        result = []
        for d in sorted(os.scandir(path), key=lambda e: e.name):
            if d.is_dir():
                files = self._parse_folder(d)
                if len(files) >= self.min_seq_len:
                    result.append(files)
        return result

    def _parse_folder(self, path):
        # Collect per-frame path patterns like ".../0000000_{}", with the loader
        # suffix left as a placeholder, and trim frames at both ends.
        result = sorted(glob(os.path.join(path, '*' + self.image_loaders[0][0])))
        result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
        if len(result) <= self.skip_beginning + self.skip_end:
            return []
        if self.skip_end == 0:
            return result[self.skip_beginning:]
        return result[self.skip_beginning:-self.skip_end]

    def _load_ids(self, path_patterns, loaders, transform=None):
        result = []
        for loader in loaders:
            for p in path_patterns:
                x = loader[1](p.format(loader[0]), *loader[2:])
                if transform:
                    x = transform(x)
                result.append(x)
        return tuple(result)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        raise NotImplementedError("This is a base class and should not be used directly")

class NFrameSequenceDataset(BaseSequenceDataset):
    def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, **kwargs):
        self.cat_name = cat_name
        self.image_loaders = [("rgb" + rgb_suffix, torchvision.datasets.folder.default_loader)]
        self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
        self.bbox_loaders = [("box.txt", box_loader)]
        super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq)
        if num_sample_frames > 1:
            self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)]
        else:
            self.flow_loaders = None
        self.num_sample_frames = num_sample_frames
        self.random_sample = random_sample
        if self.random_sample:
            if shuffle:
                random.shuffle(self.sequences)
            self.samples = self.sequences
        else:
            # Index every frame (dense) or every clip of num_sample_frames (strided).
            for i, s in enumerate(self.sequences):
                stride = 1 if dense_sample else self.num_sample_frames
                self.samples += [(i, k) for k in range(0, len(s), stride)]
            if shuffle:
                random.shuffle(self.samples)
        self.in_image_size = in_image_size
        self.out_image_size = out_image_size
        self.load_background = load_background
        self.color_jitter = color_jitter
        self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
        self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
        if self.flow_loaders is not None:
            # 16-bit flow image: reverse cv2's BGR channel order, keep the first
            # two channels, and map them to [-1, 1].
            self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:, :, :2] / 65535.) * 2 - 1
        self.random_flip = random_flip
        self.load_dino_feature = load_dino_feature
        if load_dino_feature:
            self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
        self.load_dino_cluster = load_dino_cluster
        if load_dino_cluster:
            self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)]

    def __getitem__(self, index):
        if self.random_sample:
            # Pick a random window of num_sample_frames from a sequence.
            seq_idx = index % len(self.sequences)
            seq = self.sequences[seq_idx]
            if len(seq) < self.num_sample_frames:
                start_frame_idx = 0
            else:
                start_frame_idx = np.random.randint(len(seq) - self.num_sample_frames + 1)
        else:
            seq_idx, start_frame_idx = self.samples[index % len(self.samples)]
            seq = self.sequences[seq_idx]
            # Edge case: when only the last frame is left, sample the last two
            # frames instead, unless the sequence has a single frame.
            if len(seq) <= start_frame_idx + 1:
                start_frame_idx = max(0, start_frame_idx - 1)
        paths = seq[start_frame_idx:start_frame_idx + self.num_sample_frames]

        masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0)  # NxCxHxW
        mask_dt = compute_distance_transform(masks)

        jitter = False
        if self.color_jitter is not None:
            prob, b, h = self.color_jitter
            if np.random.rand() < prob:
                jitter = True
                # Sample one jitter transform for the foreground and one for the
                # background, each shared across all frames of the clip.
                color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1 - b, 1 + b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
                color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1 - b, 1 + b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
        if jitter:
            # Jitter foreground and background independently, then recombine via the masks.
            images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0)
            images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0)
            images = images_fg * masks + images_bg * (1 - masks)
        else:
            images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0)

        if len(paths) > 1:
            # Flow between consecutive frames, (N-1)x(x,y)xHxW, in [-1, 1].
            flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0, 3, 1, 2)
            flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear")
        else:
            flows = torch.zeros(1)

        bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0)  # NxK
        mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size))  # exclude pixels cropped outside the original image

        if self.load_background:
            bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
            if jitter:
                bg_image = color_jitter_tsf_bg(bg_image)
            bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
        else:
            bg_images = torch.zeros_like(images)

        if self.load_dino_feature:
            dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0)  # Fx64x224x224
        else:
            dino_features = torch.zeros(1)
        if self.load_dino_cluster:
            dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0)  # Fx3x55x55
        else:
            dino_clusters = torch.zeros(1)

        seq_idx = torch.LongTensor([seq_idx])
        frame_idx = torch.arange(start_frame_idx, start_frame_idx + len(paths)).long()

        if self.random_flip and np.random.rand() < 0.5:
            images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)

        # Pad sequences shorter than num_sample_frames by repeating the first frame.
        if len(paths) < self.num_sample_frames:
            num_pad = self.num_sample_frames - len(paths)
            images = torch.cat([images[:1]] * num_pad + [images], 0)
            masks = torch.cat([masks[:1]] * num_pad + [masks], 0)
            mask_dt = torch.cat([mask_dt[:1]] * num_pad + [mask_dt], 0)
            mask_valid = torch.cat([mask_valid[:1]] * num_pad + [mask_valid], 0)
            if flows.dim() > 1:
                flows = torch.cat([flows[:1] * 0] * num_pad + [flows], 0)  # zero flow for padded frames
            bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
            bg_images = torch.cat([bg_images[:1]] * num_pad + [bg_images], 0)
            if dino_features.dim() > 1:
                dino_features = torch.cat([dino_features[:1]] * num_pad + [dino_features], 0)
            if dino_clusters.dim() > 1:
                dino_clusters = torch.cat([dino_clusters[:1]] * num_pad + [dino_clusters], 0)
            frame_idx = torch.cat([frame_idx[:1]] * num_pad + [frame_idx], 0)

        return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name

def get_sequence_loader(data_dir, **kwargs):
    # data_dir may be a dict mapping category names to dataset roots, yielding
    # one loader per category; otherwise a single loader is built.
    if isinstance(data_dir, dict):
        loaders = []
        for k, v in data_dir.items():
            dataset = NFrameSequenceDataset(v, cat_name=k, **kwargs)
            loader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'], shuffle=kwargs['shuffle'], num_workers=kwargs['num_workers'], pin_memory=True)
            loaders += [loader]
        return loaders
    else:
        return [get_sequence_loader_single(data_dir, **kwargs)]

def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64):
    if mode == 'n_frame':
        dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim)
    else:
        raise NotImplementedError
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not is_validation,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader

class ImageDataset(Dataset):
    def __init__(self, root, is_validation=False, image_size=256, color_jitter=None):
        super().__init__()
        self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader)
        self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader)
        self.bbox_loader = ("box.txt", np.loadtxt, 'str')
        self.samples = self._parse_folder(root)
        self.image_size = image_size
        self.color_jitter = color_jitter
        self.image_transform = transforms.Compose([transforms.Resize(self.image_size), transforms.ToTensor()])
        self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=Image.NEAREST), transforms.ToTensor()])

    def _parse_folder(self, path):
        result = sorted(glob(os.path.join(path, '**/*' + self.image_loader[0]), recursive=True))
        result = [p.replace(self.image_loader[0], '{}') for p in result]
        return result

    def _load_ids(self, path, loader, transform=None):
        x = loader[1](path.format(loader[0]), *loader[2:])
        if transform:
            x = transform(x)
        return x

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path = self.samples[index % len(self.samples)]
        masks = self._load_ids(path, self.mask_loader, transform=self.mask_transform).unsqueeze(0)
        mask_dt = compute_distance_transform(masks)
        jitter = False
        if self.color_jitter is not None:
            prob, b, h = self.color_jitter
            if np.random.rand() < prob:
                jitter = True
                color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1 - b, 1 + b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_fg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_fg, transforms.ToTensor()])
                color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1 - b, 1 + b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_bg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_bg, transforms.ToTensor()])
        if jitter:
            # Jitter foreground and background independently, then recombine via the mask.
            images_fg = self._load_ids(path, self.image_loader, transform=image_transform_fg).unsqueeze(0)
            images_bg = self._load_ids(path, self.image_loader, transform=image_transform_bg).unsqueeze(0)
            images = images_fg * masks + images_bg * (1 - masks)
        else:
            images = self._load_ids(path, self.image_loader, transform=self.image_transform).unsqueeze(0)
        flows = torch.zeros(1)  # no flow for single images
        bboxs = self._load_ids(path, self.bbox_loader, transform=None)
        bboxs[0] = '0'  # replace the frame-id string so the whole row casts to float
        bboxs = torch.FloatTensor(bboxs.astype('float')).unsqueeze(0)
        bg_fpath = os.path.join(os.path.dirname(path), 'background_frame.jpg')
        if os.path.isfile(bg_fpath):
            bg_image = torchvision.datasets.folder.default_loader(bg_fpath)
            if jitter:
                bg_image = color_jitter_tsf_bg(bg_image)
            bg_image = transforms.ToTensor()(bg_image)
        else:
            bg_image = images[0]
        seq_idx = torch.LongTensor([index])
        frame_idx = torch.LongTensor([0])
        return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx

def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
    dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader
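

# Minimal usage sketch under an assumed data layout: each sequence directory
# under the (hypothetical) root below holds per-frame files such as
# 0000000_rgb.png, 0000000_mask.png, 0000000_box.txt and, for multi-frame
# clips, 0000000_flow.png. The category name and hyper-parameters are placeholders.
if __name__ == '__main__':
    loaders = get_sequence_loader(
        {'horse': 'path/to/horse_sequences'},  # category -> dataset root
        batch_size=8,
        shuffle=True,
        num_workers=4,
        num_sample_frames=2,
        in_image_size=256,
        out_image_size=256,
    )
    images = next(iter(loaders[0]))[0]
    print(images.shape)  # Bx2x3xHxW: a batch of 2-frame clips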