# gigant's picture
# Update app.py
# 7d6e745
# raw
# history blame
# 4.7 kB
# Number of sampled frames hashed together in one batch.
BATCH_SIZE = 64
# Stride, in frames, between two sampled frames of the video.
DOWNSAMPLE = 24
# Directory prefix used when building the paths of the saved slide PNGs.
FOLDER_PATH = "."
import phash_jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from PIL import Image
import statistics
from decord import VideoReader
from decord import cpu
import gradio
def binary_array_to_hex(arr):
    """
    Encode a binary (0/1 or boolean) array as a zero-padded hex string.

    The array is flattened, its entries are read as the bits of one big
    integer (first element = most significant bit), and that integer is
    rendered in lowercase hex, left-padded with zeros to ceil(n_bits / 4)
    digits.
    """
    # Multiplying by 1 turns boolean entries into integers, so that each one
    # stringifies as "0" or "1".
    bits = 1 * arr.flatten()
    bit_string = ''.join(map(str, bits))
    n_hex_digits = int(jnp.ceil(len(bit_string) / 4))
    value = int(bit_string, 2)
    return format(value, 'x').rjust(n_hex_digits, '0')
def compute_batch_hashes(vid_path):
    """Hash every DOWNSAMPLE-th frame of a video, batch by batch.

    The video is opened with decord using ``width=64, height=64``. Frames are
    sampled with a stride of ``DOWNSAMPLE`` and hashed in groups of at most
    ``BATCH_SIZE`` with ``phash_jax.batch_phash``. For each sampled frame the
    result stores its frame index, its hash as a hex string, and the
    ``phash_jax.hash_dist`` between its hash and the hash of the previously
    sampled frame (the very first frame is compared with itself).

    Args:
        vid_path: Path of the video file to read.

    Returns:
        A ``gradio.update`` whose ``value`` is the list of
        ``{"frame_id", "hash", "distance"}`` dicts and whose ``visible`` flag
        is ``False``.
    """
    vr = VideoReader(vid_path, ctx=cpu(0), width=64, height=64)
    n_frames = len(vr)
    frames_per_batch = DOWNSAMPLE * BATCH_SIZE

    hashes = []
    h_prev = None
    for batch_start in range(0, n_frames, frames_per_batch):
        print(f"batch_{batch_start}")
        batch_end = min(batch_start + frames_per_batch, n_frames)
        ids = list(range(batch_start, batch_end, DOWNSAMPLE))
        # The reader is rewound before each batch read, as in the original
        # code. NOTE(review): presumably a decord seeking workaround - confirm.
        vr.seek(0)
        batch = jnp.array(vr.get_batch(ids).asnumpy())
        batch_h = phash_jax.batch_phash(batch)
        for k, frame_id in enumerate(ids):
            h = batch_h[k]
            # Identity test: `== None` on an array-valued hash is an
            # (elementwise) array comparison, not a "not yet set" check.
            if h_prev is None:
                h_prev = h
            hashes.append({
                "frame_id": frame_id,
                "hash": binary_array_to_hex(h),
                "distance": int(phash_jax.hash_dist(h, h_prev)),
            })
            h_prev = h
    return gradio.update(value=hashes, visible=False)
def plot_hash_distance(hashes, threshold):
    """Plot the frame-to-frame hash distances together with a threshold line.

    Args:
        hashes: List of dicts carrying at least ``"frame_id"`` and
            ``"distance"`` entries.
        threshold: Value drawn as a horizontal red line across all frame ids.

    Returns:
        The matplotlib ``Figure`` holding the plot.
    """
    # Local import so this block is self-contained. The figure is built
    # directly instead of through ``plt.figure()``: pyplot keeps every figure
    # it creates alive until it is explicitly closed, and this function is
    # re-run on every slider change, so pyplot figures would pile up.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot(111)
    ids = [h["frame_id"] for h in hashes]
    distances = [h["distance"] for h in hashes]
    ax.plot(ids, distances, ".")
    ax.plot(ids, [threshold] * len(ids), "r-")
    return fig
def compute_threshold(hashes, *, downsample=None):
    """Pick a hash-distance threshold that does not over-segment the video.

    Every distinct distance value found in ``hashes`` is tried as a threshold,
    from the largest to the smallest (the largest itself is skipped; the
    starting answer is "largest minus one"). For a given threshold, a segment
    is counted each time a sample's distance exceeds the threshold and the
    samples since the previous cut span more than ``24 * 3`` frame ids. The
    threshold is accepted when the number of segments stays strictly below
    ``(len(hashes) * downsample / 24) / 20``; the last (i.e. smallest)
    accepted threshold wins.

    NOTE(review): the constant 24 is presumably the video frame rate, which
    would make the limits "3 seconds minimum" and "one segment per 20
    seconds" - confirm.

    Args:
        hashes: List of dicts with ``"frame_id"`` and ``"distance"`` entries,
            ordered by frame.
        downsample: Stride between sampled frames. Defaults to the
            module-level ``DOWNSAMPLE``.

    Returns:
        The selected threshold.

    Raises:
        IndexError: If ``hashes`` is empty.
    """
    if downsample is None:
        downsample = DOWNSAMPLE
    min_length = 24 * 3

    candidates = sorted({h["distance"] for h in hashes}, reverse=True)
    best = candidates[0] - 1
    # Does not depend on the threshold being tried, so compute it once.
    max_segments = (len(hashes) * downsample / 24) / 20

    for threshold in candidates[1:]:
        n_segments = 0
        i_start = 0
        # Start at 1: a cut looks at the previous sample, and for index 0
        # `hashes[i - 1]` would silently wrap around to the last sample.
        for i in range(1, len(hashes)):
            if hashes[i]["distance"] <= threshold:
                continue
            span = hashes[i - 1]["frame_id"] - hashes[i_start]["frame_id"]
            if span > min_length:
                n_segments += 1
                i_start = i
        if n_segments < max_segments:
            best = threshold
    return best
def get_slides(vid_path, hashes, threshold):
    """Cut the video into slides at large hash jumps and save one PNG each.

    A cut happens at sample ``i`` when its distance exceeds ``threshold`` and
    the samples since the previous cut span more than ``24 * 1.5`` frame ids.
    For each cut, the frame of the sample just before it is saved. The last
    frame of the video is always saved as the final slide.

    Files are written to ``FOLDER_PATH`` as ``<stem>_<start>_<end>.png``,
    where ``<stem>`` is the video file name up to its first dot.
    NOTE(review): for intermediate slides ``start``/``end`` are indices into
    ``hashes``, whereas the final slide's ``end`` is a frame index
    (``len(vr) - 1``); kept as in the original naming.

    Args:
        vid_path: Path of the video file ("/"-separated).
        hashes: List of dicts with ``"frame_id"`` and ``"distance"`` entries,
            ordered by frame.
        threshold: Distance above which a sample may start a new slide.

    Returns:
        List of the saved image paths, in order.
    """
    min_length = 24 * 1.5
    vr = VideoReader(vid_path, ctx=cpu(0))
    # Same for every slide, so derive it once.
    stem = vid_path.split("/")[-1].split(".")[0]

    slide_paths = []
    i_start = 0
    # Start at 1: a cut saves the previous sample's frame, and for index 0
    # `hashes[i - 1]` would silently wrap around to the last sample.
    for i in range(1, len(hashes)):
        if hashes[i]["distance"] <= threshold:
            continue
        prev_frame_id = hashes[i - 1]["frame_id"]
        if prev_frame_id - hashes[i_start]["frame_id"] > min_length:
            path = f'{FOLDER_PATH}/{stem}_{i_start}_{i - 1}.png'
            Image.fromarray(vr[prev_frame_id].asnumpy()).save(path)
            slide_paths.append(path)
            i_start = i

    # Whatever follows the last cut becomes the final slide.
    path = f'{FOLDER_PATH}/{stem}_{i_start}_{len(vr) - 1}.png'
    Image.fromarray(vr[-1].asnumpy()).save(path)
    slide_paths.append(path)
    return slide_paths
def trigger_plots(f2f_distance_plot, hashes, threshold):
    """Build a gradio update carrying a fresh hash-distance plot.

    ``f2f_distance_plot`` is accepted because the event wiring passes it in,
    but it is not read here.
    """
    figure = plot_hash_distance(hashes, threshold)
    return gradio.update(value=figure)
def set_visible():
    """Return a gradio update that switches the target component to visible."""
    show = gradio.update(visible=True)
    return show
# Gradio Blocks UI: component layout, event wiring, and app launch.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from
# a copy of the file that had lost its indentation - confirm the layout.
demo = gradio.Blocks()
with demo:
    with gradio.Row():
        # Left column: video input, the "Compute hashes" button, and a plot
        # component that is created hidden.
        with gradio.Column():
            with gradio.Row():
                vid=gradio.Video(mirror_webcam=False)
            with gradio.Row():
                btn_vid_proc = gradio.Button("Compute hashes")
            with gradio.Row():
                hist_plot = gradio.Plot(label="Frame to frame hash distance histogram", visible=False)
        # Middle column: JSON component that receives the computed hashes.
        with gradio.Column():
            hashes = gradio.JSON()
        # Right column: threshold controls; hidden until `hashes` changes.
        with gradio.Column(visible=False) as result_row:
            btn_plot = gradio.Button("Plot & compute optimal threshold")
            threshold = gradio.Slider(minimum=1, maximum=30, value=5, label="Threshold")
            f2f_distance_plot = gradio.Plot(label="Frame to frame hash distance")
            btn_slides = gradio.Button("Extract Slides")
    with gradio.Row():
        slideshow = gradio.Gallery(label="Extracted slides")
        slideshow.style(grid=6)
    # Event wiring.
    # "Compute hashes" fills the JSON component from the selected video.
    btn_vid_proc.click(fn=compute_batch_hashes, inputs=[vid], outputs=[hashes])
    # Any change of the hashes reveals the threshold/plot column.
    hashes.change(fn=set_visible, inputs=[], outputs=[result_row])
    # The plot button has two handlers: one writes the computed threshold to
    # the slider, the other redraws the distance plot.
    btn_plot.click(fn=compute_threshold, inputs=[hashes], outputs=[threshold])
    btn_plot.click(fn=trigger_plots, inputs=[f2f_distance_plot, hashes, threshold], outputs=[f2f_distance_plot])
    # Changing the slider redraws the plot with the new threshold line.
    threshold.change(fn=plot_hash_distance, inputs=[hashes, threshold], outputs=f2f_distance_plot)
    # "Extract Slides" saves the slide images and shows them in the gallery.
    btn_slides.click(fn=get_slides, inputs=[vid, hashes, threshold], outputs=[slideshow])
# Enable request queuing, then start the app.
demo.queue(default_enabled=True).launch()