# P-PD / local_detector.py
# Author: mrneuralnet. Commit: "Initial commit" (e875957). Size: 2.65 kB.
import argparse
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from networks.drn_seg import DRNSeg
from utils.tools import *
from utils.visualize import *
def _parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``input_path``, ``dest_folder``,
        ``model_path``, ``gpu_id`` and ``no_crop``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path", required=True, help="the model input")
    parser.add_argument(
        "--dest_folder", required=True, help="folder to store the results")
    parser.add_argument(
        "--model_path", required=True, help="path to the drn model")
    parser.add_argument(
        "--gpu_id", default='0', help="the id of the gpu to run model on")
    parser.add_argument(
        "--no_crop",
        action="store_true",
        help="do not use a face detector, instead run on the full input image")
    return parser.parse_args(argv)


def _load_model(model_path, device):
    """Build a ``DRNSeg(2)`` network, load its checkpoint and put it in eval mode.

    The checkpoint is expected to be a dict holding the weights under the
    ``'model'`` key.
    """
    # NOTE(review): DRNSeg(2) presumably yields 2 output channels (the x/y flow
    # components indexed below) -- confirm against networks.drn_seg.
    model = DRNSeg(2)
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only pass checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.eval()
    return model


def _load_face(img_path, no_crop):
    """Return the image region to analyse, or ``None`` if no face is found.

    With ``no_crop`` the whole input image (converted to RGB) is used.
    Otherwise the first detection returned by ``face_detection`` is used.
    """
    if no_crop:
        # The context manager closes the file handle once convert() has
        # produced an independent copy of the pixels.
        with Image.open(img_path) as im:
            return im.convert('RGB')
    faces = face_detection(img_path, verbose=False)
    # len() rather than truthiness, since the container type returned by
    # face_detection is not visible here.
    if len(faces) == 0:
        return None
    face, _box = faces[0]
    return face


def main(argv=None):
    """Run the local warp detector on one image and write the results.

    Writes three files into ``--dest_folder``, creating the folder if needed:
    ``cropped_input.jpg`` (the analysed region), ``warped.jpg`` (the region
    with the predicted warp undone) and ``heatmap.jpg`` (the per-pixel flow
    magnitude drawn over the input). If no face is detected, the process
    exits with a non-zero status and the message goes to stderr.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    img_path = args.input_path
    dest_folder = args.dest_folder

    # Loading the model
    if torch.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu_id)
    else:
        device = 'cpu'
    model = _load_model(args.model_path, device)

    # Data preprocessing: tensor conversion plus the standard ImageNet
    # mean/std normalization.
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    face = _load_face(img_path, args.no_crop)
    if face is None:
        # sys.exit(str) prints to stderr and exits with status 1, so callers
        # can tell this failure apart from a successful run.
        sys.exit("no face detected by dlib, exiting")
    # NOTE(review): presumably rescales so the shorter side is 400 px; [0]
    # takes the image out of the helper's return tuple -- confirm.
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)

    # Warping field prediction
    with torch.no_grad():
        flow = model(face_tens.unsqueeze(0))[0].cpu().numpy()
    # Channels-first (C, H, W) -> channels-last (H, W, C).
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undoing the warps: match the input to the flow's spatial size, then
    # apply the predicted flow.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    reverse = Image.fromarray(reverse_np)

    # Saving the results. Create the folder only now, so an early failure
    # does not leave an empty directory behind.
    os.makedirs(dest_folder, exist_ok=True)
    modified.save(
        os.path.join(dest_folder, 'cropped_input.jpg'),
        quality=90)
    reverse.save(
        os.path.join(dest_folder, 'warped.jpg'),
        quality=90)
    # Per-pixel Euclidean magnitude of the 2-D flow vectors.
    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    save_heatmap_cv(
        modified_np, flow_magn,
        os.path.join(dest_folder, 'heatmap.jpg'))


if __name__ == '__main__':
    main()