Upload 203 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- bin/analyze_errors.py +316 -0
- bin/blur_predicts.py +57 -0
- bin/calc_dataset_stats.py +88 -0
- bin/debug/analyze_overlapping_masks.sh +31 -0
- bin/evaluate_predicts.py +79 -0
- bin/evaluator_example.py +76 -0
- bin/extract_masks.py +63 -0
- bin/filter_sharded_dataset.py +69 -0
- bin/gen_debug_mask_dataset.py +61 -0
- bin/gen_mask_dataset.py +130 -0
- bin/gen_mask_dataset_hydra.py +124 -0
- bin/gen_outpainting_dataset.py +88 -0
- bin/make_checkpoint.py +79 -0
- bin/mask_example.py +14 -0
- bin/paper_runfiles/blur_tests.sh +37 -0
- bin/paper_runfiles/env.sh +8 -0
- bin/paper_runfiles/find_best_checkpoint.py +54 -0
- bin/paper_runfiles/generate_test_celeba-hq.sh +17 -0
- bin/paper_runfiles/generate_test_ffhq.sh +17 -0
- bin/paper_runfiles/generate_test_paris.sh +17 -0
- bin/paper_runfiles/generate_test_paris_256.sh +17 -0
- bin/paper_runfiles/generate_val_test.sh +28 -0
- bin/paper_runfiles/predict_inner_features.sh +20 -0
- bin/paper_runfiles/update_test_data_stats.sh +30 -0
- bin/predict.py +89 -0
- bin/predict_inner_features.py +119 -0
- bin/report_from_tb.py +83 -0
- bin/sample_from_dataset.py +87 -0
- bin/side_by_side.py +76 -0
- bin/split_tar.py +22 -0
- bin/train.py +72 -0
- configs/analyze_mask_errors.yaml +7 -0
- configs/data_gen/gen_segm_dataset1.yaml +25 -0
- configs/data_gen/gen_segm_dataset3.yaml +25 -0
- configs/data_gen/random_medium_256.yaml +33 -0
- configs/data_gen/random_medium_512.yaml +33 -0
- configs/data_gen/random_thick_256.yaml +33 -0
- configs/data_gen/random_thick_512.yaml +33 -0
- configs/data_gen/random_thin_256.yaml +25 -0
- configs/data_gen/random_thin_512.yaml +25 -0
- configs/data_gen/segm_256.yaml +27 -0
- configs/data_gen/segm_512.yaml +27 -0
- configs/data_gen/sr_256.yaml +25 -0
- configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml +5 -0
- configs/data_gen/whydra/location/mml-ws01-ffhq.yaml +5 -0
- configs/data_gen/whydra/location/mml-ws01-paris.yaml +5 -0
- configs/data_gen/whydra/location/mml7-places.yaml +5 -0
- configs/data_gen/whydra/random_medium_256.yaml +42 -0
- configs/data_gen/whydra/random_medium_512.yaml +42 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
saicinpainting/evaluation/masks/countless/images/gcim.jpg filter=lfs diff=lfs merge=lfs -text
|
bin/analyze_errors.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import sklearn
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
+
import pandas as pd
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
|
12 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset, load_image
|
13 |
+
from saicinpainting.evaluation.losses.fid.inception import InceptionV3
|
14 |
+
from saicinpainting.evaluation.utils import load_yaml
|
15 |
+
from saicinpainting.training.visualizers.base import visualize_mask_and_images
|
16 |
+
|
17 |
+
|
18 |
+
def draw_score(img, score):
|
19 |
+
img = np.transpose(img, (1, 2, 0))
|
20 |
+
cv2.putText(img, f'{score:.2f}',
|
21 |
+
(40, 40),
|
22 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
23 |
+
1,
|
24 |
+
(0, 1, 0),
|
25 |
+
thickness=3)
|
26 |
+
img = np.transpose(img, (2, 0, 1))
|
27 |
+
return img
|
28 |
+
|
29 |
+
|
30 |
+
def save_global_samples(global_mask_fnames, mask2real_fname, mask2fake_fname, out_dir, real_scores_by_fname, fake_scores_by_fname):
|
31 |
+
for cur_mask_fname in global_mask_fnames:
|
32 |
+
cur_real_fname = mask2real_fname[cur_mask_fname]
|
33 |
+
orig_img = load_image(cur_real_fname, mode='RGB')
|
34 |
+
fake_img = load_image(mask2fake_fname[cur_mask_fname], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
35 |
+
mask = load_image(cur_mask_fname, mode='L')[None, ...]
|
36 |
+
|
37 |
+
draw_score(orig_img, real_scores_by_fname.loc[cur_real_fname, 'real_score'])
|
38 |
+
draw_score(fake_img, fake_scores_by_fname.loc[cur_mask_fname, 'fake_score'])
|
39 |
+
|
40 |
+
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=mask, fake=fake_img),
|
41 |
+
keys=['image', 'fake'],
|
42 |
+
last_without_mask=True)
|
43 |
+
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
|
44 |
+
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
|
45 |
+
cv2.imwrite(os.path.join(out_dir, os.path.splitext(os.path.basename(cur_mask_fname))[0] + '.jpg'),
|
46 |
+
cur_grid)
|
47 |
+
|
48 |
+
|
49 |
+
def save_samples_by_real(worst_best_by_real, mask2fake_fname, fake_info, out_dir):
|
50 |
+
for real_fname in worst_best_by_real.index:
|
51 |
+
worst_mask_path = worst_best_by_real.loc[real_fname, 'worst']
|
52 |
+
best_mask_path = worst_best_by_real.loc[real_fname, 'best']
|
53 |
+
orig_img = load_image(real_fname, mode='RGB')
|
54 |
+
worst_mask_img = load_image(worst_mask_path, mode='L')[None, ...]
|
55 |
+
worst_fake_img = load_image(mask2fake_fname[worst_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
56 |
+
best_mask_img = load_image(best_mask_path, mode='L')[None, ...]
|
57 |
+
best_fake_img = load_image(mask2fake_fname[best_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
|
58 |
+
|
59 |
+
draw_score(orig_img, worst_best_by_real.loc[real_fname, 'real_score'])
|
60 |
+
draw_score(worst_fake_img, worst_best_by_real.loc[real_fname, 'worst_score'])
|
61 |
+
draw_score(best_fake_img, worst_best_by_real.loc[real_fname, 'best_score'])
|
62 |
+
|
63 |
+
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=np.zeros_like(worst_mask_img),
|
64 |
+
worst_mask=worst_mask_img, worst_img=worst_fake_img,
|
65 |
+
best_mask=best_mask_img, best_img=best_fake_img),
|
66 |
+
keys=['image', 'worst_mask', 'worst_img', 'best_mask', 'best_img'],
|
67 |
+
rescale_keys=['worst_mask', 'best_mask'],
|
68 |
+
last_without_mask=True)
|
69 |
+
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
|
70 |
+
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
|
71 |
+
cv2.imwrite(os.path.join(out_dir,
|
72 |
+
os.path.splitext(os.path.basename(real_fname))[0] + '.jpg'),
|
73 |
+
cur_grid)
|
74 |
+
|
75 |
+
fig, (ax1, ax2) = plt.subplots(1, 2)
|
76 |
+
cur_stat = fake_info[fake_info['real_fname'] == real_fname]
|
77 |
+
cur_stat['fake_score'].hist(ax=ax1)
|
78 |
+
cur_stat['real_score'].hist(ax=ax2)
|
79 |
+
fig.tight_layout()
|
80 |
+
fig.savefig(os.path.join(out_dir,
|
81 |
+
os.path.splitext(os.path.basename(real_fname))[0] + '_scores.png'))
|
82 |
+
plt.close(fig)
|
83 |
+
|
84 |
+
|
85 |
+
def extract_overlapping_masks(mask_fnames, cur_i, fake_scores_table, max_overlaps_n=2):
|
86 |
+
result_pairs = []
|
87 |
+
result_scores = []
|
88 |
+
mask_fname_a = mask_fnames[cur_i]
|
89 |
+
mask_a = load_image(mask_fname_a, mode='L')[None, ...] > 0.5
|
90 |
+
cur_score_a = fake_scores_table.loc[mask_fname_a, 'fake_score']
|
91 |
+
for mask_fname_b in mask_fnames[cur_i + 1:]:
|
92 |
+
mask_b = load_image(mask_fname_b, mode='L')[None, ...] > 0.5
|
93 |
+
if not np.any(mask_a & mask_b):
|
94 |
+
continue
|
95 |
+
cur_score_b = fake_scores_table.loc[mask_fname_b, 'fake_score']
|
96 |
+
result_pairs.append((mask_fname_a, mask_fname_b))
|
97 |
+
result_scores.append(cur_score_b - cur_score_a)
|
98 |
+
if len(result_pairs) >= max_overlaps_n:
|
99 |
+
break
|
100 |
+
return result_pairs, result_scores
|
101 |
+
|
102 |
+
|
103 |
+
def main(args):
|
104 |
+
config = load_yaml(args.config)
|
105 |
+
|
106 |
+
latents_dir = os.path.join(args.outpath, 'latents')
|
107 |
+
os.makedirs(latents_dir, exist_ok=True)
|
108 |
+
global_worst_dir = os.path.join(args.outpath, 'global_worst')
|
109 |
+
os.makedirs(global_worst_dir, exist_ok=True)
|
110 |
+
global_best_dir = os.path.join(args.outpath, 'global_best')
|
111 |
+
os.makedirs(global_best_dir, exist_ok=True)
|
112 |
+
worst_best_by_best_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_max')
|
113 |
+
os.makedirs(worst_best_by_best_worst_score_diff_max_dir, exist_ok=True)
|
114 |
+
worst_best_by_best_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_min')
|
115 |
+
os.makedirs(worst_best_by_best_worst_score_diff_min_dir, exist_ok=True)
|
116 |
+
worst_best_by_real_best_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_max')
|
117 |
+
os.makedirs(worst_best_by_real_best_score_diff_max_dir, exist_ok=True)
|
118 |
+
worst_best_by_real_best_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_min')
|
119 |
+
os.makedirs(worst_best_by_real_best_score_diff_min_dir, exist_ok=True)
|
120 |
+
worst_best_by_real_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_max')
|
121 |
+
os.makedirs(worst_best_by_real_worst_score_diff_max_dir, exist_ok=True)
|
122 |
+
worst_best_by_real_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_min')
|
123 |
+
os.makedirs(worst_best_by_real_worst_score_diff_min_dir, exist_ok=True)
|
124 |
+
|
125 |
+
if not args.only_report:
|
126 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
127 |
+
inception_model = InceptionV3([block_idx]).eval().cuda()
|
128 |
+
|
129 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
130 |
+
|
131 |
+
real2vector_cache = {}
|
132 |
+
|
133 |
+
real_features = []
|
134 |
+
fake_features = []
|
135 |
+
|
136 |
+
orig_fnames = []
|
137 |
+
mask_fnames = []
|
138 |
+
mask2real_fname = {}
|
139 |
+
mask2fake_fname = {}
|
140 |
+
|
141 |
+
for batch_i, batch in enumerate(dataset):
|
142 |
+
orig_img_fname = dataset.img_filenames[batch_i]
|
143 |
+
mask_fname = dataset.mask_filenames[batch_i]
|
144 |
+
fake_fname = dataset.pred_filenames[batch_i]
|
145 |
+
mask2real_fname[mask_fname] = orig_img_fname
|
146 |
+
mask2fake_fname[mask_fname] = fake_fname
|
147 |
+
|
148 |
+
cur_real_vector = real2vector_cache.get(orig_img_fname, None)
|
149 |
+
if cur_real_vector is None:
|
150 |
+
with torch.no_grad():
|
151 |
+
in_img = torch.from_numpy(batch['image'][None, ...]).cuda()
|
152 |
+
cur_real_vector = inception_model(in_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
|
153 |
+
real2vector_cache[orig_img_fname] = cur_real_vector
|
154 |
+
|
155 |
+
pred_img = torch.from_numpy(batch['inpainted'][None, ...]).cuda()
|
156 |
+
cur_fake_vector = inception_model(pred_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
|
157 |
+
|
158 |
+
real_features.append(cur_real_vector)
|
159 |
+
fake_features.append(cur_fake_vector)
|
160 |
+
|
161 |
+
orig_fnames.append(orig_img_fname)
|
162 |
+
mask_fnames.append(mask_fname)
|
163 |
+
|
164 |
+
ids_features = np.concatenate(real_features + fake_features, axis=0)
|
165 |
+
ids_labels = np.array(([1] * len(real_features)) + ([0] * len(fake_features)))
|
166 |
+
|
167 |
+
with open(os.path.join(latents_dir, 'featues.pkl'), 'wb') as f:
|
168 |
+
pickle.dump(ids_features, f, protocol=3)
|
169 |
+
with open(os.path.join(latents_dir, 'labels.pkl'), 'wb') as f:
|
170 |
+
pickle.dump(ids_labels, f, protocol=3)
|
171 |
+
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'wb') as f:
|
172 |
+
pickle.dump(orig_fnames, f, protocol=3)
|
173 |
+
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'wb') as f:
|
174 |
+
pickle.dump(mask_fnames, f, protocol=3)
|
175 |
+
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'wb') as f:
|
176 |
+
pickle.dump(mask2real_fname, f, protocol=3)
|
177 |
+
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'wb') as f:
|
178 |
+
pickle.dump(mask2fake_fname, f, protocol=3)
|
179 |
+
|
180 |
+
svm = sklearn.svm.LinearSVC(dual=False)
|
181 |
+
svm.fit(ids_features, ids_labels)
|
182 |
+
|
183 |
+
pred_scores = svm.decision_function(ids_features)
|
184 |
+
real_scores = pred_scores[:len(real_features)]
|
185 |
+
fake_scores = pred_scores[len(real_features):]
|
186 |
+
|
187 |
+
with open(os.path.join(latents_dir, 'pred_scores.pkl'), 'wb') as f:
|
188 |
+
pickle.dump(pred_scores, f, protocol=3)
|
189 |
+
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'wb') as f:
|
190 |
+
pickle.dump(real_scores, f, protocol=3)
|
191 |
+
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'wb') as f:
|
192 |
+
pickle.dump(fake_scores, f, protocol=3)
|
193 |
+
else:
|
194 |
+
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'rb') as f:
|
195 |
+
orig_fnames = pickle.load(f)
|
196 |
+
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'rb') as f:
|
197 |
+
mask_fnames = pickle.load(f)
|
198 |
+
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'rb') as f:
|
199 |
+
mask2real_fname = pickle.load(f)
|
200 |
+
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'rb') as f:
|
201 |
+
mask2fake_fname = pickle.load(f)
|
202 |
+
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'rb') as f:
|
203 |
+
real_scores = pickle.load(f)
|
204 |
+
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'rb') as f:
|
205 |
+
fake_scores = pickle.load(f)
|
206 |
+
|
207 |
+
real_info = pd.DataFrame(data=[dict(real_fname=fname,
|
208 |
+
real_score=score)
|
209 |
+
for fname, score
|
210 |
+
in zip(orig_fnames, real_scores)])
|
211 |
+
real_info.set_index('real_fname', drop=True, inplace=True)
|
212 |
+
|
213 |
+
fake_info = pd.DataFrame(data=[dict(mask_fname=fname,
|
214 |
+
fake_fname=mask2fake_fname[fname],
|
215 |
+
real_fname=mask2real_fname[fname],
|
216 |
+
fake_score=score)
|
217 |
+
for fname, score
|
218 |
+
in zip(mask_fnames, fake_scores)])
|
219 |
+
fake_info = fake_info.join(real_info, on='real_fname', how='left')
|
220 |
+
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
|
221 |
+
|
222 |
+
fake_stats_by_real = fake_info.groupby('real_fname')['fake_score'].describe()[['mean', 'std']].rename(
|
223 |
+
{'mean': 'mean_fake_by_real', 'std': 'std_fake_by_real'}, axis=1)
|
224 |
+
fake_info = fake_info.join(fake_stats_by_real, on='real_fname', rsuffix='stat_by_real')
|
225 |
+
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
|
226 |
+
fake_info.to_csv(os.path.join(latents_dir, 'join_scores_table.csv'), sep='\t', index=False)
|
227 |
+
|
228 |
+
fake_scores_table = fake_info.set_index('mask_fname')['fake_score'].to_frame()
|
229 |
+
real_scores_table = fake_info.set_index('real_fname')['real_score'].drop_duplicates().to_frame()
|
230 |
+
|
231 |
+
fig, (ax1, ax2) = plt.subplots(1, 2)
|
232 |
+
ax1.hist(fake_scores)
|
233 |
+
ax2.hist(real_scores)
|
234 |
+
fig.tight_layout()
|
235 |
+
fig.savefig(os.path.join(args.outpath, 'global_scores_hist.png'))
|
236 |
+
plt.close(fig)
|
237 |
+
|
238 |
+
global_worst_masks = fake_info.sort_values('fake_score', ascending=True)['mask_fname'].iloc[:config.take_global_top].to_list()
|
239 |
+
global_best_masks = fake_info.sort_values('fake_score', ascending=False)['mask_fname'].iloc[:config.take_global_top].to_list()
|
240 |
+
save_global_samples(global_worst_masks, mask2real_fname, mask2fake_fname, global_worst_dir, real_scores_table, fake_scores_table)
|
241 |
+
save_global_samples(global_best_masks, mask2real_fname, mask2fake_fname, global_best_dir, real_scores_table, fake_scores_table)
|
242 |
+
|
243 |
+
# grouped by real
|
244 |
+
worst_samples_by_real = fake_info.groupby('real_fname').apply(
|
245 |
+
lambda d: d.set_index('mask_fname')['fake_score'].idxmin()).to_frame().rename({0: 'worst'}, axis=1)
|
246 |
+
best_samples_by_real = fake_info.groupby('real_fname').apply(
|
247 |
+
lambda d: d.set_index('mask_fname')['fake_score'].idxmax()).to_frame().rename({0: 'best'}, axis=1)
|
248 |
+
worst_best_by_real = pd.concat([worst_samples_by_real, best_samples_by_real], axis=1)
|
249 |
+
|
250 |
+
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'worst_score'}, axis=1),
|
251 |
+
on='worst')
|
252 |
+
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'best_score'}, axis=1),
|
253 |
+
on='best')
|
254 |
+
worst_best_by_real = worst_best_by_real.join(real_scores_table)
|
255 |
+
|
256 |
+
worst_best_by_real['best_worst_score_diff'] = worst_best_by_real['best_score'] - worst_best_by_real['worst_score']
|
257 |
+
worst_best_by_real['real_best_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['best_score']
|
258 |
+
worst_best_by_real['real_worst_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['worst_score']
|
259 |
+
|
260 |
+
worst_best_by_best_worst_score_diff_min = worst_best_by_real.sort_values('best_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
261 |
+
worst_best_by_best_worst_score_diff_max = worst_best_by_real.sort_values('best_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
262 |
+
save_samples_by_real(worst_best_by_best_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_min_dir)
|
263 |
+
save_samples_by_real(worst_best_by_best_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_max_dir)
|
264 |
+
|
265 |
+
worst_best_by_real_best_score_diff_min = worst_best_by_real.sort_values('real_best_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
266 |
+
worst_best_by_real_best_score_diff_max = worst_best_by_real.sort_values('real_best_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
267 |
+
save_samples_by_real(worst_best_by_real_best_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_min_dir)
|
268 |
+
save_samples_by_real(worst_best_by_real_best_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_max_dir)
|
269 |
+
|
270 |
+
worst_best_by_real_worst_score_diff_min = worst_best_by_real.sort_values('real_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
|
271 |
+
worst_best_by_real_worst_score_diff_max = worst_best_by_real.sort_values('real_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
|
272 |
+
save_samples_by_real(worst_best_by_real_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_min_dir)
|
273 |
+
save_samples_by_real(worst_best_by_real_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_max_dir)
|
274 |
+
|
275 |
+
# analyze what change of mask causes bigger change of score
|
276 |
+
overlapping_mask_fname_pairs = []
|
277 |
+
overlapping_mask_fname_score_diffs = []
|
278 |
+
for cur_real_fname in orig_fnames:
|
279 |
+
cur_fakes_info = fake_info[fake_info['real_fname'] == cur_real_fname]
|
280 |
+
cur_mask_fnames = sorted(cur_fakes_info['mask_fname'].unique())
|
281 |
+
|
282 |
+
cur_mask_pairs_and_scores = Parallel(args.n_jobs)(
|
283 |
+
delayed(extract_overlapping_masks)(cur_mask_fnames, i, fake_scores_table)
|
284 |
+
for i in range(len(cur_mask_fnames) - 1)
|
285 |
+
)
|
286 |
+
for cur_pairs, cur_scores in cur_mask_pairs_and_scores:
|
287 |
+
overlapping_mask_fname_pairs.extend(cur_pairs)
|
288 |
+
overlapping_mask_fname_score_diffs.extend(cur_scores)
|
289 |
+
|
290 |
+
overlapping_mask_fname_pairs = np.asarray(overlapping_mask_fname_pairs)
|
291 |
+
overlapping_mask_fname_score_diffs = np.asarray(overlapping_mask_fname_score_diffs)
|
292 |
+
overlapping_sort_idx = np.argsort(overlapping_mask_fname_score_diffs)
|
293 |
+
overlapping_mask_fname_pairs = overlapping_mask_fname_pairs[overlapping_sort_idx]
|
294 |
+
overlapping_mask_fname_score_diffs = overlapping_mask_fname_score_diffs[overlapping_sort_idx]
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
if __name__ == '__main__':
|
302 |
+
import argparse
|
303 |
+
|
304 |
+
aparser = argparse.ArgumentParser()
|
305 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
306 |
+
aparser.add_argument('datadir', type=str,
|
307 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
308 |
+
aparser.add_argument('predictdir', type=str,
|
309 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
310 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
311 |
+
aparser.add_argument('--only-report', action='store_true',
|
312 |
+
help='Whether to skip prediction and feature extraction, '
|
313 |
+
'load all the possible latents and proceed with report only')
|
314 |
+
aparser.add_argument('--n-jobs', type=int, default=8, help='how many processes to use for pair mask mining')
|
315 |
+
|
316 |
+
main(aparser.parse_args())
|
bin/blur_predicts.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
10 |
+
from saicinpainting.evaluation.utils import load_yaml
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
config = load_yaml(args.config)
|
15 |
+
|
16 |
+
if not args.predictdir.endswith('/'):
|
17 |
+
args.predictdir += '/'
|
18 |
+
|
19 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
20 |
+
|
21 |
+
os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
|
22 |
+
|
23 |
+
for img_i in tqdm.trange(len(dataset)):
|
24 |
+
pred_fname = dataset.pred_filenames[img_i]
|
25 |
+
cur_out_fname = os.path.join(args.outpath, pred_fname[len(args.predictdir):])
|
26 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
27 |
+
|
28 |
+
sample = dataset[img_i]
|
29 |
+
img = sample['image']
|
30 |
+
mask = sample['mask']
|
31 |
+
inpainted = sample['inpainted']
|
32 |
+
|
33 |
+
inpainted_blurred = cv2.GaussianBlur(np.transpose(inpainted, (1, 2, 0)),
|
34 |
+
ksize=(args.k, args.k),
|
35 |
+
sigmaX=args.s, sigmaY=args.s,
|
36 |
+
borderType=cv2.BORDER_REFLECT)
|
37 |
+
|
38 |
+
cur_res = (1 - mask) * np.transpose(img, (1, 2, 0)) + mask * inpainted_blurred
|
39 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
|
40 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
41 |
+
cv2.imwrite(cur_out_fname, cur_res)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == '__main__':
|
45 |
+
import argparse
|
46 |
+
|
47 |
+
aparser = argparse.ArgumentParser()
|
48 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config')
|
49 |
+
aparser.add_argument('datadir', type=str,
|
50 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
51 |
+
aparser.add_argument('predictdir', type=str,
|
52 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
53 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
54 |
+
aparser.add_argument('-s', type=float, default=0.1, help='Gaussian blur sigma')
|
55 |
+
aparser.add_argument('-k', type=int, default=5, help='Kernel size in gaussian blur')
|
56 |
+
|
57 |
+
main(aparser.parse_args())
|
bin/calc_dataset_stats.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from scipy.ndimage.morphology import distance_transform_edt
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import InpaintingDataset
|
10 |
+
from saicinpainting.evaluation.vis import save_item_for_vis
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
dataset = InpaintingDataset(args.datadir, img_suffix='.png')
|
15 |
+
|
16 |
+
area_bins = np.linspace(0, 1, args.area_bins + 1)
|
17 |
+
|
18 |
+
heights = []
|
19 |
+
widths = []
|
20 |
+
image_areas = []
|
21 |
+
hole_areas = []
|
22 |
+
hole_area_percents = []
|
23 |
+
known_pixel_distances = []
|
24 |
+
|
25 |
+
area_bins_count = np.zeros(args.area_bins)
|
26 |
+
area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
|
27 |
+
|
28 |
+
bin2i = [[] for _ in range(args.area_bins)]
|
29 |
+
|
30 |
+
for i, item in enumerate(tqdm.tqdm(dataset)):
|
31 |
+
h, w = item['image'].shape[1:]
|
32 |
+
heights.append(h)
|
33 |
+
widths.append(w)
|
34 |
+
full_area = h * w
|
35 |
+
image_areas.append(full_area)
|
36 |
+
bin_mask = item['mask'] > 0.5
|
37 |
+
hole_area = bin_mask.sum()
|
38 |
+
hole_areas.append(hole_area)
|
39 |
+
hole_percent = hole_area / full_area
|
40 |
+
hole_area_percents.append(hole_percent)
|
41 |
+
bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
|
42 |
+
area_bins_count[bin_i] += 1
|
43 |
+
bin2i[bin_i].append(i)
|
44 |
+
|
45 |
+
cur_dist = distance_transform_edt(bin_mask)
|
46 |
+
cur_dist_inside_mask = cur_dist[bin_mask]
|
47 |
+
known_pixel_distances.append(cur_dist_inside_mask.mean())
|
48 |
+
|
49 |
+
os.makedirs(args.outdir, exist_ok=True)
|
50 |
+
with open(os.path.join(args.outdir, 'summary.txt'), 'w') as f:
|
51 |
+
f.write(f'''Location: {args.datadir}
|
52 |
+
|
53 |
+
Number of samples: {len(dataset)}
|
54 |
+
|
55 |
+
Image height: min {min(heights):5d} max {max(heights):5d} mean {np.mean(heights):.2f}
|
56 |
+
Image width: min {min(widths):5d} max {max(widths):5d} mean {np.mean(widths):.2f}
|
57 |
+
Image area: min {min(image_areas):7d} max {max(image_areas):7d} mean {np.mean(image_areas):.2f}
|
58 |
+
Hole area: min {min(hole_areas):7d} max {max(hole_areas):7d} mean {np.mean(hole_areas):.2f}
|
59 |
+
Hole area %: min {min(hole_area_percents) * 100:2.2f} max {max(hole_area_percents) * 100:2.2f} mean {np.mean(hole_area_percents) * 100:2.2f}
|
60 |
+
Dist 2known: min {min(known_pixel_distances):2.2f} max {max(known_pixel_distances):2.2f} mean {np.mean(known_pixel_distances):2.2f} median {np.median(known_pixel_distances):2.2f}
|
61 |
+
|
62 |
+
Stats by hole area %:
|
63 |
+
''')
|
64 |
+
for bin_i in range(args.area_bins):
|
65 |
+
f.write(f'{area_bin_titles[bin_i]}%: '
|
66 |
+
f'samples number {area_bins_count[bin_i]}, '
|
67 |
+
f'{area_bins_count[bin_i] / len(dataset) * 100:.1f}%\n')
|
68 |
+
|
69 |
+
for bin_i in range(args.area_bins):
|
70 |
+
bindir = os.path.join(args.outdir, 'samples', area_bin_titles[bin_i])
|
71 |
+
os.makedirs(bindir, exist_ok=True)
|
72 |
+
bin_idx = bin2i[bin_i]
|
73 |
+
for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
|
74 |
+
save_item_for_vis(dataset[sample_i], os.path.join(bindir, f'{sample_i}.png'))
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == '__main__':
|
78 |
+
import argparse
|
79 |
+
|
80 |
+
aparser = argparse.ArgumentParser()
|
81 |
+
aparser.add_argument('datadir', type=str,
|
82 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
83 |
+
aparser.add_argument('outdir', type=str, help='Where to put results')
|
84 |
+
aparser.add_argument('--samples-n', type=int, default=10,
|
85 |
+
help='Number of sample images with masks to copy for visualization for each area bin')
|
86 |
+
aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
|
87 |
+
|
88 |
+
main(aparser.parse_args())
|
bin/debug/analyze_overlapping_masks.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
BASEDIR="$(dirname $0)"
|
4 |
+
|
5 |
+
# paths are valid for mml7
|
6 |
+
|
7 |
+
# select images
|
8 |
+
#ls /data/inpainting/work/data/train | shuf | head -2000 | xargs -n1 -I{} cp {} /data/inpainting/mask_analysis/src
|
9 |
+
|
10 |
+
# generate masks
|
11 |
+
#"$BASEDIR/../gen_debug_mask_dataset.py" \
|
12 |
+
# "$BASEDIR/../../configs/debug_mask_gen.yaml" \
|
13 |
+
# "/data/inpainting/mask_analysis/src" \
|
14 |
+
# "/data/inpainting/mask_analysis/generated"
|
15 |
+
|
16 |
+
# predict
|
17 |
+
#"$BASEDIR/../predict.py" \
|
18 |
+
# model.path="simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/saved_checkpoint/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15_epoch22-step-574999" \
|
19 |
+
# indir="/data/inpainting/mask_analysis/generated" \
|
20 |
+
# outdir="/data/inpainting/mask_analysis/predicted" \
|
21 |
+
# dataset.img_suffix=.jpg \
|
22 |
+
# +out_ext=.jpg
|
23 |
+
|
24 |
+
# analyze good and bad samples
|
25 |
+
"$BASEDIR/../analyze_errors.py" \
|
26 |
+
--only-report \
|
27 |
+
--n-jobs 8 \
|
28 |
+
"$BASEDIR/../../configs/analyze_mask_errors.yaml" \
|
29 |
+
"/data/inpainting/mask_analysis/small/generated" \
|
30 |
+
"/data/inpainting/mask_analysis/small/predicted" \
|
31 |
+
"/data/inpainting/mask_analysis/small/report"
|
bin/evaluate_predicts.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
8 |
+
from saicinpainting.evaluation.evaluator import InpaintingEvaluator, lpips_fid100_f1
|
9 |
+
from saicinpainting.evaluation.losses.base_loss import SegmentationAwareSSIM, \
|
10 |
+
SegmentationClassStats, SSIMScore, LPIPSScore, FIDScore, SegmentationAwareLPIPS, SegmentationAwareFID
|
11 |
+
from saicinpainting.evaluation.utils import load_yaml
|
12 |
+
|
13 |
+
|
14 |
+
def main(args):
|
15 |
+
config = load_yaml(args.config)
|
16 |
+
|
17 |
+
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)
|
18 |
+
|
19 |
+
metrics = {
|
20 |
+
'ssim': SSIMScore(),
|
21 |
+
'lpips': LPIPSScore(),
|
22 |
+
'fid': FIDScore()
|
23 |
+
}
|
24 |
+
enable_segm = config.get('segmentation', dict(enable=False)).get('enable', False)
|
25 |
+
if enable_segm:
|
26 |
+
weights_path = os.path.expandvars(config.segmentation.weights_path)
|
27 |
+
metrics.update(dict(
|
28 |
+
segm_stats=SegmentationClassStats(weights_path=weights_path),
|
29 |
+
segm_ssim=SegmentationAwareSSIM(weights_path=weights_path),
|
30 |
+
segm_lpips=SegmentationAwareLPIPS(weights_path=weights_path),
|
31 |
+
segm_fid=SegmentationAwareFID(weights_path=weights_path)
|
32 |
+
))
|
33 |
+
evaluator = InpaintingEvaluator(dataset, scores=metrics,
|
34 |
+
integral_title='lpips_fid100_f1', integral_func=lpips_fid100_f1,
|
35 |
+
**config.evaluator_kwargs)
|
36 |
+
|
37 |
+
os.makedirs(os.path.dirname(args.outpath), exist_ok=True)
|
38 |
+
|
39 |
+
results = evaluator.evaluate()
|
40 |
+
|
41 |
+
results = pd.DataFrame(results).stack(1).unstack(0)
|
42 |
+
results.dropna(axis=1, how='all', inplace=True)
|
43 |
+
results.to_csv(args.outpath, sep='\t', float_format='%.4f')
|
44 |
+
|
45 |
+
if enable_segm:
|
46 |
+
only_short_results = results[[c for c in results.columns if not c[0].startswith('segm_')]].dropna(axis=1, how='all')
|
47 |
+
only_short_results.to_csv(args.outpath + '_short', sep='\t', float_format='%.4f')
|
48 |
+
|
49 |
+
print(only_short_results)
|
50 |
+
|
51 |
+
segm_metrics_results = results[['segm_ssim', 'segm_lpips', 'segm_fid']].dropna(axis=1, how='all').transpose().unstack(0).reorder_levels([1, 0], axis=1)
|
52 |
+
segm_metrics_results.drop(['mean', 'std'], axis=0, inplace=True)
|
53 |
+
|
54 |
+
segm_stats_results = results['segm_stats'].dropna(axis=1, how='all').transpose()
|
55 |
+
segm_stats_results.index = pd.MultiIndex.from_tuples(n.split('/') for n in segm_stats_results.index)
|
56 |
+
segm_stats_results = segm_stats_results.unstack(0).reorder_levels([1, 0], axis=1)
|
57 |
+
segm_stats_results.sort_index(axis=1, inplace=True)
|
58 |
+
segm_stats_results.dropna(axis=0, how='all', inplace=True)
|
59 |
+
|
60 |
+
segm_results = pd.concat([segm_metrics_results, segm_stats_results], axis=1, sort=True)
|
61 |
+
segm_results.sort_values(('mask_freq', 'total'), ascending=False, inplace=True)
|
62 |
+
|
63 |
+
segm_results.to_csv(args.outpath + '_segm', sep='\t', float_format='%.4f')
|
64 |
+
else:
|
65 |
+
print(results)
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == '__main__':
|
69 |
+
import argparse
|
70 |
+
|
71 |
+
aparser = argparse.ArgumentParser()
|
72 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config')
|
73 |
+
aparser.add_argument('datadir', type=str,
|
74 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
75 |
+
aparser.add_argument('predictdir', type=str,
|
76 |
+
help='Path to folder with predicts (e.g. predict_hifill_baseline.py)')
|
77 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
78 |
+
|
79 |
+
main(aparser.parse_args())
|
bin/evaluator_example.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from skimage import io
|
7 |
+
from skimage.transform import resize
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
from saicinpainting.evaluation.evaluator import InpaintingEvaluator
|
11 |
+
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
|
12 |
+
|
13 |
+
|
14 |
+
class SimpleImageDataset(Dataset):
|
15 |
+
def __init__(self, root_dir, image_size=(400, 600)):
|
16 |
+
self.root_dir = root_dir
|
17 |
+
self.files = sorted(os.listdir(root_dir))
|
18 |
+
self.image_size = image_size
|
19 |
+
|
20 |
+
def __getitem__(self, index):
|
21 |
+
img_name = os.path.join(self.root_dir, self.files[index])
|
22 |
+
image = io.imread(img_name)
|
23 |
+
image = resize(image, self.image_size, anti_aliasing=True)
|
24 |
+
image = torch.FloatTensor(image).permute(2, 0, 1)
|
25 |
+
return image
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return len(self.files)
|
29 |
+
|
30 |
+
|
31 |
+
def create_rectangle_mask(height, width):
|
32 |
+
mask = np.ones((height, width))
|
33 |
+
up_left_corner = width // 4, height // 4
|
34 |
+
down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1)
|
35 |
+
cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED)
|
36 |
+
return mask
|
37 |
+
|
38 |
+
|
39 |
+
class Model():
|
40 |
+
def __call__(self, img_batch, mask_batch):
|
41 |
+
mean = (img_batch * mask_batch[:, None, :, :]).sum(dim=(2, 3)) / mask_batch.sum(dim=(1, 2))[:, None]
|
42 |
+
inpainted = mean[:, :, None, None] * (1 - mask_batch[:, None, :, :]) + img_batch * mask_batch[:, None, :, :]
|
43 |
+
return inpainted
|
44 |
+
|
45 |
+
|
46 |
+
class SimpleImageSquareMaskDataset(Dataset):
|
47 |
+
def __init__(self, dataset):
|
48 |
+
self.dataset = dataset
|
49 |
+
self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
|
50 |
+
self.model = Model()
|
51 |
+
|
52 |
+
def __getitem__(self, index):
|
53 |
+
img = self.dataset[index]
|
54 |
+
mask = self.mask.clone()
|
55 |
+
inpainted = self.model(img[None, ...], mask[None, ...])
|
56 |
+
return dict(image=img, mask=mask, inpainted=inpainted)
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.dataset)
|
60 |
+
|
61 |
+
|
62 |
+
dataset = SimpleImageDataset('imgs')
|
63 |
+
mask_dataset = SimpleImageSquareMaskDataset(dataset)
|
64 |
+
model = Model()
|
65 |
+
metrics = {
|
66 |
+
'ssim': SSIMScore(),
|
67 |
+
'lpips': LPIPSScore(),
|
68 |
+
'fid': FIDScore()
|
69 |
+
}
|
70 |
+
|
71 |
+
evaluator = InpaintingEvaluator(
|
72 |
+
mask_dataset, scores=metrics, batch_size=3, area_grouping=True
|
73 |
+
)
|
74 |
+
|
75 |
+
results = evaluator.evaluate(model)
|
76 |
+
print(results)
|
bin/extract_masks.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL.Image as Image
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def main(args):
|
7 |
+
if not args.indir.endswith('/'):
|
8 |
+
args.indir += '/'
|
9 |
+
os.makedirs(args.outdir, exist_ok=True)
|
10 |
+
|
11 |
+
src_images = [
|
12 |
+
args.indir+fname for fname in os.listdir(args.indir)]
|
13 |
+
|
14 |
+
tgt_masks = [
|
15 |
+
args.outdir+fname[:-4] + f'_mask000.png'
|
16 |
+
for fname in os.listdir(args.indir)]
|
17 |
+
|
18 |
+
for img_name, msk_name in zip(src_images, tgt_masks):
|
19 |
+
#print(img)
|
20 |
+
#print(msk)
|
21 |
+
|
22 |
+
image = Image.open(img_name).convert('RGB')
|
23 |
+
image = np.transpose(np.array(image), (2, 0, 1))
|
24 |
+
|
25 |
+
mask = (image == 255).astype(int)
|
26 |
+
|
27 |
+
print(mask.dtype, mask.shape)
|
28 |
+
|
29 |
+
|
30 |
+
Image.fromarray(
|
31 |
+
np.clip(mask[0,:,:] * 255, 0, 255).astype('uint8'),mode='L'
|
32 |
+
).save(msk_name)
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
'''
|
38 |
+
for infile in src_images:
|
39 |
+
try:
|
40 |
+
file_relpath = infile[len(indir):]
|
41 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
42 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
43 |
+
|
44 |
+
image = Image.open(infile).convert('RGB')
|
45 |
+
|
46 |
+
mask =
|
47 |
+
|
48 |
+
Image.fromarray(
|
49 |
+
np.clip(
|
50 |
+
cur_mask * 255, 0, 255).astype('uint8'),
|
51 |
+
mode='L'
|
52 |
+
).save(cur_basename + f'_mask{i:03d}.png')
|
53 |
+
'''
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
import argparse
|
59 |
+
aparser = argparse.ArgumentParser()
|
60 |
+
aparser.add_argument('--indir', type=str, help='Path to folder with images')
|
61 |
+
aparser.add_argument('--outdir', type=str, help='Path to folder to store aligned images and masks to')
|
62 |
+
|
63 |
+
main(aparser.parse_args())
|
bin/filter_sharded_dataset.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
|
8 |
+
import braceexpand
|
9 |
+
import webdataset as wds
|
10 |
+
|
11 |
+
DEFAULT_CATS_FILE = os.path.join(os.path.dirname(__file__), '..', 'configs', 'places2-categories_157.txt')
|
12 |
+
|
13 |
+
def is_good_key(key, cats):
|
14 |
+
return any(c in key for c in cats)
|
15 |
+
|
16 |
+
|
17 |
+
def main(args):
|
18 |
+
if args.categories == 'nofilter':
|
19 |
+
good_categories = None
|
20 |
+
else:
|
21 |
+
with open(args.categories, 'r') as f:
|
22 |
+
good_categories = set(line.strip().split(' ')[0] for line in f if line.strip())
|
23 |
+
|
24 |
+
all_input_files = list(braceexpand.braceexpand(args.infile))
|
25 |
+
chunk_size = int(math.ceil(len(all_input_files) / args.n_read_streams))
|
26 |
+
|
27 |
+
input_iterators = [iter(wds.Dataset(all_input_files[start : start + chunk_size]).shuffle(args.shuffle_buffer))
|
28 |
+
for start in range(0, len(all_input_files), chunk_size)]
|
29 |
+
output_datasets = [wds.ShardWriter(args.outpattern.format(i)) for i in range(args.n_write_streams)]
|
30 |
+
|
31 |
+
good_readers = list(range(len(input_iterators)))
|
32 |
+
step_i = 0
|
33 |
+
good_samples = 0
|
34 |
+
bad_samples = 0
|
35 |
+
while len(good_readers) > 0:
|
36 |
+
if step_i % args.print_freq == 0:
|
37 |
+
print(f'Iterations done {step_i}; readers alive {good_readers}; good samples {good_samples}; bad samples {bad_samples}')
|
38 |
+
|
39 |
+
step_i += 1
|
40 |
+
|
41 |
+
ri = random.choice(good_readers)
|
42 |
+
try:
|
43 |
+
sample = next(input_iterators[ri])
|
44 |
+
except StopIteration:
|
45 |
+
good_readers = list(set(good_readers) - {ri})
|
46 |
+
continue
|
47 |
+
|
48 |
+
if good_categories is not None and not is_good_key(sample['__key__'], good_categories):
|
49 |
+
bad_samples += 1
|
50 |
+
continue
|
51 |
+
|
52 |
+
wi = random.randint(0, args.n_write_streams - 1)
|
53 |
+
output_datasets[wi].write(sample)
|
54 |
+
good_samples += 1
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
import argparse
|
59 |
+
|
60 |
+
aparser = argparse.ArgumentParser()
|
61 |
+
aparser.add_argument('--categories', type=str, default=DEFAULT_CATS_FILE)
|
62 |
+
aparser.add_argument('--shuffle-buffer', type=int, default=10000)
|
63 |
+
aparser.add_argument('--n-read-streams', type=int, default=10)
|
64 |
+
aparser.add_argument('--n-write-streams', type=int, default=10)
|
65 |
+
aparser.add_argument('--print-freq', type=int, default=1000)
|
66 |
+
aparser.add_argument('infile', type=str)
|
67 |
+
aparser.add_argument('outpattern', type=str)
|
68 |
+
|
69 |
+
main(aparser.parse_args())
|
bin/gen_debug_mask_dataset.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
|
6 |
+
import PIL.Image as Image
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import tqdm
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
|
13 |
+
from saicinpainting.evaluation.utils import load_yaml
|
14 |
+
|
15 |
+
|
16 |
+
def generate_masks_for_img(infile, outmask_pattern, mask_size=200, step=0.5):
|
17 |
+
inimg = Image.open(infile)
|
18 |
+
width, height = inimg.size
|
19 |
+
step_abs = int(mask_size * step)
|
20 |
+
|
21 |
+
mask = np.zeros((height, width), dtype='uint8')
|
22 |
+
mask_i = 0
|
23 |
+
|
24 |
+
for start_vertical in range(0, height - step_abs, step_abs):
|
25 |
+
for start_horizontal in range(0, width - step_abs, step_abs):
|
26 |
+
mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 255
|
27 |
+
|
28 |
+
cv2.imwrite(outmask_pattern.format(mask_i), mask)
|
29 |
+
|
30 |
+
mask[start_vertical:start_vertical + mask_size, start_horizontal:start_horizontal + mask_size] = 0
|
31 |
+
mask_i += 1
|
32 |
+
|
33 |
+
|
34 |
+
def main(args):
|
35 |
+
if not args.indir.endswith('/'):
|
36 |
+
args.indir += '/'
|
37 |
+
if not args.outdir.endswith('/'):
|
38 |
+
args.outdir += '/'
|
39 |
+
|
40 |
+
config = load_yaml(args.config)
|
41 |
+
|
42 |
+
in_files = list(glob.glob(os.path.join(args.indir, '**', f'*{config.img_ext}'), recursive=True))
|
43 |
+
for infile in tqdm.tqdm(in_files):
|
44 |
+
outimg = args.outdir + infile[len(args.indir):]
|
45 |
+
outmask_pattern = outimg[:-len(config.img_ext)] + '_mask{:04d}.png'
|
46 |
+
|
47 |
+
os.makedirs(os.path.dirname(outimg), exist_ok=True)
|
48 |
+
shutil.copy2(infile, outimg)
|
49 |
+
|
50 |
+
generate_masks_for_img(infile, outmask_pattern, **config.gen_kwargs)
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
import argparse
|
55 |
+
|
56 |
+
aparser = argparse.ArgumentParser()
|
57 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
58 |
+
aparser.add_argument('indir', type=str, help='Path to folder with images')
|
59 |
+
aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
|
60 |
+
|
61 |
+
main(aparser.parse_args())
|
bin/gen_mask_dataset.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import traceback
|
7 |
+
|
8 |
+
import PIL.Image as Image
|
9 |
+
import numpy as np
|
10 |
+
from joblib import Parallel, delayed
|
11 |
+
|
12 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
|
13 |
+
from saicinpainting.evaluation.utils import load_yaml, SmallMode
|
14 |
+
from saicinpainting.training.data.masks import MixedMaskGenerator
|
15 |
+
|
16 |
+
|
17 |
+
class MakeManyMasksWrapper:
|
18 |
+
def __init__(self, impl, variants_n=2):
|
19 |
+
self.impl = impl
|
20 |
+
self.variants_n = variants_n
|
21 |
+
|
22 |
+
def get_masks(self, img):
|
23 |
+
img = np.transpose(np.array(img), (2, 0, 1))
|
24 |
+
return [self.impl(img)[0] for _ in range(self.variants_n)]
|
25 |
+
|
26 |
+
|
27 |
+
def process_images(src_images, indir, outdir, config):
|
28 |
+
if config.generator_kind == 'segmentation':
|
29 |
+
mask_generator = SegmentationMask(**config.mask_generator_kwargs)
|
30 |
+
elif config.generator_kind == 'random':
|
31 |
+
variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
|
32 |
+
mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
|
33 |
+
variants_n=variants_n)
|
34 |
+
else:
|
35 |
+
raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
|
36 |
+
|
37 |
+
max_tamper_area = config.get('max_tamper_area', 1)
|
38 |
+
|
39 |
+
for infile in src_images:
|
40 |
+
try:
|
41 |
+
file_relpath = infile[len(indir):]
|
42 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
43 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
44 |
+
|
45 |
+
image = Image.open(infile).convert('RGB')
|
46 |
+
|
47 |
+
# scale input image to output resolution and filter smaller images
|
48 |
+
if min(image.size) < config.cropping.out_min_size:
|
49 |
+
handle_small_mode = SmallMode(config.cropping.handle_small_mode)
|
50 |
+
if handle_small_mode == SmallMode.DROP:
|
51 |
+
continue
|
52 |
+
elif handle_small_mode == SmallMode.UPSCALE:
|
53 |
+
factor = config.cropping.out_min_size / min(image.size)
|
54 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
55 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
56 |
+
else:
|
57 |
+
factor = config.cropping.out_min_size / min(image.size)
|
58 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
59 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
60 |
+
|
61 |
+
# generate and select masks
|
62 |
+
src_masks = mask_generator.get_masks(image)
|
63 |
+
|
64 |
+
filtered_image_mask_pairs = []
|
65 |
+
for cur_mask in src_masks:
|
66 |
+
if config.cropping.out_square_crop:
|
67 |
+
(crop_left,
|
68 |
+
crop_top,
|
69 |
+
crop_right,
|
70 |
+
crop_bottom) = propose_random_square_crop(cur_mask,
|
71 |
+
min_overlap=config.cropping.crop_min_overlap)
|
72 |
+
cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
|
73 |
+
cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
|
74 |
+
else:
|
75 |
+
cur_image = image
|
76 |
+
|
77 |
+
if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
|
78 |
+
continue
|
79 |
+
|
80 |
+
filtered_image_mask_pairs.append((cur_image, cur_mask))
|
81 |
+
|
82 |
+
mask_indices = np.random.choice(len(filtered_image_mask_pairs),
|
83 |
+
size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
|
84 |
+
replace=False)
|
85 |
+
|
86 |
+
# crop masks; save masks together with input image
|
87 |
+
mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
|
88 |
+
for i, idx in enumerate(mask_indices):
|
89 |
+
cur_image, cur_mask = filtered_image_mask_pairs[idx]
|
90 |
+
cur_basename = mask_basename + f'_crop{i:03d}'
|
91 |
+
Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
|
92 |
+
mode='L').save(cur_basename + f'_mask{i:03d}.png')
|
93 |
+
cur_image.save(cur_basename + '.png')
|
94 |
+
except KeyboardInterrupt:
|
95 |
+
return
|
96 |
+
except Exception as ex:
|
97 |
+
print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
|
98 |
+
|
99 |
+
|
100 |
+
def main(args):
|
101 |
+
if not args.indir.endswith('/'):
|
102 |
+
args.indir += '/'
|
103 |
+
|
104 |
+
os.makedirs(args.outdir, exist_ok=True)
|
105 |
+
|
106 |
+
config = load_yaml(args.config)
|
107 |
+
|
108 |
+
in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
|
109 |
+
if args.n_jobs == 0:
|
110 |
+
process_images(in_files, args.indir, args.outdir, config)
|
111 |
+
else:
|
112 |
+
in_files_n = len(in_files)
|
113 |
+
chunk_size = in_files_n // args.n_jobs + (1 if in_files_n % args.n_jobs > 0 else 0)
|
114 |
+
Parallel(n_jobs=args.n_jobs)(
|
115 |
+
delayed(process_images)(in_files[start:start+chunk_size], args.indir, args.outdir, config)
|
116 |
+
for start in range(0, len(in_files), chunk_size)
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
import argparse
|
122 |
+
|
123 |
+
aparser = argparse.ArgumentParser()
|
124 |
+
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
|
125 |
+
aparser.add_argument('indir', type=str, help='Path to folder with images')
|
126 |
+
aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
|
127 |
+
aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
|
128 |
+
aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')
|
129 |
+
|
130 |
+
main(aparser.parse_args())
|
bin/gen_mask_dataset_hydra.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import traceback
|
7 |
+
import hydra
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
import PIL.Image as Image
|
11 |
+
import numpy as np
|
12 |
+
from joblib import Parallel, delayed
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
|
15 |
+
from saicinpainting.evaluation.utils import load_yaml, SmallMode
|
16 |
+
from saicinpainting.training.data.masks import MixedMaskGenerator
|
17 |
+
|
18 |
+
|
19 |
+
class MakeManyMasksWrapper:
|
20 |
+
def __init__(self, impl, variants_n=2):
|
21 |
+
self.impl = impl
|
22 |
+
self.variants_n = variants_n
|
23 |
+
|
24 |
+
def get_masks(self, img):
|
25 |
+
img = np.transpose(np.array(img), (2, 0, 1))
|
26 |
+
return [self.impl(img)[0] for _ in range(self.variants_n)]
|
27 |
+
|
28 |
+
|
29 |
+
def process_images(src_images, indir, outdir, config):
|
30 |
+
if config.generator_kind == 'segmentation':
|
31 |
+
mask_generator = SegmentationMask(**config.mask_generator_kwargs)
|
32 |
+
elif config.generator_kind == 'random':
|
33 |
+
mask_generator_kwargs = OmegaConf.to_container(config.mask_generator_kwargs, resolve=True)
|
34 |
+
variants_n = mask_generator_kwargs.pop('variants_n', 2)
|
35 |
+
mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**mask_generator_kwargs),
|
36 |
+
variants_n=variants_n)
|
37 |
+
else:
|
38 |
+
raise ValueError(f'Unexpected generator kind: {config.generator_kind}')
|
39 |
+
|
40 |
+
max_tamper_area = config.get('max_tamper_area', 1)
|
41 |
+
|
42 |
+
for infile in src_images:
|
43 |
+
try:
|
44 |
+
file_relpath = infile[len(indir):]
|
45 |
+
img_outpath = os.path.join(outdir, file_relpath)
|
46 |
+
os.makedirs(os.path.dirname(img_outpath), exist_ok=True)
|
47 |
+
|
48 |
+
image = Image.open(infile).convert('RGB')
|
49 |
+
|
50 |
+
# scale input image to output resolution and filter smaller images
|
51 |
+
if min(image.size) < config.cropping.out_min_size:
|
52 |
+
handle_small_mode = SmallMode(config.cropping.handle_small_mode)
|
53 |
+
if handle_small_mode == SmallMode.DROP:
|
54 |
+
continue
|
55 |
+
elif handle_small_mode == SmallMode.UPSCALE:
|
56 |
+
factor = config.cropping.out_min_size / min(image.size)
|
57 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
58 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
59 |
+
else:
|
60 |
+
factor = config.cropping.out_min_size / min(image.size)
|
61 |
+
out_size = (np.array(image.size) * factor).round().astype('uint32')
|
62 |
+
image = image.resize(out_size, resample=Image.BICUBIC)
|
63 |
+
|
64 |
+
# generate and select masks
|
65 |
+
src_masks = mask_generator.get_masks(image)
|
66 |
+
|
67 |
+
filtered_image_mask_pairs = []
|
68 |
+
for cur_mask in src_masks:
|
69 |
+
if config.cropping.out_square_crop:
|
70 |
+
(crop_left,
|
71 |
+
crop_top,
|
72 |
+
crop_right,
|
73 |
+
crop_bottom) = propose_random_square_crop(cur_mask,
|
74 |
+
min_overlap=config.cropping.crop_min_overlap)
|
75 |
+
cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
|
76 |
+
cur_image = image.copy().crop((crop_left, crop_top, crop_right, crop_bottom))
|
77 |
+
else:
|
78 |
+
cur_image = image
|
79 |
+
|
80 |
+
if len(np.unique(cur_mask)) == 0 or cur_mask.mean() > max_tamper_area:
|
81 |
+
continue
|
82 |
+
|
83 |
+
filtered_image_mask_pairs.append((cur_image, cur_mask))
|
84 |
+
|
85 |
+
mask_indices = np.random.choice(len(filtered_image_mask_pairs),
|
86 |
+
size=min(len(filtered_image_mask_pairs), config.max_masks_per_image),
|
87 |
+
replace=False)
|
88 |
+
|
89 |
+
# crop masks; save masks together with input image
|
90 |
+
mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
|
91 |
+
for i, idx in enumerate(mask_indices):
|
92 |
+
cur_image, cur_mask = filtered_image_mask_pairs[idx]
|
93 |
+
cur_basename = mask_basename + f'_crop{i:03d}'
|
94 |
+
Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
|
95 |
+
mode='L').save(cur_basename + f'_mask{i:03d}.png')
|
96 |
+
cur_image.save(cur_basename + '.png')
|
97 |
+
except KeyboardInterrupt:
|
98 |
+
return
|
99 |
+
except Exception as ex:
|
100 |
+
print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
|
101 |
+
|
102 |
+
|
103 |
+
@hydra.main(config_path='../configs/data_gen/whydra', config_name='random_medium_256.yaml')
|
104 |
+
def main(config: OmegaConf):
|
105 |
+
if not config.indir.endswith('/'):
|
106 |
+
config.indir += '/'
|
107 |
+
|
108 |
+
os.makedirs(config.outdir, exist_ok=True)
|
109 |
+
|
110 |
+
in_files = list(glob.glob(os.path.join(config.indir, '**', f'*.{config.location.extension}'),
|
111 |
+
recursive=True))
|
112 |
+
if config.n_jobs == 0:
|
113 |
+
process_images(in_files, config.indir, config.outdir, config)
|
114 |
+
else:
|
115 |
+
in_files_n = len(in_files)
|
116 |
+
chunk_size = in_files_n // config.n_jobs + (1 if in_files_n % config.n_jobs > 0 else 0)
|
117 |
+
Parallel(n_jobs=config.n_jobs)(
|
118 |
+
delayed(process_images)(in_files[start:start+chunk_size], config.indir, config.outdir, config)
|
119 |
+
for start in range(0, len(in_files), chunk_size)
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
main()
|
bin/gen_outpainting_dataset.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import sys
|
7 |
+
import traceback
|
8 |
+
|
9 |
+
from saicinpainting.evaluation.data import load_image
|
10 |
+
from saicinpainting.evaluation.utils import move_to_device
|
11 |
+
|
12 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
13 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
14 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
15 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
16 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import hydra
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import tqdm
|
23 |
+
import yaml
|
24 |
+
from omegaconf import OmegaConf
|
25 |
+
from torch.utils.data._utils.collate import default_collate
|
26 |
+
|
27 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
28 |
+
from saicinpainting.training.trainers import load_checkpoint
|
29 |
+
from saicinpainting.utils import register_debug_signal_handlers
|
30 |
+
|
31 |
+
LOGGER = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
def main(args):
|
35 |
+
try:
|
36 |
+
if not args.indir.endswith('/'):
|
37 |
+
args.indir += '/'
|
38 |
+
|
39 |
+
for in_img in glob.glob(os.path.join(args.indir, '**', '*' + args.img_suffix), recursive=True):
|
40 |
+
if 'mask' in os.path.basename(in_img):
|
41 |
+
continue
|
42 |
+
|
43 |
+
out_img_path = os.path.join(args.outdir, os.path.splitext(in_img[len(args.indir):])[0] + '.png')
|
44 |
+
out_mask_path = f'{os.path.splitext(out_img_path)[0]}_mask.png'
|
45 |
+
|
46 |
+
os.makedirs(os.path.dirname(out_img_path), exist_ok=True)
|
47 |
+
|
48 |
+
img = load_image(in_img)
|
49 |
+
height, width = img.shape[1:]
|
50 |
+
pad_h, pad_w = int(height * args.coef / 2), int(width * args.coef / 2)
|
51 |
+
|
52 |
+
mask = np.zeros((height, width), dtype='uint8')
|
53 |
+
|
54 |
+
if args.expand:
|
55 |
+
img = np.pad(img, ((0, 0), (pad_h, pad_h), (pad_w, pad_w)))
|
56 |
+
mask = np.pad(mask, ((pad_h, pad_h), (pad_w, pad_w)), mode='constant', constant_values=255)
|
57 |
+
else:
|
58 |
+
mask[:pad_h] = 255
|
59 |
+
mask[-pad_h:] = 255
|
60 |
+
mask[:, :pad_w] = 255
|
61 |
+
mask[:, -pad_w:] = 255
|
62 |
+
|
63 |
+
# img = np.pad(img, ((0, 0), (pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode='symmetric')
|
64 |
+
# mask = np.pad(mask, ((pad_h * 2, pad_h * 2), (pad_w * 2, pad_w * 2)), mode = 'symmetric')
|
65 |
+
|
66 |
+
img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype('uint8')
|
67 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
68 |
+
cv2.imwrite(out_img_path, img)
|
69 |
+
|
70 |
+
cv2.imwrite(out_mask_path, mask)
|
71 |
+
except KeyboardInterrupt:
|
72 |
+
LOGGER.warning('Interrupted by user')
|
73 |
+
except Exception as ex:
|
74 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
75 |
+
sys.exit(1)
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == '__main__':
|
79 |
+
import argparse
|
80 |
+
|
81 |
+
aparser = argparse.ArgumentParser()
|
82 |
+
aparser.add_argument('indir', type=str, help='Root directory with images')
|
83 |
+
aparser.add_argument('outdir', type=str, help='Where to store results')
|
84 |
+
aparser.add_argument('--img-suffix', type=str, default='.png', help='Input image extension')
|
85 |
+
aparser.add_argument('--expand', action='store_true', help='Generate mask by padding (true) or by cropping (false)')
|
86 |
+
aparser.add_argument('--coef', type=float, default=0.2, help='How much to crop/expand in order to get masks')
|
87 |
+
|
88 |
+
main(aparser.parse_args())
|
bin/make_checkpoint.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def get_checkpoint_files(s):
|
10 |
+
s = s.strip()
|
11 |
+
if ',' in s:
|
12 |
+
return [get_checkpoint_files(chunk) for chunk in s.split(',')]
|
13 |
+
return 'last.ckpt' if s == 'last' else f'{s}.ckpt'
|
14 |
+
|
15 |
+
|
16 |
+
def main(args):
|
17 |
+
checkpoint_fnames = get_checkpoint_files(args.epochs)
|
18 |
+
if isinstance(checkpoint_fnames, str):
|
19 |
+
checkpoint_fnames = [checkpoint_fnames]
|
20 |
+
assert len(checkpoint_fnames) >= 1
|
21 |
+
|
22 |
+
checkpoint_path = os.path.join(args.indir, 'models', checkpoint_fnames[0])
|
23 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
24 |
+
del checkpoint['optimizer_states']
|
25 |
+
|
26 |
+
if len(checkpoint_fnames) > 1:
|
27 |
+
for fname in checkpoint_fnames[1:]:
|
28 |
+
print('sum', fname)
|
29 |
+
sum_tensors_cnt = 0
|
30 |
+
other_cp = torch.load(os.path.join(args.indir, 'models', fname), map_location='cpu')
|
31 |
+
for k in checkpoint['state_dict'].keys():
|
32 |
+
if checkpoint['state_dict'][k].dtype is torch.float:
|
33 |
+
checkpoint['state_dict'][k].data.add_(other_cp['state_dict'][k].data)
|
34 |
+
sum_tensors_cnt += 1
|
35 |
+
print('summed', sum_tensors_cnt, 'tensors')
|
36 |
+
|
37 |
+
for k in checkpoint['state_dict'].keys():
|
38 |
+
if checkpoint['state_dict'][k].dtype is torch.float:
|
39 |
+
checkpoint['state_dict'][k].data.mul_(1 / float(len(checkpoint_fnames)))
|
40 |
+
|
41 |
+
state_dict = checkpoint['state_dict']
|
42 |
+
|
43 |
+
if not args.leave_discriminators:
|
44 |
+
for k in list(state_dict.keys()):
|
45 |
+
if k.startswith('discriminator.'):
|
46 |
+
del state_dict[k]
|
47 |
+
|
48 |
+
if not args.leave_losses:
|
49 |
+
for k in list(state_dict.keys()):
|
50 |
+
if k.startswith('loss_'):
|
51 |
+
del state_dict[k]
|
52 |
+
|
53 |
+
out_checkpoint_path = os.path.join(args.outdir, 'models', 'best.ckpt')
|
54 |
+
os.makedirs(os.path.dirname(out_checkpoint_path), exist_ok=True)
|
55 |
+
|
56 |
+
torch.save(checkpoint, out_checkpoint_path)
|
57 |
+
|
58 |
+
shutil.copy2(os.path.join(args.indir, 'config.yaml'),
|
59 |
+
os.path.join(args.outdir, 'config.yaml'))
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
import argparse
|
64 |
+
|
65 |
+
aparser = argparse.ArgumentParser()
|
66 |
+
aparser.add_argument('indir',
|
67 |
+
help='Path to directory with output of training '
|
68 |
+
'(i.e. directory, which has samples, modules, config.yaml and train.log')
|
69 |
+
aparser.add_argument('outdir',
|
70 |
+
help='Where to put minimal checkpoint, which can be consumed by "bin/predict.py"')
|
71 |
+
aparser.add_argument('--epochs', type=str, default='last',
|
72 |
+
help='Which checkpoint to take. '
|
73 |
+
'Can be "last" or integer - number of epoch')
|
74 |
+
aparser.add_argument('--leave-discriminators', action='store_true',
|
75 |
+
help='If enabled, the state of discriminators will not be removed from the checkpoint')
|
76 |
+
aparser.add_argument('--leave-losses', action='store_true',
|
77 |
+
help='If enabled, weights of nn-based losses (e.g. perceptual) will not be removed')
|
78 |
+
|
79 |
+
main(aparser.parse_args())
|
bin/mask_example.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
from skimage import io
|
3 |
+
from skimage.transform import resize
|
4 |
+
|
5 |
+
from saicinpainting.evaluation.masks.mask import SegmentationMask
|
6 |
+
|
7 |
+
im = io.imread('imgs/ex4.jpg')
|
8 |
+
im = resize(im, (512, 1024), anti_aliasing=True)
|
9 |
+
mask_seg = SegmentationMask(num_variants_per_mask=10)
|
10 |
+
mask_examples = mask_seg.get_masks(im)
|
11 |
+
for i, example in enumerate(mask_examples):
|
12 |
+
plt.imshow(example)
|
13 |
+
plt.show()
|
14 |
+
plt.imsave(f'tmp/img_masks/{i}.png', example)
|
bin/paper_runfiles/blur_tests.sh
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##!/usr/bin/env bash
|
2 |
+
#
|
3 |
+
## !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
|
4 |
+
#
|
5 |
+
## paths to data are valid for mml7
|
6 |
+
#PLACES_ROOT="/data/inpainting/Places365"
|
7 |
+
#OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
|
8 |
+
#
|
9 |
+
#source "$(dirname $0)/env.sh"
|
10 |
+
#
|
11 |
+
#for datadir in test_large_30k # val_large
|
12 |
+
#do
|
13 |
+
# for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
14 |
+
# do
|
15 |
+
# "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
16 |
+
# "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
|
17 |
+
#
|
18 |
+
# "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
19 |
+
# done
|
20 |
+
#
|
21 |
+
# for conf in segm_256 segm_512
|
22 |
+
# do
|
23 |
+
# "$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
24 |
+
# "$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
|
25 |
+
#
|
26 |
+
# "$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
27 |
+
# done
|
28 |
+
#done
|
29 |
+
#
|
30 |
+
#IN_DIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k/random_medium_512"
|
31 |
+
#PRED_DIR="/data/inpainting/predictions/final/images/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37/random_medium_512"
|
32 |
+
#BLUR_OUT_DIR="/data/inpainting/predictions/final/blur/images"
|
33 |
+
#
|
34 |
+
#for b in 0.1
|
35 |
+
#
|
36 |
+
#"$BINDIR/blur_predicts.py" "$BASEDIR/../../configs/eval2.yaml" "$CUR_IN_DIR" "$CUR_OUT_DIR" "$CUR_EVAL_DIR"
|
37 |
+
#
|
bin/paper_runfiles/env.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DIRNAME="$(dirname $0)"
|
2 |
+
DIRNAME="$(realpath ""$DIRNAME"")"
|
3 |
+
|
4 |
+
BINDIR="$DIRNAME/.."
|
5 |
+
SRCDIR="$BINDIR/.."
|
6 |
+
CONFIGDIR="$SRCDIR/configs"
|
7 |
+
|
8 |
+
export PYTHONPATH="$SRCDIR:$PYTHONPATH"
|
bin/paper_runfiles/find_best_checkpoint.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
|
7 |
+
|
8 |
+
def ssim_fid100_f1(metrics, fid_scale=100):
|
9 |
+
ssim = metrics.loc['total', 'ssim']['mean']
|
10 |
+
fid = metrics.loc['total', 'fid']['mean']
|
11 |
+
fid_rel = max(0, fid_scale - fid) / fid_scale
|
12 |
+
f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
|
13 |
+
return f1
|
14 |
+
|
15 |
+
|
16 |
+
def find_best_checkpoint(model_list, models_dir):
|
17 |
+
with open(model_list) as f:
|
18 |
+
models = [m.strip() for m in f.readlines()]
|
19 |
+
with open(f'{model_list}_best', 'w') as f:
|
20 |
+
for model in models:
|
21 |
+
print(model)
|
22 |
+
best_f1 = 0
|
23 |
+
best_epoch = 0
|
24 |
+
best_step = 0
|
25 |
+
with open(os.path.join(models_dir, model, 'train.log')) as fm:
|
26 |
+
lines = fm.readlines()
|
27 |
+
for line_index in range(len(lines)):
|
28 |
+
line = lines[line_index]
|
29 |
+
if 'Validation metrics after epoch' in line:
|
30 |
+
sharp_index = line.index('#')
|
31 |
+
cur_ep = line[sharp_index + 1:]
|
32 |
+
comma_index = cur_ep.index(',')
|
33 |
+
cur_ep = int(cur_ep[:comma_index])
|
34 |
+
total_index = line.index('total ')
|
35 |
+
step = int(line[total_index:].split()[1].strip())
|
36 |
+
total_line = lines[line_index + 5]
|
37 |
+
if not total_line.startswith('total'):
|
38 |
+
continue
|
39 |
+
words = total_line.strip().split()
|
40 |
+
f1 = float(words[-1])
|
41 |
+
print(f'\tEpoch: {cur_ep}, f1={f1}')
|
42 |
+
if f1 > best_f1:
|
43 |
+
best_f1 = f1
|
44 |
+
best_epoch = cur_ep
|
45 |
+
best_step = step
|
46 |
+
f.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n')
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == '__main__':
|
50 |
+
parser = ArgumentParser()
|
51 |
+
parser.add_argument('model_list')
|
52 |
+
parser.add_argument('models_dir')
|
53 |
+
args = parser.parse_args()
|
54 |
+
find_best_checkpoint(args.model_list, args.models_dir)
|
bin/paper_runfiles/generate_test_celeba-hq.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/CelebA-HQ_val_test"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in "val" "test"
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-celeba-hq \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_test_ffhq.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/FFHQ_val"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in test
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-ffhq \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_test_paris.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in paris_eval_gt
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
|
13 |
+
location.out_dir=OUT_DIR cropping.out_square_crop=False cropping.out_min_size=227
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_test_paris_256.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml-ws01
|
4 |
+
OUT_DIR="/media/inpainting/paper_data/Paris_StreetView_Dataset_val_256"
|
5 |
+
|
6 |
+
source "$(dirname $0)/env.sh"
|
7 |
+
|
8 |
+
for datadir in paris_eval_gt
|
9 |
+
do
|
10 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 segm_256
|
11 |
+
do
|
12 |
+
"$BINDIR/gen_mask_dataset_hydra.py" -cn $conf datadir=$datadir location=mml-ws01-paris \
|
13 |
+
location.out_dir=$OUT_DIR cropping.out_square_crop=False cropping.out_min_size=256
|
14 |
+
|
15 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
16 |
+
done
|
17 |
+
done
|
bin/paper_runfiles/generate_val_test.sh
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# !!! file set to make test_large_30k from the vanilla test_large: configs/test_large_30k.lst
|
4 |
+
|
5 |
+
# paths to data are valid for mml7
|
6 |
+
PLACES_ROOT="/data/inpainting/Places365"
|
7 |
+
OUT_DIR="/data/inpainting/paper_data/Places365_val_test"
|
8 |
+
|
9 |
+
source "$(dirname $0)/env.sh"
|
10 |
+
|
11 |
+
for datadir in test_large_30k # val_large
|
12 |
+
do
|
13 |
+
for conf in random_thin_256 random_medium_256 random_thick_256 random_thin_512 random_medium_512 random_thick_512
|
14 |
+
do
|
15 |
+
"$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
16 |
+
"$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 8
|
17 |
+
|
18 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
19 |
+
done
|
20 |
+
|
21 |
+
for conf in segm_256 segm_512
|
22 |
+
do
|
23 |
+
"$BINDIR/gen_mask_dataset.py" "$CONFIGDIR/data_gen/${conf}.yaml" \
|
24 |
+
"$PLACES_ROOT/$datadir" "$OUT_DIR/$datadir/$conf" --n-jobs 2
|
25 |
+
|
26 |
+
"$BINDIR/calc_dataset_stats.py" --samples-n 20 "$OUT_DIR/$datadir/$conf" "$OUT_DIR/$datadir/${conf}_stats"
|
27 |
+
done
|
28 |
+
done
|
bin/paper_runfiles/predict_inner_features.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml7
|
4 |
+
|
5 |
+
source "$(dirname $0)/env.sh"
|
6 |
+
|
7 |
+
"$BINDIR/predict_inner_features.py" \
|
8 |
+
-cn default_inner_features_ffc \
|
9 |
+
model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-34-05_train_ablv2_work_ffc075_resume_epoch39" \
|
10 |
+
indir="/data/inpainting/paper_data/inner_features_vis/input/" \
|
11 |
+
outdir="/data/inpainting/paper_data/inner_features_vis/output/ffc" \
|
12 |
+
dataset.img_suffix=.png
|
13 |
+
|
14 |
+
|
15 |
+
"$BINDIR/predict_inner_features.py" \
|
16 |
+
-cn default_inner_features_work \
|
17 |
+
model.path="/data/inpainting/paper_data/final_models/ours/r.suvorov_2021-03-05_17-08-35_train_ablv2_work_resume_epoch37" \
|
18 |
+
indir="/data/inpainting/paper_data/inner_features_vis/input/" \
|
19 |
+
outdir="/data/inpainting/paper_data/inner_features_vis/output/work" \
|
20 |
+
dataset.img_suffix=.png
|
bin/paper_runfiles/update_test_data_stats.sh
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# paths to data are valid for mml7
|
4 |
+
|
5 |
+
source "$(dirname $0)/env.sh"
|
6 |
+
|
7 |
+
#INDIR="/data/inpainting/paper_data/Places365_val_test/test_large_30k"
|
8 |
+
#
|
9 |
+
#for dataset in random_medium_256 random_medium_512 random_thick_256 random_thick_512 random_thin_256 random_thin_512
|
10 |
+
#do
|
11 |
+
# "$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
12 |
+
#done
|
13 |
+
#
|
14 |
+
#"$BINDIR/calc_dataset_stats.py" "/data/inpainting/evalset2" "/data/inpainting/evalset2_stats2"
|
15 |
+
|
16 |
+
|
17 |
+
INDIR="/data/inpainting/paper_data/CelebA-HQ_val_test/test"
|
18 |
+
|
19 |
+
for dataset in random_medium_256 random_thick_256 random_thin_256
|
20 |
+
do
|
21 |
+
"$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
22 |
+
done
|
23 |
+
|
24 |
+
|
25 |
+
INDIR="/data/inpainting/paper_data/Paris_StreetView_Dataset_val_256/paris_eval_gt"
|
26 |
+
|
27 |
+
for dataset in random_medium_256 random_thick_256 random_thin_256
|
28 |
+
do
|
29 |
+
"$BINDIR/calc_dataset_stats.py" "$INDIR/$dataset" "$INDIR/${dataset}_stats2"
|
30 |
+
done
|
bin/predict.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Example command:
|
4 |
+
# ./bin/predict.py \
|
5 |
+
# model.path=<path to checkpoint, prepared by make_checkpoint.py> \
|
6 |
+
# indir=<path to input data> \
|
7 |
+
# outdir=<where to store predicts>
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.utils import move_to_device
|
15 |
+
|
16 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
17 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
18 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
19 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
20 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
21 |
+
|
22 |
+
import cv2
|
23 |
+
import hydra
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import tqdm
|
27 |
+
import yaml
|
28 |
+
from omegaconf import OmegaConf
|
29 |
+
from torch.utils.data._utils.collate import default_collate
|
30 |
+
|
31 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
32 |
+
from saicinpainting.training.trainers import load_checkpoint
|
33 |
+
from saicinpainting.utils import register_debug_signal_handlers
|
34 |
+
|
35 |
+
LOGGER = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
@hydra.main(config_path='../configs/prediction', config_name='default.yaml')
|
39 |
+
def main(predict_config: OmegaConf):
|
40 |
+
try:
|
41 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
42 |
+
|
43 |
+
device = torch.device(predict_config.device)
|
44 |
+
|
45 |
+
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
|
46 |
+
with open(train_config_path, 'r') as f:
|
47 |
+
train_config = OmegaConf.create(yaml.safe_load(f))
|
48 |
+
|
49 |
+
train_config.training_model.predict_only = True
|
50 |
+
|
51 |
+
out_ext = predict_config.get('out_ext', '.png')
|
52 |
+
|
53 |
+
checkpoint_path = os.path.join(predict_config.model.path,
|
54 |
+
'models',
|
55 |
+
predict_config.model.checkpoint)
|
56 |
+
model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
|
57 |
+
model.freeze()
|
58 |
+
model.to(device)
|
59 |
+
|
60 |
+
if not predict_config.indir.endswith('/'):
|
61 |
+
predict_config.indir += '/'
|
62 |
+
|
63 |
+
dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
|
64 |
+
with torch.no_grad():
|
65 |
+
for img_i in tqdm.trange(len(dataset)):
|
66 |
+
mask_fname = dataset.mask_filenames[img_i]
|
67 |
+
cur_out_fname = os.path.join(
|
68 |
+
predict_config.outdir,
|
69 |
+
os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
|
70 |
+
)
|
71 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
72 |
+
|
73 |
+
batch = move_to_device(default_collate([dataset[img_i]]), device)
|
74 |
+
batch['mask'] = (batch['mask'] > 0) * 1
|
75 |
+
batch = model(batch)
|
76 |
+
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
|
77 |
+
|
78 |
+
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
|
79 |
+
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
|
80 |
+
cv2.imwrite(cur_out_fname, cur_res)
|
81 |
+
except KeyboardInterrupt:
|
82 |
+
LOGGER.warning('Interrupted by user')
|
83 |
+
except Exception as ex:
|
84 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
85 |
+
sys.exit(1)
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == '__main__':
|
89 |
+
main()
|
bin/predict_inner_features.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Example command:
|
4 |
+
# ./bin/predict.py \
|
5 |
+
# model.path=<path to checkpoint, prepared by make_checkpoint.py> \
|
6 |
+
# indir=<path to input data> \
|
7 |
+
# outdir=<where to store predicts>
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import traceback
|
13 |
+
|
14 |
+
from saicinpainting.evaluation.utils import move_to_device
|
15 |
+
|
16 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
17 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
18 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
19 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
20 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
21 |
+
|
22 |
+
import cv2
|
23 |
+
import hydra
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import tqdm
|
27 |
+
import yaml
|
28 |
+
from omegaconf import OmegaConf
|
29 |
+
from torch.utils.data._utils.collate import default_collate
|
30 |
+
|
31 |
+
from saicinpainting.training.data.datasets import make_default_val_dataset
|
32 |
+
from saicinpainting.training.trainers import load_checkpoint, DefaultInpaintingTrainingModule
|
33 |
+
from saicinpainting.utils import register_debug_signal_handlers, get_shape
|
34 |
+
|
35 |
+
LOGGER = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
@hydra.main(config_path='../configs/prediction', config_name='default_inner_features.yaml')
|
39 |
+
def main(predict_config: OmegaConf):
|
40 |
+
try:
|
41 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
42 |
+
|
43 |
+
device = torch.device(predict_config.device)
|
44 |
+
|
45 |
+
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
|
46 |
+
with open(train_config_path, 'r') as f:
|
47 |
+
train_config = OmegaConf.create(yaml.safe_load(f))
|
48 |
+
|
49 |
+
checkpoint_path = os.path.join(predict_config.model.path, 'models', predict_config.model.checkpoint)
|
50 |
+
model = load_checkpoint(train_config, checkpoint_path, strict=False)
|
51 |
+
model.freeze()
|
52 |
+
model.to(device)
|
53 |
+
|
54 |
+
assert isinstance(model, DefaultInpaintingTrainingModule), 'Only DefaultInpaintingTrainingModule is supported'
|
55 |
+
assert isinstance(getattr(model.generator, 'model', None), torch.nn.Sequential)
|
56 |
+
|
57 |
+
if not predict_config.indir.endswith('/'):
|
58 |
+
predict_config.indir += '/'
|
59 |
+
|
60 |
+
dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
|
61 |
+
|
62 |
+
max_level = max(predict_config.levels)
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
for img_i in tqdm.trange(len(dataset)):
|
66 |
+
mask_fname = dataset.mask_filenames[img_i]
|
67 |
+
cur_out_fname = os.path.join(predict_config.outdir, os.path.splitext(mask_fname[len(predict_config.indir):])[0])
|
68 |
+
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
|
69 |
+
|
70 |
+
batch = move_to_device(default_collate([dataset[img_i]]), device)
|
71 |
+
|
72 |
+
img = batch['image']
|
73 |
+
mask = batch['mask']
|
74 |
+
mask[:] = 0
|
75 |
+
mask_h, mask_w = mask.shape[-2:]
|
76 |
+
mask[:, :,
|
77 |
+
mask_h // 2 - predict_config.hole_radius : mask_h // 2 + predict_config.hole_radius,
|
78 |
+
mask_w // 2 - predict_config.hole_radius : mask_w // 2 + predict_config.hole_radius] = 1
|
79 |
+
|
80 |
+
masked_img = torch.cat([img * (1 - mask), mask], dim=1)
|
81 |
+
|
82 |
+
feats = masked_img
|
83 |
+
for level_i, level in enumerate(model.generator.model):
|
84 |
+
feats = level(feats)
|
85 |
+
if level_i in predict_config.levels:
|
86 |
+
cur_feats = torch.cat([f for f in feats if torch.is_tensor(f)], dim=1) \
|
87 |
+
if isinstance(feats, tuple) else feats
|
88 |
+
|
89 |
+
if predict_config.slice_channels:
|
90 |
+
cur_feats = cur_feats[:, slice(*predict_config.slice_channels)]
|
91 |
+
|
92 |
+
cur_feat = cur_feats.pow(2).mean(1).pow(0.5).clone()
|
93 |
+
cur_feat -= cur_feat.min()
|
94 |
+
cur_feat /= cur_feat.std()
|
95 |
+
cur_feat = cur_feat.clamp(0, 1) / 1
|
96 |
+
cur_feat = cur_feat.cpu().numpy()[0]
|
97 |
+
cur_feat *= 255
|
98 |
+
cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
|
99 |
+
cv2.imwrite(cur_out_fname + f'_lev{level_i:02d}_norm.png', cur_feat)
|
100 |
+
|
101 |
+
# for channel_i in predict_config.channels:
|
102 |
+
#
|
103 |
+
# cur_feat = cur_feats[0, channel_i].clone().detach().cpu().numpy()
|
104 |
+
# cur_feat -= cur_feat.min()
|
105 |
+
# cur_feat /= cur_feat.max()
|
106 |
+
# cur_feat *= 255
|
107 |
+
# cur_feat = np.clip(cur_feat, 0, 255).astype('uint8')
|
108 |
+
# cv2.imwrite(cur_out_fname + f'_lev{level_i}_ch{channel_i}.png', cur_feat)
|
109 |
+
elif level_i >= max_level:
|
110 |
+
break
|
111 |
+
except KeyboardInterrupt:
|
112 |
+
LOGGER.warning('Interrupted by user')
|
113 |
+
except Exception as ex:
|
114 |
+
LOGGER.critical(f'Prediction failed due to {ex}:\n{traceback.format_exc()}')
|
115 |
+
sys.exit(1)
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == '__main__':
|
119 |
+
main()
|
bin/report_from_tb.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
|
7 |
+
import tensorflow as tf
|
8 |
+
from torch.utils.tensorboard import SummaryWriter
|
9 |
+
|
10 |
+
|
11 |
+
GROUPING_RULES = [
|
12 |
+
re.compile(r'^(?P<group>train|test|val|extra_val_.*?(256|512))_(?P<title>.*)', re.I)
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
DROP_RULES = [
|
17 |
+
re.compile(r'_std$', re.I)
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def need_drop(tag):
|
22 |
+
for rule in DROP_RULES:
|
23 |
+
if rule.search(tag):
|
24 |
+
return True
|
25 |
+
return False
|
26 |
+
|
27 |
+
|
28 |
+
def get_group_and_title(tag):
|
29 |
+
for rule in GROUPING_RULES:
|
30 |
+
match = rule.search(tag)
|
31 |
+
if match is None:
|
32 |
+
continue
|
33 |
+
return match.group('group'), match.group('title')
|
34 |
+
return None, None
|
35 |
+
|
36 |
+
|
37 |
+
def main(args):
|
38 |
+
os.makedirs(args.outdir, exist_ok=True)
|
39 |
+
|
40 |
+
ignored_events = set()
|
41 |
+
|
42 |
+
for orig_fname in glob.glob(args.inglob):
|
43 |
+
cur_dirpath = os.path.dirname(orig_fname) # remove filename, this should point to "version_0" directory
|
44 |
+
subdirname = os.path.basename(cur_dirpath) # == "version_0" most of time
|
45 |
+
exp_root_path = os.path.dirname(cur_dirpath) # remove "version_0"
|
46 |
+
exp_name = os.path.basename(exp_root_path)
|
47 |
+
|
48 |
+
writers_by_group = {}
|
49 |
+
|
50 |
+
for e in tf.compat.v1.train.summary_iterator(orig_fname):
|
51 |
+
for v in e.summary.value:
|
52 |
+
if need_drop(v.tag):
|
53 |
+
continue
|
54 |
+
|
55 |
+
cur_group, cur_title = get_group_and_title(v.tag)
|
56 |
+
if cur_group is None:
|
57 |
+
if v.tag not in ignored_events:
|
58 |
+
print(f'WARNING: Could not detect group for {v.tag}, ignoring it')
|
59 |
+
ignored_events.add(v.tag)
|
60 |
+
continue
|
61 |
+
|
62 |
+
cur_writer = writers_by_group.get(cur_group, None)
|
63 |
+
if cur_writer is None:
|
64 |
+
if args.include_version:
|
65 |
+
cur_outdir = os.path.join(args.outdir, exp_name, f'{subdirname}_{cur_group}')
|
66 |
+
else:
|
67 |
+
cur_outdir = os.path.join(args.outdir, exp_name, cur_group)
|
68 |
+
cur_writer = SummaryWriter(cur_outdir)
|
69 |
+
writers_by_group[cur_group] = cur_writer
|
70 |
+
|
71 |
+
cur_writer.add_scalar(cur_title, v.simple_value, global_step=e.step, walltime=e.wall_time)
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == '__main__':
|
75 |
+
import argparse
|
76 |
+
|
77 |
+
aparser = argparse.ArgumentParser()
|
78 |
+
aparser.add_argument('inglob', type=str)
|
79 |
+
aparser.add_argument('outdir', type=str)
|
80 |
+
aparser.add_argument('--include-version', action='store_true',
|
81 |
+
help='Include subdirectory name e.g. "version_0" into output path')
|
82 |
+
|
83 |
+
main(aparser.parse_args())
|
bin/sample_from_dataset.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from skimage import io
|
8 |
+
from skimage.segmentation import mark_boundaries
|
9 |
+
|
10 |
+
from saicinpainting.evaluation.data import InpaintingDataset
|
11 |
+
from saicinpainting.evaluation.vis import save_item_for_vis
|
12 |
+
|
13 |
+
def save_mask_for_sidebyside(item, out_file):
|
14 |
+
mask = item['mask']# > 0.5
|
15 |
+
if mask.ndim == 3:
|
16 |
+
mask = mask[0]
|
17 |
+
mask = np.clip(mask * 255, 0, 255).astype('uint8')
|
18 |
+
io.imsave(out_file, mask)
|
19 |
+
|
20 |
+
def save_img_for_sidebyside(item, out_file):
|
21 |
+
img = np.transpose(item['image'], (1, 2, 0))
|
22 |
+
img = np.clip(img * 255, 0, 255).astype('uint8')
|
23 |
+
io.imsave(out_file, img)
|
24 |
+
|
25 |
+
def save_masked_img_for_sidebyside(item, out_file):
|
26 |
+
mask = item['mask']
|
27 |
+
img = item['image']
|
28 |
+
|
29 |
+
img = (1-mask) * img + mask
|
30 |
+
img = np.transpose(img, (1, 2, 0))
|
31 |
+
|
32 |
+
img = np.clip(img * 255, 0, 255).astype('uint8')
|
33 |
+
io.imsave(out_file, img)
|
34 |
+
|
35 |
+
def main(args):
|
36 |
+
dataset = InpaintingDataset(args.datadir, img_suffix='.png')
|
37 |
+
|
38 |
+
area_bins = np.linspace(0, 1, args.area_bins + 1)
|
39 |
+
|
40 |
+
heights = []
|
41 |
+
widths = []
|
42 |
+
image_areas = []
|
43 |
+
hole_areas = []
|
44 |
+
hole_area_percents = []
|
45 |
+
area_bins_count = np.zeros(args.area_bins)
|
46 |
+
area_bin_titles = [f'{area_bins[i] * 100:.0f}-{area_bins[i + 1] * 100:.0f}' for i in range(args.area_bins)]
|
47 |
+
|
48 |
+
bin2i = [[] for _ in range(args.area_bins)]
|
49 |
+
|
50 |
+
for i, item in enumerate(tqdm.tqdm(dataset)):
|
51 |
+
h, w = item['image'].shape[1:]
|
52 |
+
heights.append(h)
|
53 |
+
widths.append(w)
|
54 |
+
full_area = h * w
|
55 |
+
image_areas.append(full_area)
|
56 |
+
hole_area = (item['mask'] == 1).sum()
|
57 |
+
hole_areas.append(hole_area)
|
58 |
+
hole_percent = hole_area / full_area
|
59 |
+
hole_area_percents.append(hole_percent)
|
60 |
+
bin_i = np.clip(np.searchsorted(area_bins, hole_percent) - 1, 0, len(area_bins_count) - 1)
|
61 |
+
area_bins_count[bin_i] += 1
|
62 |
+
bin2i[bin_i].append(i)
|
63 |
+
|
64 |
+
os.makedirs(args.outdir, exist_ok=True)
|
65 |
+
|
66 |
+
for bin_i in range(args.area_bins):
|
67 |
+
bindir = os.path.join(args.outdir, area_bin_titles[bin_i])
|
68 |
+
os.makedirs(bindir, exist_ok=True)
|
69 |
+
bin_idx = bin2i[bin_i]
|
70 |
+
for sample_i in np.random.choice(bin_idx, size=min(len(bin_idx), args.samples_n), replace=False):
|
71 |
+
item = dataset[sample_i]
|
72 |
+
path = os.path.join(bindir, dataset.img_filenames[sample_i].split('/')[-1])
|
73 |
+
save_masked_img_for_sidebyside(item, path)
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__':
|
77 |
+
import argparse
|
78 |
+
|
79 |
+
aparser = argparse.ArgumentParser()
|
80 |
+
aparser.add_argument('--datadir', type=str,
|
81 |
+
help='Path to folder with images and masks (output of gen_mask_dataset.py)')
|
82 |
+
aparser.add_argument('--outdir', type=str, help='Where to put results')
|
83 |
+
aparser.add_argument('--samples-n', type=int, default=10,
|
84 |
+
help='Number of sample images with masks to copy for visualization for each area bin')
|
85 |
+
aparser.add_argument('--area-bins', type=int, default=10, help='How many area bins to have')
|
86 |
+
|
87 |
+
main(aparser.parse_args())
|
bin/side_by_side.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset
|
9 |
+
from saicinpainting.evaluation.utils import load_yaml
|
10 |
+
from saicinpainting.training.visualizers.base import visualize_mask_and_images
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
config = load_yaml(args.config)
|
15 |
+
|
16 |
+
datasets = [PrecomputedInpaintingResultsDataset(args.datadir, cur_predictdir, **config.dataset_kwargs)
|
17 |
+
for cur_predictdir in args.predictdirs]
|
18 |
+
assert len({len(ds) for ds in datasets}) == 1
|
19 |
+
len_first = len(datasets[0])
|
20 |
+
|
21 |
+
indices = list(range(len_first))
|
22 |
+
if len_first > args.max_n:
|
23 |
+
indices = sorted(random.sample(indices, args.max_n))
|
24 |
+
|
25 |
+
os.makedirs(args.outpath, exist_ok=True)
|
26 |
+
|
27 |
+
filename2i = {}
|
28 |
+
|
29 |
+
keys = ['image'] + [i for i in range(len(datasets))]
|
30 |
+
for img_i in indices:
|
31 |
+
try:
|
32 |
+
mask_fname = os.path.basename(datasets[0].mask_filenames[img_i])
|
33 |
+
if mask_fname in filename2i:
|
34 |
+
filename2i[mask_fname] += 1
|
35 |
+
idx = filename2i[mask_fname]
|
36 |
+
mask_fname_only, ext = os.path.split(mask_fname)
|
37 |
+
mask_fname = f'{mask_fname_only}_{idx}{ext}'
|
38 |
+
else:
|
39 |
+
filename2i[mask_fname] = 1
|
40 |
+
|
41 |
+
cur_vis_dict = datasets[0][img_i]
|
42 |
+
for ds_i, ds in enumerate(datasets):
|
43 |
+
cur_vis_dict[ds_i] = ds[img_i]['inpainted']
|
44 |
+
|
45 |
+
vis_img = visualize_mask_and_images(cur_vis_dict, keys,
|
46 |
+
last_without_mask=False,
|
47 |
+
mask_only_first=True,
|
48 |
+
black_mask=args.black)
|
49 |
+
vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
|
50 |
+
|
51 |
+
out_fname = os.path.join(args.outpath, mask_fname)
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
|
56 |
+
cv2.imwrite(out_fname, vis_img)
|
57 |
+
except Exception as ex:
|
58 |
+
print(f'Could not process {img_i} due to {ex}')
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
import argparse
|
63 |
+
|
64 |
+
aparser = argparse.ArgumentParser()
|
65 |
+
aparser.add_argument('--max-n', type=int, default=100, help='Maximum number of images to print')
|
66 |
+
aparser.add_argument('--black', action='store_true', help='Whether to fill mask on GT with black')
|
67 |
+
aparser.add_argument('config', type=str, help='Path to evaluation config (e.g. configs/eval1.yaml)')
|
68 |
+
aparser.add_argument('outpath', type=str, help='Where to put results')
|
69 |
+
aparser.add_argument('datadir', type=str,
|
70 |
+
help='Path to folder with images and masks')
|
71 |
+
aparser.add_argument('predictdirs', type=str,
|
72 |
+
nargs='+',
|
73 |
+
help='Path to folders with predicts')
|
74 |
+
|
75 |
+
|
76 |
+
main(aparser.parse_args())
|
bin/split_tar.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import tqdm
|
5 |
+
import webdataset as wds
|
6 |
+
|
7 |
+
|
8 |
+
def main(args):
|
9 |
+
input_dataset = wds.Dataset(args.infile)
|
10 |
+
output_dataset = wds.ShardWriter(args.outpattern)
|
11 |
+
for rec in tqdm.tqdm(input_dataset):
|
12 |
+
output_dataset.write(rec)
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == '__main__':
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
aparser = argparse.ArgumentParser()
|
19 |
+
aparser.add_argument('infile', type=str)
|
20 |
+
aparser.add_argument('outpattern', type=str)
|
21 |
+
|
22 |
+
main(aparser.parse_args())
|
bin/train.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import traceback
|
7 |
+
|
8 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
9 |
+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
10 |
+
os.environ['MKL_NUM_THREADS'] = '1'
|
11 |
+
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
|
12 |
+
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
13 |
+
|
14 |
+
import hydra
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from pytorch_lightning import Trainer
|
17 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
18 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
19 |
+
from pytorch_lightning.plugins import DDPPlugin
|
20 |
+
|
21 |
+
from saicinpainting.training.trainers import make_training_model
|
22 |
+
from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
|
23 |
+
handle_deterministic_config
|
24 |
+
|
25 |
+
LOGGER = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
@handle_ddp_subprocess()
|
29 |
+
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
|
30 |
+
def main(config: OmegaConf):
|
31 |
+
try:
|
32 |
+
need_set_deterministic = handle_deterministic_config(config)
|
33 |
+
|
34 |
+
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
|
35 |
+
|
36 |
+
is_in_ddp_subprocess = handle_ddp_parent_process()
|
37 |
+
|
38 |
+
config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
|
39 |
+
if not is_in_ddp_subprocess:
|
40 |
+
LOGGER.info(OmegaConf.to_yaml(config))
|
41 |
+
OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))
|
42 |
+
|
43 |
+
checkpoints_dir = os.path.join(os.getcwd(), 'models')
|
44 |
+
os.makedirs(checkpoints_dir, exist_ok=True)
|
45 |
+
|
46 |
+
# there is no need to suppress this logger in ddp, because it handles rank on its own
|
47 |
+
metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
|
48 |
+
metrics_logger.log_hyperparams(config)
|
49 |
+
|
50 |
+
training_model = make_training_model(config)
|
51 |
+
|
52 |
+
trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
|
53 |
+
if need_set_deterministic:
|
54 |
+
trainer_kwargs['deterministic'] = True
|
55 |
+
|
56 |
+
trainer = Trainer(
|
57 |
+
# there is no need to suppress checkpointing in ddp, because it handles rank on its own
|
58 |
+
callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
|
59 |
+
logger=metrics_logger,
|
60 |
+
default_root_dir=os.getcwd(),
|
61 |
+
**trainer_kwargs
|
62 |
+
)
|
63 |
+
trainer.fit(training_model)
|
64 |
+
except KeyboardInterrupt:
|
65 |
+
LOGGER.warning('Interrupted by user')
|
66 |
+
except Exception as ex:
|
67 |
+
LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
|
68 |
+
sys.exit(1)
|
69 |
+
|
70 |
+
|
71 |
+
if __name__ == '__main__':
|
72 |
+
main()
|
configs/analyze_mask_errors.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_kwargs:
|
2 |
+
img_suffix: .jpg
|
3 |
+
inpainted_suffix: .jpg
|
4 |
+
|
5 |
+
take_global_top: 30
|
6 |
+
take_worst_best_top: 30
|
7 |
+
take_overlapping_top: 30
|
configs/data_gen/gen_segm_dataset1.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.02
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 5
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 0.3
|
11 |
+
max_foreground_intersection: 0.7
|
12 |
+
max_mask_intersection: 0.1
|
13 |
+
max_hidden_area: 0.1
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.2
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 5
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 512
|
23 |
+
handle_small_mode: drop
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 0.5
|
configs/data_gen/gen_segm_dataset3.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.07
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 3
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 0.4
|
11 |
+
max_foreground_intersection: 0.8
|
12 |
+
max_mask_intersection: 0.2
|
13 |
+
max_hidden_area: 0.1
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.3
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 3
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 512
|
23 |
+
handle_small_mode: drop
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 0.5
|
configs/data_gen/random_medium_256.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 5
|
8 |
+
max_width: 50
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 100
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 0
|
15 |
+
bbox_min_size: 10
|
16 |
+
bbox_max_size: 50
|
17 |
+
max_times: 5
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 256
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_medium_512.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 10
|
8 |
+
max_width: 100
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 200
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 0
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 150
|
17 |
+
max_times: 5
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 512
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thick_256.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 1
|
7 |
+
max_times: 5
|
8 |
+
max_width: 100
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 200
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 10
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 150
|
17 |
+
max_times: 3
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 256
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thick_512.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 1
|
7 |
+
max_times: 5
|
8 |
+
max_width: 250
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 450
|
11 |
+
|
12 |
+
box_proba: 0.3
|
13 |
+
box_kwargs:
|
14 |
+
margin: 10
|
15 |
+
bbox_min_size: 30
|
16 |
+
bbox_max_size: 300
|
17 |
+
max_times: 4
|
18 |
+
min_times: 1
|
19 |
+
|
20 |
+
segm_proba: 0
|
21 |
+
squares_proba: 0
|
22 |
+
|
23 |
+
variants_n: 5
|
24 |
+
|
25 |
+
max_masks_per_image: 1
|
26 |
+
|
27 |
+
cropping:
|
28 |
+
out_min_size: 512
|
29 |
+
handle_small_mode: upscale
|
30 |
+
out_square_crop: True
|
31 |
+
crop_min_overlap: 1
|
32 |
+
|
33 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thin_256.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 50
|
8 |
+
max_width: 10
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 40
|
11 |
+
box_proba: 0
|
12 |
+
segm_proba: 0
|
13 |
+
squares_proba: 0
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 256
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 0.5
|
configs/data_gen/random_thin_512.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 1
|
5 |
+
irregular_kwargs:
|
6 |
+
min_times: 4
|
7 |
+
max_times: 70
|
8 |
+
max_width: 20
|
9 |
+
max_angle: 4
|
10 |
+
max_len: 100
|
11 |
+
box_proba: 0
|
12 |
+
segm_proba: 0
|
13 |
+
squares_proba: 0
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 512
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 0.5
|
configs/data_gen/segm_256.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.05
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 3
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 1 # turn off filtering by overlap
|
11 |
+
max_foreground_intersection: 1 # turn off filtering by overlap
|
12 |
+
max_mask_intersection: 0.2 # the lower this value the higher diversity
|
13 |
+
max_hidden_area: 0.5
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.3
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 1
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 256
|
23 |
+
handle_small_mode: upscale
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 1
|
26 |
+
|
27 |
+
max_tamper_area: 0.5
|
configs/data_gen/segm_512.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: segmentation
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
confidence_threshold: 0.5
|
5 |
+
max_object_area: 0.5
|
6 |
+
min_mask_area: 0.05
|
7 |
+
downsample_levels: 6
|
8 |
+
num_variants_per_mask: 3
|
9 |
+
rigidness_mode: 1
|
10 |
+
max_foreground_coverage: 1 # turn off filtering by overlap
|
11 |
+
max_foreground_intersection: 1 # turn off filtering by overlap
|
12 |
+
max_mask_intersection: 0.2 # the lower this value the higher diversity
|
13 |
+
max_hidden_area: 0.5
|
14 |
+
max_scale_change: 0.25
|
15 |
+
horizontal_flip: True
|
16 |
+
max_vertical_shift: 0.3
|
17 |
+
position_shuffle: True
|
18 |
+
|
19 |
+
max_masks_per_image: 1
|
20 |
+
|
21 |
+
cropping:
|
22 |
+
out_min_size: 512
|
23 |
+
handle_small_mode: upscale
|
24 |
+
out_square_crop: True
|
25 |
+
crop_min_overlap: 1
|
26 |
+
|
27 |
+
max_tamper_area: 0.5
|
configs/data_gen/sr_256.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_kind: random
|
2 |
+
|
3 |
+
mask_generator_kwargs:
|
4 |
+
irregular_proba: 0
|
5 |
+
box_proba: 0
|
6 |
+
segm_proba: 0
|
7 |
+
squares_proba: 0
|
8 |
+
superres_proba: 1
|
9 |
+
superres_kwargs:
|
10 |
+
min_step: 2
|
11 |
+
max_step: 4
|
12 |
+
min_width: 1
|
13 |
+
max_width: 3
|
14 |
+
|
15 |
+
variants_n: 5
|
16 |
+
|
17 |
+
max_masks_per_image: 1
|
18 |
+
|
19 |
+
cropping:
|
20 |
+
out_min_size: 256
|
21 |
+
handle_small_mode: upscale
|
22 |
+
out_square_crop: True
|
23 |
+
crop_min_overlap: 1
|
24 |
+
|
25 |
+
max_tamper_area: 1
|
configs/data_gen/whydra/location/mml-ws01-celeba-hq.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /media/inpainting/CelebA-HQ
|
4 |
+
out_dir: /media/inpainting/paper_data/CelebA-HQ_val_test
|
5 |
+
extension: jpg
|
configs/data_gen/whydra/location/mml-ws01-ffhq.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /media/inpainting/FFHQ/
|
4 |
+
out_dir: /media/inpainting/paper_data/FFHQ_val
|
5 |
+
extension: png
|
configs/data_gen/whydra/location/mml-ws01-paris.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /media/inpainting/Paris_StreetView_Dataset
|
4 |
+
out_dir: /media/inpainting/paper_data/Paris_StreetView_Dataset_val
|
5 |
+
extension: png
|
configs/data_gen/whydra/location/mml7-places.yaml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _group_
|
2 |
+
|
3 |
+
root_dir: /data/inpainting/Places365
|
4 |
+
out_dir: /data/inpainting/paper_data/Places365_val_test
|
5 |
+
extension: jpg
|
configs/data_gen/whydra/random_medium_256.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datadir: val_large
|
2 |
+
indir: ${location.root_dir}/${datadir}
|
3 |
+
outdir: ${location.out_dir}/${datadir}/random_medium_256
|
4 |
+
|
5 |
+
n_jobs: 8
|
6 |
+
|
7 |
+
generator_kind: random
|
8 |
+
|
9 |
+
mask_generator_kwargs:
|
10 |
+
irregular_proba: 1
|
11 |
+
irregular_kwargs:
|
12 |
+
min_times: 4
|
13 |
+
max_times: 5
|
14 |
+
max_width: 50
|
15 |
+
max_angle: 4
|
16 |
+
max_len: 100
|
17 |
+
|
18 |
+
box_proba: 0.3
|
19 |
+
box_kwargs:
|
20 |
+
margin: 0
|
21 |
+
bbox_min_size: 10
|
22 |
+
bbox_max_size: 50
|
23 |
+
max_times: 5
|
24 |
+
min_times: 1
|
25 |
+
|
26 |
+
segm_proba: 0
|
27 |
+
squares_proba: 0
|
28 |
+
|
29 |
+
variants_n: 5
|
30 |
+
|
31 |
+
max_masks_per_image: 1
|
32 |
+
|
33 |
+
cropping:
|
34 |
+
out_min_size: 256
|
35 |
+
handle_small_mode: upscale
|
36 |
+
out_square_crop: True
|
37 |
+
crop_min_overlap: 1
|
38 |
+
|
39 |
+
max_tamper_area: 0.5
|
40 |
+
|
41 |
+
defaults:
|
42 |
+
- location: mml7-places
|
configs/data_gen/whydra/random_medium_512.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datadir: val_large
|
2 |
+
indir: ${location.root_dir}/${datadir}
|
3 |
+
outdir: ${location.out_dir}/${datadir}/random_medium_512
|
4 |
+
|
5 |
+
n_jobs: 8
|
6 |
+
|
7 |
+
generator_kind: random
|
8 |
+
|
9 |
+
mask_generator_kwargs:
|
10 |
+
irregular_proba: 1
|
11 |
+
irregular_kwargs:
|
12 |
+
min_times: 4
|
13 |
+
max_times: 10
|
14 |
+
max_width: 100
|
15 |
+
max_angle: 4
|
16 |
+
max_len: 200
|
17 |
+
|
18 |
+
box_proba: 0.3
|
19 |
+
box_kwargs:
|
20 |
+
margin: 0
|
21 |
+
bbox_min_size: 30
|
22 |
+
bbox_max_size: 150
|
23 |
+
max_times: 5
|
24 |
+
min_times: 1
|
25 |
+
|
26 |
+
segm_proba: 0
|
27 |
+
squares_proba: 0
|
28 |
+
|
29 |
+
variants_n: 5
|
30 |
+
|
31 |
+
max_masks_per_image: 1
|
32 |
+
|
33 |
+
cropping:
|
34 |
+
out_min_size: 512
|
35 |
+
handle_small_mode: upscale
|
36 |
+
out_square_crop: True
|
37 |
+
crop_min_overlap: 1
|
38 |
+
|
39 |
+
max_tamper_area: 0.5
|
40 |
+
|
41 |
+
defaults:
|
42 |
+
- location: mml7-places
|