akhaliq3 committed
Commit
035e10c
1 Parent(s): 80e980c

spaces demo

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Huage001
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
inference/.DS_Store ADDED
Binary file (6.15 kB).
 
inference/brush/brush_large_horizontal.png ADDED
inference/brush/brush_large_vertical.png ADDED
inference/brush/brush_small_horizontal.png ADDED
inference/brush/brush_small_vertical.png ADDED
inference/inference.py ADDED
@@ -0,0 +1,496 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ import network
+ import morphology
+ import os
+ import math
+
+ idx = 0
+
+
+ def save_img(img, output_path):
+     result = Image.fromarray((img.data.cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
+     result.save(output_path)
+
+
+ def param2stroke(param, H, W, meta_brushes):
+     """
+     Input a set of stroke parameters and output the corresponding foregrounds and alpha maps.
+     Args:
+         param: a tensor with shape n_strokes x n_param_per_stroke. Here, n_param_per_stroke is 8:
+             x_center, y_center, width, height, theta, R, G, and B.
+         H: output height.
+         W: output width.
+         meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
+             The first slice on the batch dimension denotes the vertical brush and the second the horizontal brush.
+
+     Returns:
+         foregrounds: a tensor with shape n_strokes x 3 x H x W, containing color information.
+         alphas: a tensor with shape n_strokes x 3 x H x W, a binary map indicating
+             whether each pixel belongs to the stroke (alpha matte), used in the painting process.
+     """
+     # First, resize the meta brushes to the required shape
+     # to reduce GPU memory usage, especially when the required shape is small.
+     meta_brushes_resize = F.interpolate(meta_brushes, (H, W))
+     b = param.shape[0]
+     # Extract shape parameters and color parameters.
+     param_list = torch.split(param, 1, dim=1)
+     x0, y0, w, h, theta = [item.squeeze(-1) for item in param_list[:5]]
+     R, G, B = param_list[5:]
+     # Pre-compute sin theta and cos theta.
+     sin_theta = torch.sin(torch.acos(torch.tensor(-1., device=param.device)) * theta)
+     cos_theta = torch.cos(torch.acos(torch.tensor(-1., device=param.device)) * theta)
+     # index selects the meta stroke for each stroke: when h > w, the vertical meta stroke
+     # is used; when h <= w, the horizontal one is used.
+     index = torch.full((b,), -1, device=param.device, dtype=torch.long)
+     index[h > w] = 0
+     index[h <= w] = 1
+     brush = meta_brushes_resize[index.long()]
+
+     # Calculate the warp matrix according to the rules defined by PyTorch, in order to warp.
+     warp_00 = cos_theta / w
+     warp_01 = sin_theta * H / (W * w)
+     warp_02 = (1 - 2 * x0) * cos_theta / w + (1 - 2 * y0) * sin_theta * H / (W * w)
+     warp_10 = -sin_theta * W / (H * h)
+     warp_11 = cos_theta / h
+     warp_12 = (1 - 2 * y0) * cos_theta / h - (1 - 2 * x0) * sin_theta * W / (H * h)
+     warp_0 = torch.stack([warp_00, warp_01, warp_02], dim=1)
+     warp_1 = torch.stack([warp_10, warp_11, warp_12], dim=1)
+     warp = torch.stack([warp_0, warp_1], dim=1)
+     # Conduct the warping.
+     grid = F.affine_grid(warp, [b, 3, H, W], align_corners=False)
+     brush = F.grid_sample(brush, grid, align_corners=False)
+     # alphas is the binary map indicating whether a pixel belongs to the stroke.
+     alphas = (brush > 0).float()
+     brush = brush.repeat(1, 3, 1, 1)
+     alphas = alphas.repeat(1, 3, 1, 1)
+     # Give color to the foreground strokes.
+     color_map = torch.cat([R, G, B], dim=1)
+     color_map = color_map.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, H, W)
+     foreground = brush * color_map
+     # Dilation and erosion are applied to foregrounds and alphas respectively to prevent artifacts on stroke borders.
+     foreground = morphology.dilation(foreground)
+     alphas = morphology.erosion(alphas)
+     return foreground, alphas
+
+
+ def param2img_serial(
+         param, decision, meta_brushes, cur_canvas, frame_dir, has_border=False, original_h=None, original_w=None):
+     """
+     Input stroke parameters and decisions for each patch, meta brushes, the current canvas, a frame directory,
+     and whether there is a border (used when intermediate painting results are required).
+     Output the painting results of adding the corresponding strokes to the current canvas.
+     Args:
+         param: a tensor with shape batch size x patches along height x patches along width
+             x n_strokes_per_patch x n_params_per_stroke
+         decision: a binary (0/1) tensor with shape batch size x patches along height x patches along width
+             x n_strokes_per_patch
+         meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
+             The first slice on the batch dimension denotes the vertical brush and the second the horizontal brush.
+         cur_canvas: a tensor with shape batch size x 3 x H x W,
+             where H and W denote the height and width of the padded original images.
+         frame_dir: directory to save intermediate painting results. None means intermediate results are not required.
+         has_border: on the last painting layer, to make sure that the painting results do not miss
+             any important detail, we paint this layer again but shift by patch_size // 2 pixels when
+             cutting patches. In this case, if intermediate results are required, we need to cut the shifted length
+             off the border before saving, or there would be a black border.
+         original_h: the original height, for cropping when saving intermediate results.
+         original_w: the original width, for cropping when saving intermediate results.
+
+     Returns:
+         cur_canvas: a tensor with shape batch size x 3 x H x W, denoting the painting results.
+     """
+     # param: b, h, w, stroke_per_patch, param_per_stroke
+     # decision: b, h, w, stroke_per_patch
+     b, h, w, s, p = param.shape
+     H, W = cur_canvas.shape[-2:]
+     is_odd_y = h % 2 == 1
+     is_odd_x = w % 2 == 1
+     patch_size_y = 2 * H // h
+     patch_size_x = 2 * W // w
+     even_idx_y = torch.arange(0, h, 2, device=cur_canvas.device)
+     even_idx_x = torch.arange(0, w, 2, device=cur_canvas.device)
+     odd_idx_y = torch.arange(1, h, 2, device=cur_canvas.device)
+     odd_idx_x = torch.arange(1, w, 2, device=cur_canvas.device)
+     even_y_even_x_coord_y, even_y_even_x_coord_x = torch.meshgrid([even_idx_y, even_idx_x])
+     odd_y_odd_x_coord_y, odd_y_odd_x_coord_x = torch.meshgrid([odd_idx_y, odd_idx_x])
+     even_y_odd_x_coord_y, even_y_odd_x_coord_x = torch.meshgrid([even_idx_y, odd_idx_x])
+     odd_y_even_x_coord_y, odd_y_even_x_coord_x = torch.meshgrid([odd_idx_y, even_idx_x])
+     cur_canvas = F.pad(cur_canvas, [patch_size_x // 4, patch_size_x // 4,
+                                     patch_size_y // 4, patch_size_y // 4, 0, 0, 0, 0])
+
+     def partial_render(this_canvas, patch_coord_y, patch_coord_x, stroke_id):
+         canvas_patch = F.unfold(this_canvas, (patch_size_y, patch_size_x),
+                                 stride=(patch_size_y // 2, patch_size_x // 2))
+         # canvas_patch: b, 3 * py * px, h * w
+         canvas_patch = canvas_patch.view(b, 3, patch_size_y, patch_size_x, h, w).contiguous()
+         canvas_patch = canvas_patch.permute(0, 4, 5, 1, 2, 3).contiguous()
+         # canvas_patch: b, h, w, 3, py, px
+         selected_canvas_patch = canvas_patch[:, patch_coord_y, patch_coord_x, :, :, :]
+         selected_h, selected_w = selected_canvas_patch.shape[1:3]
+         selected_param = param[:, patch_coord_y, patch_coord_x, stroke_id, :].view(-1, p).contiguous()
+         selected_decision = decision[:, patch_coord_y, patch_coord_x, stroke_id].view(-1).contiguous()
+         selected_foregrounds = torch.zeros(selected_param.shape[0], 3, patch_size_y, patch_size_x,
+                                            device=this_canvas.device)
+         selected_alphas = torch.zeros(selected_param.shape[0], 3, patch_size_y, patch_size_x, device=this_canvas.device)
+         if selected_param[selected_decision, :].shape[0] > 0:
+             selected_foregrounds[selected_decision, :, :, :], selected_alphas[selected_decision, :, :, :] = \
+                 param2stroke(selected_param[selected_decision, :], patch_size_y, patch_size_x, meta_brushes)
+         selected_foregrounds = selected_foregrounds.view(
+             b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
+         selected_alphas = selected_alphas.view(b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
+         selected_decision = selected_decision.view(b, selected_h, selected_w, 1, 1, 1).contiguous()
+         selected_canvas_patch = selected_foregrounds * selected_alphas * selected_decision + selected_canvas_patch * (
+             1 - selected_alphas * selected_decision)
+         this_canvas = selected_canvas_patch.permute(0, 3, 1, 4, 2, 5).contiguous()
+         # this_canvas: b, 3, selected_h, py, selected_w, px
+         this_canvas = this_canvas.view(b, 3, selected_h * patch_size_y, selected_w * patch_size_x).contiguous()
+         # this_canvas: b, 3, selected_h * py, selected_w * px
+         return this_canvas
+
+     global idx
+     if has_border:
+         factor = 2
+     else:
+         factor = 4
+     if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
+         for i in range(s):
+             canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x, i)
+             if not is_odd_y:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
+             if not is_odd_x:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+             cur_canvas = canvas
+             idx += 1
+             if frame_dir is not None:
+                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
+                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
+                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+
+     if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
+         for i in range(s):
+             canvas = partial_render(cur_canvas, odd_y_odd_x_coord_y, odd_y_odd_x_coord_x, i)
+             canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, -canvas.shape[3]:], canvas], dim=2)
+             canvas = torch.cat([cur_canvas[:, :, -canvas.shape[2]:, :patch_size_x // 2], canvas], dim=3)
+             if is_odd_y:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
+             if is_odd_x:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+             cur_canvas = canvas
+             idx += 1
+             if frame_dir is not None:
+                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
+                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
+                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+
+     if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
+         for i in range(s):
+             canvas = partial_render(cur_canvas, odd_y_even_x_coord_y, odd_y_even_x_coord_x, i)
+             canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, :canvas.shape[3]], canvas], dim=2)
+             if is_odd_y:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
+             if not is_odd_x:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+             cur_canvas = canvas
+             idx += 1
+             if frame_dir is not None:
+                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
+                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
+                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+
+     if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
+         for i in range(s):
+             canvas = partial_render(cur_canvas, even_y_odd_x_coord_y, even_y_odd_x_coord_x, i)
+             canvas = torch.cat([cur_canvas[:, :, :canvas.shape[2], :patch_size_x // 2], canvas], dim=3)
+             if not is_odd_y:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, -canvas.shape[3]:]], dim=2)
+             if is_odd_x:
+                 canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+             cur_canvas = canvas
+             idx += 1
+             if frame_dir is not None:
+                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
+                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
+                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+
+     cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
+
+     return cur_canvas
+
+
+ def param2img_parallel(param, decision, meta_brushes, cur_canvas):
+     """
+     Input stroke parameters and decisions for each patch, meta brushes, and the current canvas.
+     Output the painting results of adding the corresponding strokes to the current canvas.
+     Args:
+         param: a tensor with shape batch size x patches along height x patches along width
+             x n_strokes_per_patch x n_params_per_stroke
+         decision: a binary (0/1) tensor with shape batch size x patches along height x patches along width
+             x n_strokes_per_patch
+         meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
+             The first slice on the batch dimension denotes the vertical brush and the second the horizontal brush.
+         cur_canvas: a tensor with shape batch size x 3 x H x W,
+             where H and W denote the height and width of the padded original images.
+
+     Returns:
+         cur_canvas: a tensor with shape batch size x 3 x H x W, denoting the painting results.
+     """
+     # param: b, h, w, stroke_per_patch, param_per_stroke
+     # decision: b, h, w, stroke_per_patch
+     b, h, w, s, p = param.shape
+     param = param.view(-1, 8).contiguous()
+     decision = decision.view(-1).contiguous().bool()
+     H, W = cur_canvas.shape[-2:]
+     is_odd_y = h % 2 == 1
+     is_odd_x = w % 2 == 1
+     patch_size_y = 2 * H // h
+     patch_size_x = 2 * W // w
+     even_idx_y = torch.arange(0, h, 2, device=cur_canvas.device)
+     even_idx_x = torch.arange(0, w, 2, device=cur_canvas.device)
+     odd_idx_y = torch.arange(1, h, 2, device=cur_canvas.device)
+     odd_idx_x = torch.arange(1, w, 2, device=cur_canvas.device)
+     even_y_even_x_coord_y, even_y_even_x_coord_x = torch.meshgrid([even_idx_y, even_idx_x])
+     odd_y_odd_x_coord_y, odd_y_odd_x_coord_x = torch.meshgrid([odd_idx_y, odd_idx_x])
+     even_y_odd_x_coord_y, even_y_odd_x_coord_x = torch.meshgrid([even_idx_y, odd_idx_x])
+     odd_y_even_x_coord_y, odd_y_even_x_coord_x = torch.meshgrid([odd_idx_y, even_idx_x])
+     cur_canvas = F.pad(cur_canvas, [patch_size_x // 4, patch_size_x // 4,
+                                     patch_size_y // 4, patch_size_y // 4, 0, 0, 0, 0])
+     foregrounds = torch.zeros(param.shape[0], 3, patch_size_y, patch_size_x, device=cur_canvas.device)
+     alphas = torch.zeros(param.shape[0], 3, patch_size_y, patch_size_x, device=cur_canvas.device)
+     valid_foregrounds, valid_alphas = param2stroke(param[decision, :], patch_size_y, patch_size_x, meta_brushes)
+     foregrounds[decision, :, :, :] = valid_foregrounds
+     alphas[decision, :, :, :] = valid_alphas
+     # foregrounds, alphas: b * h * w * stroke_per_patch, 3, patch_size_y, patch_size_x
+     foregrounds = foregrounds.view(-1, h, w, s, 3, patch_size_y, patch_size_x).contiguous()
+     alphas = alphas.view(-1, h, w, s, 3, patch_size_y, patch_size_x).contiguous()
+     # foregrounds, alphas: b, h, w, stroke_per_patch, 3, render_size_y, render_size_x
+     decision = decision.view(-1, h, w, s, 1, 1, 1).contiguous()
+     # decision: b, h, w, stroke_per_patch, 1, 1, 1
+
+     def partial_render(this_canvas, patch_coord_y, patch_coord_x):
+         canvas_patch = F.unfold(this_canvas, (patch_size_y, patch_size_x),
+                                 stride=(patch_size_y // 2, patch_size_x // 2))
+         # canvas_patch: b, 3 * py * px, h * w
+         canvas_patch = canvas_patch.view(b, 3, patch_size_y, patch_size_x, h, w).contiguous()
+         canvas_patch = canvas_patch.permute(0, 4, 5, 1, 2, 3).contiguous()
+         # canvas_patch: b, h, w, 3, py, px
+         selected_canvas_patch = canvas_patch[:, patch_coord_y, patch_coord_x, :, :, :]
+         selected_foregrounds = foregrounds[:, patch_coord_y, patch_coord_x, :, :, :, :]
+         selected_alphas = alphas[:, patch_coord_y, patch_coord_x, :, :, :, :]
+         selected_decisions = decision[:, patch_coord_y, patch_coord_x, :, :, :, :]
+         for i in range(s):
+             cur_foreground = selected_foregrounds[:, :, :, i, :, :, :]
+             cur_alpha = selected_alphas[:, :, :, i, :, :, :]
+             cur_decision = selected_decisions[:, :, :, i, :, :, :]
+             selected_canvas_patch = cur_foreground * cur_alpha * cur_decision + selected_canvas_patch * (
+                 1 - cur_alpha * cur_decision)
+         this_canvas = selected_canvas_patch.permute(0, 3, 1, 4, 2, 5).contiguous()
+         # this_canvas: b, 3, h_half, py, w_half, px
+         h_half = this_canvas.shape[2]
+         w_half = this_canvas.shape[4]
+         this_canvas = this_canvas.view(b, 3, h_half * patch_size_y, w_half * patch_size_x).contiguous()
+         # this_canvas: b, 3, h_half * py, w_half * px
+         return this_canvas
+
+     if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
+         canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x)
+         if not is_odd_y:
+             canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
+         if not is_odd_x:
+             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+         cur_canvas = canvas
+
+     if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
+         canvas = partial_render(cur_canvas, odd_y_odd_x_coord_y, odd_y_odd_x_coord_x)
+         canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, -canvas.shape[3]:], canvas], dim=2)
+         canvas = torch.cat([cur_canvas[:, :, -canvas.shape[2]:, :patch_size_x // 2], canvas], dim=3)
+         if is_odd_y:
+             canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
+         if is_odd_x:
+             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+         cur_canvas = canvas
+
+     if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
+         canvas = partial_render(cur_canvas, odd_y_even_x_coord_y, odd_y_even_x_coord_x)
+         canvas = torch.cat([cur_canvas[:, :, :patch_size_y // 2, :canvas.shape[3]], canvas], dim=2)
+         if is_odd_y:
+             canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, :canvas.shape[3]]], dim=2)
+         if not is_odd_x:
+             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+         cur_canvas = canvas
+
+     if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
+         canvas = partial_render(cur_canvas, even_y_odd_x_coord_y, even_y_odd_x_coord_x)
+         canvas = torch.cat([cur_canvas[:, :, :canvas.shape[2], :patch_size_x // 2], canvas], dim=3)
+         if not is_odd_y:
+             canvas = torch.cat([canvas, cur_canvas[:, :, -patch_size_y // 2:, -canvas.shape[3]:]], dim=2)
+         if is_odd_x:
+             canvas = torch.cat([canvas, cur_canvas[:, :, :canvas.shape[2], -patch_size_x // 2:]], dim=3)
+         cur_canvas = canvas
+
+     cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
+
+     return cur_canvas
+
+
+ def read_img(img_path, img_type='RGB', h=None, w=None):
+     img = Image.open(img_path).convert(img_type)
+     if h is not None and w is not None:
+         img = img.resize((w, h), resample=Image.NEAREST)
+     img = np.array(img)
+     if img.ndim == 2:
+         img = np.expand_dims(img, axis=-1)
+     img = img.transpose((2, 0, 1))
+     img = torch.from_numpy(img).unsqueeze(0).float() / 255.
+     return img
+
+
+ def pad(img, H, W):
+     b, c, h, w = img.shape
+     pad_h = (H - h) // 2
+     pad_w = (W - w) // 2
+     remainder_h = (H - h) % 2
+     remainder_w = (W - w) % 2
+     img = torch.cat([torch.zeros((b, c, pad_h, w), device=img.device), img,
+                      torch.zeros((b, c, pad_h + remainder_h, w), device=img.device)], dim=-2)
+     img = torch.cat([torch.zeros((b, c, H, pad_w), device=img.device), img,
+                      torch.zeros((b, c, H, pad_w + remainder_w), device=img.device)], dim=-1)
+     return img
+
+
+ def crop(img, h, w):
+     H, W = img.shape[-2:]
+     pad_h = (H - h) // 2
+     pad_w = (W - w) // 2
+     remainder_h = (H - h) % 2
+     remainder_w = (W - w) % 2
+     img = img[:, :, pad_h:H - pad_h - remainder_h, pad_w:W - pad_w - remainder_w]
+     return img
+
+
+ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
+     if not os.path.exists(output_dir):
+         os.mkdir(output_dir)
+     input_name = os.path.basename(input_path)
+     output_path = os.path.join(output_dir, input_name)
+     frame_dir = None
+     if need_animation:
+         if not serial:
+             print('Animation results require serial mode, so the serial flag has been set to True.')
+             serial = True
+         frame_dir = os.path.join(output_dir, input_name[:input_name.find('.')])
+         if not os.path.exists(frame_dir):
+             os.mkdir(frame_dir)
+     patch_size = 32
+     stroke_num = 8
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     net_g = network.Painter(5, stroke_num, 256, 8, 3, 3).to(device)
+     net_g.load_state_dict(torch.load(model_path))
+     net_g.eval()
+     for param in net_g.parameters():
+         param.requires_grad = False
+
+     brush_large_vertical = read_img('brush/brush_large_vertical.png', 'L').to(device)
+     brush_large_horizontal = read_img('brush/brush_large_horizontal.png', 'L').to(device)
+     meta_brushes = torch.cat(
+         [brush_large_vertical, brush_large_horizontal], dim=0)
+
+     with torch.no_grad():
+         original_img = read_img(input_path, 'RGB', resize_h, resize_w).to(device)
+         original_h, original_w = original_img.shape[-2:]
+         K = max(math.ceil(math.log2(max(original_h, original_w) / patch_size)), 0)
+         original_img_pad_size = patch_size * (2 ** K)
+         original_img_pad = pad(original_img, original_img_pad_size, original_img_pad_size)
+         final_result = torch.zeros_like(original_img_pad).to(device)
+         for layer in range(0, K + 1):
+             layer_size = patch_size * (2 ** layer)
+             img = F.interpolate(original_img_pad, (layer_size, layer_size))
+             result = F.interpolate(final_result, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
+             img_patch = F.unfold(img, (patch_size, patch_size), stride=(patch_size, patch_size))
+             result_patch = F.unfold(result, (patch_size, patch_size),
+                                     stride=(patch_size, patch_size))
+             # There are patch_num * patch_num patches in total.
+             patch_num = (layer_size - patch_size) // patch_size + 1
+
+             # img_patch, result_patch: b, 3 * output_size * output_size, h * w
+             img_patch = img_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
+             result_patch = result_patch.permute(0, 2, 1).contiguous().view(
+                 -1, 3, patch_size, patch_size).contiguous()
+             shape_param, stroke_decision = net_g(img_patch, result_patch)
+             stroke_decision = network.SignWithSigmoidGrad.apply(stroke_decision)
+
+             grid = shape_param[:, :, :2].view(img_patch.shape[0] * stroke_num, 1, 1, 2).contiguous()
+             img_temp = img_patch.unsqueeze(1).contiguous().repeat(1, stroke_num, 1, 1, 1).view(
+                 img_patch.shape[0] * stroke_num, 3, patch_size, patch_size).contiguous()
+             color = F.grid_sample(img_temp, 2 * grid - 1, align_corners=False).view(
+                 img_patch.shape[0], stroke_num, 3).contiguous()
+             stroke_param = torch.cat([shape_param, color], dim=-1)
+             # stroke_param: b * h * w, stroke_per_patch, param_per_stroke
+             # stroke_decision: b * h * w, stroke_per_patch, 1
+             param = stroke_param.view(1, patch_num, patch_num, stroke_num, 8).contiguous()
+             decision = stroke_decision.view(1, patch_num, patch_num, stroke_num).contiguous().bool()
+             # param: b, h, w, stroke_per_patch, 8
+             # decision: b, h, w, stroke_per_patch
+             param[..., :2] = param[..., :2] / 2 + 0.25
+             param[..., 2:4] = param[..., 2:4] / 2
+             if serial:
+                 final_result = param2img_serial(param, decision, meta_brushes, final_result,
+                                                 frame_dir, False, original_h, original_w)
+             else:
+                 final_result = param2img_parallel(param, decision, meta_brushes, final_result)
+
+         border_size = original_img_pad_size // (2 * patch_num)
+         img = F.interpolate(original_img_pad, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
+         result = F.interpolate(final_result, (patch_size * (2 ** layer), patch_size * (2 ** layer)))
+         img = F.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2,
+                           0, 0, 0, 0])
+         result = F.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2,
+                                 0, 0, 0, 0])
+         img_patch = F.unfold(img, (patch_size, patch_size), stride=(patch_size, patch_size))
+         result_patch = F.unfold(result, (patch_size, patch_size), stride=(patch_size, patch_size))
+         final_result = F.pad(final_result, [border_size, border_size, border_size, border_size, 0, 0, 0, 0])
+         h = (img.shape[2] - patch_size) // patch_size + 1
+         w = (img.shape[3] - patch_size) // patch_size + 1
+         # img_patch, result_patch: b, 3 * output_size * output_size, h * w
+         img_patch = img_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
+         result_patch = result_patch.permute(0, 2, 1).contiguous().view(-1, 3, patch_size, patch_size).contiguous()
+         shape_param, stroke_decision = net_g(img_patch, result_patch)
+
+         grid = shape_param[:, :, :2].view(img_patch.shape[0] * stroke_num, 1, 1, 2).contiguous()
+         img_temp = img_patch.unsqueeze(1).contiguous().repeat(1, stroke_num, 1, 1, 1).view(
+             img_patch.shape[0] * stroke_num, 3, patch_size, patch_size).contiguous()
+         color = F.grid_sample(img_temp, 2 * grid - 1, align_corners=False).view(
+             img_patch.shape[0], stroke_num, 3).contiguous()
+         stroke_param = torch.cat([shape_param, color], dim=-1)
+         # stroke_param: b * h * w, stroke_per_patch, param_per_stroke
+         # stroke_decision: b * h * w, stroke_per_patch, 1
+         param = stroke_param.view(1, h, w, stroke_num, 8).contiguous()
+         decision = stroke_decision.view(1, h, w, stroke_num).contiguous().bool()
+         # param: b, h, w, stroke_per_patch, 8
+         # decision: b, h, w, stroke_per_patch
+         param[..., :2] = param[..., :2] / 2 + 0.25
+         param[..., 2:4] = param[..., 2:4] / 2
+         if serial:
+             final_result = param2img_serial(param, decision, meta_brushes, final_result,
+                                             frame_dir, True, original_h, original_w)
+         else:
+             final_result = param2img_parallel(param, decision, meta_brushes, final_result)
+         final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
+
+         final_result = crop(final_result, original_h, original_w)
+         save_img(final_result[0], output_path)
+
+
+ if __name__ == '__main__':
+     main(input_path='input/chicago.jpg',
+          model_path='model.pth',
+          output_dir='output/',
+          need_animation=False,  # whether intermediate results for animation are needed.
+          resize_h=None,  # resize original input to this size. None means do not resize.
+          resize_w=None,  # resize original input to this size. None means do not resize.
+          serial=False)  # if animation is needed, serial must be True.
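As a sanity check, param2stroke can be exercised in isolation; a minimal sketch, assuming it is run from the inference/ directory so that inference.py is importable, with random tensors standing in for the real brush PNGs:

    import torch
    from inference import param2stroke

    # Two fake meta brushes (vertical, horizontal) instead of the real brush PNGs.
    meta_brushes = torch.rand(2, 1, 128, 128)
    # Two strokes: x_center, y_center, width, height, theta, R, G, B, all in [0, 1].
    param = torch.tensor([[0.5, 0.5, 0.3, 0.6, 0.00, 1.0, 0.0, 0.0],
                          [0.3, 0.7, 0.6, 0.2, 0.25, 0.0, 1.0, 0.0]])
    foregrounds, alphas = param2stroke(param, 64, 64, meta_brushes)
    print(foregrounds.shape, alphas.shape)  # torch.Size([2, 3, 64, 64]) for both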
inference/input/.DS_Store ADDED
Binary file (6.15 kB).
 
inference/input/temp.txt ADDED
File without changes
inference/morphology.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Erosion2d(nn.Module):
+
+     def __init__(self, m=1):
+         super(Erosion2d, self).__init__()
+         self.m = m
+         self.pad = [m, m, m, m]
+         self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
+
+     def forward(self, x):
+         batch_size, c, h, w = x.shape
+         x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)
+         channel = self.unfold(x_pad).view(batch_size, c, -1, h, w)
+         result = torch.min(channel, dim=2)[0]
+         return result
+
+
+ def erosion(x, m=1):
+     b, c, h, w = x.shape
+     x_pad = F.pad(x, pad=[m, m, m, m], mode='constant', value=1e9)
+     channel = nn.functional.unfold(x_pad, 2 * m + 1, padding=0, stride=1).view(b, c, -1, h, w)
+     result = torch.min(channel, dim=2)[0]
+     return result
+
+
+ class Dilation2d(nn.Module):
+
+     def __init__(self, m=1):
+         super(Dilation2d, self).__init__()
+         self.m = m
+         self.pad = [m, m, m, m]
+         self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
+
+     def forward(self, x):
+         batch_size, c, h, w = x.shape
+         x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)
+         channel = self.unfold(x_pad).view(batch_size, c, -1, h, w)
+         result = torch.max(channel, dim=2)[0]
+         return result
+
+
+ def dilation(x, m=1):
+     b, c, h, w = x.shape
+     x_pad = F.pad(x, pad=[m, m, m, m], mode='constant', value=-1e9)
+     channel = nn.functional.unfold(x_pad, 2 * m + 1, padding=0, stride=1).view(b, c, -1, h, w)
+     result = torch.max(channel, dim=2)[0]
+     return result
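Erosion and dilation here are min-/max-pooling over a (2m + 1) x (2m + 1) window via unfold, with large-magnitude padding so border pixels are unaffected. A quick sanity check (toy tensor, values chosen for illustration):

    import torch
    import morphology

    x = torch.zeros(1, 1, 5, 5)
    x[0, 0, 2, 2] = 1.0                        # a single bright pixel
    dilated = morphology.dilation(x, m=1)      # grows it into a 3x3 block of ones
    eroded = morphology.erosion(dilated, m=1)  # shrinks it back to the single pixel
    print(dilated.sum().item(), eroded.sum().item())  # 9.0 1.0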
inference/network.py ADDED
@@ -0,0 +1,84 @@
+ import torch
+ import torch.nn as nn
+
+
+ class SignWithSigmoidGrad(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x):
+         result = (x > 0).float()
+         sigmoid_result = torch.sigmoid(x)
+         ctx.save_for_backward(sigmoid_result)
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_result):
+         (sigmoid_result,) = ctx.saved_tensors
+         if ctx.needs_input_grad[0]:
+             grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)
+         else:
+             grad_input = None
+         return grad_input
+
+
+ class Painter(nn.Module):
+
+     def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
+         super().__init__()
+         self.enc_img = nn.Sequential(
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(3, 32, 3, 1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(32, 64, 3, 2),
+             nn.BatchNorm2d(64),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(64, 128, 3, 2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(True))
+         self.enc_canvas = nn.Sequential(
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(3, 32, 3, 1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(32, 64, 3, 2),
+             nn.BatchNorm2d(64),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(64, 128, 3, 2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(True))
+         self.conv = nn.Conv2d(128 * 2, hidden_dim, 1)
+         self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
+         self.linear_param = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(True),
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(True),
+             nn.Linear(hidden_dim, param_per_stroke))
+         self.linear_decider = nn.Linear(hidden_dim, 1)
+         self.query_pos = nn.Parameter(torch.rand(total_strokes, hidden_dim))
+         self.row_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
+         self.col_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
+
+     def forward(self, img, canvas):
+         b, _, H, W = img.shape
+         img_feat = self.enc_img(img)
+         canvas_feat = self.enc_canvas(canvas)
+         h, w = img_feat.shape[-2:]
+         feat = torch.cat([img_feat, canvas_feat], dim=1)
+         feat_conv = self.conv(feat)
+
+         pos_embed = torch.cat([
+             self.col_embed[:w].unsqueeze(0).contiguous().repeat(h, 1, 1),
+             self.row_embed[:h].unsqueeze(1).contiguous().repeat(1, w, 1),
+         ], dim=-1).flatten(0, 1).unsqueeze(1)
+         hidden_state = self.transformer(pos_embed + feat_conv.flatten(2).permute(2, 0, 1).contiguous(),
+                                         self.query_pos.unsqueeze(1).contiguous().repeat(1, b, 1))
+         hidden_state = hidden_state.permute(1, 0, 2).contiguous()
+         param = self.linear_param(hidden_state)
+         decision = self.linear_decider(hidden_state)
+         return param, decision
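SignWithSigmoidGrad is a straight-through estimator: the forward pass is a hard 0/1 threshold, while the backward pass substitutes the sigmoid's gradient. A shape-level smoke test for Painter, using the same constructor arguments as the call in inference.py (the random patches are placeholders):

    import torch
    from network import Painter

    net = Painter(5, 8, 256, n_heads=8, n_enc_layers=3, n_dec_layers=3)
    img = torch.rand(4, 3, 32, 32)     # a batch of 32x32 image patches
    canvas = torch.rand(4, 3, 32, 32)  # the matching current-canvas patches
    param, decision = net(img, canvas)
    print(param.shape, decision.shape)  # torch.Size([4, 8, 5]) and torch.Size([4, 8, 1])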
train/brush/brush_large_horizontal.png ADDED
train/brush/brush_large_vertical.png ADDED
train/brush/brush_small_horizontal.png ADDED
train/brush/brush_small_vertical.png ADDED
train/data/__init__.py ADDED
@@ -0,0 +1,94 @@
+ """This package includes all the modules related to data loading and preprocessing.
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+     -- <__init__>: initialize the class; first call BaseDataset.__init__(self, opt).
+     -- <__len__>: return the size of the dataset.
+     -- <__getitem__>: get a data point from the data loader.
+     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+
+ Now you can use the dataset class by specifying the flag '--dataset_mode dummy'.
+ See our template dataset class 'template_dataset.py' for more details.
+ """
+ import importlib
+ import torch.utils.data
+ from data.base_dataset import BaseDataset
+
+
+ def find_dataset_using_name(dataset_name):
+     """Import the module "data/[dataset_name]_dataset.py".
+
+     In the file, the class called DatasetNameDataset() will
+     be instantiated. It has to be a subclass of BaseDataset,
+     and the name matching is case-insensitive.
+     """
+     dataset_filename = "data." + dataset_name + "_dataset"
+     datasetlib = importlib.import_module(dataset_filename)
+
+     dataset = None
+     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+     for name, cls in datasetlib.__dict__.items():
+         if name.lower() == target_dataset_name.lower() \
+                 and issubclass(cls, BaseDataset):
+             dataset = cls
+
+     if dataset is None:
+         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with a class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+     return dataset
+
+
+ def get_option_setter(dataset_name):
+     """Return the static method <modify_commandline_options> of the dataset class."""
+     dataset_class = find_dataset_using_name(dataset_name)
+     return dataset_class.modify_commandline_options
+
+
+ def create_dataset(opt):
+     """Create a dataset given the option.
+
+     This function wraps the class CustomDatasetDataLoader.
+     This is the main interface between this package and 'train.py'/'test.py'.
+
+     Example:
+         >>> from data import create_dataset
+         >>> dataset = create_dataset(opt)
+     """
+     data_loader = CustomDatasetDataLoader(opt)
+     dataset = data_loader.load_data()
+     return dataset
+
+
+ class CustomDatasetDataLoader():
+     """Wrapper class of the Dataset class that performs multi-threaded data loading"""
+
+     def __init__(self, opt):
+         """Initialize this class.
+
+         Step 1: create a dataset instance given the name [dataset_mode].
+         Step 2: create a multi-threaded data loader.
+         """
+         self.opt = opt
+         dataset_class = find_dataset_using_name(opt.dataset_mode)
+         self.dataset = dataset_class(opt)
+         print("dataset [%s] was created" % type(self.dataset).__name__)
+         self.dataloader = torch.utils.data.DataLoader(
+             self.dataset,
+             batch_size=opt.batch_size,
+             shuffle=not opt.serial_batches,
+             num_workers=int(opt.num_threads),
+             drop_last=True)
+
+     def load_data(self):
+         return self
+
+     def __len__(self):
+         """Return the number of items in the dataset."""
+         return min(len(self.dataset), self.opt.max_dataset_size)
+
+     def __iter__(self):
+         """Return a batch of data."""
+         for i, data in enumerate(self.dataloader):
+             if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                 break
+             yield data
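Following the convention described above, a hypothetical data/dummy_dataset.py that find_dataset_using_name('dummy') would pick up could look like this (all names here are illustrative, mirroring null_dataset.py below):

    from data.base_dataset import BaseDataset


    class DummyDataset(BaseDataset):
        # Minimal dataset, selected via '--dataset_mode dummy'.

        def __init__(self, opt):
            BaseDataset.__init__(self, opt)

        def __len__(self):
            return 10

        def __getitem__(self, index):
            return {'A_paths': 'dummy_%d.jpg' % index}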
train/data/base_dataset.py ADDED
@@ -0,0 +1,153 @@
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can later be used in subclasses.
+ """
+ import random
+ import numpy as np
+ import torch.utils.data as data
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from abc import ABC, abstractmethod
+
+
+ class BaseDataset(data.Dataset, ABC):
+     """This class is an abstract base class (ABC) for datasets.
+
+     To create a subclass, you need to implement the following four functions:
+     -- <__init__>: initialize the class; first call BaseDataset.__init__(self, opt).
+     -- <__len__>: return the size of the dataset.
+     -- <__getitem__>: get a data point.
+     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+     """
+
+     def __init__(self, opt):
+         """Initialize the class; save the options in the class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         self.opt = opt
+         self.root = opt.dataroot
+
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         """Add new dataset-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser -- original option parser
+             is_train (bool) -- whether it is the training phase or the test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+         """
+         return parser
+
+     @abstractmethod
+     def __len__(self):
+         """Return the total number of images in the dataset."""
+         return 0
+
+     @abstractmethod
+     def __getitem__(self, index):
+         """Return a data point and its metadata information.
+
+         Parameters:
+             index -- a random integer for data indexing
+
+         Returns:
+             a dictionary of data with their names. It usually contains the data itself and its metadata information.
+         """
+         pass
+
+
+ def get_params(opt, size):
+     w, h = size
+     new_h = h
+     new_w = w
+     if opt.preprocess == 'resize_and_crop':
+         new_h = new_w = opt.load_size
+     elif opt.preprocess == 'scale_width_and_crop':
+         new_w = opt.load_size
+         new_h = opt.load_size * h // w
+
+     x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
+     y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
+
+     flip = random.random() > 0.5
+
+     return {'crop_pos': (x, y), 'flip': flip}
+
+
+ def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
+     transform_list = []
+     if grayscale:
+         transform_list.append(transforms.Grayscale(1))
+     if 'resize' in opt.preprocess:
+         osize = [opt.load_size, opt.load_size]
+         transform_list.append(transforms.Resize(osize, method))
+     elif 'scale_width' in opt.preprocess:
+         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
+
+     if 'crop' in opt.preprocess:
+         if params is None:
+             transform_list.append(transforms.RandomCrop(opt.crop_size))
+         else:
+             transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+
+     if opt.preprocess == 'none':
+         transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
+
+     if not opt.no_flip:
+         if params is None:
+             transform_list.append(transforms.RandomHorizontalFlip())
+         elif params['flip']:
+             transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+     if convert:
+         transform_list += [transforms.ToTensor()]
+     return transforms.Compose(transform_list)
+
+
+ def __make_power_2(img, base, method=Image.BICUBIC):
+     ow, oh = img.size
+     h = int(round(oh / base) * base)
+     w = int(round(ow / base) * base)
+     if h == oh and w == ow:
+         return img
+
+     __print_size_warning(ow, oh, w, h)
+     return img.resize((w, h), method)
+
+
+ def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
+     ow, oh = img.size
+     if ow == target_size and oh >= crop_size:
+         return img
+     w = target_size
+     h = int(max(target_size * oh / ow, crop_size))
+     return img.resize((w, h), method)
+
+
+ def __crop(img, pos, size):
+     ow, oh = img.size
+     x1, y1 = pos
+     tw = th = size
+     if (ow > tw or oh > th):
+         return img.crop((x1, y1, x1 + tw, y1 + th))
+     return img
+
+
+ def __flip(img, flip):
+     if flip:
+         return img.transpose(Image.FLIP_LEFT_RIGHT)
+     return img
+
+
+ def __print_size_warning(ow, oh, w, h):
+     """Print a warning about the image size (only printed once)."""
+     if not hasattr(__print_size_warning, 'has_printed'):
+         print("The image size needs to be a multiple of 4. "
+               "The loaded image size was (%d, %d), so it was adjusted to "
+               "(%d, %d). This adjustment will be applied to all images "
+               "whose sizes are not multiples of 4." % (ow, oh, w, h))
+         __print_size_warning.has_printed = True
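For instance, with a minimal options object (a SimpleNamespace stand-in for the real BaseOptions result, so the attribute set here is an assumption), get_params and get_transform compose a deterministic resize-crop-flip pipeline:

    from types import SimpleNamespace
    from PIL import Image
    from data.base_dataset import get_params, get_transform

    opt = SimpleNamespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=False)
    img = Image.new('RGB', (512, 384))
    params = get_params(opt, img.size)      # fixed crop position and flip decision
    transform = get_transform(opt, params)  # Resize -> deterministic crop -> (flip) -> ToTensor
    tensor = transform(img)
    print(tensor.shape)  # torch.Size([3, 256, 256])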
train/data/null_dataset.py ADDED
@@ -0,0 +1,15 @@
+ from data.base_dataset import BaseDataset
+ import os
+
+
+ class NullDataset(BaseDataset):
+
+     def __init__(self, opt):
+         BaseDataset.__init__(self, opt)
+
+     def __getitem__(self, index):
+         return {'A_paths': os.path.join(self.opt.dataroot, '%d.jpg' % index)}
+
+     def __len__(self):
+         """Return the total number of images in the dataset."""
+         return self.opt.max_dataset_size
train/models/__init__.py ADDED
@@ -0,0 +1,67 @@
+ """This package contains modules related to objective functions, optimizations, and network architectures.
+
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+ You need to implement the following five functions:
+     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+     -- <set_input>: unpack data from the dataset and apply preprocessing.
+     -- <forward>: produce intermediate results.
+     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+ In the function <__init__>, you need to define four lists:
+     -- self.loss_names (str list): specify the training losses that you want to plot and save.
+     -- self.model_names (str list): define the networks used in our training.
+     -- self.visual_names (str list): specify the images that you want to display and save.
+     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
+
+ Now you can use the model class by specifying the flag '--model dummy'.
+ See our template model class 'template_model.py' for more details.
+ """
+
+ import importlib
+ from models.base_model import BaseModel
+
+
+ def find_model_using_name(model_name):
+     """Import the module "models/[model_name]_model.py".
+
+     In the file, the class called DatasetNameModel() will
+     be instantiated. It has to be a subclass of BaseModel,
+     and the name matching is case-insensitive.
+     """
+     model_filename = "models." + model_name + "_model"
+     modellib = importlib.import_module(model_filename)
+     model = None
+     target_model_name = model_name.replace('_', '') + 'model'
+     for name, cls in modellib.__dict__.items():
+         if name.lower() == target_model_name.lower() \
+                 and issubclass(cls, BaseModel):
+             model = cls
+
+     if model is None:
+         print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (model_filename, target_model_name))
+         exit(0)
+
+     return model
+
+
+ def get_option_setter(model_name):
+     """Return the static method <modify_commandline_options> of the model class."""
+     model_class = find_model_using_name(model_name)
+     return model_class.modify_commandline_options
+
+
+ def create_model(opt):
+     """Create a model given the option.
+
+     This function wraps the model class selected by opt.model.
+     This is the main interface between this package and 'train.py'/'test.py'.
+
+     Example:
+         >>> from models import create_model
+         >>> model = create_model(opt)
+     """
+     model = find_model_using_name(opt.model)
+     instance = model(opt)
+     print("model [%s] was created" % type(instance).__name__)
+     return instance
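The lookup is purely name-based, so a hypothetical models/dummy_model.py only needs to define DummyModel for '--model dummy' to resolve; a skeleton sketch (illustrative, mirroring the five functions listed above):

    from models.base_model import BaseModel


    class DummyModel(BaseModel):
        # Skeleton model; found by find_model_using_name('dummy').

        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = []
            self.model_names = []
            self.visual_names = []
            self.optimizers = []

        def set_input(self, input):
            self.image_paths = input['A_paths']

        def forward(self):
            pass

        def optimize_parameters(self):
            pass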
train/models/base_model.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from . import networks
6
+
7
+
8
+ class BaseModel(ABC):
9
+ """This class is an abstract base class (ABC) for models.
10
+ To create a subclass, you need to implement the following five functions:
11
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12
+ -- <set_input>: unpack data from dataset and apply preprocessing.
13
+ -- <forward>: produce intermediate results.
14
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16
+ """
17
+
18
+ def __init__(self, opt):
19
+ """Initialize the BaseModel class.
20
+
21
+ Parameters:
22
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23
+
24
+ When creating your custom class, you need to implement your own initialization.
25
+ In this function, you should first call <BaseModel.__init__(self, opt)>
26
+ Then, you need to define four lists:
27
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
28
+ -- self.model_names (str list): define networks used in our training.
29
+ -- self.visual_names (str list): specify the images that you want to display and save.
30
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31
+ """
32
+ self.opt = opt
33
+ self.gpu_ids = opt.gpu_ids
34
+ self.isTrain = opt.isTrain
35
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38
+ torch.backends.cudnn.benchmark = True
39
+ self.loss_names = []
40
+ self.model_names = []
41
+ self.visual_names = []
42
+ self.optimizers = []
43
+ self.image_paths = []
44
+ self.metric = 0 # used for learning rate policy 'plateau'
45
+
46
+ @staticmethod
47
+ def modify_commandline_options(parser, is_train):
48
+ """Add new model-specific options, and rewrite default values for existing options.
49
+
50
+ Parameters:
51
+ parser -- original option parser
52
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
53
+
54
+ Returns:
55
+ the modified parser.
56
+ """
57
+ return parser
58
+
59
+ @abstractmethod
60
+ def set_input(self, input):
61
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
62
+
63
+ Parameters:
64
+ input (dict): includes the data itself and its metadata information.
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ def forward(self):
70
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
71
+ pass
72
+
73
+ @abstractmethod
74
+ def optimize_parameters(self):
75
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
76
+ pass
77
+
78
+ def setup(self, opt):
79
+ """Load and print networks; create schedulers
80
+
81
+ Parameters:
82
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
83
+ """
84
+ if self.isTrain:
85
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
86
+ if not self.isTrain or opt.continue_train:
87
+ load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
88
+ self.load_networks(load_suffix)
89
+ self.print_networks(opt.verbose)
90
+
91
+ def eval(self):
92
+ """Make models eval mode during test time"""
93
+ for name in self.model_names:
94
+ if isinstance(name, str):
95
+ net = getattr(self, 'net_' + name)
96
+ net.eval()
97
+
98
+ def test(self):
99
+ """Forward function used in test time.
100
+
101
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
102
+ It also calls <compute_visuals> to produce additional visualization results
103
+ """
104
+ with torch.no_grad():
105
+ self.forward()
106
+ self.compute_visuals()
107
+
108
+     def compute_visuals(self):
+         """Calculate additional output images for visdom and HTML visualization"""
+         pass
+
+     def get_image_paths(self):
+         """Return image paths that are used to load current data"""
+         return self.image_paths
+
+     def update_learning_rate(self):
+         """Update learning rates for all the networks; called at the end of every epoch"""
+         old_lr = self.optimizers[0].param_groups[0]['lr']
+         for scheduler in self.schedulers:
+             if self.opt.lr_policy == 'plateau':
+                 scheduler.step(self.metric)
+             else:
+                 scheduler.step()
+
+         lr = self.optimizers[0].param_groups[0]['lr']
+         print('learning rate %.7f -> %.7f' % (old_lr, lr))
+
+     def get_current_visuals(self):
+         """Return visualization images. train.py will display these images with visdom, and save the images to an HTML"""
+         visual_ret = OrderedDict()
+         for name in self.visual_names:
+             if isinstance(name, str):
+                 visual_ret[name] = getattr(self, name)
+         return visual_ret
+
+     def get_current_losses(self):
+         """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
+         errors_ret = OrderedDict()
+         for name in self.loss_names:
+             if isinstance(name, str):
+                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+         return errors_ret
+
+     def save_networks(self, epoch):
+         """Save all the networks to the disk.
+
+         Parameters:
+             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+         """
+         for name in self.model_names:
+             if isinstance(name, str):
+                 save_filename = '%s_net_%s.pth' % (epoch, name)
+                 save_path = os.path.join(self.save_dir, save_filename)
+                 net = getattr(self, 'net_' + name)
+
+                 if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+                     torch.save(net.module.cpu().state_dict(), save_path)
+                     net.cuda(self.gpu_ids[0])
+                 else:
+                     torch.save(net.cpu().state_dict(), save_path)
+
+     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+         key = keys[i]
+         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+             if module.__class__.__name__.startswith('InstanceNorm') and \
+                     (key == 'running_mean' or key == 'running_var'):
+                 if getattr(module, key) is None:
+                     state_dict.pop('.'.join(keys))
+             if module.__class__.__name__.startswith('InstanceNorm') and \
+                     (key == 'num_batches_tracked'):
+                 state_dict.pop('.'.join(keys))
+         else:
+             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+     def load_networks(self, epoch):
+         """Load all the networks from the disk.
+
+         Parameters:
+             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+         """
+         for name in self.model_names:
+             if isinstance(name, str):
+                 load_filename = '%s_net_%s.pth' % (epoch, name)
+                 load_path = os.path.join(self.save_dir, load_filename)
+                 net = getattr(self, 'net_' + name)
+                 if isinstance(net, torch.nn.DataParallel):
+                     net = net.module
+                 print('loading the model from %s' % load_path)
+                 # if you are using PyTorch newer than 0.4 (e.g., built from
+                 # GitHub source), you can remove str() on self.device
+                 state_dict = torch.load(load_path, map_location=str(self.device))
+                 if hasattr(state_dict, '_metadata'):
+                     del state_dict._metadata
+
+                 # patch InstanceNorm checkpoints prior to 0.4
+                 for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                 net.load_state_dict(state_dict)
+
+     def print_networks(self, verbose):
+         """Print the total number of parameters in the network and (if verbose) network architecture
+
+         Parameters:
+             verbose (bool) -- if verbose: print the network architecture
+         """
+         print('---------- Networks initialized -------------')
+         for name in self.model_names:
+             if isinstance(name, str):
+                 net = getattr(self, 'net_' + name)
+                 num_params = 0
+                 for param in net.parameters():
+                     num_params += param.numel()
+                 if verbose:
+                     print(net)
+                 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+         print('-----------------------------------------------')
+
+     def set_requires_grad(self, nets, requires_grad=False):
+         """Set requires_grad=False for all the networks to avoid unnecessary computations
+         Parameters:
+             nets (network list)   -- a list of networks
+             requires_grad (bool)  -- whether the networks require gradients or not
+         """
+         if not isinstance(nets, list):
+             nets = [nets]
+         for net in nets:
+             if net is not None:
+                 for param in net.parameters():
+                     param.requires_grad = requires_grad
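
A quick aside on set_requires_grad, since it is the hook this template uses to freeze networks. The sketch below is an editor's illustration with a stand-in nn.Linear, not code from this commit:

import torch
import torch.nn as nn

# Freezing a module the way BaseModel.set_requires_grad(net, False) does:
net = nn.Linear(4, 1)
for p in net.parameters():
    p.requires_grad = False
out = net(torch.randn(2, 4)).sum()
# out.backward() would raise here, because no leaf of the graph requires grad.
# In training, gradients can still flow *through* a frozen net to upstream
# parameters that do require grad, which is why freezing is safe and cheap.
print([p.requires_grad for p in net.parameters()])  # [False, False]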
train/models/networks.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import init
+ from torch.optim import lr_scheduler
+
+
+ def get_scheduler(optimizer, opt):
+     if opt.lr_policy == 'linear':
+         def lambda_rule(epoch):
+             # lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
+             lr_l = 0.3 ** max(0, (epoch + opt.epoch_count - opt.n_epochs) // 5)
+             return lr_l
+
+         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+     elif opt.lr_policy == 'step':
+         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+     elif opt.lr_policy == 'plateau':
+         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+     elif opt.lr_policy == 'cosine':
+         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+     else:
+         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+     return scheduler
+
+
+ def init_weights(net, init_type='normal', init_gain=0.02):
+     def init_func(m):
+         classname = m.__class__.__name__
+         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+             if init_type == 'normal':
+                 init.normal_(m.weight.data, 0.0, init_gain)
+             elif init_type == 'xavier':
+                 init.xavier_normal_(m.weight.data, gain=init_gain)
+             elif init_type == 'kaiming':
+                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+             elif init_type == 'orthogonal':
+                 init.orthogonal_(m.weight.data, gain=init_gain)
+             else:
+                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+             if hasattr(m, 'bias') and m.bias is not None:
+                 init.constant_(m.bias.data, 0.0)
+         elif classname.find('BatchNorm2d') != -1:
+             init.normal_(m.weight.data, 1.0, init_gain)
+             init.constant_(m.bias.data, 0.0)
+
+     print('initialize network with %s' % init_type)
+     net.apply(init_func)
+
+
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=()):
+     if len(gpu_ids) > 0:
+         assert (torch.cuda.is_available())
+         net.to(gpu_ids[0])
+         net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
+     init_weights(net, init_type, init_gain=init_gain)
+     return net
+
+
+ class SignWithSigmoidGrad(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x):
+         result = (x > 0).float()
+         sigmoid_result = torch.sigmoid(x)
+         ctx.save_for_backward(sigmoid_result)
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_result):
+         (sigmoid_result,) = ctx.saved_tensors
+         if ctx.needs_input_grad[0]:
+             grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)
+         else:
+             grad_input = None
+         return grad_input
+
+
+ class Painter(nn.Module):
+
+     def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
+         super().__init__()
+         self.enc_img = nn.Sequential(
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(3, 32, 3, 1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(32, 64, 3, 2),
+             nn.BatchNorm2d(64),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(64, 128, 3, 2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(True))
+         self.enc_canvas = nn.Sequential(
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(3, 32, 3, 1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(32, 64, 3, 2),
+             nn.BatchNorm2d(64),
+             nn.ReLU(True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(64, 128, 3, 2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(True))
+         self.conv = nn.Conv2d(128 * 2, hidden_dim, 1)
+         self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
+         self.linear_param = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(True),
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(True),
+             nn.Linear(hidden_dim, param_per_stroke))
+         self.linear_decider = nn.Linear(hidden_dim, 1)
+         self.query_pos = nn.Parameter(torch.rand(total_strokes, hidden_dim))
+         self.row_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
+         self.col_embed = nn.Parameter(torch.rand(8, hidden_dim // 2))
+
+     def forward(self, img, canvas):
+         b, _, H, W = img.shape
+         img_feat = self.enc_img(img)
+         canvas_feat = self.enc_canvas(canvas)
+         h, w = img_feat.shape[-2:]
+         feat = torch.cat([img_feat, canvas_feat], dim=1)
+         feat_conv = self.conv(feat)
+
+         pos_embed = torch.cat([
+             self.col_embed[:w].unsqueeze(0).contiguous().repeat(h, 1, 1),
+             self.row_embed[:h].unsqueeze(1).contiguous().repeat(1, w, 1),
+         ], dim=-1).flatten(0, 1).unsqueeze(1)
+         hidden_state = self.transformer(pos_embed + feat_conv.flatten(2).permute(2, 0, 1).contiguous(),
+                                         self.query_pos.unsqueeze(1).contiguous().repeat(1, b, 1))
+         hidden_state = hidden_state.permute(1, 0, 2).contiguous()
+         param = self.linear_param(hidden_state)
+         s = hidden_state.shape[1]
+         grid = param[:, :, :2].view(b * s, 1, 1, 2).contiguous()
+         img_temp = img.unsqueeze(1).contiguous().repeat(1, s, 1, 1, 1).view(b * s, 3, H, W).contiguous()
+         color = nn.functional.grid_sample(img_temp, 2 * grid - 1, align_corners=False).view(b, s, 3).contiguous()
+         decision = self.linear_decider(hidden_state)
+         return torch.cat([param, color, color, torch.rand(b, s, 1, device=img.device)], dim=-1), decision
+
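
Worth a sanity check: SignWithSigmoidGrad is a straight-through estimator, binarizing stroke decisions in the forward pass while backpropagating the sigmoid's derivative. A minimal editor's sketch, assuming the class above is importable from train/models/networks.py:

import torch
from models.networks import SignWithSigmoidGrad  # assumed import path

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = SignWithSigmoidGrad.apply(x)
print(y)             # tensor([0., 1., 1.]) -- hard 0/1 decisions
y.sum().backward()
print(x.grad)        # sigmoid(x) * (1 - sigmoid(x)) -- smooth surrogate gradient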
train/models/painter_model.py ADDED
@@ -0,0 +1,247 @@
+ import torch
+ import numpy as np
+ from .base_model import BaseModel
+ from . import networks
+ from util import morphology
+ from scipy.optimize import linear_sum_assignment
+ from PIL import Image
+
+
+ class PainterModel(BaseModel):
+
+     @staticmethod
+     def modify_commandline_options(parser, is_train=True):
+         parser.set_defaults(dataset_mode='null')
+         parser.add_argument('--used_strokes', type=int, default=8,
+                             help='number of strokes actually generated')
+         parser.add_argument('--num_blocks', type=int, default=3,
+                             help='number of transformer blocks for stroke generator')
+         parser.add_argument('--lambda_w', type=float, default=10.0, help='weight for w loss of stroke shape')
+         parser.add_argument('--lambda_pixel', type=float, default=10.0, help='weight for pixel-level L1 loss')
+         parser.add_argument('--lambda_gt', type=float, default=1.0, help='weight for ground-truth loss')
+         parser.add_argument('--lambda_decision', type=float, default=10.0, help='weight for stroke decision loss')
+         parser.add_argument('--lambda_recall', type=float, default=10.0, help='weight of recall for stroke decision loss')
+         return parser
+
+     def __init__(self, opt):
+         BaseModel.__init__(self, opt)
+         self.loss_names = ['pixel', 'gt', 'w', 'decision']
+         self.visual_names = ['old', 'render', 'rec']
+         self.model_names = ['g']
+         self.d = 12  # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A
+         self.d_shape = 5
+
+         def read_img(img_path, img_type='RGB'):
+             img = Image.open(img_path).convert(img_type)
+             img = np.array(img)
+             if img.ndim == 2:
+                 img = np.expand_dims(img, axis=-1)
+             img = img.transpose((2, 0, 1))
+             img = torch.from_numpy(img).unsqueeze(0).float() / 255.
+             return img
+
+         brush_large_vertical = read_img('brush/brush_large_vertical.png', 'L').to(self.device)
+         brush_large_horizontal = read_img('brush/brush_large_horizontal.png', 'L').to(self.device)
+         self.meta_brushes = torch.cat(
+             [brush_large_vertical, brush_large_horizontal], dim=0)
+         net_g = networks.Painter(self.d_shape, opt.used_strokes, opt.ngf,
+                                  n_enc_layers=opt.num_blocks, n_dec_layers=opt.num_blocks)
+         self.net_g = networks.init_net(net_g, opt.init_type, opt.init_gain, self.gpu_ids)
+         self.old = None
+         self.render = None
+         self.rec = None
+         self.gt_param = None
+         self.pred_param = None
+         self.gt_decision = None
+         self.pred_decision = None
+         self.patch_size = 32
+         self.loss_pixel = torch.tensor(0., device=self.device)
+         self.loss_gt = torch.tensor(0., device=self.device)
+         self.loss_w = torch.tensor(0., device=self.device)
+         self.loss_decision = torch.tensor(0., device=self.device)
+         self.criterion_pixel = torch.nn.L1Loss().to(self.device)
+         self.criterion_decision = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(opt.lambda_recall)).to(self.device)
+         if self.isTrain:
+             self.optimizer = torch.optim.Adam(self.net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+             self.optimizers.append(self.optimizer)
+
+     def param2stroke(self, param, H, W):
+         # param: b, 12
+         b = param.shape[0]
+         param_list = torch.split(param, 1, dim=1)
+         x0, y0, w, h, theta = [item.squeeze(-1) for item in param_list[:5]]
+         R0, G0, B0, R2, G2, B2, _ = param_list[5:]
+         sin_theta = torch.sin(torch.acos(torch.tensor(-1., device=param.device)) * theta)
+         cos_theta = torch.cos(torch.acos(torch.tensor(-1., device=param.device)) * theta)
+         index = torch.full((b,), -1, device=param.device)
+         index[h > w] = 0
+         index[h <= w] = 1
+         brush = self.meta_brushes[index.long()]
+         alphas = torch.cat([brush, brush, brush], dim=1)
+         alphas = (alphas > 0).float()
+         t = torch.arange(0, brush.shape[2], device=param.device).unsqueeze(0) / brush.shape[2]
+         color_map = torch.stack([R0 * (1 - t) + R2 * t, G0 * (1 - t) + G2 * t, B0 * (1 - t) + B2 * t], dim=1)
+         color_map = color_map.unsqueeze(-1).repeat(1, 1, 1, brush.shape[3])
+         brush = brush * color_map
+
+         warp_00 = cos_theta / w
+         warp_01 = sin_theta * H / (W * w)
+         warp_02 = (1 - 2 * x0) * cos_theta / w + (1 - 2 * y0) * sin_theta * H / (W * w)
+         warp_10 = -sin_theta * W / (H * h)
+         warp_11 = cos_theta / h
+         warp_12 = (1 - 2 * y0) * cos_theta / h - (1 - 2 * x0) * sin_theta * W / (H * h)
+         warp_0 = torch.stack([warp_00, warp_01, warp_02], dim=1)
+         warp_1 = torch.stack([warp_10, warp_11, warp_12], dim=1)
+         warp = torch.stack([warp_0, warp_1], dim=1)
+         grid = torch.nn.functional.affine_grid(warp, torch.Size((b, 3, H, W)), align_corners=False)
+         brush = torch.nn.functional.grid_sample(brush, grid, align_corners=False)
+         alphas = torch.nn.functional.grid_sample(alphas, grid, align_corners=False)
+
+         return brush, alphas
+
+     def set_input(self, input_dict):
+         self.image_paths = input_dict['A_paths']
+         with torch.no_grad():
+             old_param = torch.rand(self.opt.batch_size // 4, self.opt.used_strokes, self.d, device=self.device)
+             old_param[:, :, :4] = old_param[:, :, :4] * 0.5 + 0.2
+             old_param[:, :, -4:-1] = old_param[:, :, -7:-4]
+             old_param = old_param.view(-1, self.d).contiguous()
+             foregrounds, alphas = self.param2stroke(old_param, self.patch_size * 2, self.patch_size * 2)
+             foregrounds = morphology.Dilation2d(m=1)(foregrounds)
+             alphas = morphology.Erosion2d(m=1)(alphas)
+             foregrounds = foregrounds.view(self.opt.batch_size // 4, self.opt.used_strokes, 3, self.patch_size * 2,
+                                            self.patch_size * 2).contiguous()
+             alphas = alphas.view(self.opt.batch_size // 4, self.opt.used_strokes, 3, self.patch_size * 2,
+                                  self.patch_size * 2).contiguous()
+             old = torch.zeros(self.opt.batch_size // 4, 3, self.patch_size * 2, self.patch_size * 2, device=self.device)
+             for i in range(self.opt.used_strokes):
+                 foreground = foregrounds[:, i, :, :, :]
+                 alpha = alphas[:, i, :, :, :]
+                 old = foreground * alpha + old * (1 - alpha)
+             old = old.view(self.opt.batch_size // 4, 3, 2, self.patch_size, 2, self.patch_size).contiguous()
+             old = old.permute(0, 2, 4, 1, 3, 5).contiguous()
+             self.old = old.view(self.opt.batch_size, 3, self.patch_size, self.patch_size).contiguous()
+
+             gt_param = torch.rand(self.opt.batch_size, self.opt.used_strokes, self.d, device=self.device)
+             gt_param[:, :, :4] = gt_param[:, :, :4] * 0.5 + 0.2
+             gt_param[:, :, -4:-1] = gt_param[:, :, -7:-4]
+             self.gt_param = gt_param[:, :, :self.d_shape]
+             gt_param = gt_param.view(-1, self.d).contiguous()
+             foregrounds, alphas = self.param2stroke(gt_param, self.patch_size, self.patch_size)
+             foregrounds = morphology.Dilation2d(m=1)(foregrounds)
+             alphas = morphology.Erosion2d(m=1)(alphas)
+             foregrounds = foregrounds.view(self.opt.batch_size, self.opt.used_strokes, 3, self.patch_size,
+                                            self.patch_size).contiguous()
+             alphas = alphas.view(self.opt.batch_size, self.opt.used_strokes, 3, self.patch_size,
+                                  self.patch_size).contiguous()
+             self.render = self.old.clone()
+             gt_decision = torch.ones(self.opt.batch_size, self.opt.used_strokes, device=self.device)
+             for i in range(self.opt.used_strokes):
+                 foreground = foregrounds[:, i, :, :, :]
+                 alpha = alphas[:, i, :, :, :]
+                 for j in range(i):
+                     iou = (torch.sum(alpha * alphas[:, j, :, :, :], dim=(-3, -2, -1)) + 1e-5) / (
+                             torch.sum(alphas[:, j, :, :, :], dim=(-3, -2, -1)) + 1e-5)
+                     gt_decision[:, i] = ((iou < 0.75) | (~gt_decision[:, j].bool())).float() * gt_decision[:, i]
+                 decision = gt_decision[:, i].view(self.opt.batch_size, 1, 1, 1).contiguous()
+                 self.render = foreground * alpha * decision + self.render * (1 - alpha * decision)
+             self.gt_decision = gt_decision
+
+     def forward(self):
+         param, decisions = self.net_g(self.render, self.old)
+         # stroke_param: b, stroke_per_patch, param_per_stroke
+         # decision: b, stroke_per_patch, 1
+         self.pred_decision = decisions.view(-1, self.opt.used_strokes).contiguous()
+         self.pred_param = param[:, :, :self.d_shape]
+         param = param.view(-1, self.d).contiguous()
+         foregrounds, alphas = self.param2stroke(param, self.patch_size, self.patch_size)
+         foregrounds = morphology.Dilation2d(m=1)(foregrounds)
+         alphas = morphology.Erosion2d(m=1)(alphas)
+         # foreground, alpha: b * stroke_per_patch, 3, output_size, output_size
+         foregrounds = foregrounds.view(-1, self.opt.used_strokes, 3, self.patch_size, self.patch_size)
+         alphas = alphas.view(-1, self.opt.used_strokes, 3, self.patch_size, self.patch_size)
+         # foreground, alpha: b, stroke_per_patch, 3, output_size, output_size
+         decisions = networks.SignWithSigmoidGrad.apply(decisions.view(-1, self.opt.used_strokes, 1, 1, 1).contiguous())
+         self.rec = self.old.clone()
+         for j in range(foregrounds.shape[1]):
+             foreground = foregrounds[:, j, :, :, :]
+             alpha = alphas[:, j, :, :, :]
+             decision = decisions[:, j, :, :, :]
+             self.rec = foreground * alpha * decision + self.rec * (1 - alpha * decision)
+
+     @staticmethod
+     def get_sigma_sqrt(w, h, theta):
+         sigma_00 = w * (torch.cos(theta) ** 2) / 2 + h * (torch.sin(theta) ** 2) / 2
+         sigma_01 = (w - h) * torch.cos(theta) * torch.sin(theta) / 2
+         sigma_11 = h * (torch.cos(theta) ** 2) / 2 + w * (torch.sin(theta) ** 2) / 2
+         sigma_0 = torch.stack([sigma_00, sigma_01], dim=-1)
+         sigma_1 = torch.stack([sigma_01, sigma_11], dim=-1)
+         sigma = torch.stack([sigma_0, sigma_1], dim=-2)
+         return sigma
+
+     @staticmethod
+     def get_sigma(w, h, theta):
+         sigma_00 = w * w * (torch.cos(theta) ** 2) / 4 + h * h * (torch.sin(theta) ** 2) / 4
+         sigma_01 = (w * w - h * h) * torch.cos(theta) * torch.sin(theta) / 4
+         sigma_11 = h * h * (torch.cos(theta) ** 2) / 4 + w * w * (torch.sin(theta) ** 2) / 4
+         sigma_0 = torch.stack([sigma_00, sigma_01], dim=-1)
+         sigma_1 = torch.stack([sigma_01, sigma_11], dim=-1)
+         sigma = torch.stack([sigma_0, sigma_1], dim=-2)
+         return sigma
+
+     def gaussian_w_distance(self, param_1, param_2):
+         mu_1, w_1, h_1, theta_1 = torch.split(param_1, (2, 1, 1, 1), dim=-1)
+         w_1 = w_1.squeeze(-1)
+         h_1 = h_1.squeeze(-1)
+         theta_1 = torch.acos(torch.tensor(-1., device=param_1.device)) * theta_1.squeeze(-1)
+         trace_1 = (w_1 ** 2 + h_1 ** 2) / 4
+         mu_2, w_2, h_2, theta_2 = torch.split(param_2, (2, 1, 1, 1), dim=-1)
+         w_2 = w_2.squeeze(-1)
+         h_2 = h_2.squeeze(-1)
+         theta_2 = torch.acos(torch.tensor(-1., device=param_2.device)) * theta_2.squeeze(-1)
+         trace_2 = (w_2 ** 2 + h_2 ** 2) / 4
+         sigma_1_sqrt = self.get_sigma_sqrt(w_1, h_1, theta_1)
+         sigma_2 = self.get_sigma(w_2, h_2, theta_2)
+         trace_12 = torch.matmul(torch.matmul(sigma_1_sqrt, sigma_2), sigma_1_sqrt)
+         trace_12 = torch.sqrt(trace_12[..., 0, 0] + trace_12[..., 1, 1] + 2 * torch.sqrt(
+             trace_12[..., 0, 0] * trace_12[..., 1, 1] - trace_12[..., 0, 1] * trace_12[..., 1, 0]))
+         return torch.sum((mu_1 - mu_2) ** 2, dim=-1) + trace_1 + trace_2 - 2 * trace_12
+
+     def optimize_parameters(self):
+         self.forward()
+         self.loss_pixel = self.criterion_pixel(self.rec, self.render) * self.opt.lambda_pixel
+         cur_valid_gt_size = 0
+         with torch.no_grad():
+             r_idx = []
+             c_idx = []
+             for i in range(self.gt_param.shape[0]):
+                 is_valid_gt = self.gt_decision[i].bool()
+                 valid_gt_param = self.gt_param[i, is_valid_gt]
+                 cost_matrix_l1 = torch.cdist(self.pred_param[i], valid_gt_param, p=1)
+                 pred_param_broad = self.pred_param[i].unsqueeze(1).contiguous().repeat(
+                     1, valid_gt_param.shape[0], 1)
+                 valid_gt_param_broad = valid_gt_param.unsqueeze(0).contiguous().repeat(
+                     self.pred_param.shape[1], 1, 1)
+                 cost_matrix_w = self.gaussian_w_distance(pred_param_broad, valid_gt_param_broad)
+                 decision = self.pred_decision[i]
+                 cost_matrix_decision = (1 - decision).unsqueeze(-1).repeat(1, valid_gt_param.shape[0])
+                 r, c = linear_sum_assignment((cost_matrix_l1 + cost_matrix_w + cost_matrix_decision).cpu())
+                 r_idx.append(torch.tensor(r + self.pred_param.shape[1] * i, device=self.device))
+                 c_idx.append(torch.tensor(c + cur_valid_gt_size, device=self.device))
+                 cur_valid_gt_size += valid_gt_param.shape[0]
+             r_idx = torch.cat(r_idx, dim=0)
+             c_idx = torch.cat(c_idx, dim=0)
+             paired_gt_decision = torch.zeros(self.gt_decision.shape[0] * self.gt_decision.shape[1], device=self.device)
+             paired_gt_decision[r_idx] = 1.
+         all_valid_gt_param = self.gt_param[self.gt_decision.bool(), :]
+         all_pred_param = self.pred_param.view(-1, self.pred_param.shape[2]).contiguous()
+         all_pred_decision = self.pred_decision.view(-1).contiguous()
+         paired_gt_param = all_valid_gt_param[c_idx, :]
+         paired_pred_param = all_pred_param[r_idx, :]
+         self.loss_gt = self.criterion_pixel(paired_pred_param, paired_gt_param) * self.opt.lambda_gt
+         self.loss_w = self.gaussian_w_distance(paired_pred_param, paired_gt_param).mean() * self.opt.lambda_w
+         self.loss_decision = self.criterion_decision(all_pred_decision, paired_gt_decision) * self.opt.lambda_decision
+         loss = self.loss_pixel + self.loss_gt + self.loss_w + self.loss_decision
+         loss.backward()
+         self.optimizer.step()
+         self.optimizer.zero_grad()
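
One note on gaussian_w_distance: it evaluates the squared 2-Wasserstein distance between the 2D Gaussians fitted to two strokes, W_2^2 = ||mu_1 - mu_2||^2 + tr(Sigma_1) + tr(Sigma_2) - 2 tr((Sigma_1^{1/2} Sigma_2 Sigma_1^{1/2})^{1/2}), with Sigma built from (w, h, theta) and theta expressed in units of pi. An editor's sketch of the special case where the covariance terms cancel:

import torch

# Two strokes with identical size and orientation: only the centers differ,
# so the trace terms cancel and W_2^2 collapses to the squared center distance.
mu_1 = torch.tensor([0.2, 0.2])
mu_2 = torch.tensor([0.5, 0.6])
w = h = 0.4
trace = (w ** 2 + h ** 2) / 4                 # tr(Sigma) for either stroke
w2_sq = ((mu_1 - mu_2) ** 2).sum() + trace + trace - 2 * trace
print(w2_sq)                                  # tensor(0.2500)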
train/options/__init__.py ADDED
@@ -0,0 +1 @@
+ """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
train/options/base_options.py ADDED
@@ -0,0 +1,151 @@
+ import argparse
+ import os
+ from util import util
+ import torch
+ import models
+ import data
+
+
+ class BaseOptions:
+     """This class defines options used during both training and test time.
+
+     It also implements several helper functions such as parsing, printing, and saving the options.
+     It also gathers additional options defined in <modify_commandline_options> functions
+     in both dataset class and model class.
+     """
+
+     def __init__(self):
+         """Reset the class; indicates the class hasn't been initialized"""
+         self.initialized = False
+
+     def initialize(self, parser):
+         """Define the common options that are used in both training and test."""
+         # basic parameters
+         parser.add_argument('--dataroot', default='.',
+                             help='path to images (should have sub-folders trainA, trainB, valA, valB, etc)')
+         parser.add_argument('--name', type=str, default='experiment_name',
+                             help='name of the experiment. It decides where to store samples and models')
+         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+         parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+         # model parameters
+         parser.add_argument('--model', type=str, default='painter',
+                             help='chooses which model to use.')
+         parser.add_argument('--input_nc', type=int, default=3,
+                             help='# of input image channels: 3 for RGB and 1 for grayscale')
+         parser.add_argument('--output_nc', type=int, default=3,
+                             help='# of output image channels: 3 for RGB and 1 for grayscale')
+         parser.add_argument('--ngf', type=int, default=256, help='# of gen filters in the first conv layer')
+         parser.add_argument('--layer_num', type=int, default=2, help='# of resnet blocks for generator')
+         parser.add_argument('--init_type', type=str, default='normal',
+                             help='network initialization [normal | xavier | kaiming | orthogonal]')
+         parser.add_argument('--init_gain', type=float, default=0.02,
+                             help='scaling factor for normal, xavier and orthogonal.')
+         # dataset parameters
+         parser.add_argument('--dataset_mode', type=str, default='single',
+                             help='chooses how datasets are loaded.')
+         parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
+         parser.add_argument('--serial_batches', action='store_true',
+                             help='if true, takes images in order to make batches, otherwise takes them randomly')
+         parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+         parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
+         parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
+         parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
+         parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
+                             help='Maximum number of samples allowed per dataset. If the dataset directory contains '
+                                  'more than max_dataset_size, only a subset is loaded.')
+         parser.add_argument('--preprocess', type=str, default='resize_and_crop',
+                             help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | '
+                                  'scale_width_and_crop | none]')
+         parser.add_argument('--no_flip', action='store_true',
+                             help='if specified, do not flip the images for data augmentation')
+         parser.add_argument('--display_winsize', type=int, default=256,
+                             help='display window size for both visdom and HTML')
+         # additional parameters
+         parser.add_argument('--epoch', type=str, default='latest',
+                             help='which epoch to load? set to latest to use latest cached model')
+         parser.add_argument('--load_iter', type=int, default='0',
+                             help='which iteration to load? if load_iter > 0, the code will load models by iter_['
+                                  'load_iter]; otherwise, the code will load models by [epoch]')
+         parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+         parser.add_argument('--suffix', default='', type=str,
+                             help='customized suffix: opt.name = opt.name + suffix')
+         self.initialized = True
+         return parser
+
+     def gather_options(self):
+         """Initialize our parser with basic options (only once).
+         Add additional model-specific and dataset-specific options.
+         These options are defined in the <modify_commandline_options> function
+         in model and dataset classes.
+         """
+         if not self.initialized:  # check if it has been initialized
+             parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+             parser = self.initialize(parser)
+
+         # get the basic options
+         opt, _ = parser.parse_known_args()
+
+         # modify model-related parser options
+         model_name = opt.model
+         model_option_setter = models.get_option_setter(model_name)
+         parser = model_option_setter(parser, self.isTrain)
+         opt, _ = parser.parse_known_args()  # parse again with new defaults
+
+         # modify dataset-related parser options
+         dataset_name = opt.dataset_mode
+         dataset_option_setter = data.get_option_setter(dataset_name)
+         parser = dataset_option_setter(parser, self.isTrain)
+
+         # save and return the parser
+         self.parser = parser
+         return parser.parse_args()
+
+     def print_options(self, opt):
+         """Print and save options
+
+         It will print both current options and default values (if different).
+         It will save options into a text file / [checkpoints_dir] / opt.txt
+         """
+         message = ''
+         message += '----------------- Options ---------------\n'
+         for k, v in sorted(vars(opt).items()):
+             comment = ''
+             default = self.parser.get_default(k)
+             if v != default:
+                 comment = '\t[default: %s]' % str(default)
+             message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+         message += '----------------- End -------------------'
+         print(message)
+
+         # save to the disk
+         expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+         util.mkdirs(expr_dir)
+         file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+         with open(file_name, 'wt') as opt_file:
+             opt_file.write(message)
+             opt_file.write('\n')
+
+     def parse(self):
+         """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+         opt = self.gather_options()
+         opt.isTrain = self.isTrain  # train or test
+
+         # process opt.suffix
+         if opt.suffix:
+             suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+             opt.name = opt.name + suffix
+
+         self.print_options(opt)
+
+         # set gpu ids
+         str_ids = opt.gpu_ids.split(',')
+         opt.gpu_ids = []
+         for str_id in str_ids:
+             id = int(str_id)
+             if id >= 0:
+                 opt.gpu_ids.append(id)
+         if len(opt.gpu_ids) > 0:
+             torch.cuda.set_device(opt.gpu_ids[0])
+
+         self.opt = opt
+         return self.opt
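
The two-pass parsing in gather_options is easy to miss: a first parse_known_args reads only the basic flags, then the selected model (and dataset) registers extra flags before the final parse. A standalone editor's sketch with hypothetical flag values, not the repo's real entry point:

import argparse

argv = ['--model', 'painter', '--used_strokes', '16']
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='painter')
opt, _ = parser.parse_known_args(argv)          # pass 1: unknown flags tolerated
# pass 2 stands in for PainterModel.modify_commandline_options(parser, ...):
parser.add_argument('--used_strokes', type=int, default=8)
opt = parser.parse_args(argv)                   # final parse with the new flags
print(opt.model, opt.used_strokes)              # painter 16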
train/options/test_options.py ADDED
@@ -0,0 +1,23 @@
+ from .base_options import BaseOptions
+
+
+ class TestOptions(BaseOptions):
+     """This class includes test options.
+
+     It also includes shared options defined in BaseOptions.
+     """
+
+     def initialize(self, parser):
+         parser = BaseOptions.initialize(self, parser)  # define shared options
+         parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+         # Dropout and Batch norm have different behavior during training and test.
+         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
+         parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
+         # rewrite default values
+         parser.set_defaults(model='test')
+         # To avoid cropping, the load_size should be the same as crop_size
+         parser.set_defaults(load_size=parser.get_default('crop_size'))
+         self.isTrain = False
+         return parser
train/options/train_options.py ADDED
@@ -0,0 +1,52 @@
+ from .base_options import BaseOptions
+
+
+ class TrainOptions(BaseOptions):
+     """This class includes training options.
+
+     It also includes shared options defined in BaseOptions.
+     """
+
+     def initialize(self, parser):
+         parser = BaseOptions.initialize(self, parser)
+         # visdom and HTML visualization parameters
+         parser.add_argument('--display_freq', type=int, default=40,
+                             help='frequency of showing training results on screen')
+         parser.add_argument('--display_ncols', type=int, default=4,
+                             help='if positive, display all images in a single visdom web panel '
+                                  'with certain number of images per row.')
+         parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
+         parser.add_argument('--display_server', type=str, default="http://localhost",
+                             help='visdom server of the web display')
+         parser.add_argument('--display_env', type=str, default='main',
+                             help='visdom display environment name (default is "main")')
+         parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
+         parser.add_argument('--update_html_freq', type=int, default=1000,
+                             help='frequency of saving training results to html')
+         parser.add_argument('--print_freq', type=int, default=10,
+                             help='frequency of showing training results on console')
+         parser.add_argument('--no_html', action='store_true',
+                             help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+         # network saving and loading parameters
+         parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+         parser.add_argument('--save_epoch_freq', type=int, default=5,
+                             help='frequency of saving checkpoints at the end of epochs')
+         parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
+         parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+         parser.add_argument('--epoch_count', type=int, default=1,
+                             help='the starting epoch count, we save the model '
+                                  'by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
+         parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+         # training parameters
+         parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
+         parser.add_argument('--n_epochs_decay', type=int, default=100,
+                             help='number of epochs to linearly decay learning rate to zero')
+         parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
+         parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
+         parser.add_argument('--lr_policy', type=str, default='linear',
+                             help='learning rate policy. [linear | step | plateau | cosine]')
+         parser.add_argument('--lr_decay_iters', type=int, default=50,
+                             help='multiply by a gamma every lr_decay_iters iterations')
+
+         self.isTrain = True
+         return parser
train/train.py ADDED
@@ -0,0 +1,58 @@
+ import time
+ from options.train_options import TrainOptions
+ from data import create_dataset
+ from models import create_model
+ from util.visualizer import Visualizer
+
+ if __name__ == '__main__':
+     opt = TrainOptions().parse()   # get training options
+     dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
+     dataset_size = len(dataset)    # get the number of images in the dataset.
+     print('The number of training images = %d' % dataset_size)
+
+     model = create_model(opt)      # create a model given opt.model and other options
+     model.setup(opt)               # regular setup: load and print networks; create schedulers
+     visualizer = Visualizer(opt)   # create a visualizer that displays/saves images and plots
+     total_iters = 0                # the total number of training iterations
+
+     for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
+         epoch_start_time = time.time()  # timer for entire epoch
+         iter_data_time = time.time()    # timer for data loading per iteration
+         epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
+         visualizer.reset()              # reset visualizer: make sure it saves results to HTML at least once every epoch
+         for i, data in enumerate(dataset):  # inner loop within one epoch
+             iter_start_time = time.time()   # timer for computation per iteration
+             if total_iters % opt.print_freq == 0:
+                 t_data = iter_start_time - iter_data_time
+
+             total_iters += opt.batch_size
+             epoch_iter += opt.batch_size
+             model.set_input(data)        # unpack data from dataset and apply preprocessing
+             model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
+
+             if total_iters % opt.display_freq == 0:  # display images on visdom and save images to an HTML file
+                 save_result = total_iters % opt.update_html_freq == 0
+                 model.compute_visuals()
+                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+
+             if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
+                 losses = model.get_current_losses()
+                 t_comp = (time.time() - iter_start_time) / opt.batch_size
+                 visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
+                 if opt.display_id > 0:
+                     visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
+
+             if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
+                 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
+                 save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
+                 model.save_networks(save_suffix)
+
+             iter_data_time = time.time()
+         if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
+             print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
+             model.save_networks('latest')
+             model.save_networks(epoch)
+
+         print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay,
+                                                               time.time() - epoch_start_time))
+         model.update_learning_rate()  # update learning rates at the end of every epoch
train/train.sh ADDED
@@ -0,0 +1,14 @@
+ python train.py \
+ --name painter \
+ --gpu_ids 0 \
+ --model painter \
+ --dataset_mode null \
+ --batch_size 64 \
+ --display_freq 25 \
+ --print_freq 25 \
+ --lr 1e-4 \
+ --init_type normal \
+ --n_epochs 200 \
+ --n_epochs_decay 20 \
+ --max_dataset_size 16384 \
+ --save_epoch_freq 20
train/util/__init__.py ADDED
@@ -0,0 +1 @@
+ """This package includes a miscellaneous collection of useful helper functions."""
train/util/html.py ADDED
@@ -0,0 +1,86 @@
+ import dominate
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+ import os
+
+
+ class HTML:
+     """This HTML class allows us to save images and write texts into a single HTML file.
+
+     It consists of functions such as <add_header> (add a text header to the HTML file),
+     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
+     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+     """
+
+     def __init__(self, web_dir, title, refresh=0):
+         """Initialize the HTML classes
+
+         Parameters:
+             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
+             title (str)   -- the webpage name
+             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+         """
+         self.title = title
+         self.web_dir = web_dir
+         self.img_dir = os.path.join(self.web_dir, 'images')
+         if not os.path.exists(self.web_dir):
+             os.makedirs(self.web_dir)
+         if not os.path.exists(self.img_dir):
+             os.makedirs(self.img_dir)
+
+         self.doc = dominate.document(title=title)
+         if refresh > 0:
+             with self.doc.head:
+                 meta(http_equiv="refresh", content=str(refresh))
+
+     def get_image_dir(self):
+         """Return the directory that stores images"""
+         return self.img_dir
+
+     def add_header(self, text):
+         """Insert a header to the HTML file
+
+         Parameters:
+             text (str) -- the header text
+         """
+         with self.doc:
+             h3(text)
+
+     def add_images(self, ims, txts, links, width=400):
+         """Add images to the HTML file
+
+         Parameters:
+             ims (str list)   -- a list of image paths
+             txts (str list)  -- a list of image names shown on the website
+             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+         """
+         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
+         self.doc.add(self.t)
+         with self.t:
+             with tr():
+                 for im, txt, link in zip(ims, txts, links):
+                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
+                         with p():
+                             with a(href=os.path.join('images', link)):
+                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
+                             br()
+                             p(txt)
+
+     def save(self):
+         """Save the current content to the HTML file"""
+         html_file = '%s/index.html' % self.web_dir
+         f = open(html_file, 'wt')
+         f.write(self.doc.render())
+         f.close()
+
+
+ if __name__ == '__main__':  # we show an example usage here.
+     html = HTML('web/', 'test_html')
+     html.add_header('hello world')
+
+     ims, txts, links = [], [], []
+     for n in range(4):
+         ims.append('image_%d.png' % n)
+         txts.append('text_%d' % n)
+         links.append('image_%d.png' % n)
+     html.add_images(ims, txts, links)
+     html.save()
train/util/morphology.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Erosion2d(nn.Module):
+
+     def __init__(self, m=1):
+         super(Erosion2d, self).__init__()
+         self.m = m
+         self.pad = [m, m, m, m]
+         self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
+
+     def forward(self, x):
+         batch_size, c, h, w = x.shape
+         x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)
+         for i in range(c):
+             channel = self.unfold(x_pad[:, [i], :, :])
+             channel = torch.min(channel, dim=1, keepdim=True)[0]
+             channel = channel.view([batch_size, 1, h, w])
+             x[:, [i], :, :] = channel
+
+         return x
+
+
+ class Dilation2d(nn.Module):
+
+     def __init__(self, m=1):
+         super(Dilation2d, self).__init__()
+         self.m = m
+         self.pad = [m, m, m, m]
+         self.unfold = nn.Unfold(2 * m + 1, padding=0, stride=1)
+
+     def forward(self, x):
+         batch_size, c, h, w = x.shape
+         x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)
+         for i in range(c):
+             channel = self.unfold(x_pad[:, [i], :, :])
+             channel = torch.max(channel, dim=1, keepdim=True)[0]
+             channel = channel.view([batch_size, 1, h, w])
+             x[:, [i], :, :] = channel
+
+         return x
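
Erosion2d and Dilation2d are min- and max-filters built on nn.Unfold: with m=1, every pixel becomes the min (or max) of its 3x3 neighborhood, and the ±1e9 padding keeps the border from dominating. Note that forward writes into x in place. An editor's sketch, assuming the two classes above are in scope:

import torch

x = torch.zeros(1, 1, 5, 5)
x[0, 0, 2, 2] = 1.0                         # a single lit pixel
print(Dilation2d(m=1)(x.clone())[0, 0])     # grows into a 3x3 block of ones
print(Erosion2d(m=1)(x.clone())[0, 0])      # the lone pixel is eroded to zeros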
train/util/util.py ADDED
@@ -0,0 +1,103 @@
+ """This module contains simple helper functions """
+ from __future__ import print_function
+ import torch
+ import numpy as np
+ from PIL import Image
+ import os
+
+
+ def tensor2im(input_image, imtype=np.uint8):
+     """Converts a Tensor array into a numpy image array.
+
+     Parameters:
+         input_image (tensor) -- the input image tensor array
+         imtype (type)        -- the desired type of the converted numpy array
+     """
+     if not isinstance(input_image, np.ndarray):
+         if isinstance(input_image, torch.Tensor):  # get the data from a variable
+             image_tensor = input_image.data
+         else:
+             return input_image
+         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+         if image_numpy.shape[0] == 1:  # grayscale to RGB
+             image_numpy = np.tile(image_numpy, (3, 1, 1))
+         image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0  # post-processing: transpose and scaling
+     else:  # if it is a numpy array
+         image_numpy = input_image * 255.
+     return image_numpy.astype(imtype)
+
+
+ def diagnose_network(net, name='network'):
+     """Calculate and print the mean of average absolute (gradients)
+
+     Parameters:
+         net (torch network) -- Torch network
+         name (str)          -- the name of the network
+     """
+     mean = 0.0
+     count = 0
+     for param in net.parameters():
+         if param.grad is not None:
+             mean += torch.mean(torch.abs(param.grad.data))
+             count += 1
+     if count > 0:
+         mean = mean / count
+     print(name)
+     print(mean)
+
+
+ def save_image(image_numpy, image_path, aspect_ratio=1.0):
+     """Save a numpy image to the disk
+
+     Parameters:
+         image_numpy (numpy array) -- input numpy array
+         image_path (str)          -- the path of the image
+     """
+
+     image_pil = Image.fromarray(image_numpy)
+     h, w, _ = image_numpy.shape
+
+     if aspect_ratio > 1.0:
+         image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+     if aspect_ratio < 1.0:
+         image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+     image_pil.save(image_path)
+
+
+ def print_numpy(x, val=True, shp=False):
+     """Print the mean, min, max, median, std, and size of a numpy array
+
+     Parameters:
+         val (bool) -- whether to print the values of the numpy array
+         shp (bool) -- whether to print the shape of the numpy array
+     """
+     x = x.astype(np.float64)
+     if shp:
+         print('shape,', x.shape)
+     if val:
+         x = x.flatten()
+         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+ def mkdirs(paths):
+     """Create empty directories if they don't exist
+
+     Parameters:
+         paths (str list) -- a list of directory paths
+     """
+     if isinstance(paths, list) and not isinstance(paths, str):
+         for path in paths:
+             mkdir(path)
+     else:
+         mkdir(paths)
+
+
+ def mkdir(path):
+     """Create a single empty directory if it didn't exist
+
+     Parameters:
+         path (str) -- a single directory path
+     """
+     if not os.path.exists(path):
+         os.makedirs(path)
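
Note that this tensor2im assumes inputs already live in [0, 1] (it scales by 255 directly), which matches how images are read elsewhere in this repo; the upstream CycleGAN helper instead maps from [-1, 1]. A quick editor's sketch, assuming tensor2im from train/util/util.py is in scope:

import torch

t = torch.rand(1, 3, 4, 4)        # one RGB image, values in [0, 1]
im = tensor2im(t)
print(im.dtype, im.shape)         # uint8 (4, 4, 3)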
train/util/visualizer.py ADDED
@@ -0,0 +1,224 @@
+ import numpy as np
+ import os
+ import sys
+ import ntpath
+ import time
+ from . import util, html
+ from subprocess import Popen, PIPE
+
+
+ if sys.version_info[0] == 2:
+     VisdomExceptionBase = Exception
+ else:
+     VisdomExceptionBase = ConnectionError
+
+
+ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+     """Save images to the disk.
+
+     Parameters:
+         webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
+         visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
+         image_path (str)         -- the string is used to create image paths
+         aspect_ratio (float)     -- the aspect ratio of saved images
+         width (int)              -- the images will be resized to width x width
+
+     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+     """
+     image_dir = webpage.get_image_dir()
+     short_path = ntpath.basename(image_path[0])
+     name = os.path.splitext(short_path)[0]
+
+     webpage.add_header(name)
+     ims, txts, links = [], [], []
+
+     for label, im_data in visuals.items():
+         im = util.tensor2im(im_data)
+         image_name = '%s_%s.png' % (name, label)
+         save_path = os.path.join(image_dir, image_name)
+         util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+         ims.append(image_name)
+         txts.append(label)
+         links.append(image_name)
+     webpage.add_images(ims, txts, links, width=width)
+
+
+ class Visualizer:
+     """This class includes several functions that can display/save images and print/save logging information.
+
+     It uses the Python library 'visdom' for display, and the Python library 'dominate' (wrapped in 'HTML') for creating
+     HTML files with images.
+     """
+
+     def __init__(self, opt):
+         """Initialize the Visualizer class
+
+         Parameters:
+             opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         Step 1: Cache the training/test options
+         Step 2: connect to a visdom server
+         Step 3: create an HTML object for saving HTML files
+         Step 4: create a logging file to store training losses
+         """
+         self.opt = opt  # cache the option
+         self.display_id = opt.display_id
+         self.use_html = opt.isTrain and not opt.no_html
+         self.win_size = opt.display_winsize
+         self.name = opt.name
+         self.port = opt.display_port
+         self.saved = False
+         if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
+             import visdom
+             self.ncols = opt.display_ncols
+             self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
+             if not self.vis.check_connection():
+                 self.create_visdom_connections()
+
+         if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under
+             # <checkpoints_dir>/web/images/
+             self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+             self.img_dir = os.path.join(self.web_dir, 'images')
+             print('create web directory %s...' % self.web_dir)
+             util.mkdirs([self.web_dir, self.img_dir])
+         # create a logging file to store training losses
+         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+         with open(self.log_name, "a") as log_file:
+             now = time.strftime("%c")
+             log_file.write('================ Training Loss (%s) ================\n' % now)
+
+     def reset(self):
+         """Reset the self.saved status"""
+         self.saved = False
+
+     def create_visdom_connections(self):
+         """If the program could not connect to the Visdom server, this function will start a new server at port <self.port>"""
+         cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
+         print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
+         print('Command: %s' % cmd)
+         Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
+
+     def display_current_results(self, visuals, epoch, save_result):
+         """Display current results on visdom; save current results to an HTML file.
+
+         Parameters:
+             visuals (OrderedDict) -- dictionary of images to display or save
+             epoch (int)           -- the current epoch
+             save_result (bool)    -- if save the current results to an HTML file
+         """
+         if self.display_id > 0:  # show images in the browser using visdom
+             ncols = self.ncols
+             if ncols > 0:  # show all the images in one visdom panel
+                 ncols = min(ncols, len(visuals))
+                 h, w = next(iter(visuals.values())).shape[:2]
+                 table_css = """<style>
+                         table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
+                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+                         </style>""" % (w, h)  # create a table css
+                 # create a table of images.
+                 title = self.name
+                 label_html = ''
+                 label_html_row = ''
+                 images = []
+                 idx = 0
+                 for label, image in visuals.items():
+                     image_numpy = util.tensor2im(image)
+                     label_html_row += '<td>%s</td>' % label
+                     images.append(image_numpy.transpose([2, 0, 1]))
+                     idx += 1
+                     if idx % ncols == 0:
+                         label_html += '<tr>%s</tr>' % label_html_row
+                         label_html_row = ''
+                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
+                 while idx % ncols != 0:
+                     images.append(white_image)
+                     label_html_row += '<td></td>'
+                     idx += 1
+                 if label_html_row != '':
+                     label_html += '<tr>%s</tr>' % label_html_row
+                 try:
+                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
+                                     padding=2, opts=dict(title=title + ' images'))
+                     label_html = '<table>%s</table>' % label_html
+                     self.vis.text(table_css + label_html, win=self.display_id + 2,
+                                   opts=dict(title=title + ' labels'))
+                 except VisdomExceptionBase:
+                     self.create_visdom_connections()
+
+             else:  # show each image in a separate visdom panel;
+                 idx = 1
+                 try:
+                     for label, image in visuals.items():
+                         image_numpy = util.tensor2im(image)
+                         self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
+                                        win=self.display_id + idx)
+                         idx += 1
+                 except VisdomExceptionBase:
+                     self.create_visdom_connections()
+
+         if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
+             self.saved = True
+             # save images to the disk
+             for label, image in visuals.items():
+                 image_numpy = util.tensor2im(image)
+                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+                 util.save_image(image_numpy, img_path)
+
+             # update website
+             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+             for n in range(epoch, 0, -1):
+                 webpage.add_header('epoch [%d]' % n)
+                 ims, txts, links = [], [], []
+
+                 for label, image_numpy in visuals.items():
+                     image_numpy = util.tensor2im(image_numpy)
+                     img_path = 'epoch%.3d_%s.png' % (n, label)
+                     ims.append(img_path)
+                     txts.append(label)
+                     links.append(img_path)
+                 webpage.add_images(ims, txts, links, width=self.win_size)
+             webpage.save()
+
+     def plot_current_losses(self, epoch, counter_ratio, losses):
+         """Display the current losses on visdom display: dictionary of error labels and values
+
+         Parameters:
+             epoch (int)           -- current epoch
+             counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
+             losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
+         """
+         if not hasattr(self, 'plot_data'):
+             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
+         self.plot_data['X'].append(epoch + counter_ratio)
+         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+         try:
+             self.vis.line(
+                 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
+                 Y=np.array(self.plot_data['Y']),
+                 opts={
+                     'title': self.name + ' loss over time',
+                     'legend': self.plot_data['legend'],
+                     'xlabel': 'epoch',
+                     'ylabel': 'loss'},
+                 win=self.display_id)
+         except VisdomExceptionBase:
+             self.create_visdom_connections()
+
+     # losses: same format as |losses| of plot_current_losses
+     def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+         """Print current losses on console; also save the losses to the disk
+
+         Parameters:
+             epoch (int)          -- current epoch
+             iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+             losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+             t_comp (float)       -- computational time per data point (normalized by batch_size)
+             t_data (float)       -- data loading time per data point (normalized by batch_size)
+         """
+         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+         for k, v in losses.items():
+             message += '%s: %.3f ' % (k, v)
+
+         print(message)  # print the message
+         with open(self.log_name, "a") as log_file:
+             log_file.write('%s\n' % message)  # save the message