P-PD / inference.py
mrneuralnet's picture
Initial commit
e875957
raw
history blame contribute delete
No virus
3.35 kB
import argparse
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *
def load_classifier(model_path, gpu_id):
    """Build a DRNSub(1) classifier and load its weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry is the
            state dict for ``DRNSub(1)``.
        gpu_id: CUDA device index. ``-1``, or CUDA being unavailable,
            selects the CPU.

    Returns:
        The model in eval mode on the selected device. The chosen device
        string is also stored on ``model.device`` so callers can place
        inputs next to it.
    """
    classifier = DRNSub(1)
    # Deserialize onto the CPU first; the model is moved afterwards.
    checkpoint = torch.load(model_path, map_location='cpu')
    classifier.load_state_dict(checkpoint['model'])

    target = ('cuda:{}'.format(gpu_id)
              if torch.cuda.is_available() and gpu_id != -1
              else 'cpu')
    classifier.to(target)
    classifier.device = target
    classifier.eval()
    return classifier
# Checkpoint locations for the two networks used by this module.
local_model_path = 'weights/local.pth'
global_model_path = 'weights/global.pth'
# CUDA device index; -1 forces CPU (same convention as load_classifier).
gpu_id = 0

# Loading the model
# Honour gpu_id == -1 here too, otherwise the local model would be sent to
# the invalid device 'cuda:-1' while the global model runs on CPU.
if torch.cuda.is_available() and gpu_id != -1:
    device = 'cuda:{}'.format(gpu_id)
else:
    device = 'cpu'

# Local model: DRNSeg with 2 output channels; heatmap_analysis reads its
# output as a per-pixel warping field.
local_model = DRNSeg(2)
state_dict = torch.load(local_model_path, map_location=device)
local_model.load_state_dict(state_dict['model'])
local_model.to(device)
local_model.eval()

# Global model: whole-face classifier used by classify_fake.
global_model = load_classifier(global_model_path, gpu_id)

# Data preprocessing: image -> float tensor, normalized per channel with the
# commonly used ImageNet mean/std values.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def classify_fake(img, no_crop=False, global_model=global_model,
                  model_file='utils/dlib_face_detector/mmod_human_face_detector.dat'):
    """Score the face in ``img`` with the global classifier.

    Returns the sigmoid of the classifier's first output as a Python float.
    Per the function name, this is the probability that the face is fake.

    Args:
        img: Input image; its ``.size`` is unpacked into two values.
        no_crop: If True, skip face detection and use ``img`` as the face.
        global_model: Classifier to run; it must expose a ``.device``
            attribute (as set by load_classifier).
        model_file: Face-detector model file passed to ``face_detection``.

    If no face is detected, prints a message and calls ``sys.exit()``.
    """
    width, height = img.size  # not used further

    face = img
    if not no_crop:
        detections = face_detection(img, verbose=False, model_file=model_file)
        if len(detections) == 0:
            print("no face detected by dlib, exiting")
            sys.exit()
        face, _box = detections[0]

    face = resize_shorter_side(face, 400)[0]
    batch = tf(face).to(global_model.device).unsqueeze(0)

    # Prediction
    with torch.no_grad():
        score = global_model(batch)[0].sigmoid().cpu().item()
    return score
def heatmap_analysis(img, no_crop=False, model_file=None):
    """Predict the local warping field for a face and visualize it.

    Args:
        img: Input image. When ``no_crop`` is True it is used directly as
            the face, so it must support ``.resize`` (e.g. a PIL image).
        no_crop: If True, skip face detection and use ``img`` as the face.
        model_file: Optional face-detector model path forwarded to
            ``face_detection``. When None, ``face_detection`` is called
            without it, so its own default applies (the previous behavior).

    Returns:
        ``(modified, reverse, heatmap)``, all PIL images:
        ``modified`` is the face resized (bicubic) to the flow's spatial
        size; ``reverse`` is ``modified`` passed through ``warp`` with the
        predicted flow (the attempted undo of the edits); ``heatmap`` is
        ``get_heatmap_cv`` applied to ``modified`` and the per-pixel flow
        magnitude.

    If no face is detected, prints a message and calls ``sys.exit()``,
    matching classify_fake.
    """
    if no_crop:
        face = img
    else:
        if model_file is None:
            faces = face_detection(img, verbose=False)
        else:
            faces = face_detection(img, verbose=False, model_file=model_file)
        if len(faces) == 0:
            print("no face detected by dlib, exiting")
            sys.exit()
        face, _box = faces[0]
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)

    # Warping field prediction: take the single batch item and move the
    # channel axis last, giving an (h, w, channels) array.
    with torch.no_grad():
        flow = local_model(face_tens.unsqueeze(0))[0].cpu().numpy()
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undoing the warps
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    reverse = Image.fromarray(reverse_np)

    # Per-pixel magnitude of the first two flow channels drives the heatmap.
    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    cv_out = get_heatmap_cv(modified_np, flow_magn, 7)
    heatmap = Image.fromarray(cv_out)
    return modified, reverse, heatmap
# Saving the results
# modified.save(
# os.path.join(dest_folder, 'cropped_input.jpg'),
# quality=90)
# reverse.save(
# os.path.join(dest_folder, 'warped.jpg'),
# quality=90)
# flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
# save_heatmap_cv(
# modified_np, flow_magn,
# os.path.join(dest_folder, 'heatmap.jpg'))