gigant commited on
Commit
79b8121
1 Parent(s): d8810be

Upload phash_jax.py

Browse files
Files changed (1) hide show
  1. phash_jax.py +21 -0
phash_jax.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ def convert_L(image):
5
+ #convert image to greyscale using the ITU-R 601-2 luma transform
6
+ # PIL.Image convert('L') method actually uses Floyd-Steinberg dithering
7
+ return jnp.maximum(jnp.minimum(image[:,:,0] * 0.299 + image[:,:,1] * 0.587 + image[:,:,2] * 0.114, 255), 0).astype("uint8")
8
+
9
+ def phash_jax(image, hash_size=8, highfreq_factor=4):
10
+ img_size = hash_size * highfreq_factor
11
+ image = jax.image.resize(convert_L(image), [img_size, img_size], "lanczos3") #convert to greyscale
12
+ dct = jax.scipy.fft.dct(jax.scipy.fft.dct(image, axis=0), axis=1)
13
+ dctlowfreq = dct[:hash_size, :hash_size]
14
+ med = jnp.median(dctlowfreq)
15
+ diff = dctlowfreq > med
16
+ return diff
17
+
18
+ def hash_dist(h1, h2):
19
+ return jnp.count_nonzero(h1.flatten() != h2.flatten())
20
+
21
+ batch_phash = jax.vmap(jax.jit(phash_jax))