Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
import numpy as np | |
import streamlit as st | |
import torch | |
from huggingface_hub import PyTorchModelHubMixin | |
from PIL import Image | |
from torchvision import transforms | |
from torchvision.transforms.functional import to_pil_image | |
from model import ICN | |
# Run on the default CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def mask_processing(x):
    """Map one greyscale mask value to an overlay opacity level.

    Applied per pixel (via ``Image.eval``) to the upscaled heatmap mask, whose
    result is later used as the paste mask:

    * below 80        -> 0   (fully transparent)
    * 80 through 90   -> 255 (fully opaque; presumably draws an outline
      around the highlighted region)
    * above 90        -> 140 (semi-transparent fill)
    """
    # The two outer bands are disjoint, so checking them first is order-safe.
    if x < 80:
        return 0
    if x > 90:
        return 140
    return 255
def grid_to_heatmap(grid, size=1024):
    """Render a 7x7 score grid as a coloured overlay plus its paste mask.

    Args:
        grid: tensor holding 49 scores; ``view(7, 7)`` requires it to be
            contiguous. Presumably floats in [0, 1] (summary_image min-max
            normalises before calling) -- TODO confirm for other callers.
        size: side length, in pixels, of both square outputs.

    Returns:
        (heatmap, mask): ``heatmap`` is the mask coloured through matplotlib's
        "Wistia" colormap; ``mask`` is the bicubically upscaled grid after
        ``mask_processing`` thresholding, used by callers as a paste mask.
    """
    # Upscale the coarse grid smoothly, then quantise it to three opacity levels.
    smooth = to_pil_image(grid.view(7, 7)).resize((size, size), Image.BICUBIC)
    mask = Image.eval(smooth, mask_processing)

    # Colour the mask; assuming an 8-bit mask, its integer values index the
    # colormap's lookup table, and the float colours are scaled back to bytes.
    colours = np.array(plt.get_cmap("Wistia")(mask))
    heatmap = Image.fromarray((colours * 255).astype(np.uint8))
    return heatmap, mask
def summary_image(img, fake, prediction, size=1024):
    """Overlay the predicted heatmap on resized copies of both images.

    Args:
        img: first input image (PIL); it is not modified.
        fake: second input image (PIL); it is not modified.
        prediction: tensor of per-cell scores, reshaped to 7x7 by
            grid_to_heatmap. The caller's tensor is not modified.
        size: side length, in pixels, of the square output images.

    Returns:
        (img1, img2): ``img`` and ``fake`` resized to ``size`` x ``size`` with
        the coloured heatmap pasted on top through its mask.
    """
    # Min-max normalise into [0, 1]. Out-of-place so the caller's tensor is
    # left untouched.
    prediction = prediction - prediction.min()
    peak = prediction.max()
    # A constant prediction has zero range: keep it all zeros (which
    # mask_processing maps to a fully transparent mask) instead of computing
    # 0 / 0 = NaN.
    if peak > 0:
        prediction = prediction / peak

    img1 = img.resize((size, size))
    img2 = fake.resize((size, size))
    heatmap, mask = grid_to_heatmap(prediction, size=size)
    img1.paste(heatmap, (0, 0), mask)
    img2.paste(heatmap, (0, 0), mask)
    return img1, img2
def load_model(path="traced_model.pt"):
    """Load a TorchScript model from ``path`` in eval mode on ``device``.

    Args:
        path: filesystem path of the TorchScript archive.

    Returns:
        The loaded script module, switched to eval mode and moved to ``device``.
    """
    # map_location remaps stored tensors at load time, so an archive saved from
    # a GPU still loads on a CPU-only host.
    model = torch.jit.load(path, map_location=device)
    model.eval().to(device)
    return model
# Fetch the pretrained comparator from the Hugging Face Hub and put it in eval
# mode on the selected device.
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the model is rebuilt on each rerun; wrapping the load in st.cache_resource
# (Streamlit >= 1.18) would avoid that -- confirm the Space's Streamlit version.
model = ICN.from_pretrained("AlexBlck/image-comparator").eval().to(device)
# To use the TorchScript export instead: model = load_model()

st.title("Image Comparator Network")
st.write("## Upload a pair of images")

cols = st.columns(2)
with cols[0]:
    im1 = st.file_uploader("Image 1", type=["jpg", "png"])
with cols[1]:
    im2 = st.file_uploader("Image 2", type=["jpg", "png"])

# Halt this run until both files are uploaded and the user clicks "Run".
if not (im1 and im2):
    st.stop()
btn = st.button("Run")
if not btn:
    st.stop()

im1 = Image.open(im1).convert("RGB")
im2 = Image.open(im2).convert("RGB")

# Resize to 224x224 and normalise with the usual ImageNet mean/std.
tr = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Stack the two 3-channel images into one 6-channel input with a batch
# dimension: (1, 6, 224, 224).
img = torch.vstack((tr(im1), tr(im2))).unsqueeze(0)

# Inference only: no_grad skips building an autograd graph for the forward pass.
with torch.no_grad():
    heatmap, cl = model(img.to(device))
# Softmax is monotonic, so the argmax of the logits is the predicted class.
pred = torch.argmax(cl, dim=1).item()

if pred == 0:
    st.success("No Manipulation Detected")
elif pred == 1:
    st.warning("Manipulation Detected!")
else:
    st.error("Images are not related.")

if pred == 1:
    img1, img2 = summary_image(im1, im2, heatmap[0])
else:
    # Nothing to localise: show the images at the same 1024x1024 size without
    # an overlay, rather than normalising an all-zero heatmap (0 / 0).
    img1, img2 = im1.resize((1024, 1024)), im2.resize((1024, 1024))

cols = st.columns(2)
with cols[0]:
    st.image(img1)
with cols[1]:
    st.image(img2)