antoyo123 commited on
Commit
b0df2a8
1 Parent(s): f4beaa6

Upload 203 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. bin/analyze_errors.py +316 -0
  3. bin/blur_predicts.py +57 -0
  4. bin/calc_dataset_stats.py +88 -0
  5. bin/debug/analyze_overlapping_masks.sh +31 -0
  6. bin/evaluate_predicts.py +79 -0
  7. bin/evaluator_example.py +76 -0
  8. bin/extract_masks.py +63 -0
  9. bin/filter_sharded_dataset.py +69 -0
  10. bin/gen_debug_mask_dataset.py +61 -0
  11. bin/gen_mask_dataset.py +130 -0
  12. bin/gen_mask_dataset_hydra.py +124 -0
  13. bin/gen_outpainting_dataset.py +88 -0
  14. bin/make_checkpoint.py +79 -0
  15. bin/mask_example.py +14 -0
  16. bin/paper_runfiles/blur_tests.sh +37 -0
  17. bin/paper_runfiles/env.sh +8 -0
  18. bin/paper_runfiles/find_best_checkpoint.py +54 -0
  19. bin/paper_runfiles/generate_test_celeba-hq.sh +17 -0
  20. bin/paper_runfiles/generate_test_ffhq.sh +17 -0
  21. bin/paper_runfiles/generate_test_paris.sh +17 -0
  22. bin/paper_runfiles/generate_test_paris_256.sh +17 -0
  23. bin/paper_runfiles/generate_val_test.sh +28 -0
  24. bin/paper_runfiles/predict_inner_features.sh +20 -0
  25. bin/paper_runfiles/update_test_data_stats.sh +30 -0
  26. bin/predict.py +89 -0
  27. bin/predict_inner_features.py +119 -0
  28. bin/report_from_tb.py +83 -0
  29. bin/sample_from_dataset.py +87 -0
  30. bin/side_by_side.py +76 -0
  31. bin/split_tar.py +22 -0
  32. bin/train.py +72 -0
  33. configs/analyze_mask_errors.yaml +7 -0
  34. configs/data_gen/gen_segm_dataset1.yaml +25 -0
  35. configs/data_gen/gen_segm_dataset3.yaml +25 -0
  36. configs/data_gen/random_medium_256.yaml +33 -0
  37. configs/data_gen/random_medium_512.yaml +33 -0
  38. configs/data_gen/random_thick_256.yaml +33 -0
  39. configs/data_gen/random_thick_512.yaml +33 -0
  40. configs/data_gen/random_thin_256.yaml +25 -0
  41. configs/data_gen/random_thin_512.yaml +25 -0
  42. configs/data_gen/segm_256.yaml +27 -0
  43. configs/data_gen/segm_512.yaml +27 -0
  44. configs/data_gen/sr_256.yaml +25 -0
  45. configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml +5 -0
  46. configs/data_gen/whydra/location/mml-ws01-ffhq.yaml +5 -0
  47. configs/data_gen/whydra/location/mml-ws01-paris.yaml +5 -0
  48. configs/data_gen/whydra/location/mml7-places.yaml +5 -0
  49. configs/data_gen/whydra/random_medium_256.yaml +42 -0
  50. configs/data_gen/whydra/random_medium_512.yaml +42 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ saicinpainting/evaluation/masks/countless/images/gcim.jpg filter=lfs diff=lfs merge=lfs -text
bin/analyze_errors.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import cv2
3
+ import numpy as np
4
+ import sklearn
5
+ import torch
6
+ import os
7
+ import pickle
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ from joblib import Parallel, delayed
11
+
12
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset, load_image
13
+ from saicinpainting.evaluation.losses.fid.inception import InceptionV3
14
+ from saicinpainting.evaluation.utils import load_yaml
15
+ from saicinpainting.training.visualizers.base import visualize_mask_and_images
16
+
17
+
18
+ def draw_score(img, score):
19
+ img = np.transpose(img, (1, 2, 0))
20
+ cv2.putText(img, f'{score:.2f}',
21
+ (40, 40),
22
+ cv2.FONT_HERSHEY_SIMPLEX,
23
+ 1,
24
+ (0, 1, 0),
25
+ thickness=3)
26
+ img = np.transpose(img, (2, 0, 1))
27
+ return img
28
+
29
+
30
+ def save_global_samples(global_mask_fnames, mask2real_fname, mask2fake_fname, out_dir, real_scores_by_fname, fake_scores_by_fname):
31
+ for cur_mask_fname in global_mask_fnames:
32
+ cur_real_fname = mask2real_fname[cur_mask_fname]
33
+ orig_img = load_image(cur_real_fname, mode='RGB')
34
+ fake_img = load_image(mask2fake_fname[cur_mask_fname], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
35
+ mask = load_image(cur_mask_fname, mode='L')[None, ...]
36
+
37
+ draw_score(orig_img, real_scores_by_fname.loc[cur_real_fname, 'real_score'])
38
+ draw_score(fake_img, fake_scores_by_fname.loc[cur_mask_fname, 'fake_score'])
39
+
40
+ cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=mask, fake=fake_img),
41
+ keys=['image', 'fake'],
42
+ last_without_mask=True)
43
+ cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
44
+ cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
45
+ cv2.imwrite(os.path.join(out_dir, os.path.splitext(os.path.basename(cur_mask_fname))[0] + '.jpg'),
46
+ cur_grid)
47
+
48
+
49
+ def save_samples_by_real(worst_best_by_real, mask2fake_fname, fake_info, out_dir):
50
+ for real_fname in worst_best_by_real.index:
51
+ worst_mask_path = worst_best_by_real.loc[real_fname, 'worst']
52
+ best_mask_path = worst_best_by_real.loc[real_fname, 'best']
53
+ orig_img = load_image(real_fname, mode='RGB')
54
+ worst_mask_img = load_image(worst_mask_path, mode='L')[None, ...]
55
+ worst_fake_img = load_image(mask2fake_fname[worst_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
56
+ best_mask_img = load_image(best_mask_path, mode='L')[None, ...]
57
+ best_fake_img = load_image(mask2fake_fname[best_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
58
+
59
+ draw_score(orig_img, worst_best_by_real.loc[real_fname, 'real_score'])
60
+ draw_score(worst_fake_img, worst_best_by_real.loc[real_fname, 'worst_score'])
61
+ draw_score(best_fake_img, worst_best_by_real.loc[real_fname, 'best_score'])
62
+
63
+ cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=np.zeros_like(worst_mask_img),
64
+ worst_mask=worst_mask_img, worst_img=worst_fake_img,
65
+ best_mask=best_mask_img, best_img=best_fake_img),
66
+ keys=['image', 'worst_mask', 'worst_img', 'best_mask', 'best_img'],
67
+ rescale_keys=['worst_mask', 'best_mask'],
68
+ last_without_mask=True)
69
+ cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
70
+ cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
71
+ cv2.imwrite(os.path.join(out_dir,
72
+ os.path.splitext(os.path.basename(real_fname))[0] + '.jpg'),
73
+ cur_grid)
74
+
75
+ fig, (ax1, ax2) = plt.subplots(1, 2)
76
+ cur_stat = fake_info[fake_info['real_fname'] == real_fname]
77
+ cur_stat['fake_score'].hist(ax=ax1)
78
+ cur_stat['real_score'].hist(ax=ax2)
79
+ fig.tight_layout()
80
+ fig.savefig(os.path.join(out_dir,
81
+ os.path.splitext(os.path.basename(real_fname))[0] + '_scores.png'))
82
+ plt.close(fig)
83
+
84
+
85
+ def extract_overlapping_masks(mask_fnames, cur_i, fake_scores_table, max_overlaps_n=2):
86
+ result_pairs = []
87
+ result_scores = []
88
+ mask_fname_a = mask_fnames[cur_i]
89
+ mask_a = load_image(mask_fname_a, mode='L')[None, ...] > 0.5
90
+ cur_score_a = fake_scores_table.loc[mask_fname_a, 'fake_score']
91
+ for mask_fname_b in mask_fnames[cur_i + 1:]:
92
+ mask_b = load_image(mask_fname_b, mode='L')[None, ...] > 0.5
93
+ if not np.any(mask_a & mask_b):
94
+ continue
95
+ cur_score_b = fake_scores_table.loc[mask_fname_b, 'fake_score']
96
+ result_pairs.append((mask_fname_a, mask_fname_b))
97
+ result_scores.append(cur_score_b - cur_score_a)
98
+ if len(result_pairs) >= max_overlaps_n:
99
+ break
100
+ return result_pairs, result_scores
101
+
102
+
103
+ def main(args):
104
+ config = load_yaml(args.config)
105
+
106
+ latents_dir = os.path.join(args.outpath, 'latents')
107
+ os.makedirs(latents_dir, exist_ok=True)
108
+ global_worst_dir = os.path.join(args.outpath, 'global_worst')
109
+ os.makedirs(global_worst_dir, exist_ok=True)
110
+ global_best_dir = os.path.join(args.outpath, 'global_best')
111
+ os.makedirs(global_best_dir, exist_ok=True)
112
+ worst_best_by_best_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_max')
113
+ os.makedirs(worst_best_by_best_worst_score_diff_max_dir, exist_ok=True)
114
+ worst_best_by_best_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_min')
115
+ os.makedirs(worst_best_by_best_worst_score_diff_min_dir, exist_ok=True)
116
+ worst_best_by_real_best_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_max')
117
+ os.makedirs(worst_best_by_real_best_score_diff_max_dir, exist_ok=True)
118
+ worst_best_by_real_best_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_min')
119
+ os.makedirs(worst_best_by_real_best_score_diff_min_dir, exist_ok=True)
120
+ worst_best_by_real_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_max')
121
+ os.makedirs(worst_best_by_real_worst_score_diff_max_dir, exist_ok=True)
122
+ worst_best_by_real_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_min')
123
+ os.makedirs(worst_best_by_real_worst_score_diff_min_dir, exist_ok=True)
124
+
125
+ if not args.only_report:
126
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
127
+ inception_model = InceptionV3([block_idx]).eval().cuda()
128
+
129
+ dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
130
+
131
+ real2vector_cache = {}
132
+
133
+ real_features = []
134
+ fake_features = []
135
+
136
+ orig_fnames = []
137
+ mask_fnames = []
138
+ mask2real_fname = {}
139
+ mask2fake_fname = {}
140
+
141
+ for batch_i, batch in enumerate(dataset):
142
+ orig_img_fname = dataset.img_filenames[batch_i]
143
+ mask_fname = dataset.mask_filenames[batch_i]
144
+ fake_fname = dataset.pred_filenames[batch_i]
145
+ mask2real_fname[mask_fname] = orig_img_fname
146
+ mask2fake_fname[mask_fname] = fake_fname
147
+
148
+ cur_real_vector = real2vector_cache.get(orig_img_fname, None)
149
+ if cur_real_vector is None:
150
+ with torch.no_grad():
151
+ in_img = torch.from_numpy(batch['image'][None, ...]).cuda()
152
+ cur_real_vector = inception_model(in_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
153
+ real2vector_cache[orig_img_fname] = cur_real_vector
154
+
155
+ pred_img = torch.from_numpy(batch['inpainted'][None, ...]).cuda()
156
+ cur_fake_vector = inception_model(pred_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
157
+
158
+ real_features.append(cur_real_vector)
159
+ fake_features.append(cur_fake_vector)
160
+
161
+ orig_fnames.append(orig_img_fname)
162
+ mask_fnames.append(mask_fname)
163
+
164
+ ids_features = np.concatenate(real_features + fake_features, axis=0)
165
+ ids_labels = np.array(([1] * len(real_features)) + ([0] * len(fake_features)))
166
+
167
+ with open(os.path.join(latents_dir, 'featues.pkl'), 'wb') as f:
168
+ pickle.dump(ids_features, f, protocol=3)
169
+ with open(os.path.join(latents_dir, 'labels.pkl'), 'wb') as f:
170
+ pickle.dump(ids_labels, f, protocol=3)
171
+ with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'wb') as f:
172
+ pickle.dump(orig_fnames, f, protocol=3)
173
+ with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'wb') as f:
174
+ pickle.dump(mask_fnames, f, protocol=3)
175
+ with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'wb') as f:
176
+ pickle.dump(mask2real_fname, f, protocol=3)
177
+ with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'wb') as f:
178
+ pickle.dump(mask2fake_fname, f, protocol=3)
179
+
180
+ svm = sklearn.svm.LinearSVC(dual=False)
181
+ svm.fit(ids_features, ids_labels)
182
+
183
+ pred_scores = svm.decision_function(ids_features)
184
+ real_scores = pred_scores[:len(real_features)]
185
+ fake_scores = pred_scores[len(real_features):]
186
+
187
+ with open(os.path.join(latents_dir, 'pred_scores.pkl'), 'wb') as f:
188
+ pickle.dump(pred_scores, f, protocol=3)
189
+ with open(os.path.join(latents_dir, 'real_scores.pkl'), 'wb') as f:
190
+ pickle.dump(real_scores, f, protocol=3)
191
+ with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'wb') as f:
192
+ pickle.dump(fake_scores, f, protocol=3)
193
+ else:
194
+ with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'rb') as f:
195
+ orig_fnames = pickle.load(f)
196
+ with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'rb') as f:
197
+ mask_fnames = pickle.load(f)
198
+ with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'rb') as f:
199
+ mask2real_fname = pickle.load(f)
200
+ with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'rb') as f:
201
+ mask2fake_fname = pickle.load(f)
202
+ with open(os.path.join(latents_dir, 'real_scores.pkl'), 'rb') as f:
203
+ real_scores = pickle.load(f)
204
+ with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'rb') as f:
205
+ fake_scores = pickle.load(f)
206
+
207
+ real_info = pd.DataFrame(data=[dict(real_fname=fname,
208
+ real_score=score)
209
+ for fname, score
210
+ in zip(orig_fnames, real_scores)])
211
+ real_info.set_index('real_fname', drop=True, inplace=True)
212
+
213
+ fake_info = pd.DataFrame(data=[dict(mask_fname=fname,
214
+ fake_fname=mask2fake_fname[fname],
215
+ real_fname=mask2real_fname[fname],
216
+ fake_score=score)
217
+ for fname, score
218
+ in zip(mask_fnames, fake_scores)])
219
+ fake_info = fake_info.join(real_info, on='real_fname', how='left')
220
+ fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
221
+
222
+ fake_stats_by_real = fake_info.groupby('real_fname')['fake_score'].describe()[['mean', 'std']].rename(
223
+ {'mean': 'mean_fake_by_real', 'std': 'std_fake_by_real'}, axis=1)
224
+ fake_info = fake_info.join(fake_stats_by_real, on='real_fname', rsuffix='stat_by_real')
225
+ fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
226
+ fake_info.to_csv(os.path.join(latents_dir, 'join_scores_table.csv'), sep='\t', index=False)
227
+
228
+ fake_scores_table = fake_info.set_index('mask_fname')['fake_score'].to_frame()
229
+ real_scores_table = fake_info.set_index('real_fname')['real_score'].drop_duplicates().to_frame()
230
+
231
+ fig, (ax1, ax2) = plt.subplots(1, 2)
232
+ ax1.hist(fake_scores)
233
+ ax2.hist(real_scores)
234
+ fig.tight_layout()
235
+ fig.savefig(os.path.join(args.outpath, 'global_scores_hist.png'))
236
+ plt.close(fig)
237
+
238
+ global_worst_masks = fake_info.sort_values('fake_score', ascending=True)['mask_fname'].iloc[:config.take_global_top].to_list()
239
+ global_best_masks = fake_info.sort_values('fake_score', ascending=False)['mask_fname'].iloc[:config.take_global_top].to_list()
240
+ save_global_samples(global_worst_masks, mask2real_fname, mask2fake_fname, global_worst_dir, real_scores_table, fake_scores_table)
241
+ save_global_samples(global_best_masks, mask2real_fname, mask2fake_fname, global_best_dir, real_scores_table, fake_scores_table)
242
+
243
+ # grouped by real
244
+ worst_samples_by_real = fake_info.groupby('real_fname').apply(
245
+ lambda d: d.set_index('mask_fname')['fake_score'].idxmin()).to_frame().rename({0: 'worst'}, axis=1)
246
+ best_samples_by_real = fake_info.groupby('real_fname').apply(
247
+ lambda d: d.set_index('mask_fname')['fake_score'].idxmax()).to_frame().rename({0: 'best'}, axis=1)
248
+ worst_best_by_real = pd.concat([worst_samples_by_real, best_samples_by_real], axis=1)
249
+
250
+ worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'worst_score'}, axis=1),
251
+ on='worst')
252
+ worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'best_score'}, axis=1),
253
+ on='best')
254
+ worst_best_by_real = worst_best_by_real.join(real_scores_table)
255
+
256
+ worst_best_by_real['best_worst_score_diff'] = worst_best_by_real['best_score'] - worst_best_by_real['worst_score']
257
+ worst_best_by_real['real_best_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['best_score']
258
+ worst_best_by_real['real_worst_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['worst_score']
259
+
260
+ worst_best_by_best_worst_score_diff_min = worst_best_by_real.sort_values('best_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
261
+ worst_best_by_best_worst_score_diff_max = worst_best_by_real.sort_values('best_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
262
+ save_samples_by_real(worst_best_by_best_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_min_dir)
263
+ save_samples_by_real(worst_best_by_best_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_max_dir)
264
+
265
+ worst_best_by_real_best_score_diff_min = worst_best_by_real.sort_values('real_best_score_diff', ascending=True).iloc[:config.take_worst_best_top]
266
+ worst_best_by_real_best_score_diff_max = worst_best_by_real.sort_values('real_best_score_diff', ascending=False).iloc[:config.take_worst_best_top]
267
+ save_samples_by_real(worst_best_by_real_best_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_min_dir)
268
+ save_samples_by_real(worst_best_by_real_best_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_max_dir)
269
+
270
+ worst_best_by_real_worst_score_diff_min = worst_best_by_real.sort_values('real_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
271
+ worst_best_by_real_worst_score_diff_max = worst_best_by_real.sort_values('real_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
272
+ save_samples_by_real(worst_best_by_real_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_min_dir)
273
+ save_samples_by_real(worst_best_by_real_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_max_dir)
274
+
275
+ # analyze what change of mask causes bigger change of score
276
+ overlapping_mask_fname_pairs = []
277
+ overlapping_mask_fname_score_diffs = []
278
+ for cur_real_fname in orig_fnames:
279
+ cur_fakes_info = fake_info[fake_info['real_fname'] == cur_real_fname]
280
+ cur_mask_fnames = sorted(cur_fakes_info['mask_fname'].unique())
281
+
282
+ cur_mask_pairs_and_scores = Parallel(args.n_jobs)(
283
+ delayed(extract_overlapping_masks)(cur_mask_fnames, i, fake_scores_table)
284
+ for i in range(len(cur_mask_fnames) - 1)
285
+ )
286
+ for cur_pairs, cur_scores in cur_mask_pairs_and_scores:
287
+ overlapping_mask_fname_pairs.extend(cur_pairs)
288
+ overlapping_mask_fname_score_diffs.extend(cur_scores)
289
+
290
+ overlapping_mask_fname_pairs = np.asarray(overlapping_mask_fname_pairs)
291
+ overlapping_mask_fname_score_diffs = np.asarray(overlapping_mask_fname_score_diffs)
292
+ overlapping_sort_idx = np.argsort(overlapping_mask_fname_score_diffs)
293
+ overlapping_mask_fname_pairs = overlapping_mask_fname_pairs[overlapping_sort_idx]
294
+ overlapping_mask_fname_score_diffs = overlapping_mask_fname_score_diffs[overlapping_sort_idx]
295
+
296
+
297
+
298
+
299
+
300
+
301
+ if __name__ == '__main__':
302
+ import argparse
303
+
304
+ aparser = argparse.ArgumentParser()
305
+ aparser.add_argument('config', type=str, help='Path to config for dataset generation')
306
+ aparser.add_argument('datadir', type=str,
307
+ help='Path to folder with images and masks (output of gen_mask_dataset.py)')
308
+ aparser.add_argument('predictdir', type=str,
309
+ help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
310
+ aparser.add_argument('outpath', type=str, help='Where to put results')
311
+ aparser.add_argument('--only-report', action='store_true',
312
+ help='Whether to skip prediction and feature extraction, '
313
+ 'load all the possible latents and proceed with report only')
314
+ aparser.add_argument('--n-jobs', type=int, default=8, help='how many processes to use for pair mask mining')
315
+
316
+ main(aparser.parse_args())
bin/blur_predicts.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import tqdm
8
+
9
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
10
+ from saicinpainting.evaluation.utils import load_yaml
11
+
12
+
13
+ def main(args):
14
+ config = load_yaml(args.config)
15
+
16
+ if not args.predictdir.endswith('/'):
17
+ args.predictdir += '/'
18
+
19
+ dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
20
+
21
+ os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
22
+
23
+ for img_i in tqdm.trange(len(dataset)):
24
+ pred_fname = dataset.pred_filenames[img_i]
25
+ cur_out_fname = os.path.join(args.outpath, pred_fname[len(args.predictdir):])
26
+ os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
27
+
28
+ sample = dataset[img_i]
29
+ img = sample['image']
30
+ mask = sample['mask']
31
+ inpainted = sample['inpainted']
32
+
33
+ inpainted_blurred = cv2.GaussianBlur(np.transpose(inpainted, (1, 2, 0)),
34
+ ksize=(args.k, args.k),
35
+ sigmaX=args.s, sigmaY=args.s,
36
+ borderType=cv2.BORDER_REFLECT)
37
+
38
+ cur_res = (1 - mask) * np.transpose(img, (1, 2, 0)) + mask * inpainted_blurred
39
+ cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
40
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
41
+ cv2.imwrite(cur_out_fname, cur_res)
42
+
43
+
44
+ if __name__ == '__main__':
45
+ import argparse
46
+
47
+ aparser = argparse.ArgumentParser()
48
+ aparser.add_argument('config', type=str, help='Path to evaluation config')
49
+ aparser.add_argument('datadir', type=str,
50
+ help='Path to folder with images and masks (output of gen_mask_dataset.py)')
51
+ aparser.add_argument('predictdir', type=str,
52
+ help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
53
+ aparser.add_argument('outpath', type=str, help='Where to put results')
54
+ aparser.add_argument('-s', type=float, default=0.1, help='Gaussian blur sigma')
55
+ aparser.add_argument('-k', type=int, default=5, help='Kernel size in gaussian blur')
56
+
57
+ main(aparser.parse_args())
bin/calc_dataset_stats.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ import tqdm
7
+ from scipy.ndimage.morphology import distance_transform_edt
8
+
9
+ from saicinpainting.evaluation.data import InpaintingDataset
10
+ from saicinpainting.evaluation.vis import save_item_for_vis
11
+
12
+
13
+ def main(args):
14
+ dataset = InpaintingDataset(args.datadir, img_suffix='.png')
15
+
16
+ area_bins = np.linspace(0, 1, args.area_bins + 1)
17
+
18
+ heights = []
19
+ widths = []
20
+ image_areas = []
21
+ hole_areas = []
22
+ hole_area_percents = []
23
+ known_pixel_distances = []
24
+
25
+ area_bins_count = np.zeros(args.area_bins)
26
+ area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
27
+
28
+ bin2i = [[] for _ in range(args.area_bins)]
29
+
30
+ for i, item in enumerate(tqdm.tqdm(dataset)):
31
+ h, w = item['image'].shape[1:]
32
+ heights.append(h)
33
+ widths.append(w)
34
+ full_area = h * w
35
+ image_areas.append(full_area)
36
+ bin_mask = item['mask'] > 0.5
37
+ hole_area = bin_mask.sum()
38
+ hole_areas.append(hole_area)
39
+ hole_percent = hole_area / full_area
40
+ hole_area_percents.append(hole_percent)
41
+ bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
42
+ area_bins_count[bin_i] += 1
43
+ bin2i[bin_i].append(i)
44
+
45
+ cur_dist = distance_transform_edt(bin_mask)
46
+ cur_dist_inside_mask = cur_dist[bin_mask]
47
+ known_pixel_distances.append(cur_dist_inside_mask.mean())
48
+
49
+ os.makedirs(args.outdir, exist_ok=True)
50
+ with open(os.path.join(args.outdir, 'summary.txt'), 'w') as f:
51
+ f.write(f'''Location: {args.datadir}
52
+
53
+ Number of samples: {len(dataset)}
54
+
55
+ Image height: min {min(heights):5d} max {max(heights):5d} mean {np.mean(heights):.2f}
56
+ Image width: min {min(widths):5d} max {max(widths):5d} mean {np.mean(widths):.2f}
57
+ Image area: min {min(image_areas):7d} max {max(image_areas):7d} mean {np.mean(image_areas):.2f}
58
+ Hole area: min {min(hole_areas):7d} max {max(hole_areas):7d} mean {np.mean(hole_areas):.2f}
59
+ Hole area %: min {min(hole_area_percents) * 100:2.2f} max {max(hole_area_percents) * 100:2.2f} mean {np.mean(hole_area_percents) * 100:2.2f}
60
+ Dist 2known: min {min(known_pixel_distances):2.2f} max {max(known_pixel_distances):2.2f} mean {np.mean(known_pixel_distances):2.2f} median {np.median(known_pixel_distances):2.2f}
61
+
62
+ Stats by hole area %:
63
+ ''')
64
+ for bin_i in range(args.area_bins):
65
+ f.write(f'{area_bin_titles[bin_i]}%: '
66
+ f'samples number {area_bins_count[bin_i]}, '
67
+ f'{area_bins_count[bin_i] / len(dataset) * 100:.1f}%\n')
68
+
69
+ for bin_i in range(args.area_bins):
70
+ bindir = os.path.join(args.outdir, 'samples', area_bin_titles[bin_i])
71
+ os.makedirs(bindir, exist_ok=True)
72
+ bin_idx = bin2i[bin_i]
73
+ for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
74
+ save_item_for_vis(dataset[sample_i], os.path.join(bindir, f'{sample_i}.png'))
75
+
76
+
77
+ if __name__ == '__main__':
78
+ import argparse
79
+
80
+ aparser = argparse.ArgumentParser()
81
+ aparser.add_argument('datadir', type=str,
82
+ help='Path to folder with images and masks (output of gen_mask_dataset.py)')
83
+ aparser.add_argument('outdir', type=str, help='Where to put results')
84
+ aparser.add_argument('--samples-n', type=int, default=10,
85
+ help='Number of sample images with masks to copy for visualization for each area bin')
86
+ aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
87
+
88
+ main(aparser.parse_args())
bin/debug/analyze_overlapping_masks.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ BASEDIR="$(dirname $0)"
4
+
5
+ # paths are valid for mml7
6
+
7
+ # select images
8
+ #ls /data/inpainting/work/data/train | shuf | head -2000 | xargs -n1 -I{} cp {} /data/inpainting/mask_analysis/src
9
+
10
+ # generate masks
11
+ #"$BASEDIR/../gen_debug_mask_dataset.py" \
12
+ # "$BASEDIR/../../configs/debug_mask_gen.yaml" \
13
+ # "/data/inpainting/mask_analysis/src" \
14
+ # "/data/inpainting/mask_analysis/generated"
15
+
16
+ # predict
17
+ #"$BASEDIR/../predict.py" \
18
+ # model.path="simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/saved_checkpoint/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15_epoch22-step-574999" \
19
+ # indir="/data/inpainting/mask_analysis/generated" \
20
+ # outdir="/data/inpainting/mask_analysis/predicted" \
21
+ # dataset.img_suffix=.jpg \
22
+ # +out_ext=.jpg
23
+
24
+ # analyze good and bad samples
25
+ "$BASEDIR/../analyze_errors.py" \
26
+ --only-report \
27
+ --n-jobs 8 \
28
+ "$BASEDIR/../../configs/analyze_mask_errors.yaml" \
29
+ "/data/inpainting/mask_analysis/small/generated" \
30
+ "/data/inpainting/mask_analysis/small/predicted" \
31
+ "/data/inpainting/mask_analysis/small/report"
bin/evaluate_predicts.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import pandas as pd
6
+
7
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
8
+ from saicinpainting.evaluation.evaluator import InpaintingEvaluator, lpips_fid100_f1
9
+ from saicinpainting.evaluation.losses.base_loss import SegmentationAwareSSIM, \
10
+ SegmentationClassStats, SSIMScore, LPIPSScore, FIDScore, SegmentationAwareLPIPS, SegmentationAwareFID
11
+ from saicinpainting.evaluation.utils import load_yaml
12
+
13
+
14
+ def main(args):
15
+ config = load_yaml(args.config)
16
+
17
+ dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
18
+
19
+ metrics = {
20
+ 'ssim': SSIMScore(),
21
+ 'lpips': LPIPSScore(),
22
+ 'fid': FIDScore()
23
+ }
24
+ enable_segm = config.get('segmentation', dict(enable=False)).get('enable', False)
25
+ if enable_segm:
26
+ weights_path = os.path.expandvars(config.segmentation.weights_path)
27
+ metrics.update(dict(
28
+ segm_stats=SegmentationClassStats(weights_path=weights_path),
29
+ segm_ssim=SegmentationAwareSSIM(weights_path=weights_path),
30
+ segm_lpips=SegmentationAwareLPIPS(weights_path=weights_path),
31
+ segm_fid=SegmentationAwareFID(weights_path=weights_path)
32
+ ))
33
+ evaluator = InpaintingEvaluator(dataset, scores=metrics,
34
+ integral_title='lpips_fid100_f1', integral_func=lpips_fid100_f1,
35
+ **config.evaluator_kwargs)
36
+
37
+ os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
38
+
39
+ results = evaluator.evaluate()
40
+
41
+ results = pd.DataFrame(results).stack(1).unstack(0)
42
+ results.dropna(axis=1, how='all', inplace=True)
43
+ results.to_csv(args.outpath, sep='\t', float_format='%.4f')
44
+
45
+ if enable_segm:
46
+ only_short_results = results[[c for c in results.columns if not c[0].startswith('segm_')]].dropna(axis=1, how='all')
47
+ only_short_results.to_csv(args.outpath + '_short', sep='\t', float_format='%.4f')
48
+
49
+ print(only_short_results)
50
+
51
+ segm_metrics_results = results[['segm_ssim', 'segm_lpips', 'segm_fid']].dropna(axis=1, how='all').transpose().unstack(0).reorder_levels([1, 0], axis=1)
52
+ segm_metrics_results.drop(['mean', 'std'], axis=0, inplace=True)
53
+
54
+ segm_stats_results = results['segm_stats'].dropna(axis=1, how='all').transpose()
55
+ segm_stats_results.index = pd.MultiIndex.from_tuples(n.split('/') for n in segm_stats_results.index)
56
+ segm_stats_results = segm_stats_results.unstack(0).reorder_levels([1, 0], axis=1)
57
+ segm_stats_results.sort_index(axis=1, inplace=True)
58
+ segm_stats_results.dropna(axis=0, how='all', inplace=True)
59
+
60
+ segm_results = pd.concat([segm_metrics_results, segm_stats_results], axis=1, sort=True)
61
+ segm_results.sort_values(('mask_freq', 'total'), ascending=False, inplace=True)
62
+
63
+ segm_results.to_csv(args.outpath + '_segm', sep='\t', float_format='%.4f')
64
+ else:
65
+ print(results)
66
+
67
+
68
+ if __name__ == '__main__':
69
+ import argparse
70
+
71
+ aparser = argparse.ArgumentParser()
72
+ aparser.add_argument('config', type=str, help='Path to evaluation config')
73
+ aparser.add_argument('datadir', type=str,
74
+ help='Path to folder with images and masks (output of gen_mask_dataset.py)')
75
+ aparser.add_argument('predictdir', type=str,
76
+ help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
77
+ aparser.add_argument('outpath', type=str, help='Where to put results')
78
+
79
+ main(aparser.parse_args())
bin/evaluator_example.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from skimage import io
7
+ from skimage.transform import resize
8
+ from torch.utils.data import Dataset
9
+
10
+ from saicinpainting.evaluation.evaluator import InpaintingEvaluator
11
+ from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
12
+
13
+
14
+ class SimpleImageDataset(Dataset):
15
+ def __init__(self, root_dir, image_size=(400, 600)):
16
+ self.root_dir = root_dir
17
+ self.files = sorted(os.listdir(root_dir))
18
+ self.image_size = image_size
19
+
20
+ def __getitem__(self, index):
21
+ img_name = os.path.join(self.root_dir, self.files[index])
22
+ image = io.imread(img_name)
23
+ image = resize(image, self.image_size, anti_aliasing=True)
24
+ image = torch.FloatTensor(image).permute(2, 0, 1)
25
+ return image
26
+
27
+ def __len__(self):
28
+ return len(self.files)
29
+
30
+
31
+ def create_rectangle_mask(height, width):
32
+ mask = np.ones((height, width))
33
+ up_left_corner = width // 4, height // 4
34
+ down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1)
35
+ cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED)
36
+ return mask
37
+
38
+
39
+ class Model():
40
+ def __call__(self, img_batch, mask_batch):
41
+ mean = (img_batch * mask_batch[:, None, :, :]).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None]
42
+ inpainted = mean[:, :, None, None] * (1 - mask_batch[:, None, :, :]) + img_batch * mask_batch[:, None, :, :]
43
+ return inpainted
44
+
45
+
46
+ class SimpleImageSquareMaskDataset(Dataset):
47
+ def __init__(self, dataset):
48
+ self.dataset = dataset
49
+ self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
50
+ self.model = Model()
51
+
52
+ def __getitem__(self, index):
53
+ img = self.dataset[index]
54
+ mask = self.mask.clone()
55
+ inpainted = self.model(img[None, ...], mask[None, ...])
56
+ return dict(image=img, mask=mask, inpainted=inpainted)
57
+
58
+ def __len__(self):
59
+ return len(self.dataset)
60
+
61
+
62
+ dataset = SimpleImageDataset('imgs')
63
+ mask_dataset = SimpleImageSquareMaskDataset(dataset)
64
+ model = Model()
65
+ metrics = {
66
+ 'ssim': SSIMScore(),
67
+ 'lpips': LPIPSScore(),
68
+ 'fid': FIDScore()
69
+ }
70
+
71
+ evaluator = InpaintingEvaluator(
72
+ mask_dataset, scores=metrics, batch_size=3, area_grouping=True
73
+ )
74
+
75
+ results = evaluator.evaluate(model)
76
+ print(results)
bin/extract_masks.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL.Image as Image
2
+ import numpy as np
3
+ import os
4
+
5
+
6
+ def main(args):
7
+ if not args.indir.endswith('/'):
8
+ args.indir += '/'
9
+ os.makedirs(args.outdir, exist_ok=True)
10
+
11
+ src_images = [
12
+ args.indir+fname for fname in os.listdir(args.indir)]
13
+
14
+ tgt_masks = [
15
+ args.outdir+fname[:-4] + f'_mask000.png'
16
+ for fname in os.listdir(args.indir)]
17
+
18
+ for img_name, msk_name in zip(src_images, tgt_masks):
19
+ #print(img)
20
+ #print(msk)
21
+
22
+ image = Image.open(img_name).convert('RGB')
23
+ image = np.transpose(np.array(image), (2, 0, 1))
24
+
25
+ mask = (image == 255).astype(int)
26
+
27
+ print(mask.dtype, mask.shape)
28
+
29
+
30
+ Image.fromarray(
31
+ np.clip(mask[0,:,:] * 255, 0, 255).astype('uint8'),mode='L'
32
+ ).save(msk_name)
33
+
34
+
35
+
36
+
37
+ '''
38
+ for infile in src_images:
39
+ try:
40
+ file_relpath = infile[len(indir):]
41
+ img_outpath = os.path.join(outdir, file_relpath)
42
+ os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
43
+
44
+ image = Image.open(infile).convert('RGB')
45
+
46
+ mask =
47
+
48
+ Image.fromarray(
49
+ np.clip(
50
+ cur_mask * 255, 0, 255).astype('uint8'),
51
+ mode='L'
52
+ ).save(cur_basename + f'_mask{i:03d}.png')
53
+ '''
54
+
55
+
56
+
57
+ if __name__ == '__main__':
58
+ import argparse
59
+ aparser = argparse.ArgumentParser()
60
+ aparser.add_argument('--indir', type=str, help='Path to folder with images')
61
+ aparser.add_argument('--outdir', type=str, help='Path to folder to store aligned images and masks to')
62
+
63
+ main(aparser.parse_args())
bin/filter_sharded_dataset.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import math
5
+ import os
6
+ import random
7
+
8
+ import braceexpand
9
+ import webdataset as wds
10
+
11
+ DEFAULT_CATS_FILE = os.path.join(os.path.dirname(__file__), '..', 'configs', 'places2-categories_157.txt')
12
+
13
+ def is_good_key(key, cats):
14
+ return any(c in key for c in cats)
15
+
16
+
17
+ def main(args):
18
+ if args.categories == 'nofilter':
19
+ good_categories = None
20
+ else:
21
+ with open(args.categories, 'r') as f:
22
+ good_categories = set(line.strip().split(' ')[0] for line in f if line.strip())
23
+
24
+ all_input_files = list(braceexpand.braceexpand(args.infile))
25
+ chunk_size = int(math.ceil(len(all_input_files) / args.n_read_streams))
26
+
27
+ input_iterators = [iter(wds.Dataset(all_input_files[start : start + chunk_size]).shuffle(args.shuffle_buffer))
28
+ for start in range(0, len(all_input_files), chunk_size)]
29
+ output_datasets = [wds.ShardWriter(args.outpattern.format(i)) for i in range(args.n_write_streams)]
30
+
31
+ good_readers = list(range(len(input_iterators)))
32
+ step_i = 0
33
+ good_samples = 0
34
+ bad_samples = 0
35
+ while len(good_readers) > 0:
36
+ if step_i % args.print_freq == 0:
37
+ print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')
38
+
39
+ step_i += 1
40
+
41
+ ri = random.choice(good_readers)
42
+ try:
43
+ sample = next(input_iterators[ri])
44
+ except StopIteration:
45
+ good_readers = list(set(good_readers) - {ri})
46
+ continue
47
+
48
+ if good_categories is not None and not is_good_key(sample['__key__'], good_categories):
49
+ bad_samples += 1
50
+ continue
51
+
52
+ wi = random.randint(0, args.n_write_streams - 1)
53
+ output_datasets[wi].write(sample)
54
+ good_samples += 1
55
+
56
+
57
+ if __name__ == '__main__':
58
+ import argparse
59
+
60
+ aparser = argparse.ArgumentParser()
61
+ aparser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE)
62
+ aparser.add_argument('--shuffle-buffer', type=int, default=10000)
63
+ aparser.add_argument('--n-read-streams', type=int, default=10)
64
+ aparser.add_argument('--n-write-streams', type=int, default=10)
65
+ aparser.add_argument('--print-freq', type=int, default=1000)
66
+ aparser.add_argument('infile', type=str)
67
+ aparser.add_argument('outpattern', type=str)
68
+
69
+ main(aparser.parse_args())
bin/gen_debug_mask_dataset.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+
6
+ import PIL.Image as Image
7
+ import cv2
8
+ import numpy as np
9
+ import tqdm
10
+ import shutil
11
+
12
+
13
+ from saicinpainting.evaluation.utils import load_yaml
14
+
15
+
16
+ def generate_masks_for_img(infile, outmask_pattern, mask_size=200, step=0.5):
17
+ inimg = Image.open(infile)
18
+ width, height = inimg.size
19
+ step_abs = int(mask_size * step)
20
+
21
+ mask = np.zeros((height, width), dtype='uint8')
22
+ mask_i = 0
23
+
24
+ for start_vertical in range(0, height - step_abs, step_abs):
25
+ for start_horizontal in range(0, width - step_abs, step_abs):
26
+ mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 255
27
+
28
+ cv2.imwrite(outmask_pattern.format(mask_i), mask)
29
+
30
+ mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 0
31
+ mask_i += 1
32
+
33
+
34
+ def main(args):
35
+ if not args.indir.endswith('/'):
36
+ args.indir += '/'
37
+ if not args.outdir.endswith('/'):
38
+ args.outdir += '/'
39
+
40
+ config = load_yaml(args.config)
41
+
42
+ in_files = list(glob.glob(os.path.join(args.indir, '**', f'*{config.img_ext}'), recursive=True))
43
+ for infile in tqdm.tqdm(in_files):
44
+ outimg = args.outdir + infile[len(args.indir):]
45
+ outmask_pattern = outimg[:-len(config.img_ext)] + '_mask{:04d}.png'
46
+
47
+ os.makedirs(os.path.dirname(outimg), exist_ok=True)
48
+ shutil.copy2(infile, outimg)
49
+
50
+ generate_masks_for_img(infile, outmask_pattern, **config.gen_kwargs)
51
+
52
+
53
+ if __name__ == '__main__':
54
+ import argparse
55
+
56
+ aparser = argparse.ArgumentParser()
57
+ aparser.add_argument('config', type=str, help='Path to config for dataset generation')
58
+ aparser.add_argument('indir', type=str, help='Path to folder with images')
59
+ aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
60
+
61
+ main(aparser.parse_args())
bin/gen_mask_dataset.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+ import shutil
6
+ import traceback
7
+
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+ from joblib import Parallel, delayed
11
+
12
+ from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
13
+ from saicinpainting.evaluation.utils import load_yaml, SmallMode
14
+ from saicinpainting.training.data.masks import MixedMaskGenerator
15
+
16
+
17
+ class MakeManyMasksWrapper:
18
+ def __init__(self, impl, variants_n=2):
19
+ self.impl = impl
20
+ self.variants_n = variants_n
21
+
22
+ def get_masks(self, img):
23
+ img = np.transpose(np.array(img), (2, 0, 1))
24
+ return [self.impl(img)[0] for _ in range(self.variants_n)]
25
+
26
+
27
+ def process_images(src_images, indir, outdir, config):
28
+ if config.generator_kind == 'segmentation':
29
+ mask_generator = SegmentationMask(**config.mask_generator_kwargs)
30
+ elif config.generator_kind == 'random':
31
+ variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
32
+ mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
33
+ variants_n=variants_n)
34
+ else:
35
+ raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
36
+
37
+ max_tamper_area = config.get('max_tamper_area', 1)
38
+
39
+ for infile in src_images:
40
+ try:
41
+ file_relpath = infile[len(indir):]
42
+ img_outpath = os.path.join(outdir, file_relpath)
43
+ os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
44
+
45
+ image = Image.open(infile).convert('RGB')
46
+
47
+ # scale input image to output resolution and filter smaller images
48
+ if min(image.size) < config.cropping.out_min_size:
49
+ handle_small_mode = SmallMode(config.cropping.handle_small_mode)
50
+ if handle_small_mode == SmallMode.DROP:
51
+ continue
52
+ elif handle_small_mode == SmallMode.UPSCALE:
53
+ factor = config.cropping.out_min_size / min(image.size)
54
+ out_size = (np.array(image.size) * factor).round().astype('uint32')
55
+ image = image.resize(out_size, resample=Image.BICUBIC)
56
+ else:
57
+ factor = config.cropping.out_min_size / min(image.size)
58
+ out_size = (np.array(image.size) * factor).round().astype('uint32')
59
+ image = image.resize(out_size, resample=Image.BICUBIC)
60
+
61
+ # generate and select masks
62
+ src_masks = mask_generator.get_masks(image)
63
+
64
+ filtered_image_mask_pairs = []
65
+ for cur_mask in src_masks:
66
+ if config.cropping.out_square_crop:
67
+ (crop_left,
68
+ crop_top,
69
+ crop_right,
70
+ crop_bottom) = propose_random_square_crop(cur_mask,
71
+ min_overlap=config.cropping.crop_min_overlap)
72
+ cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
73
+ cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
74
+ else:
75
+ cur_image = image
76
+
77
+ if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
78
+ continue
79
+
80
+ filtered_image_mask_pairs.append((cur_image, cur_mask))
81
+
82
+ mask_indices = np.random.choice(len(filtered_image_mask_pairs),
83
+ size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
84
+ replace=False)
85
+
86
+ # crop masks; save masks together with input image
87
+ mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
88
+ for i, idx in enumerate(mask_indices):
89
+ cur_image, cur_mask = filtered_image_mask_pairs[idx]
90
+ cur_basename = mask_basename + f'_crop{i:03d}'
91
+ Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
92
+ mode='L').save(cur_basename + f'_mask{i:03d}.png')
93
+ cur_image.save(cur_basename + '.png')
94
+ except KeyboardInterrupt:
95
+ return
96
+ except Exception as ex:
97
+ print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
98
+
99
+
100
+ def main(args):
101
+ if not args.indir.endswith('/'):
102
+ args.indir += '/'
103
+
104
+ os.makedirs(args.outdir, exist_ok=True)
105
+
106
+ config = load_yaml(args.config)
107
+
108
+ in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
109
+ if args.n_jobs == 0:
110
+ process_images(in_files, args.indir, args.outdir, config)
111
+ else:
112
+ in_files_n = len(in_files)
113
+ chunk_size = in_files_n // args.n_jobs + (1 if in_files_n % args.n_jobs > 0 else 0)
114
+ Parallel(n_jobs=args.n_jobs)(
115
+ delayed(process_images)(in_files[start:start+chunk_size], args.indir, args.outdir, config)
116
+ for start in range(0, len(in_files), chunk_size)
117
+ )
118
+
119
+
120
+ if __name__ == '__main__':
121
+ import argparse
122
+
123
+ aparser = argparse.ArgumentParser()
124
+ aparser.add_argument('config', type=str, help='Path to config for dataset generation')
125
+ aparser.add_argument('indir', type=str, help='Path to folder with images')
126
+ aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
127
+ aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
128
+ aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')
129
+
130
+ main(aparser.parse_args())
bin/gen_mask_dataset_hydra.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+ import shutil
6
+ import traceback
7
+ import hydra
8
+ from omegaconf import OmegaConf
9
+
10
+ import PIL.Image as Image
11
+ import numpy as np
12
+ from joblib import Parallel, delayed
13
+
14
+ from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
15
+ from saicinpainting.evaluation.utils import load_yaml, SmallMode
16
+ from saicinpainting.training.data.masks import MixedMaskGenerator
17
+
18
+
19
+ class MakeManyMasksWrapper:
20
+ def __init__(self, impl, variants_n=2):
21
+ self.impl = impl
22
+ self.variants_n = variants_n
23
+
24
+ def get_masks(self, img):
25
+ img = np.transpose(np.array(img), (2, 0, 1))
26
+ return [self.impl(img)[0] for _ in range(self.variants_n)]
27
+
28
+
29
+ def process_images(src_images, indir, outdir, config):
30
+ if config.generator_kind == 'segmentation':
31
+ mask_generator = SegmentationMask(**config.mask_generator_kwargs)
32
+ elif config.generator_kind == 'random':
33
+ mask_generator_kwargs = OmegaConf.to_container(config.mask_generator_kwargs, resolve=True)
34
+ variants_n = mask_generator_kwargs.pop('variants_n', 2)
35
+ mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**mask_generator_kwargs),
36
+ variants_n=variants_n)
37
+ else:
38
+ raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
39
+
40
+ max_tamper_area = config.get('max_tamper_area', 1)
41
+
42
+ for infile in src_images:
43
+ try:
44
+ file_relpath = infile[len(indir):]
45
+ img_outpath = os.path.join(outdir, file_relpath)
46
+ os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
47
+
48
+ image = Image.open(infile).convert('RGB')
49
+
50
+ # scale input image to output resolution and filter smaller images
51
+ if min(image.size) < config.cropping.out_min_size:
52
+ handle_small_mode = SmallMode(config.cropping.handle_small_mode)
53
+ if handle_small_mode == SmallMode.DROP:
54
+ continue
55
+ elif handle_small_mode == SmallMode.UPSCALE:
56
+ factor = config.cropping.out_min_size / min(image.size)
57
+ out_size = (np.array(image.size) * factor).round().astype('uint32')
58
+ image = image.resize(out_size, resample=Image.BICUBIC)
59
+ else:
60
+ factor = config.cropping.out_min_size / min(image.size)
61
+ out_size = (np.array(image.size) * factor).round().astype('uint32')
62
+ image = image.resize(out_size, resample=Image.BICUBIC)
63
+
64
+ # generate and select masks
65
+ src_masks = mask_generator.get_masks(image)
66
+
67
+ filtered_image_mask_pairs = []
68
+ for cur_mask in src_masks:
69
+ if config.cropping.out_square_crop:
70
+ (crop_left,
71
+ crop_top,
72
+ crop_right,
73
+ crop_bottom) = propose_random_square_crop(cur_mask,
74
+ min_overlap=config.cropping.crop_min_overlap)
75
+ cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
76
+ cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
77
+ else:
78
+ cur_image = image
79
+
80
+ if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
81
+ continue
82
+
83
+ filtered_image_mask_pairs.append((cur_image, cur_mask))
84
+
85
+ mask_indices = np.random.choice(len(filtered_image_mask_pairs),
86
+ size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
87
+ replace=False)
88
+
89
+ # crop masks; save masks together with input image
90
+ mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
91
+ for i, idx in enumerate(mask_indices):
92
+ cur_image, cur_mask = filtered_image_mask_pairs[idx]
93
+ cur_basename = mask_basename + f'_crop{i:03d}'
94
+ Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
95
+ mode='L').save(cur_basename + f'_mask{i:03d}.png')
96
+ cur_image.save(cur_basename + '.png')
97
+ except KeyboardInterrupt:
98
+ return
99
+ except Exception as ex:
100
+ print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
101
+
102
+
103
+ @hydra.main(config_path='../configs/data_gen/whydra', config_name='random_medium_256.yaml')
104
+ def main(config: OmegaConf):
105
+ if not config.indir.endswith('/'):
106
+ config.indir += '/'
107
+
108
+ os.makedirs(config.outdir, exist_ok=True)
109
+
110
+ in_files = list(glob.glob(os.path.join(config.indir, '**', f'*.{config.location.extension}'),
111
+ recursive=True))
112
+ if config.n_jobs == 0:
113
+ process_images(in_files, config.indir, config.outdir, config)
114
+ else:
115
+ in_files_n = len(in_files)
116
+ chunk_size = in_files_n // config.n_jobs + (1 if in_files_n % config.n_jobs > 0 else 0)
117
+ Parallel(n_jobs=config.n_jobs)(
118
+ delayed(process_images)(in_files[start:start+chunk_size], config.indir, config.outdir, config)
119
+ for start in range(0, len(in_files), chunk_size)
120
+ )
121
+
122
+
123
+ if __name__ == '__main__':
124
+ main()
bin/gen_outpainting_dataset.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import glob
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import sys
7
+ import traceback
8
+
9
+ from saicinpainting.evaluation.data import load_image
10
+ from saicinpainting.evaluation.utils import move_to_device
11
+
12
+ os.environ['OMP_NUM_THREADS'] = '1'
13
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
14
+ os.environ['MKL_NUM_THREADS'] = '1'
15
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
16
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
17
+
18
+ import cv2
19
+ import hydra
20
+ import numpy as np
21
+ import torch
22
+ import tqdm
23
+ import yaml
24
+ from omegaconf import OmegaConf
25
+ from torch.utils.data._utils.collate import default_collate
26
+
27
+ from saicinpainting.training.data.datasets import make_default_val_dataset
28
+ from saicinpainting.training.trainers import load_checkpoint
29
+ from saicinpainting.utils import register_debug_signal_handlers
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+
34
+ def main(args):
35
+ try:
36
+ if not args.indir.endswith('/'):
37
+ args.indir += '/'
38
+
39
+ for in_img in glob.glob(os.path.join(args.indir, '**', '*' + args.img_suffix), recursive=True):
40
+ if 'mask' in os.path.basename(in_img):
41
+ continue
42
+
43
+ out_img_path = os.path.join(args.outdir, os.path.splitext(in_img[len(args.indir):])[0] + '.png')
44
+ out_mask_path = f'{os.path.splitext(out_img_path)[0]}_mask.png'
45
+
46
+ os.makedirs(os.path.dirname(out_img_path), exist_ok=True)
47
+
48
+ img = load_image(in_img)
49
+ height, width = img.shape[1:]
50
+ pad_h, pad_w = int(height * args.coef / 2), int(width * args.coef / 2)
51
+
52
+ mask = np.zeros((height, width), dtype='uint8')
53
+
54
+ if args.expand:
55
+ img = np.pad(img, ((0, 0), (pad_h, pad_h), (pad_w, pad_w)))
56
+ mask = np.pad(mask, ((pad_h, pad_h), (pad_w, pad_w)), mode='constant', constant_values=255)
57
+ else:
58
+ mask[:pad_h] = 255
59
+ mask[-pad_h:] = 255
60
+ mask[:, :pad_w] = 255
61
+ mask[:, -pad_w:] = 255
62
+
63
+ # img = np.pad(img, ((0, 0), (pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode='symmetric')
64
+ # mask = np.pad(mask, ((pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode = 'symmetric')
65
+
66
+ img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype('uint8')
67
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
68
+ cv2.imwrite(out_img_path, img)
69
+
70
+ cv2.imwrite(out_mask_path, mask)
71
+ except KeyboardInterrupt:
72
+ LOGGER.warning('Interrupted by user')
73
+ except Exception as ex:
74
+ LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
75
+ sys.exit(1)
76
+
77
+
78
+ if __name__ == '__main__':
79
+ import argparse
80
+
81
+ aparser = argparse.ArgumentParser()
82
+ aparser.add_argument('indir', type=str, help='Root directory with images')
83
+ aparser.add_argument('outdir', type=str, help='Where to store results')
84
+ aparser.add_argument('--img-suffix', type=str, default='.png', help='Input image extension')
85
+ aparser.add_argument('--expand', action='store_true', help='Generate mask by padding (true) or by cropping (false)')
86
+ aparser.add_argument('--coef', type=float, default=0.2, help='How much to crop/expand in order to get masks')
87
+
88
+ main(aparser.parse_args())
bin/make_checkpoint.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import shutil
5
+
6
+ import torch
7
+
8
+
9
+ def get_checkpoint_files(s):
10
+ s = s.strip()
11
+ if ',' in s:
12
+ return [get_checkpoint_files(chunk) for chunk in s.split(',')]
13
+ return 'last.ckpt' if s == 'last' else f'{s}.ckpt'
14
+
15
+
16
+ def main(args):
17
+ checkpoint_fnames = get_checkpoint_files(args.epochs)
18
+ if isinstance(checkpoint_fnames, str):
19
+ checkpoint_fnames = [checkpoint_fnames]
20
+ assert len(checkpoint_fnames) >= 1
21
+
22
+ checkpoint_path = os.path.join(args.indir, 'models', checkpoint_fnames[0])
23
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
24
+ del checkpoint['optimizer_states']
25
+
26
+ if len(checkpoint_fnames) > 1:
27
+ for fname in checkpoint_fnames[1:]:
28
+ print('sum', fname)
29
+ sum_tensors_cnt = 0
30
+ other_cp = torch.load(os.path.join(args.indir, 'models', fname), map_location='cpu')
31
+ for k in checkpoint['state_dict'].keys():
32
+ if checkpoint['state_dict'][k].dtype is torch.float:
33
+ checkpoint['state_dict'][k].data.add_(other_cp['state_dict'][k].data)
34
+ sum_tensors_cnt += 1
35
+ print('summed', sum_tensors_cnt, 'tensors')
36
+
37
+ for k in checkpoint['state_dict'].keys():
38
+ if checkpoint['state_dict'][k].dtype is torch.float:
39
+ checkpoint['state_dict'][k].data.mul_(1 / float(len(checkpoint_fnames)))
40
+
41
+ state_dict = checkpoint['state_dict']
42
+
43
+ if not args.leave_discriminators:
44
+ for k in list(state_dict.keys()):
45
+ if k.startswith('discriminator.'):
46
+ del state_dict[k]
47
+
48
+ if not args.leave_losses:
49
+ for k in list(state_dict.keys()):
50
+ if k.startswith('loss_'):
51
+ del state_dict[k]
52
+
53
+ out_checkpoint_path = os.path.join(args.outdir, 'models', 'best.ckpt')
54
+ os.makedirs(os.path.dirname(out_checkpoint_path), exist_ok=True)
55
+
56
+ torch.save(checkpoint, out_checkpoint_path)
57
+
58
+ shutil.copy2(os.path.join(args.indir, 'config.yaml'),
59
+ os.path.join(args.outdir, 'config.yaml'))
60
+
61
+
62
+ if __name__ == '__main__':
63
+ import argparse
64
+
65
+ aparser = argparse.ArgumentParser()
66
+ aparser.add_argument('indir',
67
+ help='Path to directory with output of training '
68
+ '(i.e. directory, which has samples, modules, config.yaml and train.log')
69
+ aparser.add_argument('outdir',
70
+ help='Where to put minimal checkpoint, which can be consumed by "bin/predict.py"')
71
+ aparser.add_argument('--epochs', type=str, default='last',
72
+ help='Which checkpoint to take. '
73
+ 'Can be "last" or integer - number of epoch')
74
+ aparser.add_argument('--leave-discriminators', action='store_true',
75
+ help='If enabled, the state of discriminators will not be removed from the checkpoint')
76
+ aparser.add_argument('--leave-losses', action='store_true',
77
+ help='If enabled, weights of nn-based losses (e.g. perceptual) will not be removed')
78
+
79
+ main(aparser.parse_args())
bin/mask_example.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from skimage import io
3
+ from skimage.transform import resize
4
+
5
+ from saicinpainting.evaluation.masks.mask import SegmentationMask
6
+
7
+ im = io.imread('imgs/ex4.jpg')
8
+ im = resize(im, (512, 1024), anti_aliasing=True)
9
+ mask_seg = SegmentationMask(num_variants_per_mask=10)
10
+ mask_examples = mask_seg.get_masks(im)
11
+ for i, example in enumerate(mask_examples):
12
+ plt.imshow(example)
13
+ plt.show()
14
+ plt.imsave(f'tmp/img_masks/{i}.png', example)
bin/paper_runfiles/blur_tests.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##!/usr/bin/env bash
2
+ #
3
+ ## !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
4
+ #
5
+ ## paths to data are valid for mml7
6
+ #PLACES_ROOT="/data/inpainting/Places365"
7
+ #OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
8
+ #
9
+ #source "$(dirname $0)/env.sh"
10
+ #
11
+ #for datadir in test_large_30k # val_large
12
+ #do
13
+ # for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
14
+ # do
15
+ # "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
16
+ # "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
17
+ #
18
+ # "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
19
+ # done
20
+ #
21
+ # for conf in segm_256 segm_512
22
+ # do
23
+ # "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
24
+ # "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
25
+ #
26
+ # "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
27
+ # done
28
+ #done
29
+ #
30
+ #IN_DIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k/random_medium_512"
31
+ #PRED_DIR="/data/inpainting/predictions/final/images/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37/random_medium_512"
32
+ #BLUR_OUT_DIR="/data/inpainting/predictions/final/blur/images"
33
+ #
34
+ #for b in 0.1
35
+ #
36
+ #"$BINDIR/blur_predicts.py" "$BASEDIR/../../configs/eval2.yaml" "$CUR_IN_DIR" "$CUR_OUT_DIR" "$CUR_EVAL_DIR"
37
+ #
bin/paper_runfiles/env.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ DIRNAME="$(dirname $0)"
2
+ DIRNAME="$(realpath ""$DIRNAME"")"
3
+
4
+ BINDIR="$DIRNAME/.."
5
+ SRCDIR="$BINDIR/.."
6
+ CONFIGDIR="$SRCDIR/configs"
7
+
8
+ export PYTHONPATH="$SRCDIR:$PYTHONPATH"
bin/paper_runfiles/find_best_checkpoint.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import os
5
+ from argparse import ArgumentParser
6
+
7
+
8
+ def ssim_fid100_f1(metrics, fid_scale=100):
9
+ ssim = metrics.loc['total', 'ssim']['mean']
10
+ fid = metrics.loc['total', 'fid']['mean']
11
+ fid_rel = max(0, fid_scale - fid) / fid_scale
12
+ f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
13
+ return f1
14
+
15
+
16
+ def find_best_checkpoint(model_list, models_dir):
17
+ with open(model_list) as f:
18
+ models = [m.strip() for m in f.readlines()]
19
+ with open(f'{model_list}_best', 'w') as f:
20
+ for model in models:
21
+ print(model)
22
+ best_f1 = 0
23
+ best_epoch = 0
24
+ best_step = 0
25
+ with open(os.path.join(models_dir, model, 'train.log')) as fm:
26
+ lines = fm.readlines()
27
+ for line_index in range(len(lines)):
28
+ line = lines[line_index]
29
+ if 'Validation metrics after epoch' in line:
30
+ sharp_index = line.index('#')
31
+ cur_ep = line[sharp_index + 1:]
32
+ comma_index = cur_ep.index(',')
33
+ cur_ep = int(cur_ep[:comma_index])
34
+ total_index = line.index('total ')
35
+ step = int(line[total_index:].split()[1].strip())
36
+ total_line = lines[line_index + 5]
37
+ if not total_line.startswith('total'):
38
+ continue
39
+ words = total_line.strip().split()
40
+ f1 = float(words[-1])
41
+ print(f'\tEpoch: {cur_ep}, f1={f1}')
42
+ if f1 > best_f1:
43
+ best_f1 = f1
44
+ best_epoch = cur_ep
45
+ best_step = step
46
+ f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n')
47
+
48
+
49
+ if __name__ == '__main__':
50
+ parser = ArgumentParser()
51
+ parser.add_argument('model_list')
52
+ parser.add_argument('models_dir')
53
+ args = parser.parse_args()
54
+ find_best_checkpoint(args.model_list, args.models_dir)
bin/paper_runfiles/generate_test_celeba-hq.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # paths to data are valid for mml-ws01
4
+ OUT_DIR="/media/inpainting/paper_data/CelebA-HQ_val_test"
5
+
6
+ source "$(dirname $0)/env.sh"
7
+
8
+ for datadir in "val" "test"
9
+ do
10
+ for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
11
+ do
12
+ "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-celeba-hq \
13
+ location.out_dir=$OUT_DIR cropping.out_square_crop=False
14
+
15
+ "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
16
+ done
17
+ done
bin/paper_runfiles/generate_test_ffhq.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # paths to data are valid for mml-ws01
4
+ OUT_DIR="/media/inpainting/paper_data/FFHQ_val"
5
+
6
+ source "$(dirname $0)/env.sh"
7
+
8
+ for datadir in test
9
+ do
10
+ for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
11
+ do
12
+ "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-ffhq \
13
+ location.out_dir=$OUT_DIR cropping.out_square_crop=False
14
+
15
+ "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
16
+ done
17
+ done
bin/paper_runfiles/generate_test_paris.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # paths to data are valid for mml-ws01
4
+ OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val"
5
+
6
+ source "$(dirname $0)/env.sh"
7
+
8
+ for datadir in paris_eval_gt
9
+ do
10
+ for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
11
+ do
12
+ "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
13
+ location.out_dir=OUT_DIR cropping.out_square_crop=False cropping.out_min_size=227
14
+
15
+ "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
16
+ done
17
+ done
bin/paper_runfiles/generate_test_paris_256.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # paths to data are valid for mml-ws01
4
+ OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val_256"
5
+
6
+ source "$(dirname $0)/env.sh"
7
+
8
+ for datadir in paris_eval_gt
9
+ do
10
+ for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
11
+ do
12
+ "$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
13
+ location.out_dir=$OUT_DIR cropping.out_square_crop=False cropping.out_min_size=256
14
+
15
+ "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
16
+ done
17
+ done
bin/paper_runfiles/generate_val_test.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
4
+
5
+ # paths to data are valid for mml7
6
+ PLACES_ROOT="/data/inpainting/Places365"
7
+ OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
8
+
9
+ source "$(dirname $0)/env.sh"
10
+
11
+ for datadir in test_large_30k # val_large
12
+ do
13
+ for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
14
+ do
15
+ "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
16
+ "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
17
+
18
+ "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
19
+ done
20
+
21
+ for conf in segm_256 segm_512
22
+ do
23
+ "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
24
+ "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
25
+
26
+ "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
27
+ done
28
+ done
bin/paper_runfiles/predict_inner_features.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # paths to data are valid for mml7
4
+
5
+ source "$(dirname $0)/env.sh"
6
+
7
+ "$BINDIR/predict_inner_features.py" \
8
+ -cn default_inner_features_ffc \
9
+ model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-34-05_train_ablv2_work_ffc075_resume_epoch39" \
10
+ indir="/data/inpainting/paper_data/inner_features_vis/input/" \
11
+ outdir="/data/inpainting/paper_data/inner_features_vis/output/ffc" \
12
+ dataset.img_suffix=.png
13
+
14
+
15
+ "$BINDIR/predict_inner_features.py" \
16
+ -cn default_inner_features_work \
17
+ model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37" \
18
+ indir="/data/inpainting/paper_data/inner_features_vis/input/" \
19
+ outdir="/data/inpainting/paper_data/inner_features_vis/output/work" \
20
+ dataset.img_suffix=.png
bin/paper_runfiles/update_test_data_stats.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # paths to data are valid for mml7
4
+
5
+ source "$(dirname $0)/env.sh"
6
+
7
+ #INDIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k"
8
+ #
9
+ #for dataset in random_medium_256 random_medium_512 random_thick_256 random_thick_512 random_thin_256 random_thin_512
10
+ #do
11
+ # "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
12
+ #done
13
+ #
14
+ #"$BINDIR/calc_dataset_stats.py" "/data/inpainting/evalset2" "/data/inpainting/evalset2_stats2"
15
+
16
+
17
+ INDIR="/data/inpainting/paper_data/CelebA-HQ_val_test/test"
18
+
19
+ for dataset in random_medium_256 random_thick_256 random_thin_256
20
+ do
21
+ "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
22
+ done
23
+
24
+
25
+ INDIR="/data/inpainting/paper_data/Paris_StreetView_Dataset_val_256/paris_eval_gt"
26
+
27
+ for dataset in random_medium_256 random_thick_256 random_thin_256
28
+ do
29
+ "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
30
+ done
bin/predict.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Example command:
4
+ # ./bin/predict.py \
5
+ # model.path=<path to checkpoint, prepared by make_checkpoint.py> \
6
+ # indir=<path to input data> \
7
+ # outdir=<where to store predicts>
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+ import traceback
13
+
14
+ from saicinpainting.evaluation.utils import move_to_device
15
+
16
+ os.environ['OMP_NUM_THREADS'] = '1'
17
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
18
+ os.environ['MKL_NUM_THREADS'] = '1'
19
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
20
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
21
+
22
+ import cv2
23
+ import hydra
24
+ import numpy as np
25
+ import torch
26
+ import tqdm
27
+ import yaml
28
+ from omegaconf import OmegaConf
29
+ from torch.utils.data._utils.collate import default_collate
30
+
31
+ from saicinpainting.training.data.datasets import make_default_val_dataset
32
+ from saicinpainting.training.trainers import load_checkpoint
33
+ from saicinpainting.utils import register_debug_signal_handlers
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
+ @hydra.main(config_path='../configs/prediction', config_name='default.yaml')
39
+ def main(predict_config: OmegaConf):
40
+ try:
41
+ register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
42
+
43
+ device = torch.device(predict_config.device)
44
+
45
+ train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
46
+ with open(train_config_path, 'r') as f:
47
+ train_config = OmegaConf.create(yaml.safe_load(f))
48
+
49
+ train_config.training_model.predict_only = True
50
+
51
+ out_ext = predict_config.get('out_ext', '.png')
52
+
53
+ checkpoint_path = os.path.join(predict_config.model.path,
54
+ 'models',
55
+ predict_config.model.checkpoint)
56
+ model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
57
+ model.freeze()
58
+ model.to(device)
59
+
60
+ if not predict_config.indir.endswith('/'):
61
+ predict_config.indir += '/'
62
+
63
+ dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
64
+ with torch.no_grad():
65
+ for img_i in tqdm.trange(len(dataset)):
66
+ mask_fname = dataset.mask_filenames[img_i]
67
+ cur_out_fname = os.path.join(
68
+ predict_config.outdir,
69
+ os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
70
+ )
71
+ os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
72
+
73
+ batch = move_to_device(default_collate([dataset[img_i]]), device)
74
+ batch['mask'] = (batch['mask'] > 0) * 1
75
+ batch = model(batch)
76
+ cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
77
+
78
+ cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
79
+ cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
80
+ cv2.imwrite(cur_out_fname, cur_res)
81
+ except KeyboardInterrupt:
82
+ LOGGER.warning('Interrupted by user')
83
+ except Exception as ex:
84
+ LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
85
+ sys.exit(1)
86
+
87
+
88
+ if __name__ == '__main__':
89
+ main()
bin/predict_inner_features.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Example command:
4
+ # ./bin/predict.py \
5
+ # model.path=<path to checkpoint, prepared by make_checkpoint.py> \
6
+ # indir=<path to input data> \
7
+ # outdir=<where to store predicts>
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+ import traceback
13
+
14
+ from saicinpainting.evaluation.utils import move_to_device
15
+
16
+ os.environ['OMP_NUM_THREADS'] = '1'
17
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
18
+ os.environ['MKL_NUM_THREADS'] = '1'
19
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
20
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
21
+
22
+ import cv2
23
+ import hydra
24
+ import numpy as np
25
+ import torch
26
+ import tqdm
27
+ import yaml
28
+ from omegaconf import OmegaConf
29
+ from torch.utils.data._utils.collate import default_collate
30
+
31
+ from saicinpainting.training.data.datasets import make_default_val_dataset
32
+ from saicinpainting.training.trainers import load_checkpoint, DefaultInpaintingTrainingModule
33
+ from saicinpainting.utils import register_debug_signal_handlers, get_shape
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
+ @hydra.main(config_path='../configs/prediction', config_name='default_inner_features.yaml')
39
+ def main(predict_config: OmegaConf):
40
+ try:
41
+ register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
42
+
43
+ device = torch.device(predict_config.device)
44
+
45
+ train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
46
+ with open(train_config_path, 'r') as f:
47
+ train_config = OmegaConf.create(yaml.safe_load(f))
48
+
49
+ checkpoint_path = os.path.join(predict_config.model.path, 'models', predict_config.model.checkpoint)
50
+ model = load_checkpoint(train_config, checkpoint_path, strict=False)
51
+ model.freeze()
52
+ model.to(device)
53
+
54
+ assert isinstance(model, DefaultInpaintingTrainingModule), 'Only DefaultInpaintingTrainingModule is supported'
55
+ assert isinstance(getattr(model.generator, 'model', None), torch.nn.Sequential)
56
+
57
+ if not predict_config.indir.endswith('/'):
58
+ predict_config.indir += '/'
59
+
60
+ dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
61
+
62
+ max_level = max(predict_config.levels)
63
+
64
+ with torch.no_grad():
65
+ for img_i in tqdm.trange(len(dataset)):
66
+ mask_fname = dataset.mask_filenames[img_i]
67
+ cur_out_fname = os.path.join(predict_config.outdir, os.path.splitext(mask_fname[len(predict_config.indir):])[0])
68
+ os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
69
+
70
+ batch = move_to_device(default_collate([dataset[img_i]]), device)
71
+
72
+ img = batch['image']
73
+ mask = batch['mask']
74
+ mask[:] = 0
75
+ mask_h, mask_w = mask.shape[-2:]
76
+ mask[:, :,
77
+ mask_h // 2 - predict_config.hole_radius : mask_h // 2 + predict_config.hole_radius,
78
+ mask_w // 2 - predict_config.hole_radius : mask_w // 2 + predict_config.hole_radius] = 1
79
+
80
+ masked_img = torch.cat([img * (1 - mask), mask], dim=1)
81
+
82
+ feats = masked_img
83
+ for level_i, level in enumerate(model.generator.model):
84
+ feats = level(feats)
85
+ if level_i in predict_config.levels:
86
+ cur_feats = torch.cat([f for f in feats if torch.is_tensor(f)], dim=1) \
87
+ if isinstance(feats, tuple) else feats
88
+
89
+ if predict_config.slice_channels:
90
+ cur_feats = cur_feats[:, slice(*predict_config.slice_channels)]
91
+
92
+ cur_feat = cur_feats.pow(2).mean(1).pow(0.5).clone()
93
+ cur_feat -= cur_feat.min()
94
+ cur_feat /= cur_feat.std()
95
+ cur_feat = cur_feat.clamp(0, 1) / 1
96
+ cur_feat = cur_feat.cpu().numpy()[0]
97
+ cur_feat *= 255
98
+ cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
99
+ cv2.imwrite(cur_out_fname + f'_lev{level_i:02d}_norm.png', cur_feat)
100
+
101
+ # for channel_i in predict_config.channels:
102
+ #
103
+ # cur_feat = cur_feats[0, channel_i].clone().detach().cpu().numpy()
104
+ # cur_feat -= cur_feat.min()
105
+ # cur_feat /= cur_feat.max()
106
+ # cur_feat *= 255
107
+ # cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
108
+ # cv2.imwrite(cur_out_fname + f'_lev{level_i}_ch{channel_i}.png', cur_feat)
109
+ elif level_i >= max_level:
110
+ break
111
+ except KeyboardInterrupt:
112
+ LOGGER.warning('Interrupted by user')
113
+ except Exception as ex:
114
+ LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
115
+ sys.exit(1)
116
+
117
+
118
+ if __name__ == '__main__':
119
+ main()
bin/report_from_tb.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import glob
4
+ import os
5
+ import re
6
+
7
+ import tensorflow as tf
8
+ from torch.utils.tensorboard import SummaryWriter
9
+
10
+
11
+ GROUPING_RULES = [
12
+ re.compile(r'^(?P<group>train|test|val|extra_val_.*?(256|512))_(?P<title>.*)', re.I)
13
+ ]
14
+
15
+
16
+ DROP_RULES = [
17
+ re.compile(r'_std$', re.I)
18
+ ]
19
+
20
+
21
+ def need_drop(tag):
22
+ for rule in DROP_RULES:
23
+ if rule.search(tag):
24
+ return True
25
+ return False
26
+
27
+
28
+ def get_group_and_title(tag):
29
+ for rule in GROUPING_RULES:
30
+ match = rule.search(tag)
31
+ if match is None:
32
+ continue
33
+ return match.group('group'), match.group('title')
34
+ return None, None
35
+
36
+
37
+ def main(args):
38
+ os.makedirs(args.outdir, exist_ok=True)
39
+
40
+ ignored_events = set()
41
+
42
+ for orig_fname in glob.glob(args.inglob):
43
+ cur_dirpath = os.path.dirname(orig_fname) # remove filename, this should point to "version_0" directory
44
+ subdirname = os.path.basename(cur_dirpath) # == "version_0" most of time
45
+ exp_root_path = os.path.dirname(cur_dirpath) # remove "version_0"
46
+ exp_name = os.path.basename(exp_root_path)
47
+
48
+ writers_by_group = {}
49
+
50
+ for e in tf.compat.v1.train.summary_iterator(orig_fname):
51
+ for v in e.summary.value:
52
+ if need_drop(v.tag):
53
+ continue
54
+
55
+ cur_group, cur_title = get_group_and_title(v.tag)
56
+ if cur_group is None:
57
+ if v.tag not in ignored_events:
58
+ print(f'WARNING: Could not detect group for {v.tag}, ignoring it')
59
+ ignored_events.add(v.tag)
60
+ continue
61
+
62
+ cur_writer = writers_by_group.get(cur_group, None)
63
+ if cur_writer is None:
64
+ if args.include_version:
65
+ cur_outdir = os.path.join(args.outdir, exp_name, f'{subdirname}_{cur_group}')
66
+ else:
67
+ cur_outdir = os.path.join(args.outdir, exp_name, cur_group)
68
+ cur_writer = SummaryWriter(cur_outdir)
69
+ writers_by_group[cur_group] = cur_writer
70
+
71
+ cur_writer.add_scalar(cur_title, v.simple_value, global_step=e.step, walltime=e.wall_time)
72
+
73
+
74
+ if __name__ == '__main__':
75
+ import argparse
76
+
77
+ aparser = argparse.ArgumentParser()
78
+ aparser.add_argument('inglob', type=str)
79
+ aparser.add_argument('outdir', type=str)
80
+ aparser.add_argument('--include-version', action='store_true',
81
+ help='Include subdirectory name e.g. "version_0" into output path')
82
+
83
+ main(aparser.parse_args())
bin/sample_from_dataset.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ import tqdm
7
+ from skimage import io
8
+ from skimage.segmentation import mark_boundaries
9
+
10
+ from saicinpainting.evaluation.data import InpaintingDataset
11
+ from saicinpainting.evaluation.vis import save_item_for_vis
12
+
13
+ def save_mask_for_sidebyside(item, out_file):
14
+ mask = item['mask']# > 0.5
15
+ if mask.ndim == 3:
16
+ mask = mask[0]
17
+ mask = np.clip(mask * 255, 0, 255).astype('uint8')
18
+ io.imsave(out_file, mask)
19
+
20
+ def save_img_for_sidebyside(item, out_file):
21
+ img = np.transpose(item['image'], (1, 2, 0))
22
+ img = np.clip(img * 255, 0, 255).astype('uint8')
23
+ io.imsave(out_file, img)
24
+
25
+ def save_masked_img_for_sidebyside(item, out_file):
26
+ mask = item['mask']
27
+ img = item['image']
28
+
29
+ img = (1-mask) * img + mask
30
+ img = np.transpose(img, (1, 2, 0))
31
+
32
+ img = np.clip(img * 255, 0, 255).astype('uint8')
33
+ io.imsave(out_file, img)
34
+
35
+ def main(args):
36
+ dataset = InpaintingDataset(args.datadir, img_suffix='.png')
37
+
38
+ area_bins = np.linspace(0, 1, args.area_bins + 1)
39
+
40
+ heights = []
41
+ widths = []
42
+ image_areas = []
43
+ hole_areas = []
44
+ hole_area_percents = []
45
+ area_bins_count = np.zeros(args.area_bins)
46
+ area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
47
+
48
+ bin2i = [[] for _ in range(args.area_bins)]
49
+
50
+ for i, item in enumerate(tqdm.tqdm(dataset)):
51
+ h, w = item['image'].shape[1:]
52
+ heights.append(h)
53
+ widths.append(w)
54
+ full_area = h * w
55
+ image_areas.append(full_area)
56
+ hole_area = (item['mask'] == 1).sum()
57
+ hole_areas.append(hole_area)
58
+ hole_percent = hole_area / full_area
59
+ hole_area_percents.append(hole_percent)
60
+ bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
61
+ area_bins_count[bin_i] += 1
62
+ bin2i[bin_i].append(i)
63
+
64
+ os.makedirs(args.outdir, exist_ok=True)
65
+
66
+ for bin_i in range(args.area_bins):
67
+ bindir = os.path.join(args.outdir, area_bin_titles[bin_i])
68
+ os.makedirs(bindir, exist_ok=True)
69
+ bin_idx = bin2i[bin_i]
70
+ for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
71
+ item = dataset[sample_i]
72
+ path = os.path.join(bindir, dataset.img_filenames[sample_i].split('/')[-1])
73
+ save_masked_img_for_sidebyside(item, path)
74
+
75
+
76
+ if __name__ == '__main__':
77
+ import argparse
78
+
79
+ aparser = argparse.ArgumentParser()
80
+ aparser.add_argument('--datadir', type=str,
81
+ help='Path to folder with images and masks (output of gen_mask_dataset.py)')
82
+ aparser.add_argument('--outdir', type=str, help='Where to put results')
83
+ aparser.add_argument('--samples-n', type=int, default=10,
84
+ help='Number of sample images with masks to copy for visualization for each area bin')
85
+ aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
86
+
87
+ main(aparser.parse_args())
bin/side_by_side.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import random
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
9
+ from saicinpainting.evaluation.utils import load_yaml
10
+ from saicinpainting.training.visualizers.base import visualize_mask_and_images
11
+
12
+
13
+ def main(args):
14
+ config = load_yaml(args.config)
15
+
16
+ datasets = [PrecomputedInpaintingResultsDataset(args.datadir, cur_predictdir, **config.dataset_kwargs)
17
+ for cur_predictdir in args.predictdirs]
18
+ assert len({len(ds) for ds in datasets}) == 1
19
+ len_first = len(datasets[0])
20
+
21
+ indices = list(range(len_first))
22
+ if len_first > args.max_n:
23
+ indices = sorted(random.sample(indices, args.max_n))
24
+
25
+ os.makedirs(args.outpath, exist_ok=True)
26
+
27
+ filename2i = {}
28
+
29
+ keys = ['image'] + [i for i in range(len(datasets))]
30
+ for img_i in indices:
31
+ try:
32
+ mask_fname = os.path.basename(datasets[0].mask_filenames[img_i])
33
+ if mask_fname in filename2i:
34
+ filename2i[mask_fname] += 1
35
+ idx = filename2i[mask_fname]
36
+ mask_fname_only, ext = os.path.split(mask_fname)
37
+ mask_fname = f'{mask_fname_only}_{idx}{ext}'
38
+ else:
39
+ filename2i[mask_fname] = 1
40
+
41
+ cur_vis_dict = datasets[0][img_i]
42
+ for ds_i, ds in enumerate(datasets):
43
+ cur_vis_dict[ds_i] = ds[img_i]['inpainted']
44
+
45
+ vis_img = visualize_mask_and_images(cur_vis_dict, keys,
46
+ last_without_mask=False,
47
+ mask_only_first=True,
48
+ black_mask=args.black)
49
+ vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
50
+
51
+ out_fname = os.path.join(args.outpath, mask_fname)
52
+
53
+
54
+
55
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
56
+ cv2.imwrite(out_fname, vis_img)
57
+ except Exception as ex:
58
+ print(f'Could not process {img_i} due to {ex}')
59
+
60
+
61
+ if __name__ == '__main__':
62
+ import argparse
63
+
64
+ aparser = argparse.ArgumentParser()
65
+ aparser.add_argument('--max-n', type=int, default=100, help='Maximum number of images to print')
66
+ aparser.add_argument('--black', action='store_true', help='Whether to fill mask on GT with black')
67
+ aparser.add_argument('config', type=str, help='Path to evaluation config (e.g. configs/eval1.yaml)')
68
+ aparser.add_argument('outpath', type=str, help='Where to put results')
69
+ aparser.add_argument('datadir', type=str,
70
+ help='Path to folder with images and masks')
71
+ aparser.add_argument('predictdirs', type=str,
72
+ nargs='+',
73
+ help='Path to folders with predicts')
74
+
75
+
76
+ main(aparser.parse_args())
bin/split_tar.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import tqdm
5
+ import webdataset as wds
6
+
7
+
8
+ def main(args):
9
+ input_dataset = wds.Dataset(args.infile)
10
+ output_dataset = wds.ShardWriter(args.outpattern)
11
+ for rec in tqdm.tqdm(input_dataset):
12
+ output_dataset.write(rec)
13
+
14
+
15
+ if __name__ == '__main__':
16
+ import argparse
17
+
18
+ aparser = argparse.ArgumentParser()
19
+ aparser.add_argument('infile', type=str)
20
+ aparser.add_argument('outpattern', type=str)
21
+
22
+ main(aparser.parse_args())
bin/train.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+ import traceback
7
+
8
+ os.environ['OMP_NUM_THREADS'] = '1'
9
+ os.environ['OPENBLAS_NUM_THREADS'] = '1'
10
+ os.environ['MKL_NUM_THREADS'] = '1'
11
+ os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
12
+ os.environ['NUMEXPR_NUM_THREADS'] = '1'
13
+
14
+ import hydra
15
+ from omegaconf import OmegaConf
16
+ from pytorch_lightning import Trainer
17
+ from pytorch_lightning.callbacks import ModelCheckpoint
18
+ from pytorch_lightning.loggers import TensorBoardLogger
19
+ from pytorch_lightning.plugins import DDPPlugin
20
+
21
+ from saicinpainting.training.trainers import make_training_model
22
+ from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
23
+ handle_deterministic_config
24
+
25
+ LOGGER = logging.getLogger(__name__)
26
+
27
+
28
+ @handle_ddp_subprocess()
29
+ @hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
30
+ def main(config: OmegaConf):
31
+ try:
32
+ need_set_deterministic = handle_deterministic_config(config)
33
+
34
+ register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
35
+
36
+ is_in_ddp_subprocess = handle_ddp_parent_process()
37
+
38
+ config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
39
+ if not is_in_ddp_subprocess:
40
+ LOGGER.info(OmegaConf.to_yaml(config))
41
+ OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))
42
+
43
+ checkpoints_dir = os.path.join(os.getcwd(), 'models')
44
+ os.makedirs(checkpoints_dir, exist_ok=True)
45
+
46
+ # there is no need to suppress this logger in ddp, because it handles rank on its own
47
+ metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
48
+ metrics_logger.log_hyperparams(config)
49
+
50
+ training_model = make_training_model(config)
51
+
52
+ trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
53
+ if need_set_deterministic:
54
+ trainer_kwargs['deterministic'] = True
55
+
56
+ trainer = Trainer(
57
+ # there is no need to suppress checkpointing in ddp, because it handles rank on its own
58
+ callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
59
+ logger=metrics_logger,
60
+ default_root_dir=os.getcwd(),
61
+ **trainer_kwargs
62
+ )
63
+ trainer.fit(training_model)
64
+ except KeyboardInterrupt:
65
+ LOGGER.warning('Interrupted by user')
66
+ except Exception as ex:
67
+ LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
68
+ sys.exit(1)
69
+
70
+
71
+ if __name__ == '__main__':
72
+ main()
configs/analyze_mask_errors.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dataset_kwargs:
2
+ img_suffix: .jpg
3
+ inpainted_suffix: .jpg
4
+
5
+ take_global_top: 30
6
+ take_worst_best_top: 30
7
+ take_overlapping_top: 30
configs/data_gen/gen_segm_dataset1.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.02
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 5
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 0.3
11
+ max_foreground_intersection: 0.7
12
+ max_mask_intersection: 0.1
13
+ max_hidden_area: 0.1
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.2
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 5
20
+
21
+ cropping:
22
+ out_min_size: 512
23
+ handle_small_mode: drop
24
+ out_square_crop: True
25
+ crop_min_overlap: 0.5
configs/data_gen/gen_segm_dataset3.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.07
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 3
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 0.4
11
+ max_foreground_intersection: 0.8
12
+ max_mask_intersection: 0.2
13
+ max_hidden_area: 0.1
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.3
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 3
20
+
21
+ cropping:
22
+ out_min_size: 512
23
+ handle_small_mode: drop
24
+ out_square_crop: True
25
+ crop_min_overlap: 0.5
configs/data_gen/random_medium_256.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 5
8
+ max_width: 50
9
+ max_angle: 4
10
+ max_len: 100
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 0
15
+ bbox_min_size: 10
16
+ bbox_max_size: 50
17
+ max_times: 5
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 256
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_medium_512.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 10
8
+ max_width: 100
9
+ max_angle: 4
10
+ max_len: 200
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 0
15
+ bbox_min_size: 30
16
+ bbox_max_size: 150
17
+ max_times: 5
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 512
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_thick_256.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 1
7
+ max_times: 5
8
+ max_width: 100
9
+ max_angle: 4
10
+ max_len: 200
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 10
15
+ bbox_min_size: 30
16
+ bbox_max_size: 150
17
+ max_times: 3
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 256
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_thick_512.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 1
7
+ max_times: 5
8
+ max_width: 250
9
+ max_angle: 4
10
+ max_len: 450
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 10
15
+ bbox_min_size: 30
16
+ bbox_max_size: 300
17
+ max_times: 4
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 512
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
configs/data_gen/random_thin_256.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 50
8
+ max_width: 10
9
+ max_angle: 4
10
+ max_len: 40
11
+ box_proba: 0
12
+ segm_proba: 0
13
+ squares_proba: 0
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 256
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 0.5
configs/data_gen/random_thin_512.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 70
8
+ max_width: 20
9
+ max_angle: 4
10
+ max_len: 100
11
+ box_proba: 0
12
+ segm_proba: 0
13
+ squares_proba: 0
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 512
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 0.5
configs/data_gen/segm_256.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.05
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 3
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 1 # turn off filtering by overlap
11
+ max_foreground_intersection: 1 # turn off filtering by overlap
12
+ max_mask_intersection: 0.2 # the lower this value the higher diversity
13
+ max_hidden_area: 0.5
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.3
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 1
20
+
21
+ cropping:
22
+ out_min_size: 256
23
+ handle_small_mode: upscale
24
+ out_square_crop: True
25
+ crop_min_overlap: 1
26
+
27
+ max_tamper_area: 0.5
configs/data_gen/segm_512.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: segmentation
2
+
3
+ mask_generator_kwargs:
4
+ confidence_threshold: 0.5
5
+ max_object_area: 0.5
6
+ min_mask_area: 0.05
7
+ downsample_levels: 6
8
+ num_variants_per_mask: 3
9
+ rigidness_mode: 1
10
+ max_foreground_coverage: 1 # turn off filtering by overlap
11
+ max_foreground_intersection: 1 # turn off filtering by overlap
12
+ max_mask_intersection: 0.2 # the lower this value the higher diversity
13
+ max_hidden_area: 0.5
14
+ max_scale_change: 0.25
15
+ horizontal_flip: True
16
+ max_vertical_shift: 0.3
17
+ position_shuffle: True
18
+
19
+ max_masks_per_image: 1
20
+
21
+ cropping:
22
+ out_min_size: 512
23
+ handle_small_mode: upscale
24
+ out_square_crop: True
25
+ crop_min_overlap: 1
26
+
27
+ max_tamper_area: 0.5
configs/data_gen/sr_256.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 0
5
+ box_proba: 0
6
+ segm_proba: 0
7
+ squares_proba: 0
8
+ superres_proba: 1
9
+ superres_kwargs:
10
+ min_step: 2
11
+ max_step: 4
12
+ min_width: 1
13
+ max_width: 3
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 256
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 1
configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /media/inpainting/CelebA-HQ
4
+ out_dir: /media/inpainting/paper_data/CelebA-HQ_val_test
5
+ extension: jpg
configs/data_gen/whydra/location/mml-ws01-ffhq.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /media/inpainting/FFHQ/
4
+ out_dir: /media/inpainting/paper_data/FFHQ_val
5
+ extension: png
configs/data_gen/whydra/location/mml-ws01-paris.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /media/inpainting/Paris_StreetView_Dataset
4
+ out_dir: /media/inpainting/paper_data/Paris_StreetView_Dataset_val
5
+ extension: png
configs/data_gen/whydra/location/mml7-places.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ root_dir: /data/inpainting/Places365
4
+ out_dir: /data/inpainting/paper_data/Places365_val_test
5
+ extension: jpg
configs/data_gen/whydra/random_medium_256.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datadir: val_large
2
+ indir: ${location.root_dir}/${datadir}
3
+ outdir: ${location.out_dir}/${datadir}/random_medium_256
4
+
5
+ n_jobs: 8
6
+
7
+ generator_kind: random
8
+
9
+ mask_generator_kwargs:
10
+ irregular_proba: 1
11
+ irregular_kwargs:
12
+ min_times: 4
13
+ max_times: 5
14
+ max_width: 50
15
+ max_angle: 4
16
+ max_len: 100
17
+
18
+ box_proba: 0.3
19
+ box_kwargs:
20
+ margin: 0
21
+ bbox_min_size: 10
22
+ bbox_max_size: 50
23
+ max_times: 5
24
+ min_times: 1
25
+
26
+ segm_proba: 0
27
+ squares_proba: 0
28
+
29
+ variants_n: 5
30
+
31
+ max_masks_per_image: 1
32
+
33
+ cropping:
34
+ out_min_size: 256
35
+ handle_small_mode: upscale
36
+ out_square_crop: True
37
+ crop_min_overlap: 1
38
+
39
+ max_tamper_area: 0.5
40
+
41
+ defaults:
42
+ - location: mml7-places
configs/data_gen/whydra/random_medium_512.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datadir: val_large
2
+ indir: ${location.root_dir}/${datadir}
3
+ outdir: ${location.out_dir}/${datadir}/random_medium_512
4
+
5
+ n_jobs: 8
6
+
7
+ generator_kind: random
8
+
9
+ mask_generator_kwargs:
10
+ irregular_proba: 1
11
+ irregular_kwargs:
12
+ min_times: 4
13
+ max_times: 10
14
+ max_width: 100
15
+ max_angle: 4
16
+ max_len: 200
17
+
18
+ box_proba: 0.3
19
+ box_kwargs:
20
+ margin: 0
21
+ bbox_min_size: 30
22
+ bbox_max_size: 150
23
+ max_times: 5
24
+ min_times: 1
25
+
26
+ segm_proba: 0
27
+ squares_proba: 0
28
+
29
+ variants_n: 5
30
+
31
+ max_masks_per_image: 1
32
+
33
+ cropping:
34
+ out_min_size: 512
35
+ handle_small_mode: upscale
36
+ out_square_crop: True
37
+ crop_min_overlap: 1
38
+
39
+ max_tamper_area: 0.5
40
+
41
+ defaults:
42
+ - location: mml7-places