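# Author: AlexBlck. Commit bd0a3d5 ("Streamlit upload"), file size 2.54 kB.
# (Header recovered from a web file view that also offered "raw" and
# "history" / "blame" links.)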
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from model import ICN
# Run inference on CUDA when a GPU is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def mask_processing(x):
if x > 90:
return 140
elif x < 80:
return 0
else:
return 255
def grid_to_heatmap(grid, size=1024):
    """Turn a 7x7 score grid into a coloured heatmap and its paste mask.

    Args:
        grid: Tensor with 49 elements, laid out as a 7x7 grid. Presumably
            scores in [0, 1]: ``to_pil_image`` scales float tensors by 255
            before converting to 8-bit (TODO confirm for other callers).
        size: Side length, in pixels, of the square outputs.

    Returns:
        ``(heatmap, mask)``. ``mask`` is the single-channel grid upscaled with
        bicubic resampling and remapped through ``mask_processing``.
        ``heatmap`` is that mask coloured with matplotlib's "Wistia" colormap.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    mask = to_pil_image(grid.reshape(7, 7))
    mask = mask.resize((size, size), Image.BICUBIC)
    mask = Image.eval(mask, mask_processing)
    colormap = plt.get_cmap("Wistia")
    # The colormap returns float RGBA values in [0, 1]; scale them to 8-bit.
    heatmap = np.array(colormap(mask))
    heatmap = (heatmap * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap)
    return heatmap, mask
def summary_image(img, fake, prediction, size=1024):
    """Overlay the predicted manipulation heatmap on both input images.

    Args:
        img: First input image (PIL).
        fake: Second input image (PIL).
        prediction: Tensor of grid scores, passed to ``grid_to_heatmap``
            after min-max normalisation. The caller's tensor is not modified.
        size: Side length, in pixels, of the square output images.

    Returns:
        ``(img1, img2)``: both images resized to ``size x size`` with the
        heatmap pasted on top through its mask.
    """
    # Work on a detached copy. The previous in-place "-=" silently mutated
    # the caller's tensor.
    prediction = prediction.detach().float()
    prediction = prediction - prediction.min()
    peak = prediction.max()
    # A constant grid (e.g. the all-zero heatmap the app passes when no
    # manipulation is detected) would otherwise become 0 / 0 = NaN
    # everywhere, and NaN -> uint8 conversion is undefined.
    if peak > 0:
        prediction = prediction / peak
    img1 = img.resize((size, size))
    img2 = fake.resize((size, size))
    heatmap, mask = grid_to_heatmap(prediction, size=size)
    img1.paste(heatmap, (0, 0), mask)
    img2.paste(heatmap, (0, 0), mask)
    return img1, img2
@st.cache_resource
def load_model():
    """Load the TorchScript export ``traced_model.pt`` and cache it.

    ``st.cache_resource`` keeps the returned module across Streamlit reruns,
    so the file is not reloaded on every interaction.

    Returns:
        The scripted module, in eval mode, on ``device``.
    """
    # map_location remaps stored tensors onto the current device, so an
    # export traced on a GPU still loads on a CPU-only host.
    model = torch.jit.load("traced_model.pt", map_location=device)
    model.eval().to(device)
    return model
@st.cache_resource
def _load_hub_model():
    """Build the pretrained ICN from "AlexBlck/image-comparator" once.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, the model was rebuilt on every rerun.
    """
    return ICN.from_pretrained("AlexBlck/image-comparator").eval().to(device)


model = _load_hub_model()
# The TorchScript export can be used instead via: model = load_model()

st.title("Image Comparator Network")
st.write("## Upload a pair of images")

cols = st.columns(2)
with cols[0]:
    im1 = st.file_uploader("Image 1", type=["jpg", "png"])
with cols[1]:
    im2 = st.file_uploader("Image 2", type=["jpg", "png"])

# End this script run until both images are uploaded and "Run" is pressed.
if not (im1 and im2):
    st.stop()

btn = st.button("Run")
if not btn:
    st.stop()

im1 = Image.open(im1).convert("RGB")
im2 = Image.open(im2).convert("RGB")

# 224x224 resize with ImageNet mean/std normalisation. Presumably this
# matches the model's training preprocessing; TODO confirm.
tr = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Concatenate the two 3-channel images along the channel axis and add a
# batch dimension: (1, 6, 224, 224).
img = torch.vstack((tr(im1), tr(im2))).unsqueeze(0)

# Inference only: do not build an autograd graph.
with torch.no_grad():
    heatmap, cl = model(img.to(device))
confs = torch.softmax(cl, dim=1)
pred = torch.argmax(confs, dim=1).item()

# Class 0: no manipulation, 1: manipulation, anything else: unrelated images.
# Only class 1 shows the localisation heatmap; it is blanked otherwise.
if pred == 0:
    st.success("No Manipulation Detected")
    heatmap *= 0
elif pred == 1:
    st.warning("Manipulation Detected!")
else:
    st.error("Images are not related.")
    heatmap *= 0

img1, img2 = summary_image(im1, im2, heatmap[0])

cols = st.columns(2)
with cols[0]:
    st.image(img1)
with cols[1]:
    st.image(img2)