from functools import lru_cache
from pathlib import Path

import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch.functional import img_to_tensor
from huggingface_hub import hf_hub_download
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.utils import draw_segmentation_masks

import utils.misc as misc
from models import get_ensemble_model
from opt import get_opt


def detect_manipulation(input_image):
    """Run the WSCL ensemble on one RGB image; return (overlay, message)."""
    opt, model = _get_model()

    with torch.no_grad():
        image = np.array(input_image)
        h, w = image.shape[:2]
        # Downscale oversized inputs so they fit the model; the user is warned
        # below that this may reduce accuracy.
        if max(h, w) > 1024:
            transform = A.LongestMaxSize(1024)
        else:
            transform = None

        # Keep an unnormalized uint8 copy at the original resolution for
        # drawing the segmentation overlay later.
        dsm_image = torch.from_numpy(image).permute(2, 0, 1)

        # Record the original size so the predicted map is upsampled back to it.
        image_size = image.shape[:2]
        if transform is not None:
            image = transform(image=image)["image"]
        image = img_to_tensor(
            image,
            normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
        )
        image = image.to(opt.device).unsqueeze(0)
        outputs = model(image, seg_size=image_size)
        out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
        # Image-level score: the maximum pixel probability in the predicted map.
        pred = out_map.max().item()
        if pred > opt.mask_threshold:
            output_string = f"Found manipulation (manipulation probability {pred:.2f})."
        else:
            output_string = (
                f"No manipulation found (manipulation probability {pred:.2f})."
            )

        if transform is not None:
            output_string += (
                f"\nNote: the image was too large ({h}x{w}) and was resized to"
                " fit the model, which may decrease accuracy. We recommend"
                " images smaller than 1024x1024."
            )

        # Overlay the thresholded manipulation mask on the original image.
        overlay = draw_segmentation_masks(
            dsm_image, masks=out_map[0, ...] > opt.mask_threshold
        )
        overlay = overlay.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    return overlay, output_string


@lru_cache(maxsize=1)
def _get_model(config_path="configs/final.yaml", ckpt_path="tmp/checkpoint.pt"):
    """Build the ensemble model once and cache it across requests."""
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
        # Fetch the released checkpoint from the Hugging Face Hub on first run.
        hf_hub_download(
            repo_id="yhzhai/WSCL",
            filename="checkpoint.pt",
            local_dir=ckpt_path.parent.as_posix(),
        )

    opt = get_opt(config_path)
    opt.resume = ckpt_path.as_posix()

    model = get_ensemble_model(opt).to(opt.device)
    misc.resume_from(model, opt.resume)
    return opt, model
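

# Offline sanity check (a sketch, not executed by the app): call the detector
# directly, bypassing the Gradio UI. It assumes Pillow is installed and uses
# demo/au.jpg, one of the example images wired into the interface below.
#
#   from PIL import Image
#
#   overlay, message = detect_manipulation(
#       np.array(Image.open("demo/au.jpg").convert("RGB"))
#   )
#   print(message)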


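# Build the demo UI: an HTML header plus a gr.Interface wired to the detector.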
with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 400px !important}") as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>WSCL: Image Manipulation Detection</h1>
        <h4>This demo detects and localizes image manipulations. For best results, please use images smaller than 1024x1024.</h4>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://arxiv.org/abs/2309.01246" style="margin-right: 5px;"><img src="https://img.shields.io/badge/arXiv-2309.01246-red"></a>
            <a href="https://github.com/yhZhai/WSCL" style="margin-left: 5px;"><img src='https://img.shields.io/badge/Github-WSCL-blue'></a>
        </div>
    </div>
    """)
    
    with gr.Row():
        with gr.Column():
            gr.Interface(
                fn=detect_manipulation,
                inputs=gr.Image(),
                outputs=["image", "text"],
                examples=[["demo/au.jpg"], ["demo/tp.jpg"]],
                cache_examples=True,
            )

if __name__ == "__main__":
    demo.launch()