Spaces:
Configuration error
Configuration error
Upload phash_jax.py
Browse files- phash_jax.py +21 -0
phash_jax.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jax
|
2 |
+
import jax.numpy as jnp
|
3 |
+
|
4 |
+
def convert_L(image):
|
5 |
+
#convert image to greyscale using the ITU-R 601-2 luma transform
|
6 |
+
# PIL.Image convert('L') method actually uses Floyd-Steinberg dithering
|
7 |
+
return jnp.maximum(jnp.minimum(image[:,:,0] * 0.299 + image[:,:,1] * 0.587 + image[:,:,2] * 0.114, 255), 0).astype("uint8")
|
8 |
+
|
9 |
+
def phash_jax(image, hash_size=8, highfreq_factor=4):
|
10 |
+
img_size = hash_size * highfreq_factor
|
11 |
+
image = jax.image.resize(convert_L(image), [img_size, img_size], "lanczos3") #convert to greyscale
|
12 |
+
dct = jax.scipy.fft.dct(jax.scipy.fft.dct(image, axis=0), axis=1)
|
13 |
+
dctlowfreq = dct[:hash_size, :hash_size]
|
14 |
+
med = jnp.median(dctlowfreq)
|
15 |
+
diff = dctlowfreq > med
|
16 |
+
return diff
|
17 |
+
|
18 |
+
def hash_dist(h1, h2):
|
19 |
+
return jnp.count_nonzero(h1.flatten() != h2.flatten())
|
20 |
+
|
21 |
+
batch_phash = jax.vmap(jax.jit(phash_jax))
|