import gradio as gr
from transformers import PerceiverForOpticalFlow
import torch
import torch.nn.functional as F
import numpy as np
import requests
from PIL import Image
import itertools
import math
import cv2

model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
TRAIN_SIZE = model.config.train_size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def normalize(im):
    # Scale uint8 pixel values from [0, 255] to [-1, 1], as expected by the model
    return im / 255.0 * 2 - 1
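# Quick sanity check: normalize(np.array([0., 127.5, 255.])) -> array([-1., 0., 1.])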
# source: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/9
def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    # F.pad expects (left, right, top, bottom) for a 4D tensor, so the width
    # padding (pad_col) comes first, then the height padding (pad_row)
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))
    # Extract an overlapping kernel x kernel patch around every pixel
    patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()
    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])
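# Shape check: for an input of shape (2, 3, 368, 496) and kernel=3, this returns
# (2, 27, 368, 496): each pixel is replaced by its flattened 3x3 neighbourhood
# across the 3 colour channels, matching the per-pixel patch featurisation
# described for optical flow in the Perceiver IO paper.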
def compute_optical_flow(model, img1, img2, grid_indices, FLOW_SCALE_FACTOR=20):
    """Function to compute optical flow between two images.

    To compute the flow between images of arbitrary sizes, we divide the image
    into patches, compute the flow for each patch, and stitch the flows together.

    Args:
      model: PyTorch Perceiver model
      img1: first image
      img2: second image
      grid_indices: indices of the upper left corner for each patch.
    """
    img1 = torch.tensor(np.moveaxis(img1, -1, 0))
    img2 = torch.tensor(np.moveaxis(img2, -1, 0))
    imgs = torch.stack([img1, img2], dim=0)[None]
    height = imgs.shape[-2]
    width = imgs.shape[-1]

    patch_size = model.config.train_size

    if height < patch_size[0]:
        raise ValueError(
            f"Height of image (shape: {imgs.shape}) must be at least {patch_size[0]}. "
            "Please pad or resize your image to the minimum dimension."
        )
    if width < patch_size[1]:
        raise ValueError(
            f"Width of image (shape: {imgs.shape}) must be at least {patch_size[1]}. "
            "Please pad or resize your image to the minimum dimension."
        )
    flows = 0
    flow_count = 0
    for y, x in grid_indices:
        # Crop a training-size window from both frames
        inp_piece = imgs[..., y : y + patch_size[0], x : x + patch_size[1]]
        batch_size, _, C, H, W = inp_piece.shape
        patches = extract_image_patches(inp_piece.view(batch_size * 2, C, H, W), kernel=3)
        _, C, H, W = patches.shape
        patches = patches.view(batch_size, -1, C, H, W).float().to(model.device)

        # actual forward pass
        with torch.no_grad():
            output = model(inputs=patches).logits * FLOW_SCALE_FACTOR

        # the code below could also be implemented in PyTorch
        flow_piece = output.cpu().detach().numpy()

        # Weight each patch so that pixels near the patch border contribute less
        # than pixels near the centre when overlapping patches are blended
        weights_x, weights_y = np.meshgrid(
            np.arange(patch_size[1]), np.arange(patch_size[0]))
        weights_x = np.minimum(weights_x + 1, patch_size[1] - weights_x)
        weights_y = np.minimum(weights_y + 1, patch_size[0] - weights_y)
        weights = np.minimum(weights_x, weights_y)[np.newaxis, :, :, np.newaxis]
        padding = [(0, 0), (y, height - y - patch_size[0]),
                   (x, width - x - patch_size[1]), (0, 0)]
        flows += np.pad(flow_piece * weights, padding)
        flow_count += np.pad(weights, padding)

        # delete activations to avoid OOM
        del output

    flows /= flow_count
    return flows
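# The returned array has shape (1, height, width, 2): one flow vector per pixel
# of the first frame, whose two channels are consumed as the (x, y) displacement
# components by return_flow below.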
def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
    if min_overlap >= patch_size[0] or min_overlap >= patch_size[1]:
        raise ValueError(
            f"Overlap should be less than size of patch (got {min_overlap} "
            f"for patch size {patch_size}).")
    ys = list(range(0, image_shape[0], patch_size[0] - min_overlap))
    xs = list(range(0, image_shape[1], patch_size[1] - min_overlap))
    # Make sure the final patch is flush with the image boundary
    ys[-1] = image_shape[0] - patch_size[0]
    xs[-1] = image_shape[1] - patch_size[1]
    return itertools.product(ys, xs)
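# Worked example (assuming TRAIN_SIZE == (368, 496), the value stored in the
# model config): for a 436 x 1024 Sintel frame, ys = [0, 68] and
# xs = [0, 476, 528], i.e. six overlapping patches cover the whole image.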
def return_flow(flow):
    flow = np.array(flow)

    # Use Hue, Saturation, Value colour model: hue encodes flow direction,
    # saturation encodes magnitude, value is fixed at maximum brightness
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 2] = 255

    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv[..., 0] = ang / np.pi / 2 * 180  # OpenCV hue range is [0, 180)
    hsv[..., 1] = np.clip(mag * 255 / 24, 0, 255)
    # Convert to RGB (not BGR), since PIL interprets arrays as RGB
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return Image.fromarray(rgb)
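# Quick check with a synthetic field (purely illustrative): a uniform rightward
# flow of 5 px should render as a single solid colour, e.g.
#   uniform = np.tile(np.array([5.0, 0.0], dtype=np.float32), (64, 64, 1))
#   return_flow(uniform).save("uniform_flow.png")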
# load image examples
urls = ["https://storage.googleapis.com/perceiver_io/sintel_frame1.png",
        "https://storage.googleapis.com/perceiver_io/sintel_frame2.png"]
for idx, url in enumerate(urls):
    image = Image.open(requests.get(url, stream=True).raw)
    image.save(f"image_{idx}.png")
def process_images(image1, image2):
    # Ensure 3-channel RGB input (uploads may be RGBA or grayscale)
    im1 = np.array(image1.convert("RGB"))
    im2 = np.array(image2.convert("RGB"))

    # Divide images into patches, compute flow between corresponding patches
    # of both images, and stitch the flows together
    grid_indices = compute_grid_indices(im1.shape)
    output = compute_optical_flow(model, normalize(im1), normalize(im2), grid_indices)

    # return as PIL Image
    predicted_flow = return_flow(output[0])
    return predicted_flow
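# Example direct call, without the Gradio UI (uses the example frames saved above):
#   flow_img = process_images(Image.open("image_0.png"), Image.open("image_1.png"))
#   flow_img.save("predicted_flow.png")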
title = "Interactive demo: Perceiver for optical flow"
description = "Demo for predicting optical flow (i.e. estimating, given 2 images, the 2D displacement of each pixel in the first image) with Perceiver IO. To use it, simply upload 2 images (e.g. 2 subsequent frames) or use the example images below and click 'submit' to let the model predict the flow of the pixels. Results will show up within a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.14795'>Perceiver IO: A General Architecture for Structured Inputs & Outputs</a> | <a href='https://deepmind.com/blog/article/building-architectures-that-can-handle-the-worlds-data/'>Official blog</a></p>"
examples = [[f"image_{idx}.png" for idx in range(len(urls))]]
# Recent Gradio versions use component classes directly (gr.inputs / gr.outputs
# were removed) and configure queueing via .queue() rather than enable_queue
iface = gr.Interface(fn=process_images,
                     inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
                     outputs=gr.Image(type="pil"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples)
iface.queue().launch(debug=True)