Arnaudding001 commited on
Commit
de0dd3a
1 Parent(s): cfd00dd

Create rft_demo.py

Browse files
Files changed (1) hide show
  1. rft_demo.py +75 -0
rft_demo.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('core')
3
+
4
+ import argparse
5
+ import os
6
+ import cv2
7
+ import glob
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from raft import RAFT
13
+ from utils import flow_viz
14
+ from utils.utils import InputPadder
15
+
16
+
17
+
18
+ DEVICE = 'cuda'
19
+
20
+ def load_image(imfile):
21
+ img = np.array(Image.open(imfile)).astype(np.uint8)
22
+ img = torch.from_numpy(img).permute(2, 0, 1).float()
23
+ return img[None].to(DEVICE)
24
+
25
+
26
+ def viz(img, flo):
27
+ img = img[0].permute(1,2,0).cpu().numpy()
28
+ flo = flo[0].permute(1,2,0).cpu().numpy()
29
+
30
+ # map flow to rgb image
31
+ flo = flow_viz.flow_to_image(flo)
32
+ img_flo = np.concatenate([img, flo], axis=0)
33
+
34
+ # import matplotlib.pyplot as plt
35
+ # plt.imshow(img_flo / 255.0)
36
+ # plt.show()
37
+
38
+ cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
39
+ cv2.waitKey()
40
+
41
+
42
+ def demo(args):
43
+ model = torch.nn.DataParallel(RAFT(args))
44
+ model.load_state_dict(torch.load(args.model))
45
+
46
+ model = model.module
47
+ model.to(DEVICE)
48
+ model.eval()
49
+
50
+ with torch.no_grad():
51
+ images = glob.glob(os.path.join(args.path, '*.png')) + \
52
+ glob.glob(os.path.join(args.path, '*.jpg'))
53
+
54
+ images = sorted(images)
55
+ for imfile1, imfile2 in zip(images[:-1], images[1:]):
56
+ image1 = load_image(imfile1)
57
+ image2 = load_image(imfile2)
58
+
59
+ padder = InputPadder(image1.shape)
60
+ image1, image2 = padder.pad(image1, image2)
61
+
62
+ flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
63
+ viz(image1, flow_up)
64
+
65
+
66
+ if __name__ == '__main__':
67
+ parser = argparse.ArgumentParser()
68
+ parser.add_argument('--model', help="restore checkpoint")
69
+ parser.add_argument('--path', help="dataset for evaluation")
70
+ parser.add_argument('--small', action='store_true', help='use small model')
71
+ parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
72
+ parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
73
+ args = parser.parse_args()
74
+
75
+ demo(args)