hpoghos committed
Commit ca6b1f0
1 Parent(s): 2536d14

Update i2v_enhance/thirdparty/VFI/Trainer.py

Files changed (1):
  1. i2v_enhance/thirdparty/VFI/Trainer.py  +168 -168

The only functional change is in Model.__init__: the unconditional self.device() call, which moved the network to CUDA as soon as the model was constructed, is now commented out, leaving device placement to the caller. The diff marks every line as rewritten (+168/-168), but the two versions are otherwise identical.
i2v_enhance/thirdparty/VFI/Trainer.py CHANGED
@@ -1,168 +1,168 @@
- # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/Trainer.py
- import torch
- import torch.nn.functional as F
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.optim import AdamW
- from i2v_enhance.thirdparty.VFI.model.loss import *
- from i2v_enhance.thirdparty.VFI.config import *
-
-
- class Model:
-     def __init__(self, local_rank):
-         backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
-         backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
-         self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
-         self.name = MODEL_CONFIG['LOGNAME']
-         self.device()
-
-         # train
-         self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
-         self.lap = LapLoss()
-         if local_rank != -1:
-             self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)
-
-     def train(self):
-         self.net.train()
-
-     def eval(self):
-         self.net.eval()
-
-     def device(self):
-         self.net.to(torch.device("cuda"))
-
-     def unload(self):
-         self.net.to(torch.device("cpu"))
-
-     def load_model(self, name=None, rank=0):
-         def convert(param):
-             return {
-                 k.replace("module.", ""): v
-                 for k, v in param.items()
-                 if "module." in k and 'attn_mask' not in k and 'HW' not in k
-             }
-         if rank <= 0:
-             if name is None:
-                 name = self.name
-             # self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))
-             self.net.load_state_dict(convert(torch.load(f'{name}')))
-
-     def save_model(self, rank=0):
-         if rank == 0:
-             torch.save(self.net.state_dict(), f'ckpt/{self.name}.pkl')
-
-     @torch.no_grad()
-     def hr_inference(self, img0, img1, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False):
-         '''
-         Infer with down_scale flow
-         Note: returns BxCxHxW
-         '''
-         def infer(imgs):
-             img0, img1 = imgs[:, :3], imgs[:, 3:6]
-             imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
-
-             flow, mask = self.net.calculate_flow(imgs_down, timestep)
-
-             flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
-             mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
-
-             af, _ = self.net.feature_bone(img0, img1)
-             pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
-             return pred
-
-         imgs = torch.cat((img0, img1), 1)
-         if fast_TTA:
-             imgs_ = imgs.flip(2).flip(3)
-             input = torch.cat((imgs, imgs_), 0)
-             preds = infer(input)
-             return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
-
-         if TTA == False:
-             return infer(imgs)
-         else:
-             return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2
-
-     @torch.no_grad()
-     def inference(self, img0, img1, TTA = False, timestep = 0.5, fast_TTA = False):
-         imgs = torch.cat((img0, img1), 1)
-         '''
-         Note: returns BxCxHxW
-         '''
-         if fast_TTA:
-             imgs_ = imgs.flip(2).flip(3)
-             input = torch.cat((imgs, imgs_), 0)
-             _, _, _, preds = self.net(input, timestep=timestep)
-             return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
-
-         _, _, _, pred = self.net(imgs, timestep=timestep)
-         if TTA == False:
-             return pred
-         else:
-             _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep)
-             return (pred + pred2.flip(2).flip(3)) / 2
-
-     @torch.no_grad()
-     def multi_inference(self, img0, img1, TTA = False, down_scale = 1.0, time_list=[], fast_TTA = False):
-         '''
-         Run backbone once, get multi frames at different timesteps
-         Note: returns a list of [CxHxW]
-         '''
-         assert len(time_list) > 0, 'Time_list should not be empty!'
-         def infer(imgs):
-             img0, img1 = imgs[:, :3], imgs[:, 3:6]
-             af, mf = self.net.feature_bone(img0, img1)
-             imgs_down = None
-             if down_scale != 1.0:
-                 imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
-                 afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6])
-
-             pred_list = []
-             for timestep in time_list:
-                 if imgs_down is None:
-                     flow, mask = self.net.calculate_flow(imgs, timestep, af, mf)
-                 else:
-                     flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd)
-                     flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
-                     mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
-
-                 pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
-                 pred_list.append(pred)
-
-             return pred_list
-
-         imgs = torch.cat((img0, img1), 1)
-         if fast_TTA:
-             imgs_ = imgs.flip(2).flip(3)
-             input = torch.cat((imgs, imgs_), 0)
-             preds_lst = infer(input)
-             return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2))/2 for i in range(len(time_list))]
-
-         preds = infer(imgs)
-         if TTA is False:
-             return [preds[i][0] for i in range(len(time_list))]
-         else:
-             flip_pred = infer(imgs.flip(2).flip(3))
-             return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2))/2 for i in range(len(time_list))]
-
-     def update(self, imgs, gt, learning_rate=0, training=True):
-         for param_group in self.optimG.param_groups:
-             param_group['lr'] = learning_rate
-         if training:
-             self.train()
-         else:
-             self.eval()
-
-         if training:
-             flow, mask, merged, pred = self.net(imgs)
-             loss_l1 = (self.lap(pred, gt)).mean()
-
-             for merge in merged:
-                 loss_l1 += (self.lap(merge, gt)).mean() * 0.5
-
-             self.optimG.zero_grad()
-             loss_l1.backward()
-             self.optimG.step()
-             return pred, loss_l1
-         else:
-             with torch.no_grad():
-                 flow, mask, merged, pred = self.net(imgs)
-             return pred, 0
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/Trainer.py
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.optim import AdamW
+ from i2v_enhance.thirdparty.VFI.model.loss import *
+ from i2v_enhance.thirdparty.VFI.config import *
+
+
+ class Model:
+     def __init__(self, local_rank):
+         backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
+         backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
+         self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
+         self.name = MODEL_CONFIG['LOGNAME']
+         # self.device()
+
+         # train
+         self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
+         self.lap = LapLoss()
+         if local_rank != -1:
+             self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)
+
+     def train(self):
+         self.net.train()
+
+     def eval(self):
+         self.net.eval()
+
+     def device(self):
+         self.net.to(torch.device("cuda"))
+
+     def unload(self):
+         self.net.to(torch.device("cpu"))
+
+     def load_model(self, name=None, rank=0):
+         def convert(param):
+             return {
+                 k.replace("module.", ""): v
+                 for k, v in param.items()
+                 if "module." in k and 'attn_mask' not in k and 'HW' not in k
+             }
+         if rank <= 0:
+             if name is None:
+                 name = self.name
+             # self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))
+             self.net.load_state_dict(convert(torch.load(f'{name}')))
+
+     def save_model(self, rank=0):
+         if rank == 0:
+             torch.save(self.net.state_dict(), f'ckpt/{self.name}.pkl')
+
+     @torch.no_grad()
+     def hr_inference(self, img0, img1, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False):
+         '''
+         Infer with down_scale flow
+         Note: returns BxCxHxW
+         '''
+         def infer(imgs):
+             img0, img1 = imgs[:, :3], imgs[:, 3:6]
+             imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
+
+             flow, mask = self.net.calculate_flow(imgs_down, timestep)
+
+             flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
+             mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
+
+             af, _ = self.net.feature_bone(img0, img1)
+             pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
+             return pred
+
+         imgs = torch.cat((img0, img1), 1)
+         if fast_TTA:
+             imgs_ = imgs.flip(2).flip(3)
+             input = torch.cat((imgs, imgs_), 0)
+             preds = infer(input)
+             return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
+
+         if TTA == False:
+             return infer(imgs)
+         else:
+             return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2
+
+     @torch.no_grad()
+     def inference(self, img0, img1, TTA = False, timestep = 0.5, fast_TTA = False):
+         imgs = torch.cat((img0, img1), 1)
+         '''
+         Note: returns BxCxHxW
+         '''
+         if fast_TTA:
+             imgs_ = imgs.flip(2).flip(3)
+             input = torch.cat((imgs, imgs_), 0)
+             _, _, _, preds = self.net(input, timestep=timestep)
+             return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
+
+         _, _, _, pred = self.net(imgs, timestep=timestep)
+         if TTA == False:
+             return pred
+         else:
+             _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep)
+             return (pred + pred2.flip(2).flip(3)) / 2
+
+     @torch.no_grad()
+     def multi_inference(self, img0, img1, TTA = False, down_scale = 1.0, time_list=[], fast_TTA = False):
+         '''
+         Run backbone once, get multi frames at different timesteps
+         Note: returns a list of [CxHxW]
+         '''
+         assert len(time_list) > 0, 'Time_list should not be empty!'
+         def infer(imgs):
+             img0, img1 = imgs[:, :3], imgs[:, 3:6]
+             af, mf = self.net.feature_bone(img0, img1)
+             imgs_down = None
+             if down_scale != 1.0:
+                 imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
+                 afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6])
+
+             pred_list = []
+             for timestep in time_list:
+                 if imgs_down is None:
+                     flow, mask = self.net.calculate_flow(imgs, timestep, af, mf)
+                 else:
+                     flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd)
+                     flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
+                     mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
+
+                 pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
+                 pred_list.append(pred)
+
+             return pred_list
+
+         imgs = torch.cat((img0, img1), 1)
+         if fast_TTA:
+             imgs_ = imgs.flip(2).flip(3)
+             input = torch.cat((imgs, imgs_), 0)
+             preds_lst = infer(input)
+             return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2))/2 for i in range(len(time_list))]
+
+         preds = infer(imgs)
+         if TTA is False:
+             return [preds[i][0] for i in range(len(time_list))]
+         else:
+             flip_pred = infer(imgs.flip(2).flip(3))
+             return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2))/2 for i in range(len(time_list))]
+
+     def update(self, imgs, gt, learning_rate=0, training=True):
+         for param_group in self.optimG.param_groups:
+             param_group['lr'] = learning_rate
+         if training:
+             self.train()
+         else:
+             self.eval()
+
+         if training:
+             flow, mask, merged, pred = self.net(imgs)
+             loss_l1 = (self.lap(pred, gt)).mean()
+
+             for merge in merged:
+                 loss_l1 += (self.lap(merge, gt)).mean() * 0.5
+
+             self.optimG.zero_grad()
+             loss_l1.backward()
+             self.optimG.step()
+             return pred, loss_l1
+         else:
+             with torch.no_grad():
+                 flow, mask, merged, pred = self.net(imgs)
+             return pred, 0
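
What the change means in practice: with self.device() commented out of __init__, the network stays on the CPU after construction, and moving it to the GPU becomes an explicit, caller-controlled step via the existing device() method (with unload() as its inverse). Below is a minimal usage sketch, not code from this repository: the checkpoint path and input tensors are placeholders, and it assumes a CUDA device plus a checkpoint whose keys carry the "module." prefix that load_model strips.

import torch
from i2v_enhance.thirdparty.VFI.Trainer import Model

model = Model(local_rank=-1)        # -1 skips the DistributedDataParallel wrapper
model.load_model("vfi_ckpt.pkl")    # placeholder checkpoint path
model.eval()
model.device()                      # explicit now that __init__ no longer calls it

# Two RGB frames as BxCxHxW tensors on the same device as the network
img0 = torch.rand(1, 3, 256, 256, device="cuda")
img1 = torch.rand(1, 3, 256, 256, device="cuda")

# Single midpoint frame (BxCxHxW)
mid = model.inference(img0, img1, timestep=0.5)

# Several intermediate frames from one backbone pass (list of CxHxW)
frames = model.multi_inference(img0, img1, time_list=[0.25, 0.5, 0.75])

model.unload()                      # move the network back to the CPU when done

Nothing else in the file changes, so the training path (update) and the TTA variants behave exactly as before.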