Arnaudding001 committed
Commit
fc378c9
1 Parent(s): 191cab2

Create smooth_parsing_map.py

Files changed (1)
  1. smooth_parsing_map.py +172 -0
smooth_parsing_map.py ADDED
@@ -0,0 +1,172 @@
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import numpy as np
import cv2
import math
import argparse
from tqdm import tqdm
import torch
from torch import nn
from torchvision import transforms
import torch.nn.functional as F
from model.raft.core.raft import RAFT
from model.raft.core.utils.utils import InputPadder
from model.bisenet.model import BiSeNet
from model.stylegan.model import Downsample

class Options():
    def __init__(self):
        self.parser = argparse.ArgumentParser(description="Smooth Parsing Maps")
        self.parser.add_argument("--window_size", type=int, default=5, help="temporal window size")

        self.parser.add_argument("--faceparsing_path", type=str, default='./checkpoint/faceparsing.pth', help="path of the face parsing model")
        self.parser.add_argument("--raft_path", type=str, default='./checkpoint/raft-things.pth', help="path of the RAFT model")

        self.parser.add_argument("--video_path", type=str, help="path of the target video")
        self.parser.add_argument("--output_path", type=str, default='./output/', help="path of the output parsing maps")

    def parse(self):
        self.opt = self.parser.parse_args()
        args = vars(self.opt)
        print('Load options')
        for name, value in sorted(args.items()):
            print('%s: %s' % (str(name), str(value)))
        return self.opt

# from RAFT
def warp(x, flo):
    """
    warp an image/tensor (im2) back to im1, according to the optical flow
    x: [B, C, H, W] (im2)
    flo: [B, 2, H, W] flow
    """
    B, C, H, W = x.size()
    # build a mesh grid of pixel coordinates
    xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float()

    # x = x.cuda()
    grid = grid.cuda()
    vgrid = grid + flo  # B,2,H,W

    # scale grid to [-1,1] as required by grid_sample
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

    vgrid = vgrid.permute(0, 2, 3, 1)
    output = nn.functional.grid_sample(x, vgrid, align_corners=True)
    # validity mask: warp a tensor of ones and keep only fully covered pixels
    mask = torch.ones_like(x)
    mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
    mask[mask < 0.9999] = 0
    mask[mask > 0] = 1

    return output * mask, mask
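
# Illustrative sanity check (not part of the committed file): with a zero flow
# field, warp() should return its input unchanged and an all-ones mask:
#   x = torch.rand(1, 3, 64, 64).cuda()
#   out, m = warp(x, torch.zeros(1, 2, 64, 64).cuda())
#   assert torch.allclose(out, x, atol=1e-5) and m.min() == 1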


if __name__ == "__main__":

    parser = Options()
    args = parser.parse()
    print('*' * 98)

    device = "cuda"

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # build the argument namespace RAFT expects
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')

    raft_model = torch.nn.DataParallel(RAFT(parser.parse_args(['--model', args.raft_path])))
    raft_model.load_state_dict(torch.load(args.raft_path))

    raft_model = raft_model.module
    raft_model.to(device)
    raft_model.eval()

    parsingpredictor = BiSeNet(n_classes=19)
    parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
    parsingpredictor.to(device).eval()

    down = Downsample(kernel=[1, 3, 3, 1], factor=2).to(device).eval()

    print('Load models successfully!')

    window = args.window_size

    video_cap = cv2.VideoCapture(args.video_path)
    num = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # read all frames into memory as normalized tensors in [-1, 1]
    Is = []
    for i in range(num):
        success, frame = video_cap.read()
        if not success:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        with torch.no_grad():
            Is += [transform(frame).unsqueeze(dim=0).cpu()]
    video_cap.release()

    # enlarge frames for more accurate parsing maps and optical flows
    Is = F.interpolate(torch.cat(Is, dim=0), scale_factor=2, mode='bilinear')
    Is_ = torch.cat((Is[0:window], Is, Is[-window:]), dim=0)

    print('Load video with %d frames successfully!' % (len(Is)))

    # per-frame parsing maps from BiSeNet (19 classes)
    Ps = []
    for i in tqdm(range(len(Is))):
        with torch.no_grad():
            Ps += [parsingpredictor(2 * Is[i:i + 1].to(device))[0].detach().cpu()]
    Ps = torch.cat(Ps, dim=0)
    Ps_ = torch.cat((Ps[0:window], Ps, Ps[-window:]), dim=0)

    print('Predict parsing maps successfully!')

    # temporal weights of the (2*args.window_size+1) frames:
    # a Gaussian over the frame offset, with sigma = window + 0.5
    wt = torch.exp(-(torch.arange(2 * window + 1).float() - window) ** 2 / (2 * ((window + 0.5) ** 2))).reshape(2 * window + 1, 1, 1, 1).to(device)

    parse = []
    for ii in tqdm(range(len(Is))):
        i = ii + window
        image2 = Is_[i - window:i + window + 1].to(device)
        image1 = Is_[i].repeat(2 * window + 1, 1, 1, 1).to(device)
        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)
        with torch.no_grad():
            # RAFT expects inputs in [0, 255]; frames are in [-1, 1]
            flow_low, flow_up = raft_model((image1 + 1) * 255.0 / 2, (image2 + 1) * 255.0 / 2, iters=20, test_mode=True)
            # warp the neighboring frames and their parsing maps onto frame i
            output, mask = warp(torch.cat((image2, Ps_[i - window:i + window + 1].to(device)), dim=1), flow_up)
            aligned_Is = output[:, 0:3].detach()
            aligned_Ps = output[:, 3:].detach()
            # the spatial weight: down-weight pixels where the warped neighbor disagrees with frame i
            ws = torch.exp(-((aligned_Is - image1) ** 2).mean(dim=1, keepdim=True) / (2 * (0.2 ** 2))) * mask[:, 0:1]
            aligned_Ps[window] = Ps_[i].to(device)
            # the weight between i and i should be 1.0
            ws[window, :, :, :] = 1.0
            weights = ws * wt
            weights = weights / weights.sum(dim=0, keepdim=True)
            fused_Ps = (aligned_Ps * weights).sum(dim=0, keepdim=True)
            parse += [down(fused_Ps).detach().cpu()]
    parse = torch.cat(parse, dim=0)

    basename = os.path.basename(args.video_path).split('.')[0]
    np.save(os.path.join(args.output_path, basename + '_parsingmap.npy'), parse.numpy())

    print('Done!')
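
Usage note: a minimal sketch of how the committed script would be used. The video filename below is hypothetical; the output filename follows the basename + '_parsingmap.npy' convention in the code above.

    # Run the script first, e.g. with the default checkpoint paths:
    #   python smooth_parsing_map.py --video_path ./data/face.mp4 --output_path ./output/
    # Then load the smoothed parsing maps for downstream use:
    import numpy as np
    parse = np.load('./output/face_parsingmap.npy')
    # one 19-channel BiSeNet logit map per frame, at roughly the original resolution
    print(parse.shape)  # (num_frames, 19, H, W)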