File size: 3,347 Bytes
e875957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *

def load_classifier(model_path, gpu_id):
    """Build the DRNSub classifier and restore its weights from a checkpoint.

    Args:
        model_path: path to a checkpoint readable by ``torch.load``; its
            ``'model'`` entry must hold the ``DRNSub`` state dict.
        gpu_id: CUDA device index. ``-1``, or no available CUDA, selects
            the CPU.

    Returns:
        A ``DRNSub(1)`` model in eval mode, moved to the chosen device. The
        device string is also stored on ``model.device`` so callers can
        place their inputs next to the model.
    """
    target = (
        'cuda:{}'.format(gpu_id)
        if torch.cuda.is_available() and gpu_id != -1
        else 'cpu'
    )
    # Single output; classify_fake applies a sigmoid to it.
    classifier = DRNSub(1)
    # Deserialize onto the CPU first; the model is moved to `target` below.
    checkpoint = torch.load(model_path, map_location='cpu')
    classifier.load_state_dict(checkpoint['model'])
    classifier.to(target)
    # classify_fake reads this attribute to move its input tensor.
    classifier.device = target
    classifier.eval()
    return classifier



local_model_path = 'weights/local.pth'
global_model_path = 'weights/global.pth'
# CUDA device index for both models; -1 forces CPU (same convention as
# load_classifier).
gpu_id = 0

# Loading the model
# Use the same device rule as load_classifier so gpu_id == -1 selects the CPU
# here too, instead of producing the invalid device string 'cuda:-1', and so
# both models always end up on the same device.
if torch.cuda.is_available() and gpu_id != -1:
    device = 'cuda:{}'.format(gpu_id)
else:
    device = 'cpu'

# Local model: heatmap_analysis uses its output as the per-pixel warping field.
# The 2 is presumably the number of flow channels -- TODO confirm against DRNSeg.
local_model = DRNSeg(2)
state_dict = torch.load(local_model_path, map_location=device)
local_model.load_state_dict(state_dict['model'])
local_model.to(device)
local_model.eval()

# Global model: the classifier that classify_fake applies to the face crop.
global_model = load_classifier(global_model_path, gpu_id)

# Shared input preprocessing for both networks: ToTensor converts the PIL
# image to a float tensor, then Normalize standardizes each RGB channel with
# the widely used ImageNet mean/std values (presumably what the DRN models
# were trained with -- confirm).
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            std=[0.229, 0.224, 0.225],
            mean=[0.485, 0.456, 0.406],
        ),
    ]
)


def classify_fake(img, no_crop=False, global_model=global_model,
                  model_file='utils/dlib_face_detector/mmod_human_face_detector.dat'):
    """Score the face in ``img`` with the global classifier.

    Args:
        img: input image, presumably a PIL image; it is passed to
            ``face_detection`` and to the torchvision transform ``tf``.
        no_crop: if True, skip face detection and score the whole image.
        global_model: classifier exposing a ``.device`` attribute (as set by
            ``load_classifier``). Defaults to the module-level global model.
        model_file: detector weights forwarded to ``face_detection``.

    Returns:
        The sigmoid of the model's single output as a Python float. Going by
        the function name, this is presumably the probability that the face
        is fake.

    Note:
        If no face is detected, a message is printed and ``sys.exit()`` is
        called, which terminates the process.
    """
    # Data preprocessing
    if no_crop:
        face = img
    else:
        faces = face_detection(img, verbose=False, model_file=model_file)
        if len(faces) == 0:
            # NOTE(review): sys.exit() raises SystemExit (exit status 0) from
            # library code; raising a regular exception would let callers
            # recover. Kept as-is because callers may depend on it.
            print("no face detected by dlib, exiting")
            sys.exit()
        # Only the first detection is used; its bounding box is not needed.
        face, _ = faces[0]
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(global_model.device)

    # Prediction: add a batch dimension, then take the first (only) output.
    with torch.no_grad():
        prob = global_model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()
    return prob



def heatmap_analysis(img, no_crop=False):
    """Predict the local model's warping field for a face and visualize it.

    Args:
        img: input image, presumably a PIL image; it is passed to
            ``face_detection`` and to the torchvision transform ``tf``.
        no_crop: if True, use the whole image instead of the first face
            found by ``face_detection``.

    Returns:
        A tuple ``(modified, reverse, heatmap)`` of PIL images:
        ``modified`` is the face crop resized to the flow field's resolution;
        ``reverse`` is that crop after ``warp`` applies the predicted flow
        (intended to undo the warps);
        ``heatmap`` is the per-pixel flow magnitude rendered over the crop by
        ``get_heatmap_cv``.

    Note:
        If no face is detected, a message is printed and ``sys.exit()`` is
        called, which terminates the process.
    """
    if no_crop:
        face = img
    else:
        faces = face_detection(img, verbose=False)
        if len(faces) == 0:
            # NOTE(review): sys.exit() ends the whole process from library
            # code; kept because callers may depend on it.
            print("no face detected by dlib, exiting")
            sys.exit()
        # Only the first detection is used; its bounding box is not needed.
        face, _ = faces[0]
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)

    # Warping field prediction (batch of one; take the first output).
    with torch.no_grad():
        flow = local_model(face_tens.unsqueeze(0))[0].cpu().numpy()
    # Move the channel axis last so flow is indexed as flow[row, col, channel].
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undoing the warps: resize the crop to the flow's resolution, then apply
    # the predicted flow to it.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse = Image.fromarray(warp(modified_np, flow))

    # Per-pixel flow magnitude from the two flow channels, drawn as a heatmap.
    flow_magn = np.sqrt(flow[:, :, 0] ** 2 + flow[:, :, 1] ** 2)
    heatmap = Image.fromarray(get_heatmap_cv(modified_np, flow_magn, 7))
    return modified, reverse, heatmap