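"""Evaluate the global manipulation classifier and the local warping-field detector
on a dataset laid out as <dataroot>/original, <dataroot>/modified and
<dataroot>/reference, where each modified image has a same-named reference image.

Example invocation (a sketch; the script and checkpoint file names below are
placeholders, not fixed by this file):

    python eval.py \
        --dataroot data/validation \
        --global_pth weights/global.pth \
        --local_pth weights/local.pth \
        --gpu_id 0
"""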
import glob
import argparse
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *
from sklearn.metrics import average_precision_score, accuracy_score


def load_global_classifier(model_path, gpu_id):
    if torch.cuda.is_available() and gpu_id != -1:
        device = 'cuda:{}'.format(gpu_id)
    else:
        device = 'cpu'
    model = DRNSub(1)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model


def load_local_detector(model_path, gpu_id):
    # Same device selection as load_global_classifier: gpu_id == -1 runs on the CPU.
    if torch.cuda.is_available() and gpu_id != -1:
        device = 'cuda:{}'.format(gpu_id)
    else:
        device = 'cpu'
    model = DRNSeg(2)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model
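
# Image preprocessing shared by both models: PIL image -> tensor, normalized
# with the standard ImageNet mean/std.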
tf = transforms.Compose([transforms.ToTensor(),
                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])])


def load_data(img_path, device):
    face = Image.open(img_path).convert('RGB')
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)
    return face_tens, face
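
# Global classifier: returns the predicted probability that the input face
# image has been modified (1 = modified, 0 = original).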
def classify_fake(model, img_path):
    img = load_data(img_path, model.device)[0].unsqueeze(0)
    # Prediction
    with torch.no_grad():
        prob = model(img)[0].sigmoid().cpu().item()
    return prob
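
# PSNR for images scaled to [0, 1] (peak signal of 1), so it reduces to
# -10*log10(MSE); the 1e-6 term keeps the log finite for identical images.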
def calc_psnr(img0, img1, mask=None):
    return -10 * np.log10(np.mean((img0 - img1)**2) + 1e-6)
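
# Local detector: predicts a per-pixel warping field for the modified image,
# un-warps it with that field, and returns PSNR against the reference image
# before and after the un-warping.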
def detect_warp(model, img_path):
    img, modified = load_data(img_path, model.device)

    # Warping field prediction
    with torch.no_grad():
        flow = model(img.unsqueeze(0))[0].cpu().numpy()
        flow = np.transpose(flow, (1, 2, 0))

    # Undoing the warps
    flow = flow_resize(flow, modified.size)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)

    original = Image.open(img_path.replace('modified', 'reference')).convert('RGB')
    original_np = np.asarray(original.resize(modified.size, Image.BICUBIC))
    psnr_before = calc_psnr(original_np / 255, modified_np / 255)
    psnr_after = calc_psnr(original_np / 255, reverse_np / 255)
    return psnr_before, psnr_after


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataroot", required=True, help="the root of the dataset")
    parser.add_argument(
        "--global_pth", required=True, help="path to the global model")
    parser.add_argument(
        "--local_pth", required=True, help="path to the local model")
    parser.add_argument(
        "--gpu_id", type=int, default=0,
        help="the id of the gpu to run the models on (-1 for CPU)")
    args = parser.parse_args()

    glb_model = load_global_classifier(args.global_pth, args.gpu_id)
    lcl_model = load_local_detector(args.local_pth, args.gpu_id)

    # Score every original (label 0) and modified (label 1) image with the
    # global classifier; run the local detector only on the modified images.
    pred_prob, gt_prob, psnr_before, psnr_after = [], [], [], []
    for img_path in glob.glob(args.dataroot + '/original/*'):
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(0)
    for img_path in glob.glob(args.dataroot + '/modified/*'):
        pred_prob.append(classify_fake(glb_model, img_path))
        gt_prob.append(1)
        psnrs = detect_warp(lcl_model, img_path)
        psnr_before.append(psnrs[0])
        psnr_after.append(psnrs[1])

    pred_prob, gt_prob, psnr_before, psnr_after = \
        np.array(pred_prob), np.array(gt_prob), np.array(psnr_before), np.array(psnr_after)

    acc = accuracy_score(gt_prob, pred_prob > 0.5)
    avg_precision = average_precision_score(gt_prob, pred_prob)
    # A positive PSNR increase means un-warping with the predicted flow brings
    # the modified images closer to their references.
    delta_psnr = psnr_after.mean() - psnr_before.mean()
    print("Accuracy: ", acc)
    print("Average precision: ", avg_precision)
    print("PSNR increase: ", delta_psnr)