import gradio as gr
from transformers import PerceiverForOpticalFlow
import torch
import torch.nn.functional as F
import numpy as np
import requests
from PIL import Image
import matplotlib.pyplot as plt
import itertools
import math
import cv2
model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
TRAIN_SIZE = model.config.train_size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
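
# The flow checkpoint was trained on fixed-size crops (TRAIN_SIZE, typically
# (368, 496) for this checkpoint according to its config). Larger inputs are
# handled below by tiling them into overlapping patches of exactly this size
# and blending the per-patch predictions.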
def normalize(im):
    # Scale pixel values from [0, 255] to [-1, 1], as expected by the model.
    return im / 255.0 * 2 - 1

# source: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/9
def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding.
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    # F.pad pads the last (width) dimension first, so the column padding goes first.
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))

    # Extract patches.
    patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()

    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])
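
# Each pixel is represented by its 3x3 RGB neighbourhood: e.g. an input of
# shape (2, 3, 368, 496) with kernel=3 comes back as (2, 27, 368, 496), i.e.
# 3*3*3 = 27 feature channels per pixel, which matches the per-pixel
# representation the Perceiver flow model expects.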

def compute_optical_flow(model, img1, img2, grid_indices, FLOW_SCALE_FACTOR=20):
    """Compute optical flow between two images.

    To compute the flow between images of arbitrary sizes, we divide the image
    into patches, compute the flow for each patch, and stitch the flows together.

    Args:
        model: PyTorch Perceiver model.
        img1: first image.
        img2: second image.
        grid_indices: indices of the upper-left corner of each patch.
    """
    img1 = torch.tensor(np.moveaxis(img1, -1, 0))
    img2 = torch.tensor(np.moveaxis(img2, -1, 0))
    imgs = torch.stack([img1, img2], dim=0)[None]
    height = imgs.shape[-2]
    width = imgs.shape[-1]

    patch_size = model.config.train_size

    if height < patch_size[0]:
        raise ValueError(
            f"Height of image (shape: {imgs.shape}) must be at least {patch_size[0]}. "
            "Please pad or resize your image to the minimum dimension."
        )
    if width < patch_size[1]:
        raise ValueError(
            f"Width of image (shape: {imgs.shape}) must be at least {patch_size[1]}. "
            "Please pad or resize your image to the minimum dimension."
        )

    flows = 0
    flow_count = 0
    for y, x in grid_indices:
        inp_piece = imgs[..., y : y + patch_size[0], x : x + patch_size[1]]
        batch_size, _, C, H, W = inp_piece.shape
        patches = extract_image_patches(inp_piece.view(batch_size * 2, C, H, W), kernel=3)
        _, C, H, W = patches.shape
        patches = patches.view(batch_size, -1, C, H, W).float().to(model.device)

        # actual forward pass
        with torch.no_grad():
            output = model(inputs=patches).logits * FLOW_SCALE_FACTOR

        # the code below could also be implemented in PyTorch
        flow_piece = output.cpu().detach().numpy()

        # Weight each pixel by its distance to the patch border so that
        # overlapping patches blend smoothly when stitched together.
        weights_x, weights_y = np.meshgrid(
            np.arange(patch_size[1]), np.arange(patch_size[0]))
        weights_x = np.minimum(weights_x + 1, patch_size[1] - weights_x)
        weights_y = np.minimum(weights_y + 1, patch_size[0] - weights_y)
        weights = np.minimum(weights_x, weights_y)[np.newaxis, :, :, np.newaxis]

        padding = [(0, 0), (y, height - y - patch_size[0]),
                   (x, width - x - patch_size[1]), (0, 0)]
        flows += np.pad(flow_piece * weights, padding)
        flow_count += np.pad(weights, padding)

        # delete activations to avoid OOM
        del output

    flows /= flow_count
    return flows

def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
    if min_overlap >= patch_size[0] or min_overlap >= patch_size[1]:
        raise ValueError(
            f"Overlap should be less than size of patch (got {min_overlap} "
            f"for patch size {patch_size}).")
    ys = list(range(0, image_shape[0], patch_size[0] - min_overlap))
    xs = list(range(0, image_shape[1], patch_size[1] - min_overlap))
    # Make sure the final patch is flush with the image boundary.
    ys[-1] = image_shape[0] - patch_size[0]
    xs[-1] = image_shape[1] - patch_size[1]
    return itertools.product(ys, xs)
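
# Rough worked example, assuming TRAIN_SIZE is (368, 496) and a 436 x 1024
# Sintel-sized frame: ys = [0, 68] and xs = [0, 476, 528], so the image is
# covered by 6 overlapping patches.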

def return_flow(flow):
    flow = np.array(flow)

    # Use Hue, Saturation, Value colour model.
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    hsv[..., 2] = 255

    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv[..., 0] = ang / np.pi / 2 * 180
    hsv[..., 1] = np.clip(mag * 255 / 24, 0, 255)
    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    return Image.fromarray(bgr)
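
# Flow direction is mapped to hue and flow magnitude to saturation; the
# divisor 24 is the magnitude (in pixels) at which the colour saturates, so
# larger motions all appear fully saturated.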

# load image examples
urls = ["https://storage.googleapis.com/perceiver_io/sintel_frame1.png", "https://storage.googleapis.com/perceiver_io/sintel_frame2.png"]
for idx, url in enumerate(urls):
    image = Image.open(requests.get(url, stream=True).raw)
    image.save(f"image_{idx}.png")

def process_images(image1, image2):
    im1 = np.array(image1)
    im2 = np.array(image2)

    # Divide images into patches, compute flow between corresponding patches
    # of both images, and stitch the flows together.
    grid_indices = compute_grid_indices(im1.shape)
    output = compute_optical_flow(model, normalize(im1), normalize(im2), grid_indices)

    # return as PIL Image
    predicted_flow = return_flow(output[0])
    return predicted_flow
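
# Minimal sketch of using the pipeline without the Gradio UI, assuming the
# example frames downloaded above exist on disk under the names saved here:
#
#   frame1 = Image.open("image_0.png")
#   frame2 = Image.open("image_1.png")
#   flow_image = process_images(frame1, frame2)  # PIL Image with the flow visualisation
#   flow_image.save("predicted_flow.png")
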
title = "Interactive demo: Perceiver for optical flow"
description = "Demo for predicting optical flow (i.e. estimating, given 2 images, the 2D displacement of each pixel in the first image) with Perceiver IO. To use it, simply upload 2 images (e.g. 2 subsequent frames) or use the example images below and click 'submit' to let the model predict the flow of the pixels. Results will show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.14795'>Perceiver IO: A General Architecture for Structured Inputs & Outputs</a> | <a href='https://deepmind.com/blog/article/building-architectures-that-can-handle-the-worlds-data/'>Official blog</a></p>"
examples = [[f"image_{idx}.png" for idx in range(len(urls))]]
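
# Note: gr.inputs / gr.outputs and the enable_queue argument correspond to older
# Gradio releases; on recent versions one would pass gr.Image(type="pil")
# directly and call iface.queue() instead.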
iface = gr.Interface(fn=process_images,
                     inputs=[gr.inputs.Image(type="pil"), gr.inputs.Image(type="pil")],
                     outputs=gr.outputs.Image(type="pil"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     enable_queue=True)

iface.launch(debug=True)