|
|
|
import cv2 |
|
import numpy as np |
|
import sklearn.svm
|
import torch |
|
import os |
|
import pickle |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from joblib import Parallel, delayed |
|
|
|
from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset, load_image |
|
from saicinpainting.evaluation.losses.fid.inception import InceptionV3 |
|
from saicinpainting.evaluation.utils import load_yaml |
|
from saicinpainting.training.visualizers.base import visualize_mask_and_images |
|
|
|
|
|
def draw_score(img, score): |
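    """Draw `score` as green text in the top-left corner of a CHW float image; return the annotated image."""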
|
    img = np.ascontiguousarray(np.transpose(img, (1, 2, 0)))  # cv2.putText needs a contiguous HWC array
|
cv2.putText(img, f'{score:.2f}', |
|
(40, 40), |
|
cv2.FONT_HERSHEY_SIMPLEX, |
|
1, |
|
(0, 1, 0), |
|
thickness=3) |
|
img = np.transpose(img, (2, 0, 1)) |
|
return img |
|
|
|
|
|
def save_global_samples(global_mask_fnames, mask2real_fname, mask2fake_fname, out_dir, real_scores_by_fname, fake_scores_by_fname): |
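    """Save a side-by-side grid (original vs. inpainted, with scores drawn on) for every mask in `global_mask_fnames`."""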
|
for cur_mask_fname in global_mask_fnames: |
|
cur_real_fname = mask2real_fname[cur_mask_fname] |
|
orig_img = load_image(cur_real_fname, mode='RGB') |
|
fake_img = load_image(mask2fake_fname[cur_mask_fname], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]] |
|
mask = load_image(cur_mask_fname, mode='L')[None, ...] |
|
|
|
        orig_img = draw_score(orig_img, real_scores_by_fname.loc[cur_real_fname, 'real_score'])

        fake_img = draw_score(fake_img, fake_scores_by_fname.loc[cur_mask_fname, 'fake_score'])
|
|
|
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=mask, fake=fake_img), |
|
keys=['image', 'fake'], |
|
last_without_mask=True) |
|
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8') |
|
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR) |
|
cv2.imwrite(os.path.join(out_dir, os.path.splitext(os.path.basename(cur_mask_fname))[0] + '.jpg'), |
|
cur_grid) |
|
|
|
|
|
def save_samples_by_real(worst_best_by_real, mask2fake_fname, fake_info, out_dir): |
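    """For every real image, save a grid of its worst and best inpaintings plus a histogram of its per-mask scores."""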
|
for real_fname in worst_best_by_real.index: |
|
worst_mask_path = worst_best_by_real.loc[real_fname, 'worst'] |
|
best_mask_path = worst_best_by_real.loc[real_fname, 'best'] |
|
orig_img = load_image(real_fname, mode='RGB') |
|
worst_mask_img = load_image(worst_mask_path, mode='L')[None, ...] |
|
worst_fake_img = load_image(mask2fake_fname[worst_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]] |
|
best_mask_img = load_image(best_mask_path, mode='L')[None, ...] |
|
best_fake_img = load_image(mask2fake_fname[best_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]] |
|
|
|
        orig_img = draw_score(orig_img, worst_best_by_real.loc[real_fname, 'real_score'])

        worst_fake_img = draw_score(worst_fake_img, worst_best_by_real.loc[real_fname, 'worst_score'])

        best_fake_img = draw_score(best_fake_img, worst_best_by_real.loc[real_fname, 'best_score'])
|
|
|
cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=np.zeros_like(worst_mask_img), |
|
worst_mask=worst_mask_img, worst_img=worst_fake_img, |
|
best_mask=best_mask_img, best_img=best_fake_img), |
|
keys=['image', 'worst_mask', 'worst_img', 'best_mask', 'best_img'], |
|
rescale_keys=['worst_mask', 'best_mask'], |
|
last_without_mask=True) |
|
cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8') |
|
cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR) |
|
cv2.imwrite(os.path.join(out_dir, |
|
os.path.splitext(os.path.basename(real_fname))[0] + '.jpg'), |
|
cur_grid) |
|
|
|
fig, (ax1, ax2) = plt.subplots(1, 2) |
|
cur_stat = fake_info[fake_info['real_fname'] == real_fname] |
|
cur_stat['fake_score'].hist(ax=ax1) |
|
cur_stat['real_score'].hist(ax=ax2) |
|
fig.tight_layout() |
|
fig.savefig(os.path.join(out_dir, |
|
os.path.splitext(os.path.basename(real_fname))[0] + '_scores.png')) |
|
plt.close(fig) |
|
|
|
|
|
def extract_overlapping_masks(mask_fnames, cur_i, fake_scores_table, max_overlaps_n=2): |
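    """Find up to `max_overlaps_n` masks after index `cur_i` that overlap mask `cur_i`; return the (a, b) pairs and their score_b - score_a differences."""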
|
result_pairs = [] |
|
result_scores = [] |
|
mask_fname_a = mask_fnames[cur_i] |
|
mask_a = load_image(mask_fname_a, mode='L')[None, ...] > 0.5 |
|
cur_score_a = fake_scores_table.loc[mask_fname_a, 'fake_score'] |
|
for mask_fname_b in mask_fnames[cur_i + 1:]: |
|
mask_b = load_image(mask_fname_b, mode='L')[None, ...] > 0.5 |
|
if not np.any(mask_a & mask_b): |
|
continue |
|
cur_score_b = fake_scores_table.loc[mask_fname_b, 'fake_score'] |
|
result_pairs.append((mask_fname_a, mask_fname_b)) |
|
result_scores.append(cur_score_b - cur_score_a) |
|
if len(result_pairs) >= max_overlaps_n: |
|
break |
|
return result_pairs, result_scores |
|
|
|
|
|
def main(args): |
|
config = load_yaml(args.config) |
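
    # Prepare the output directory tree: latents plus one folder per report type.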
|
|
|
    def make_out_dir(*parts):
        # Create (if needed) and return a subdirectory of the output path.
        path = os.path.join(args.outpath, *parts)
        os.makedirs(path, exist_ok=True)
        return path

    latents_dir = make_out_dir('latents')
    global_worst_dir = make_out_dir('global_worst')
    global_best_dir = make_out_dir('global_best')
    worst_best_by_best_worst_score_diff_max_dir = make_out_dir('worst_best_by_real', 'best_worst_score_diff_max')
    worst_best_by_best_worst_score_diff_min_dir = make_out_dir('worst_best_by_real', 'best_worst_score_diff_min')
    worst_best_by_real_best_score_diff_max_dir = make_out_dir('worst_best_by_real', 'real_best_score_diff_max')
    worst_best_by_real_best_score_diff_min_dir = make_out_dir('worst_best_by_real', 'real_best_score_diff_min')
    worst_best_by_real_worst_score_diff_max_dir = make_out_dir('worst_best_by_real', 'real_worst_score_diff_max')
    worst_best_by_real_worst_score_diff_min_dir = make_out_dir('worst_best_by_real', 'real_worst_score_diff_min')
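
    # Heavy pass: run InceptionV3 over every (real, inpainted) pair and fit a real-vs-fake scorer.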
|
|
|
if not args.only_report: |
|
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] |
|
inception_model = InceptionV3([block_idx]).eval().cuda() |
|
|
|
dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs) |
|
|
|
real2vector_cache = {} |
|
|
|
real_features = [] |
|
fake_features = [] |
|
|
|
orig_fnames = [] |
|
mask_fnames = [] |
|
mask2real_fname = {} |
|
mask2fake_fname = {} |
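
        # One pooled InceptionV3 feature vector per image; real features are cached because each real image appears under many masks.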
|
|
|
for batch_i, batch in enumerate(dataset): |
|
orig_img_fname = dataset.img_filenames[batch_i] |
|
mask_fname = dataset.mask_filenames[batch_i] |
|
fake_fname = dataset.pred_filenames[batch_i] |
|
mask2real_fname[mask_fname] = orig_img_fname |
|
mask2fake_fname[mask_fname] = fake_fname |
|
|
|
cur_real_vector = real2vector_cache.get(orig_img_fname, None) |
|
if cur_real_vector is None: |
|
with torch.no_grad(): |
|
in_img = torch.from_numpy(batch['image'][None, ...]).cuda() |
|
cur_real_vector = inception_model(in_img)[0].squeeze(-1).squeeze(-1).cpu().numpy() |
|
real2vector_cache[orig_img_fname] = cur_real_vector |
|
|
|
pred_img = torch.from_numpy(batch['inpainted'][None, ...]).cuda() |
|
cur_fake_vector = inception_model(pred_img)[0].squeeze(-1).squeeze(-1).cpu().numpy() |
|
|
|
real_features.append(cur_real_vector) |
|
fake_features.append(cur_fake_vector) |
|
|
|
orig_fnames.append(orig_img_fname) |
|
mask_fnames.append(mask_fname) |
|
|
|
ids_features = np.concatenate(real_features + fake_features, axis=0) |
|
ids_labels = np.array(([1] * len(real_features)) + ([0] * len(fake_features))) |
|
|
|
        with open(os.path.join(latents_dir, 'features.pkl'), 'wb') as f:
|
pickle.dump(ids_features, f, protocol=3) |
|
with open(os.path.join(latents_dir, 'labels.pkl'), 'wb') as f: |
|
pickle.dump(ids_labels, f, protocol=3) |
|
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'wb') as f: |
|
pickle.dump(orig_fnames, f, protocol=3) |
|
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'wb') as f: |
|
pickle.dump(mask_fnames, f, protocol=3) |
|
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'wb') as f: |
|
pickle.dump(mask2real_fname, f, protocol=3) |
|
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'wb') as f: |
|
pickle.dump(mask2fake_fname, f, protocol=3) |
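
        # Fit a linear real-vs-fake classifier on the features; its signed decision function is the per-image quality score.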
|
|
|
svm = sklearn.svm.LinearSVC(dual=False) |
|
svm.fit(ids_features, ids_labels) |
|
|
|
pred_scores = svm.decision_function(ids_features) |
|
real_scores = pred_scores[:len(real_features)] |
|
fake_scores = pred_scores[len(real_features):] |
|
|
|
with open(os.path.join(latents_dir, 'pred_scores.pkl'), 'wb') as f: |
|
pickle.dump(pred_scores, f, protocol=3) |
|
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'wb') as f: |
|
pickle.dump(real_scores, f, protocol=3) |
|
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'wb') as f: |
|
pickle.dump(fake_scores, f, protocol=3) |
|
else: |
|
with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'rb') as f: |
|
orig_fnames = pickle.load(f) |
|
with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'rb') as f: |
|
mask_fnames = pickle.load(f) |
|
with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'rb') as f: |
|
mask2real_fname = pickle.load(f) |
|
with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'rb') as f: |
|
mask2fake_fname = pickle.load(f) |
|
with open(os.path.join(latents_dir, 'real_scores.pkl'), 'rb') as f: |
|
real_scores = pickle.load(f) |
|
with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'rb') as f: |
|
fake_scores = pickle.load(f) |
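
    # Assemble score tables: one row per real image, one row per (mask, inpainting) pair.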
|
|
|
real_info = pd.DataFrame(data=[dict(real_fname=fname, |
|
real_score=score) |
|
for fname, score |
|
in zip(orig_fnames, real_scores)]) |
|
real_info.set_index('real_fname', drop=True, inplace=True) |
|
|
|
fake_info = pd.DataFrame(data=[dict(mask_fname=fname, |
|
fake_fname=mask2fake_fname[fname], |
|
real_fname=mask2real_fname[fname], |
|
fake_score=score) |
|
for fname, score |
|
in zip(mask_fnames, fake_scores)]) |
|
fake_info = fake_info.join(real_info, on='real_fname', how='left') |
|
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True) |
|
|
|
fake_stats_by_real = fake_info.groupby('real_fname')['fake_score'].describe()[['mean', 'std']].rename( |
|
{'mean': 'mean_fake_by_real', 'std': 'std_fake_by_real'}, axis=1) |
|
fake_info = fake_info.join(fake_stats_by_real, on='real_fname', rsuffix='stat_by_real') |
|
fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True) |
|
fake_info.to_csv(os.path.join(latents_dir, 'join_scores_table.csv'), sep='\t', index=False) |
|
|
|
fake_scores_table = fake_info.set_index('mask_fname')['fake_score'].to_frame() |
|
    # Deduplicate by filename (not by score value) so different images with equal scores both survive.
    real_scores_table = fake_info.drop_duplicates('real_fname').set_index('real_fname')['real_score'].to_frame()
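
    # Plot global histograms of the fake (left) and real (right) scores.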
|
|
|
fig, (ax1, ax2) = plt.subplots(1, 2) |
|
ax1.hist(fake_scores) |
|
ax2.hist(real_scores) |
|
fig.tight_layout() |
|
fig.savefig(os.path.join(args.outpath, 'global_scores_hist.png')) |
|
plt.close(fig) |
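
    # Save grids for the globally worst- and best-scored inpaintings.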
|
|
|
global_worst_masks = fake_info.sort_values('fake_score', ascending=True)['mask_fname'].iloc[:config.take_global_top].to_list() |
|
global_best_masks = fake_info.sort_values('fake_score', ascending=False)['mask_fname'].iloc[:config.take_global_top].to_list() |
|
save_global_samples(global_worst_masks, mask2real_fname, mask2fake_fname, global_worst_dir, real_scores_table, fake_scores_table) |
|
save_global_samples(global_best_masks, mask2real_fname, mask2fake_fname, global_best_dir, real_scores_table, fake_scores_table) |
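
    # For each real image, pick its worst- and best-scored masks, then rank images by the various score gaps.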
|
|
|
|
|
worst_samples_by_real = fake_info.groupby('real_fname').apply( |
|
lambda d: d.set_index('mask_fname')['fake_score'].idxmin()).to_frame().rename({0: 'worst'}, axis=1) |
|
best_samples_by_real = fake_info.groupby('real_fname').apply( |
|
lambda d: d.set_index('mask_fname')['fake_score'].idxmax()).to_frame().rename({0: 'best'}, axis=1) |
|
worst_best_by_real = pd.concat([worst_samples_by_real, best_samples_by_real], axis=1) |
|
|
|
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'worst_score'}, axis=1), |
|
on='worst') |
|
worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'best_score'}, axis=1), |
|
on='best') |
|
worst_best_by_real = worst_best_by_real.join(real_scores_table) |
|
|
|
worst_best_by_real['best_worst_score_diff'] = worst_best_by_real['best_score'] - worst_best_by_real['worst_score'] |
|
worst_best_by_real['real_best_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['best_score'] |
|
worst_best_by_real['real_worst_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['worst_score'] |
|
|
|
worst_best_by_best_worst_score_diff_min = worst_best_by_real.sort_values('best_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top] |
|
worst_best_by_best_worst_score_diff_max = worst_best_by_real.sort_values('best_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top] |
|
save_samples_by_real(worst_best_by_best_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_min_dir) |
|
save_samples_by_real(worst_best_by_best_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_max_dir) |
|
|
|
worst_best_by_real_best_score_diff_min = worst_best_by_real.sort_values('real_best_score_diff', ascending=True).iloc[:config.take_worst_best_top] |
|
worst_best_by_real_best_score_diff_max = worst_best_by_real.sort_values('real_best_score_diff', ascending=False).iloc[:config.take_worst_best_top] |
|
save_samples_by_real(worst_best_by_real_best_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_min_dir) |
|
save_samples_by_real(worst_best_by_real_best_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_max_dir) |
|
|
|
worst_best_by_real_worst_score_diff_min = worst_best_by_real.sort_values('real_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top] |
|
worst_best_by_real_worst_score_diff_max = worst_best_by_real.sort_values('real_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top] |
|
save_samples_by_real(worst_best_by_real_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_min_dir) |
|
save_samples_by_real(worst_best_by_real_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_max_dir) |
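
    # Mine pairs of overlapping masks on the same image; a large score difference means the result is sensitive to the mask shape.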
|
|
|
|
|
overlapping_mask_fname_pairs = [] |
|
overlapping_mask_fname_score_diffs = [] |
|
for cur_real_fname in orig_fnames: |
|
cur_fakes_info = fake_info[fake_info['real_fname'] == cur_real_fname] |
|
cur_mask_fnames = sorted(cur_fakes_info['mask_fname'].unique()) |
|
|
|
cur_mask_pairs_and_scores = Parallel(args.n_jobs)( |
|
delayed(extract_overlapping_masks)(cur_mask_fnames, i, fake_scores_table) |
|
for i in range(len(cur_mask_fnames) - 1) |
|
) |
|
for cur_pairs, cur_scores in cur_mask_pairs_and_scores: |
|
overlapping_mask_fname_pairs.extend(cur_pairs) |
|
overlapping_mask_fname_score_diffs.extend(cur_scores) |
|
|
|
overlapping_mask_fname_pairs = np.asarray(overlapping_mask_fname_pairs) |
|
overlapping_mask_fname_score_diffs = np.asarray(overlapping_mask_fname_score_diffs) |
|
overlapping_sort_idx = np.argsort(overlapping_mask_fname_score_diffs) |
|
overlapping_mask_fname_pairs = overlapping_mask_fname_pairs[overlapping_sort_idx] |
|
overlapping_mask_fname_score_diffs = overlapping_mask_fname_score_diffs[overlapping_sort_idx] |
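
    # The mined pairs are only sorted in memory above. A minimal sketch to persist them
    # next to the other latents (the output filename is an assumption, not part of the
    # original pipeline):
    with open(os.path.join(latents_dir, 'overlapping_mask_pairs.pkl'), 'wb') as f:
        pickle.dump({'pairs': overlapping_mask_fname_pairs,
                     'score_diffs': overlapping_mask_fname_score_diffs}, f, protocol=3)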
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
import argparse |
|
|
|
aparser = argparse.ArgumentParser() |
|
    aparser.add_argument('config', type=str, help='Path to the analysis config (dataset_kwargs, take_global_top, take_worst_best_top)')
|
aparser.add_argument('datadir', type=str, |
|
help='Path to folder with images and masks (output of gen_mask_dataset.py)') |
|
aparser.add_argument('predictdir', type=str, |
|
                         help='Path to folder with predictions (e.g. the output of predict_hifill_baseline.py)')
|
aparser.add_argument('outpath', type=str, help='Where to put results') |
|
aparser.add_argument('--only-report', action='store_true', |
|
                         help='Skip feature extraction and scoring; '
                              'load previously saved latents and only rebuild the report')
|
    aparser.add_argument('--n-jobs', type=int, default=8, help='How many processes to use for overlapping mask pair mining')
|
|
|
main(aparser.parse_args()) |
|
|