# seungpyo-hong: "Update app.py" (commit 2667d51, verified)
from typing import Optional, Tuple

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from skimage import color
from sklearn.cluster import KMeans
def _dominant_chroma(img: np.ndarray, num_colors: int) -> np.ndarray:
    """Cluster *img*'s pixels in CIELAB and return the centroids' (a*, b*).

    The result has shape (num_colors, 2). Rows are sorted by a*^2 + b*^2
    (squared chroma), least saturated first; ties keep k-means order.
    """
    # Blur, then shrink to a fixed 80x80 grid: this smooths out fine texture
    # and fixes the k-means input at 6400 points whatever the input size.
    blurred = cv2.GaussianBlur(img, (11, 11), 11)
    small = cv2.resize(blurred, (80, 80))
    lab_pixels = color.rgb2lab(small).reshape(-1, 3)
    # Explicit n_init/random_state: the same image always yields the same
    # palette, independent of the installed scikit-learn's n_init default.
    km = KMeans(n_clusters=num_colors, n_init=10, random_state=0)
    km.fit(lab_pixels)
    ab = km.cluster_centers_[:, 1:]  # drop L*, keep (a*, b*)
    order = np.argsort(ab[:, 0] ** 2 + ab[:, 1] ** 2, kind="stable")
    return ab[order]


def _lightness_ramps(ab: np.ndarray, num_bins: int) -> np.ndarray:
    """Pair every (a*, b*) row with *num_bins* log-spaced L* values.

    Returns an array of shape (len(ab), num_bins, 3) in (L*, a*, b*) order.
    L* rises from 100 * log(100 / num_bins) / log(100) to 100 (i.e. 50..100
    for num_bins=10) and is spaced more densely toward the bright end.
    """
    # linspace decouples the step from the bin count; for num_bins=10 it
    # yields exactly 10, 20, ..., 100.
    seeds = np.log(np.linspace(100 / num_bins, 100, num_bins))
    ls = np.clip(seeds * 100 / seeds[-1], 0, 100)
    n_rows = ab.shape[0]
    l_chan = np.broadcast_to(ls[np.newaxis, :, np.newaxis], (n_rows, num_bins, 1))
    ab_chan = np.broadcast_to(ab[:, np.newaxis, :], (n_rows, num_bins, 2))
    return np.concatenate([l_chan, ab_chan], axis=-1)


def _drop_near_duplicates(ramps: np.ndarray, min_dist: float = 10.0) -> np.ndarray:
    """Keep row 0 plus every row farther than *min_dist* from the row before it.

    Distance is the Frobenius norm over a whole (num_bins, 3) ramp. Each row
    is compared with its immediate predecessor in input order, whether or not
    that predecessor was itself kept.
    """
    keep = [0] + [
        i
        for i in range(1, len(ramps))
        if np.linalg.norm(ramps[i] - ramps[i - 1]) > min_dist
    ]
    return ramps[keep]


def _render_ramps(ramps: np.ndarray, height: int) -> np.ndarray:
    """Convert Lab ramps to uint8 RGB and upscale to *height* with square cells."""
    rgb = (color.lab2rgb(ramps) * 255).astype(np.uint8)
    n_rows, n_cols = rgb.shape[:2]
    # Size from the rows actually kept (not the requested cluster count) so
    # each swatch stays square after near-duplicate rows are dropped.
    width = int(height / n_rows * n_cols)
    # cv2.resize takes dsize as (width, height); nearest keeps hard cell edges.
    return cv2.resize(rgb, (width, height), interpolation=cv2.INTER_NEAREST)


def proc(img: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Extract a dominant-colour palette from *img* and render lightness ramps.

    The image is blurred and downsampled, and its pixels are clustered in
    CIELAB with k-means (5 clusters). Each cluster's chroma (a*, b*) is
    combined with 10 fixed, log-spaced lightness (L*) values. Rows whose ramp
    is within distance 10 of the preceding row are dropped. The remaining
    grid is upscaled with nearest-neighbour interpolation.

    Args:
        img: Image of shape (H, W, 3). It is treated as RGB by the colour
            conversion (assumed to be Gradio's numpy RGB image). ``None`` is
            passed through.

    Returns:
        A uint8 RGB image of height ``H`` with one row of square swatches per
        kept colour. Rows are ordered least to most saturated, and lightness
        increases left to right. Returns ``None`` if *img* is ``None``.

    Raises:
        ValueError: If *img* is not a 3-channel (H, W, 3) array.
    """
    if img is None:
        # Presumably Gradio sends None when no image was provided; show nothing.
        return None
    if img.ndim != 3 or img.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB image, got shape {img.shape}")

    num_colors = 5
    num_bins = 10
    ab = _dominant_chroma(img, num_colors)
    ramps = _drop_near_duplicates(_lightness_ramps(ab, num_bins))
    return _render_ramps(ramps, img.shape[0])
# Minimal Gradio UI around `proc`: one image in, one palette image out.
demo = gr.Interface(
    fn=proc,
    inputs="image",
    outputs="image",
)

if __name__ == "__main__":
    # Start the web server only when this file is executed directly,
    # not when it is imported as a module.
    demo.launch()