"""Inference helpers for face-manipulation analysis.

Two networks are loaded at import time:

* ``global_model`` -- a ``DRNSub(1)`` classifier whose sigmoid output is
  reported as a single "fake" probability (see :func:`classify_fake`).
* ``local_model`` -- a ``DRNSeg(2)`` network whose output is treated as a
  per-pixel warping field (see :func:`heatmap_analysis`).
"""
import argparse
import os
import sys

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *


def _select_device(gpu_id):
    """Return ``'cuda:<gpu_id>'`` if CUDA is available and ``gpu_id != -1``, else ``'cpu'``."""
    if torch.cuda.is_available() and gpu_id != -1:
        return 'cuda:{}'.format(gpu_id)
    return 'cpu'


def load_classifier(model_path, gpu_id):
    """Build a ``DRNSub(1)`` classifier and load its weights from *model_path*.

    The checkpoint is deserialized onto the CPU first and the model is then
    moved to the selected device.

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry holds the
            state dict.
        gpu_id: CUDA device index; ``-1`` (or no CUDA available) selects
            the CPU.

    Returns:
        The model in eval mode, with a ``device`` attribute recording the
        device it was moved to (read by :func:`classify_fake`).
    """
    device = _select_device(gpu_id)
    model = DRNSub(1)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model


local_model_path = 'weights/local.pth'
global_model_path = 'weights/global.pth'
gpu_id = 0

# Warping-field model, used by heatmap_analysis().
device = _select_device(gpu_id)
local_model = DRNSeg(2)
state_dict = torch.load(local_model_path, map_location=device)
local_model.load_state_dict(state_dict['model'])
local_model.to(device)
local_model.eval()

# Whole-face classifier, used as the default model in classify_fake().
global_model = load_classifier(global_model_path, gpu_id)

# Shared preprocessing: PIL image -> tensor, normalized with the standard
# ImageNet per-channel mean/std.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])


def _crop_face(img, no_crop, model_file=None):
    """Select the region of *img* to analyse and pass it through ``resize_shorter_side(face, 400)``.

    With ``no_crop`` the whole image is used. Otherwise the first entry
    returned by ``face_detection`` is taken; each entry unpacks to
    ``(face_image, box)``. If no face is found, a message is printed and the
    process exits via ``sys.exit()``.

    Args:
        img: Input PIL image.
        no_crop: Skip face detection and use the full image.
        model_file: Detector model forwarded to ``face_detection``. When
            ``None``, ``face_detection`` is called without it so its own
            default applies.
    """
    if no_crop:
        face = img
    else:
        if model_file is None:
            faces = face_detection(img, verbose=False)
        else:
            faces = face_detection(img, verbose=False, model_file=model_file)
        if len(faces) == 0:
            print("no face detected by dlib, exiting")
            sys.exit()
        face, _box = faces[0]
    return resize_shorter_side(face, 400)[0]


def classify_fake(img, no_crop=False, global_model=global_model,
                  model_file='utils/dlib_face_detector/mmod_human_face_detector.dat'):
    """Return the classifier's "fake" probability for the face in *img*.

    Args:
        img: Input PIL image.
        no_crop: If true, classify the whole image instead of the first
            detected face.
        global_model: Classifier with a ``device`` attribute, as returned by
            :func:`load_classifier`.
        model_file: Face-detector model passed to ``face_detection``.

    Returns:
        float: sigmoid of the model's first output, i.e. a value in [0, 1].

    Exits the process (``sys.exit()``) if cropping is requested and no face
    is detected.
    """
    face = _crop_face(img, no_crop, model_file=model_file)
    face_tens = tf(face).to(global_model.device)

    with torch.no_grad():
        prob = global_model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()
    return prob


def heatmap_analysis(img, no_crop=False, model_file=None):
    """Predict the warping field for the face in *img* and visualise it.

    Args:
        img: Input PIL image.
        no_crop: If true, analyse the whole image instead of the first
            detected face.
        model_file: Optional face-detector model for ``face_detection``;
            ``None`` keeps ``face_detection``'s default.

    Returns:
        tuple of PIL images ``(modified, reverse, heatmap)``:
        ``modified`` is the face resized to the flow resolution, ``reverse``
        is ``modified`` after ``warp`` with the predicted flow, and
        ``heatmap`` is ``get_heatmap_cv``'s rendering of the flow magnitude.

    Exits the process (``sys.exit()``) if cropping is requested and no face
    is detected.
    """
    face = _crop_face(img, no_crop, model_file=model_file)
    face_tens = tf(face).to(device)

    # Warping field prediction
    with torch.no_grad():
        flow = local_model(face_tens.unsqueeze(0))[0].cpu().numpy()
    # Channels-first -> channels-last; channels 0 and 1 are used below as
    # the two flow components.
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undoing the warps: bring the face to the flow's resolution, then apply
    # the predicted field.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    reverse = Image.fromarray(reverse_np)

    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    # NOTE(review): the meaning of 7 is defined inside get_heatmap_cv,
    # which is not visible here -- confirm before changing.
    cv_out = get_heatmap_cv(modified_np, flow_magn, 7)
    heatmap = Image.fromarray(cv_out)
    return modified, reverse, heatmap