# yinwentao
# DockerFile
# 8d34f50
# ref:https://github.com/ShunyuYao/DFA-NeRF
import sys
import os
from tqdm import tqdm
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_path, 'core'))
from pathlib import Path
from data_test_flow import *
from models.network_test_flow import NeuralNRT
from options_test_flow import TestOptions
import torch
import numpy as np
def save_flow_numpy(filename, flow_input):
    """Write ``flow_input`` to ``filename`` in NumPy's binary format.

    Args:
        filename: Destination handed to ``np.save`` as its ``file`` argument.
        flow_input: Array data handed to ``np.save`` as its ``arr`` argument.
    """
    np.save(file=filename, arr=flow_input)
def predict(data):
    """Run the flow network on one batch and save each masked flow to disk.

    This function reads two module-level names that the ``__main__`` section
    of this script assigns before calling it: ``model`` (the network) and
    ``save_path`` (the output directory).

    Args:
        data: Batch mapping indexed with the keys ``path_flow``,
            ``src_crop_color``, ``tar_crop_color``, ``src_color``,
            ``tar_color``, ``src_mask`` and ``Crop_param``. The four image
            tensors and ``Crop_param`` are moved to the GPU with ``.cuda()``.

    Returns:
        None. One ``.npy`` file per sample is written into ``save_path``.
    """
    with torch.no_grad():
        model.eval()
        path_flow = data["path_flow"]
        src_crop_im = data["src_crop_color"].cuda()
        tar_crop_im = data["tar_crop_color"].cuda()
        src_im = data["src_color"].cuda()
        tar_im = data["tar_color"].cuda()
        crop_param = data["Crop_param"].cuda()

        # The mask is never passed to the network; it is only multiplied with
        # the result on the host. Convert it to NumPy once per batch instead
        # of uploading it to the GPU and copying it back for every sample.
        mask_np = data["src_mask"].detach().cpu().numpy()
        batch_size = mask_np.shape[0]

        flow = model(src_crop_im, tar_crop_im, src_im, tar_im, crop_param)

        for i in range(batch_size):
            masked_flow = flow[i].cpu().numpy() * mask_np[i]
            # The last six characters of the source file name are dropped
            # before ".npy" is appended (presumably a fixed suffix plus
            # extension - TODO confirm against the dataset naming).
            stem = os.path.basename(path_flow[i])[:-6]
            save_flow_numpy(os.path.join(save_path, stem + ".npy"),
                            masked_flow)
if __name__ == "__main__":
    # Not read anywhere else in the visible part of this script.
    width, height = 272, 480

    # Both checkpoints are expected next to this file.
    raft_weights = os.path.join(dir_path, 'pretrain_model/raft-small.pth')
    nnrt_weights = os.path.join(
        dir_path, 'sgd_NNRT_model_epoch19008_50000.pth')

    # Parse the test options and point them at the RAFT checkpoint.
    test_opts = TestOptions().parse()
    test_opts.pretrain_model_path = raft_weights

    testloader = CreateDataLoader(test_opts).load_data()

    # Build the network, restore its two sub-modules from the checkpoint,
    # then move the whole model to the GPU.
    model = NeuralNRT(test_opts, raft_weights)
    checkpoint = torch.load(nnrt_weights)
    model.CorresPred.load_state_dict(checkpoint["net_C"])
    model.ImportanceW.load_state_dict(checkpoint["net_W"])
    model = model.cuda()

    # predict() reads `model` and `save_path` from module scope.
    save_path = test_opts.savepath
    Path(save_path).mkdir(parents=True, exist_ok=True)

    for batch in tqdm(testloader, total=len(testloader)):
        predict(batch)