Spaces:
Runtime error
Runtime error
import os | |
import cv2 | |
import ast | |
import torch | |
import numpy as np | |
import random | |
from torch.utils.data import DataLoader, Dataset | |
# Restrict OpenCV to a single internal thread.
# NOTE(review): presumably this avoids CPU oversubscription when several
# DataLoader worker processes decode images concurrently -- confirm.
cv2.setNumThreads(1)

# Module-level default device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
class VimeoDataset(Dataset):
    """Frame-triplet dataset read from a Vimeo-triplet style directory.

    The dataset root is expected to contain:

    * ``sequences/<entry>/im1.png``, ``im2.png``, ``im3.png``
    * ``tri_trainlist.txt`` and ``tri_testlist.txt`` (one ``<entry>`` per line)

    ``dataset_name`` selects the split:

    * ``'train'`` -- the first 95% of the training list (with augmentation),
    * ``'test'``  -- the whole test list,
    * anything else -- the remaining 5% of the training list (validation).

    Each item is ``(frames, timestep)`` where ``frames`` is the channel-wise
    concatenation of ``img0``, ``img1`` and ``gt`` (in that order) and
    ``timestep`` is a tensor of shape ``(1, 1, 1)``.
    """

    # Side length of the square random crop applied to training samples.
    _TRAIN_CROP = 224

    def __init__(self, dataset_name, batch_size=32, data_root='vimeo_triplet'):
        """Read the split lists and select the entries for ``dataset_name``.

        Args:
            dataset_name: ``'train'``, ``'test'`` or any other value for the
                validation split.
            batch_size: Stored on the instance; not used by this class itself.
            data_root: Directory holding ``sequences/`` and the two list
                files. Defaults to ``'vimeo_triplet'`` (the previously
                hard-coded location).

        Raises:
            OSError: If either list file cannot be opened.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset_name = dataset_name
        # Stored for external consumers; not read anywhere in this class.
        # NOTE(review): presumably the native frame height/width -- confirm.
        self.h = 256
        self.w = 448
        self.data_root = data_root
        self.image_root = os.path.join(self.data_root, 'sequences')
        self.trainlist = self._read_list(
            os.path.join(self.data_root, 'tri_trainlist.txt'))
        self.testlist = self._read_list(
            os.path.join(self.data_root, 'tri_testlist.txt'))
        self.load_data()

    @staticmethod
    def _read_list(path):
        """Return the non-blank, whitespace-stripped lines of ``path``.

        Blank lines are dropped because an empty entry would resolve to the
        bare ``sequences`` directory and fail at image-load time, and would
        also distort the train/validation split size.
        """
        with open(path, 'r') as f:
            stripped = (line.strip() for line in f.read().splitlines())
            return [entry for entry in stripped if entry]

    def __len__(self):
        """Return the number of entries in the selected split."""
        return len(self.meta_data)

    def load_data(self):
        """Populate ``self.meta_data`` according to ``self.dataset_name``."""
        # Index separating the training part (first 95%) from validation.
        cnt = int(len(self.trainlist) * 0.95)
        if self.dataset_name == 'train':
            self.meta_data = self.trainlist[:cnt]
        elif self.dataset_name == 'test':
            self.meta_data = self.testlist
        else:
            self.meta_data = self.trainlist[cnt:]

    def crop(self, img0, gt, img1, h, w):
        """Cut the same random ``h`` x ``w`` window out of all three frames.

        The window position is drawn with ``np.random`` from the shape of
        ``img0``; the arrays are indexed as (row, column, channel).

        Returns:
            ``(img0, gt, img1)`` views of the cropped region.

        Raises:
            ValueError: From ``np.random.randint`` when ``img0`` is smaller
                than the requested window.
        """
        ih, iw, _ = img0.shape
        # Row offset is drawn first, then column offset (order matters for
        # reproducibility under a fixed NumPy seed).
        top = np.random.randint(0, ih - h + 1)
        left = np.random.randint(0, iw - w + 1)
        rows = slice(top, top + h)
        cols = slice(left, left + w)
        return img0[rows, cols, :], gt[rows, cols, :], img1[rows, cols, :]

    def getimg(self, index):
        """Load the three frames of entry ``index`` from disk.

        Returns:
            ``(img0, gt, img1, timestep)`` where ``img0``/``gt``/``img1`` come
            from ``im1.png``/``im2.png``/``im3.png`` and ``timestep`` is 0.5.

        Raises:
            OSError: If any of the three images cannot be read.
        """
        imgpath = os.path.join(self.image_root, self.meta_data[index])
        frames = []
        for name in ('/im1.png', '/im2.png', '/im3.png'):
            path = imgpath + name
            img = cv2.imread(path)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise OSError(
                    'cv2.imread could not load image: {!r}'.format(path))
            frames.append(img)
        img0, gt, img1 = frames
        # The middle frame sits halfway between the two outer frames.
        timestep = 0.5
        return img0, gt, img1, timestep
        # NOTE: a septuplet variant (RIFEm with Vimeo-Septuplet) would instead
        # pick three sorted random indices ind[0] < ind[1] < ind[2] out of
        # im1..im7 and use
        #   timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)

    def __getitem__(self, index):
        """Return one (optionally augmented) sample.

        For the ``'train'`` split the frames are randomly cropped to a
        224 x 224 window and then, each with probability 0.5: channel order
        reversed, flipped vertically, flipped horizontally, and the two input
        frames swapped (with ``timestep`` mirrored). Finally a rotation by
        90 degrees clockwise, 180 degrees, 90 degrees counter-clockwise or
        none is chosen with probability 0.25 each.

        Returns:
            ``(frames, timestep)``: ``frames`` is ``torch.cat((img0, img1,
            gt), 0)`` with every image permuted to channel-first layout;
            ``timestep`` is a tensor of shape ``(1, 1, 1)``.
        """
        img0, gt, img1, timestep = self.getimg(index)
        if self.dataset_name == 'train':
            img0, gt, img1 = self.crop(
                img0, gt, img1, self._TRAIN_CROP, self._TRAIN_CROP)

            # Each flip is decided by its own random draw, in this fixed
            # order: channel reversal, vertical flip, horizontal flip.
            flips = (
                (slice(None), slice(None), slice(None, None, -1)),  # channels
                (slice(None, None, -1),),                           # rows
                (slice(None), slice(None, None, -1)),               # columns
            )
            for flip in flips:
                if random.uniform(0, 1) < 0.5:
                    img0, gt, img1 = img0[flip], gt[flip], img1[flip]

            # Temporal reversal: swap the outer frames and mirror the time.
            if random.uniform(0, 1) < 0.5:
                img0, img1 = img1, img0
                timestep = 1 - timestep

            # Random rotation. The crop is square, so a 90-degree rotation
            # leaves the sample's spatial size unchanged.
            p = random.uniform(0, 1)
            if p < 0.25:
                rotation = cv2.ROTATE_90_CLOCKWISE
            elif p < 0.5:
                rotation = cv2.ROTATE_180
            elif p < 0.75:
                rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
            else:
                rotation = None
            if rotation is not None:
                img0 = cv2.rotate(img0, rotation)
                gt = cv2.rotate(gt, rotation)
                img1 = cv2.rotate(img1, rotation)

        # .copy() materialises the reversed (negative-stride) views, which
        # torch.from_numpy cannot wrap directly; permute gives (C, H, W).
        img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
        img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
        gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
        timestep = torch.tensor(timestep).reshape(1, 1, 1)
        return torch.cat((img0, img1, gt), 0), timestep