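# NOTE(review): "Spaces: Runtime error" (shown twice) looks like page-status text
# captured together with this file; it is not part of the script.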
import argparse
import math
import os
import sys

import cv2
import numpy as np
import torch
from torch.nn import functional as F

# Add the working directory to the import path ahead of the project-local
# `model` imports that follow.
sys.path.append('.')
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the interpolation model and prepare it for inference.
model = Model()
model.load_model('train_log')  # presumably a checkpoint directory -- confirm in model/RIFE.py
model.eval()
model.device()  # NOTE(review): presumably moves the network to the active device -- confirm
# Sequence folder names expected under 'other-data/' and 'other-gt-interp/'.
name = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']
# One entry per sequence: mean absolute difference between the prediction and
# the ground-truth middle frame, on a 0-255 scale.
IE_list = []

# Every frame pair is zero-padded onto this fixed canvas before inference.
# The size is kept as in the original script so the reported numbers do not change.
PAD_H, PAD_W = 480, 640


def _read_image(path):
    """Load an image with OpenCV, raising if it cannot be read.

    ``cv2.imread`` returns ``None`` for a missing or unreadable file instead
    of raising, which would otherwise surface later as an unrelated
    ``AttributeError``.

    Args:
        path: Location of the image file.

    Returns:
        The array returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: If OpenCV could not load the file.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return image


for seq in name:
    # Input frames: channels moved to the front and values scaled to [0, 1].
    i0 = _read_image('other-data/{}/frame10.png'.format(seq)).transpose(2, 0, 1) / 255.
    i1 = _read_image('other-data/{}/frame11.png'.format(seq)).transpose(2, 0, 1) / 255.
    # Ground-truth middle frame, left as loaded (height, width, channels; 0-255).
    gt = _read_image('other-gt-interp/{}/frame10i11.png'.format(seq))

    h, w = i0.shape[1], i0.shape[2]
    if h > PAD_H or w > PAD_W:
        raise ValueError(
            '{}: frame size {}x{} exceeds the {}x{} canvas'.format(seq, h, w, PAD_H, PAD_W))

    # Both frames go into the top-left corner of one zero canvas:
    # channels 0-2 hold frame10, channels 3-5 hold frame11.
    imgs = torch.zeros([1, 6, PAD_H, PAD_W], device=device)
    imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device)
    imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device)
    I0 = imgs[:, :3]
    I2 = imgs[:, 3:]

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        pred = model.inference(I0, I2)

    # Back to (height, width, channels), crop away the padding, rescale to 0-255.
    out = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
    out = np.round(out[:h, :w] * 255)
    IE_list.append(np.abs((out - gt * 1.0)).mean())
    # Running mean over the sequences processed so far; the last line printed
    # is the mean over all sequences.
    print(np.mean(IE_list))