# P-PD / eval.py
# Initial commit e875957 (uploaded by mrneuralnet), 4.01 kB
import glob
import argparse
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *
from sklearn.metrics import average_precision_score, accuracy_score
def load_global_classifier(model_path, gpu_id):
    """Load the global classifier checkpoint and put it in eval mode.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            ``'model'`` entry holds the state dict for ``DRNSub(1)``.
        gpu_id: Index of the CUDA device to use, given as an int or a
            numeric string (argparse passes strings unless a ``type`` is
            set). Any negative value selects the CPU.

    Returns:
        The ``DRNSub(1)`` model moved to the selected device, in eval mode,
        with an extra ``device`` attribute recording that device string.

    Raises:
        ValueError: If ``gpu_id`` cannot be converted to an integer.
    """
    # Normalise before comparing: the raw CLI string '-1' never equals the
    # int -1, which made the CPU opt-out unreachable from the command line.
    gpu_index = int(gpu_id)
    if torch.cuda.is_available() and gpu_index >= 0:
        device = 'cuda:{}'.format(gpu_index)
    else:
        device = 'cpu'

    model = DRNSub(1)
    # Deserialize onto the CPU; the weights are moved by model.to() below.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.to(device)
    # Stored so callers can place input tensors on the same device.
    model.device = device
    model.eval()
    return model
def load_local_detector(model_path, gpu_id):
    """Load the local warp detector checkpoint and put it in eval mode.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            ``'model'`` entry holds the state dict for ``DRNSeg(2)``.
        gpu_id: Index of the CUDA device to use, given as an int or a
            numeric string (argparse passes strings unless a ``type`` is
            set). Any negative value selects the CPU.

    Returns:
        The ``DRNSeg(2)`` model moved to the selected device, in eval mode,
        with an extra ``device`` attribute recording that device string.

    Raises:
        ValueError: If ``gpu_id`` cannot be converted to an integer.
    """
    # Honour a negative id as "run on CPU", matching load_global_classifier;
    # previously a negative id produced an invalid 'cuda:-1' device string.
    gpu_index = int(gpu_id)
    if torch.cuda.is_available() and gpu_index >= 0:
        device = 'cuda:{}'.format(gpu_index)
    else:
        device = 'cpu'

    model = DRNSeg(2)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    # Stored so callers can place input tensors on the same device.
    model.device = device
    model.eval()
    return model
# Preprocessing applied to every input image: convert to a tensor, then
# normalise each channel with a fixed mean/std (the values match the widely
# used ImageNet channel statistics).
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def load_data(img_path, device):
    """Read an image from disk and prepare it as a model input.

    Args:
        img_path: Path of the image file to open.
        device: Device the normalised tensor is moved to.

    Returns:
        A ``(tensor, image)`` pair: the output of the module-level ``tf``
        pipeline on ``device``, and the resized PIL image it was built from.
    """
    rgb_image = Image.open(img_path).convert('RGB')
    # Only the first element of the helper's result is used; presumably it
    # is the image resized so its shorter side is 400 px (helper is defined
    # outside this file -- confirm against utils.tools).
    resized = resize_shorter_side(rgb_image, 400)
    face = resized[0]
    return tf(face).to(device), face
def classify_fake(model, img_path):
    """Score one image with the global classifier.

    Args:
        model: Classifier carrying a ``device`` attribute.
        img_path: Path of the image to score.

    Returns:
        The sigmoid of the first element of the model output, as a Python
        float.
    """
    tensor, _ = load_data(img_path, model.device)
    # The model is fed a batch, so add a leading batch dimension of size 1.
    batch = tensor.unsqueeze(0)

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        logit = model(batch)[0]
        return logit.sigmoid().cpu().item()
def calc_psnr(img0, img1, mask=None):
    """Return ``-10 * log10(MSE + 1e-6)`` between two images.

    This equals the PSNR in dB when the images are scaled to a peak value
    of 1, which is how the callers in this file use it (they divide by 255
    first).

    Args:
        img0: First image, any array-like of numbers.
        img1: Second image, broadcast-compatible with ``img0``.
        mask: Accepted for interface compatibility; currently ignored.

    Returns:
        The score as a float. The ``1e-6`` term keeps the logarithm finite,
        so identical inputs give 60 dB rather than infinity.
    """
    # Promote to float64 before subtracting: unsigned integer images (e.g.
    # uint8) would otherwise wrap around in both the difference and the
    # square, silently corrupting the mean squared error.
    first = np.asarray(img0, dtype=np.float64)
    second = np.asarray(img1, dtype=np.float64)
    mse = np.mean((first - second) ** 2)
    return -10 * np.log10(mse + 1e-6)
def detect_warp(model, img_path):
    """Measure how much undoing the predicted warp helps for one image.

    The model output is used as a flow field to warp the modified image
    back, and both the modified and the un-warped image are compared with
    the reference image.

    Args:
        model: Local detector carrying a ``device`` attribute.
        img_path: Path of the modified image. The reference image is looked
            up at the same path with 'modified' replaced by 'reference'.

    Returns:
        ``(psnr_before, psnr_after)``: the score of the modified image and
        of the un-warped image, each against the reference.
    """
    img, modified = load_data(img_path, model.device)
    batch = img.unsqueeze(0)

    # Warping field prediction (inference only, no gradients needed).
    with torch.no_grad():
        predicted = model(batch)[0]
        flow = predicted.cpu().numpy()
    # Move the leading axis last before resizing.
    flow = flow.transpose(1, 2, 0)

    # Undo the warp at the resolution of the resized input image.
    target_size = modified.size
    flow = flow_resize(flow, target_size)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)

    # NOTE(review): str.replace rewrites every occurrence of 'modified' in
    # the path, not only the directory name -- confirm dataset paths are safe.
    reference_path = img_path.replace('modified', 'reference')
    original = Image.open(reference_path).convert('RGB')
    original_np = np.asarray(original.resize(target_size, Image.BICUBIC))

    # Scale pixel values by 1/255 before scoring.
    reference_unit = original_np / 255
    psnr_before = calc_psnr(reference_unit, modified_np / 255)
    psnr_after = calc_psnr(reference_unit, reverse_np / 255)
    return psnr_before, psnr_after
if __name__ == '__main__':
    # Evaluate both models on a dataset laid out as
    # <dataroot>/original/* (label 0) and <dataroot>/modified/* (label 1).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataroot", required=True, help='the root to the dataset')
    parser.add_argument(
        "--global_pth", required=True, help="path to the global model")
    parser.add_argument(
        "--local_pth", required=True, help="path to the local model")
    # Parse as int: with a string value the `gpu_id != -1` check in
    # load_global_classifier could never match, so CPU could not be forced.
    parser.add_argument(
        "--gpu_id", type=int, default=0,
        help="the id of the gpu to run model on")
    args = parser.parse_args()

    glb_model = load_global_classifier(args.global_pth, args.gpu_id)
    lcl_model = load_local_detector(args.local_pth, args.gpu_id)

    pred_prob, gt_prob, psnr_before, psnr_after = [], [], [], []

    # Unmodified images: only the global classifier is scored.
    for img_path in glob.glob(args.dataroot + '/original/*'):
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(0)

    # Modified images: score the classifier and the warp reversal.
    for img_path in glob.glob(args.dataroot + '/modified/*'):
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(1)
        before, after = detect_warp(lcl_model, img_path)
        psnr_before.append(before)
        psnr_after.append(after)

    # Fail with a clear message instead of an opaque metrics error when the
    # dataset folders are missing or empty.
    if not pred_prob:
        parser.error(
            'no images found under {0}/original or {0}/modified'.format(
                args.dataroot))

    pred_prob, gt_prob, psnr_before, psnr_after = \
        np.array(pred_prob), np.array(gt_prob), \
        np.array(psnr_before), np.array(psnr_after)

    acc = accuracy_score(gt_prob, pred_prob > 0.5)
    avg_precision = average_precision_score(gt_prob, pred_prob)
    # With no modified images there is nothing to average; report NaN
    # directly rather than taking the mean of an empty array.
    if psnr_after.size:
        delta_psnr = psnr_after.mean() - psnr_before.mean()
    else:
        delta_psnr = float('nan')

    print("Accuracy: ", acc)
    print("Average precision: ", avg_precision)
    print("PSNR increase: ", delta_psnr)