# Spaces:
# Configuration error
# Configuration error
import jax
import jax.image
import jax.numpy as jnp
import jax.scipy.fft
def convert_L(image):
    """Convert an RGB(A) image to 8-bit greyscale with the ITU-R 601-2 luma transform.

    Computes ``L = 0.299 R + 0.587 G + 0.114 B``, rounds to the nearest
    integer and clips to ``[0, 255]``, mirroring ``PIL.Image.convert('L')``.
    PIL does not dither for an RGB -> 'L' conversion; dithering only applies
    when converting to the 'P' or '1' modes.

    Args:
        image: Array-like of shape ``(..., H, W, C)`` with ``C >= 3``.
            Channels 0, 1, 2 are read as R, G, B; any further channels
            (e.g. alpha) are ignored.

    Returns:
        ``uint8`` array of shape ``(..., H, W)``.

    Raises:
        ValueError: if ``image`` has fewer than 3 dimensions.
    """
    rgb = jnp.asarray(image, dtype=jnp.float32)
    if rgb.ndim < 3:
        raise ValueError(
            f"expected an image of shape (..., H, W, C), got shape {rgb.shape}"
        )
    luma = rgb[..., 0] * 0.299 + rgb[..., 1] * 0.587 + rgb[..., 2] * 0.114
    # Round half up, as PIL's fixed-point `(... + 0x8000) >> 16` does. A bare
    # astype would truncate, turning e.g. a computed 99.99998 into 99.
    luma = jnp.floor(luma + 0.5)
    return jnp.clip(luma, 0, 255).astype(jnp.uint8)
def phash_jax(image, hash_size=8, highfreq_factor=4):
    """Compute a DCT-based perceptual hash of a single image.

    Steps:
      1. Convert to greyscale with ``convert_L``.
      2. Resize to ``side x side`` (``side = hash_size * highfreq_factor``)
         with Lanczos-3.
      3. Apply a 2-D type-II DCT (column pass, then row pass).
      4. Keep the top-left ``hash_size x hash_size`` low-frequency
         coefficients and threshold them at their median.

    Args:
        image: Colour image; ``convert_L`` indexes it as ``(H, W, C)``.
        hash_size: Side length of the resulting bit grid.
        highfreq_factor: Oversampling factor of the resized image relative
            to ``hash_size``.

    Returns:
        Boolean array of shape ``(hash_size, hash_size)``. ``True`` marks a
        coefficient strictly above the median.
    """
    side = highfreq_factor * hash_size
    grey = convert_L(image)
    small = jax.image.resize(grey, (side, side), "lanczos3")
    coeffs = jax.scipy.fft.dct(small, axis=0)
    coeffs = jax.scipy.fft.dct(coeffs, axis=1)
    low = coeffs[:hash_size, :hash_size]
    return low > jnp.median(low)
def hash_dist(h1, h2):
    """Return the Hamming distance between two perceptual hashes.

    Args:
        h1: First hash (array or array-like) of any shape; it is flattened.
        h2: Second hash; it must contain the same number of elements as ``h1``.

    Returns:
        Integer JAX scalar: the number of positions where the flattened
        hashes differ.

    Raises:
        ValueError: if the two hashes contain different numbers of elements.
    """
    a = jnp.ravel(jnp.asarray(h1))
    b = jnp.ravel(jnp.asarray(h2))
    # Without this check, a size-1 hash would silently broadcast against the
    # other and yield a meaningless distance. Sizes are static, so the check
    # also works under jit/vmap tracing.
    if a.size != b.size:
        raise ValueError(f"hash sizes differ: {a.size} vs {b.size}")
    return jnp.count_nonzero(a != b)
def _batch_phash(images, hash_size=8, highfreq_factor=4):
    """Apply ``phash_jax`` to every image stacked along axis 0 of ``images``."""
    return jax.vmap(lambda img: phash_jax(img, hash_size, highfreq_factor))(images)


# Jit the whole batched computation (vmap inside jit) so it is traced and
# compiled once per input shape, instead of re-running the vmap trace in
# Python on every call. hash_size and highfreq_factor determine array shapes
# (the resize target and the DCT slice), so they must be static.
batch_phash = jax.jit(_batch_phash, static_argnames=("hash_size", "highfreq_factor"))