# Homography warping demo: Gradio application for a Hugging Face Space.
import gradio as gr
import torch
import kornia as K
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.vq import kmeans
from PIL import Image
def get_coordinates_from_mask(mask_in):
    """Locate the four marked corner points in a drawing mask.

    A pixel counts as marked when it differs from opaque black
    (``[0, 0, 0, 255]``, truncated to the mask's channel count) and, for
    four-channel masks, is not fully transparent.  The positions of the
    marked pixels are clustered with k-means (k=4) and the cluster centres
    are returned.

    If no pixel is marked at all, every pixel position is clustered instead.
    This keeps the result the plain comparison against opaque black gave for
    an all-zero four-channel mask.

    Args:
        mask_in: Array-like of shape ``(H, W, 3)`` or ``(H, W, 4)``.

    Returns:
        ``np.ndarray`` of shape ``(4, 2)`` and dtype ``int64`` with one
        ``(x, y)`` centroid per row, in no particular order (k-means
        initialisation is random).

    Raises:
        ValueError: If the mask shape is unsupported, fewer than four pixel
            positions are available, or k-means ends with fewer than four
            non-empty clusters.
    """
    mask = np.asarray(mask_in)
    if mask.ndim != 3 or mask.shape[-1] not in (3, 4):
        raise ValueError(
            f"mask must have shape (H, W, 3) or (H, W, 4), got {mask.shape}"
        )

    channels = mask.shape[-1]
    # Truncate the reference colour so three-channel masks broadcast too.
    opaque_black = np.array([0, 0, 0, 255])[:channels]

    # Reduce over the channel axis so each pixel contributes exactly once,
    # no matter how many of its channels differ from the background.
    marked = np.any(mask != opaque_black, axis=-1)
    if channels == 4:
        # NOTE(review): presumably editor layers use a fully transparent
        # background -- confirm against the Gradio ImageEditor docs. An
        # invisible pixel cannot be a brush stroke, so it is ignored here.
        marked &= mask[..., 3] != 0

    if not marked.any():
        # Blank mask: fall back to clustering every pixel position.
        marked = np.ones_like(marked)

    rows, cols = np.nonzero(marked)
    # k-means expects one observation per row; store points as (x, y).
    points = np.column_stack((cols, rows)).astype(np.float32)
    if len(points) < 4:
        raise ValueError(
            f"need at least 4 marked pixels to find 4 points, got {len(points)}"
        )

    centroids, _ = kmeans(points, 4)
    if len(centroids) != 4:
        # scipy drops clusters that end up empty.
        raise ValueError(
            f"expected 4 clusters of marked pixels, found {len(centroids)}"
        )
    return centroids.astype(np.int64)
def get_top_bottom_coordinates(coords):
    """Split a group of ``(x, y)`` points into its top and bottom point.

    "Top" is the smallest ``y`` and "bottom" the largest ``y`` (image
    coordinates, y growing downwards).  The points are ordered with a stable
    sort, so when two points share the same ``y`` they are still returned as
    two distinct points in their input order.  Separate ``min``/``max``
    calls would return the same point twice in that case.

    Args:
        coords: Non-empty sequence of points indexable as ``point[1]``.

    Returns:
        Tuple ``(top_coord, bottom_coord)``.

    Raises:
        ValueError: If ``coords`` is empty.
    """
    by_height = sorted(coords, key=lambda point: point[1])
    if len(by_height) == 0:
        raise ValueError("coords must contain at least one point")
    return by_height[0], by_height[-1]
def sort_centroids_clockwise(centroids: np.ndarray):
    """Order centroids as top-left, top-right, bottom-right, bottom-left.

    The points are first ranked by their x coordinate; the two with the
    smallest x form the left pair and the two with the largest x the right
    pair.  Each pair is then split into its top and bottom point by
    ``get_top_bottom_coordinates``.

    Args:
        centroids: Array of ``(x, y)`` rows.

    Returns:
        Tuple ``(top_left, top_right, bottom_right, bottom_left)`` of
        ``[x, y]`` lists.
    """
    by_x = sorted(centroids.tolist(), key=lambda point: point[0])
    (top_left, bottom_left), (top_right, bottom_right) = (
        get_top_bottom_coordinates(pair) for pair in (by_x[:2], by_x[-2:])
    )
    return top_left, top_right, bottom_right, bottom_left
def infer(image_input, dst_height: str, dst_width: str):
    """Warp the marked quadrilateral of an image onto a rectangle.

    Four points are extracted from the drawing mask, ordered clockwise
    starting at the top-left, and mapped onto the corners of a
    ``dst_height`` x ``dst_width`` output with a perspective transform.

    Args:
        image_input: Either a dict with ``'composite'`` (the image) and
            ``'layers'`` (the first layer is used as the drawing mask), or
            a plain image array, in which case an all-zero mask is used.
        dst_height: Output height in pixels, convertible with ``int``.
        dst_width: Output width in pixels, convertible with ``int``.

    Returns:
        A matplotlib figure with the source image (selected points drawn on
        it) on the left and the warped image on the right.

    Raises:
        ValueError: If no image was supplied, or a size cannot be converted
            to ``int``.
    """
    if isinstance(image_input, dict):
        composite = image_input['composite']
        layers = image_input['layers']
    else:
        composite, layers = image_input, None
    if composite is None:
        raise ValueError("No input image was provided.")

    image_in = np.array(composite)
    mask_in = np.array(layers[0]) if layers else np.zeros_like(image_in)

    # Scale to [0, 1]; assumes 8-bit channel values -- TODO confirm.
    torch_img = K.utils.image_to_tensor(image_in).float() / 255.0

    centroids = get_coordinates_from_mask(mask_in)
    ordered_src_coords = sort_centroids_clockwise(centroids)

    # the source points are the region to crop corners
    points_src = torch.tensor([list(ordered_src_coords)], dtype=torch.float32)

    # the destination points are the image vertexes
    h, w = int(dst_height), int(dst_width)  # destination size
    points_dst = torch.tensor([[
        [0., 0.], [w - 1., 0.], [w - 1., h - 1.], [0., h - 1.],
    ]], dtype=torch.float32)

    # compute perspective transform
    M: torch.Tensor = K.geometry.transform.get_perspective_transform(points_src, points_dst)

    # warp the original image by the found transform (batch of one)
    batch = torch_img.unsqueeze(0)
    img_warp: torch.Tensor = K.geometry.transform.warp_perspective(batch, M, dsize=(h, w))

    # convert back to numpy; copy once so OpenCV gets its own contiguous
    # buffer to draw on
    img_np = K.utils.tensor_to_image(batch[0]).copy()
    img_warp_np: np.ndarray = K.utils.tensor_to_image(img_warp[0])

    # draw points into original image
    # The image holds floats in [0, 1], so the marker is green at full
    # intensity with alpha 1. A 3-component colour would be padded with 0
    # and make the markers fully transparent on a 4-channel image.
    marker_color = (0.0, 1.0, 0.0, 1.0)
    for x, y in points_src[0].long().tolist():  # plain Python ints for OpenCV
        cv2.circle(img_np, (x, y), 5, marker_color, -1)

    # create the plot
    # NOTE(review): figures created through pyplot stay registered with
    # pyplot until closed; confirm whether the caller releases them.
    fig, axs = plt.subplots(1, 2, figsize=(16, 10))
    axs = axs.ravel()

    axs[0].axis('off')
    axs[0].set_title('image source')
    axs[0].imshow(img_np)

    axs[1].axis('off')
    axs[1].set_title('image destination')
    axs[1].imshow(img_warp_np)

    return fig
# Markdown text shown on the demo page (passed to gr.Markdown below).
description = """In this space you can warp an image using perspective transform with the Kornia library as seen in [this tutorial](https://kornia.github.io/tutorials/#category=Homography).
1. Upload an image or use the example provided
2. Set 4 points into the image using the brush tool, which define the area to warp
3. Set a desired output size (or go with the default)
4. Click Submit to run the demo
"""
# Load the example image into a NumPy array. The pixel data is read while
# the file is open, and the context manager releases the file handle
# deterministically afterwards.
with Image.open("bruce.png") as example_image:
    example_image_np = np.array(example_image)
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm the row layout against the deployed app.
with gr.Blocks() as demo:
    # Page header and usage instructions.
    gr.Markdown("# Homography Warping")
    gr.Markdown(description)

    # First row: image editor (red brush for marking points) and result plot.
    with gr.Row():
        image_input = gr.ImageEditor(
            label="Input Image",
            type="numpy",
            width=600,
            height=400,
            brush=gr.Brush(default_size=5, colors=["#ff0000"]),
        )
        output_plot = gr.Plot(label="Output")

    # Second row: size of the warped output, entered as text.
    with gr.Row():
        dst_height = gr.Textbox(value="64", label="Destination Height")
        dst_width = gr.Textbox(value="128", label="Destination Width")

    # The same three components feed both the button and the example.
    _infer_inputs = [image_input, dst_height, dst_width]

    submit_button = gr.Button("Submit")
    submit_button.click(fn=infer, inputs=_infer_inputs, outputs=output_plot)

    # Bundled example; its output is computed ahead of time and cached.
    gr.Examples(
        fn=infer,
        examples=[[example_image_np, "64", "128"]],
        inputs=_infer_inputs,
        outputs=output_plot,
        cache_examples=True,
    )
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()