import glob
import argparse
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *
from sklearn.metrics import average_precision_score, accuracy_score


def load_global_classifier(model_path, gpu_id):
    if torch.cuda.is_available() and gpu_id != -1:
        device = 'cuda:{}'.format(gpu_id)
    else:
        device = 'cpu'
    model = DRNSub(1)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model


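def load_local_detector(model_path, gpu_id):
    # Fall back to the CPU when CUDA is unavailable or gpu_id is -1,
    # matching load_global_classifier.
    if torch.cuda.is_available() and gpu_id != -1: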
        device = 'cuda:{}'.format(gpu_id)
    else:
        device = 'cpu'

    model = DRNSeg(2)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model


# Convert to a tensor and normalize with the ImageNet channel mean and std.
tf = transforms.Compose([transforms.ToTensor(),
                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])])


def load_data(img_path, device):
    face = Image.open(img_path).convert('RGB')
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)
    return face_tens, face


def classify_fake(model, img_path):
    img = load_data(img_path, model.device)[0].unsqueeze(0)
    # Prediction
    with torch.no_grad():
        prob = model(img)[0].sigmoid().cpu().item()
    return prob


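def calc_psnr(img0, img1, mask=None):
    # PSNR in dB for images scaled to [0, 1] (peak value 1).
    # The 1e-6 term keeps log10 finite when the images are identical.
    # Note: `mask` is accepted but currently unused.
    return -10 * np.log10(np.mean((img0 - img1)**2) + 1e-6)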


def detect_warp(model, img_path):
    # Returns the PSNR against the reference image before and after
    # undoing the predicted warp.
    img, modified = load_data(img_path, model.device)
    # Warping field prediction
    with torch.no_grad():
        flow = model(img.unsqueeze(0))[0].cpu().numpy()
        flow = np.transpose(flow, (1, 2, 0))

    # Undoing the warps
    flow = flow_resize(flow, modified.size)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    original = Image.open(img_path.replace('modified', 'reference')).convert('RGB')
    original_np = np.asarray(original.resize(modified.size, Image.BICUBIC))

    psnr_before = calc_psnr(original_np / 255, modified_np / 255)
    psnr_after = calc_psnr(original_np / 255, reverse_np / 255)
    return psnr_before, psnr_after


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataroot", required=True, help='the root to the dataset')
    parser.add_argument(
        "--global_pth", required=True, help="path to the global model")
    parser.add_argument(
        "--local_pth", required=True, help="path to the local model")
    parser.add_argument(
        "--gpu_id", type=int, default=0,
        help="the id of the gpu to run model on")
    args = parser.parse_args()

    glb_model = load_global_classifier(args.global_pth, args.gpu_id)
    lcl_model = load_local_detector(args.local_pth, args.gpu_id)

    pred_prob, gt_prob, psnr_before, psnr_after = [], [], [], []
    for img_path in glob.glob(args.dataroot + '/original/*'):
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(0)

    for img_path in glob.glob(args.dataroot + '/modified/*'):
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(1)
        psnrs = detect_warp(lcl_model, img_path)
        psnr_before.append(psnrs[0])
        psnr_after.append(psnrs[1])

    pred_prob, gt_prob, psnr_before, psnr_after = \
        np.array(pred_prob), np.array(gt_prob), np.array(psnr_before), np.array(psnr_after)
    acc = accuracy_score(gt_prob, pred_prob > 0.5)
    avg_precision = average_precision_score(gt_prob, pred_prob)
    delta_psnr = psnr_after.mean() - psnr_before.mean()

    print("Accuracy: ", acc)
    print("Average precision: ", avg_precision)
    print("PSNR increase: ", delta_psnr)