|
from functools import lru_cache
from pathlib import Path

import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch.functional import img_to_tensor
from huggingface_hub import hf_hub_download
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.utils import draw_segmentation_masks, make_grid, save_image

import utils.misc as misc
from models import get_ensemble_model
from opt import get_opt
|
|
|
|
|
def greet(input_image):
    """Detect and localize manipulations in a single uploaded image.

    This is the Gradio callback: it receives the value of ``gr.Image()`` and
    returns the values for the ``["image", "text"]`` outputs.

    Images whose longest side exceeds 1024 px are downscaled before being fed
    to the model. The mask is still predicted at the original resolution so it
    can be drawn over the untouched input.

    Args:
        input_image: The uploaded image, converted with ``np.array``.
            Presumably an ``(H, W, 3)`` uint8 RGB array as produced by
            ``gr.Image`` (TODO: confirm for non-RGB uploads). It is ``None``
            when the user submits without an image.

    Returns:
        A ``(overlay, message)`` tuple:

        - ``overlay`` is an ``(H, W, C)`` uint8 array with the thresholded
          manipulation mask drawn over the original image.
        - ``message`` states the verdict and the maximum value of the
          ensemble output map. It also carries a note when the input was
          resized.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # np.array(None) is 0-d, so the shape unpacking below would fail with
        # an opaque ValueError; surface a readable message in the UI instead.
        raise gr.Error("Please upload an image.")

    max_side = 1024
    opt, model = _get_model()

    with torch.no_grad():
        image = np.array(input_image)
        h, w = image.shape[:2]
        # The prediction is requested at the original size (see seg_size below).
        image_size = (h, w)
        resized = max(h, w) > max_side

        # CHW view of the original pixels, used as the overlay background.
        original = torch.from_numpy(image).permute(2, 0, 1)

        if resized:
            image = A.LongestMaxSize(max_side)(image=image)["image"]

        batch = img_to_tensor(
            image,
            normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
        )
        batch = batch.to(opt.device).unsqueeze(0)

        outputs = model(batch, seg_size=image_size)
        ensemble_map = outputs["ensemble"]["out_map"]
        out_map = ensemble_map[0, ...].detach().cpu()
        pred = ensemble_map.max().item()

        # The same threshold decides both the image-level verdict and the
        # per-pixel mask drawn below.
        if pred > opt.mask_threshold:
            output_string = f"Found manipulation (manipulation probability {pred:.2f})."
        else:
            output_string = (
                f"No manipulation found (manipulation probability {pred:.2f})."
            )

        if resized:
            output_string += f"\nNote: Image was too large ({h}, {w}) and was resized to fit the model, which may decrease accuracy. We recommend image size smaller than 1024x1024."

        overlay = draw_segmentation_masks(
            original, masks=out_map[0, ...] > opt.mask_threshold
        )
        # Back to HWC uint8 for Gradio's image output.
        overlay = overlay.permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)

    return overlay, output_string
|
|
|
|
|
@lru_cache(maxsize=1)
def _get_model(config_path="configs/final.yaml", ckpt_path="tmp/checkpoint.pt"):
    """Build the WSCL ensemble model and load its checkpoint, cached per process.

    If ``ckpt_path`` does not exist, ``checkpoint.pt`` is downloaded from the
    ``yhzhai/WSCL`` Hugging Face Hub repo into ``ckpt_path``'s parent
    directory. The path actually returned by the download is the one that
    gets loaded.

    The result is memoized with ``lru_cache(maxsize=1)``. Repeated calls with
    the same arguments return the same ``(opt, model)`` objects instead of
    rebuilding the network and re-reading the checkpoint on every request.
    Callers must therefore treat both objects as shared and must not mutate
    them.

    Args:
        config_path: Config file passed to ``get_opt``.
        ckpt_path: Local checkpoint path (``str`` or ``Path``).

    Returns:
        ``(opt, model)``. ``opt`` is the parsed options, with ``opt.resume``
        set to the checkpoint path. ``model`` is the ensemble on
        ``opt.device``, with weights loaded and switched to eval mode.
    """
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
        # Use the returned location: the Hub file is always named
        # "checkpoint.pt", which may differ from ckpt_path's own file name.
        ckpt_path = Path(
            hf_hub_download(
                repo_id="yhzhai/WSCL",
                filename="checkpoint.pt",
                local_dir=ckpt_path.parent.as_posix(),
            )
        )

    opt = get_opt(config_path)
    opt.resume = ckpt_path.as_posix()

    model = get_ensemble_model(opt).to(opt.device)
    misc.resume_from(model, opt.resume)
    # Inference only. Because the model is now shared across requests, train
    # mode would also let BatchNorm-style layers update their running stats on
    # every call.
    model.eval()
    return opt, model
|
|
|
|
|
# Gradio UI: a header with paper/code links, then a single Interface wired to
# `greet`. Its example outputs are computed at build time (cache_examples=True).
with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 400px !important}") as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>WSCL: Image Manipulation Detection</h1>
        <h4>This demo detects and localizes image manipulations. For best performance, please use image of size smaller than 1024x1024.</h4>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://arxiv.org/abs/2309.01246" style="margin-right: 5px;"><img src="https://img.shields.io/badge/arXiv-2309.01246-red"></a>
            <a href="https://github.com/yhZhai/WSCL" style="margin-left: 5px;"><img src='https://img.shields.io/badge/Github-WSCL-blue'></a>
        </div>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            iface = gr.Interface(
                fn=greet,
                inputs=gr.Image(),
                outputs=["image", "text"],
                examples=[["demo/au.jpg"], ["demo/tp.jpg"]],
                cache_examples=True,
            )

# Only start the server when run as a script (e.g. `python app.py`), so the
# module can be imported without launching it.
if __name__ == "__main__":
    demo.launch()
|
|