Spaces:
Configuration error
Configuration error
# Maximum number of sampled frames hashed together in one batch.
BATCH_SIZE = 64

# Sampling stride: only every DOWNSAMPLE-th frame of the video is hashed.
DOWNSAMPLE = 24

# Directory into which extracted slide images are written.
FOLDER_PATH = "."
import phash_jax | |
import jax.numpy as jnp | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import statistics | |
from decord import VideoReader | |
from decord import cpu | |
import gradio | |
def binary_array_to_hex(arr):
    """Return the bits of a binary array as a zero-padded hex string.

    The array is flattened, every element is converted to ``0`` or ``1`` and
    the resulting bit sequence is read as a single base-2 number whose first
    bit is the most significant one.

    Args:
        arr: Array-like object exposing ``flatten()`` whose elements are
            booleans or 0/1 integers.

    Returns:
        Lowercase hex string with one digit per started group of four bits,
        left-padded with zeros. An empty array yields an empty string.
    """
    bits = ''.join(str(int(b)) for b in arr.flatten())
    if not bits:
        # int('', 2) would raise; an empty array simply has no digits.
        return ''
    # Integer ceiling division: no need for a floating-point array op here.
    width = -(-len(bits) // 4)
    return f'{int(bits, 2):0{width}x}'
def compute_batch_hashes(vid_path):
    """Compute perceptual hashes for every ``DOWNSAMPLE``-th frame of a video.

    Frames are decoded at 64x64 and hashed in batches of at most
    ``BATCH_SIZE`` sampled frames. For each sampled frame the hash distance
    to the previously sampled frame is recorded; the very first frame is
    compared with itself.

    Args:
        vid_path: Path of the video file to read.

    Returns:
        A ``gradio.update`` with ``visible=False`` whose value is a list of
        dicts with the keys ``"frame_id"``, ``"hash"`` (hex string) and
        ``"distance"`` (int).
    """
    vr = VideoReader(vid_path, ctx=cpu(0), width=64, height=64)
    n_frames = len(vr)
    frames_per_batch = DOWNSAMPLE * BATCH_SIZE

    hashes = []
    h_prev = None
    for batch_start in range(0, n_frames, frames_per_batch):
        print(f"batch_{batch_start}")
        batch_end = min(batch_start + frames_per_batch, n_frames)
        ids = list(range(batch_start, batch_end, DOWNSAMPLE))
        # NOTE(review): presumably rewinds the decoder so the batched random
        # access below starts from a known position - confirm it is needed.
        vr.seek(0)
        batch = jnp.array(vr.get_batch(ids).asnumpy())
        batch_h = phash_jax.batch_phash(batch)
        for k, frame_id in enumerate(ids):
            h = batch_h[k]
            # Identity check: ``==`` on an array value is not a None test.
            if h_prev is None:
                h_prev = h
            hashes.append({
                "frame_id": frame_id,
                "hash": binary_array_to_hex(h),
                "distance": int(phash_jax.hash_dist(h, h_prev)),
            })
            h_prev = h
    return gradio.update(value=hashes, visible=False)
def plot_hash_distance(hashes, threshold):
    """Plot frame-to-frame hash distances together with a threshold line.

    Args:
        hashes: List of dicts with the keys ``"frame_id"`` and
            ``"distance"``; ``None`` is treated as an empty list.
        threshold: Value at which the horizontal red line is drawn.

    Returns:
        The matplotlib figure containing the plot.
    """
    hashes = hashes or []
    ids = [h["frame_id"] for h in hashes]
    distances = [h["distance"] for h in hashes]

    # Draw on an explicit Axes instead of pyplot's implicit "current figure",
    # which is shared global state and unsafe when callbacks overlap.
    fig, ax = plt.subplots()
    ax.plot(ids, distances, ".")
    ax.plot(ids, [threshold] * len(ids), "r-")
    # Unregister the figure from pyplot so repeated calls (e.g. every slider
    # move) do not accumulate open figures; the Figure object stays usable.
    plt.close(fig)
    return fig
def compute_threshold(hashes, downsample=None, min_length=24 * 3):
    """Pick a hash-distance threshold for splitting the video into segments.

    Every distinct distance except the largest is tried as a threshold, from
    high to low. For each candidate the number of segment boundaries is
    counted: a boundary is an entry whose distance exceeds the candidate
    while the segment before it spans more than ``min_length`` frames. The
    lowest candidate whose boundary count stays below
    ``len(hashes) * downsample / 24 / 20`` wins.

    Args:
        hashes: Non-empty list of dicts with the keys ``"frame_id"`` and
            ``"distance"``.
        downsample: Frame stride between consecutive entries of ``hashes``;
            defaults to the module-level ``DOWNSAMPLE``.
        min_length: Minimum segment span, in frames, for a boundary to count.

    Returns:
        The selected threshold, or the largest distance minus one when no
        candidate qualifies.

    Raises:
        ValueError: If ``hashes`` is empty or ``None``.
    """
    if not hashes:
        raise ValueError("compute_threshold() needs at least one hash entry")
    if downsample is None:
        downsample = DOWNSAMPLE

    candidates = sorted({h["distance"] for h in hashes}, reverse=True)
    # NOTE(review): presumably "video seconds / 20" at an assumed 24 fps,
    # i.e. at most one segment per 20 s - confirm the intended meaning.
    max_segments = (len(hashes) * downsample / 24) / 20

    best = candidates[0] - 1
    for threshold in candidates[1:]:
        n_segments = 0
        i_start = 0
        for i, h in enumerate(hashes):
            # The first entry has no predecessor; without this guard
            # hashes[i - 1] would silently wrap around to the last entry.
            if i == 0 or h["distance"] <= threshold:
                continue
            span = hashes[i - 1]["frame_id"] - hashes[i_start]["frame_id"]
            if span > min_length:
                n_segments += 1
                i_start = i
        if n_segments < max_segments:
            best = threshold
    return best
def get_slides(vid_path, hashes, threshold):
    """Save one image per detected segment of the video and return the paths.

    A segment ends at entry ``i`` when that entry's distance exceeds
    ``threshold`` and the segment before it spans more than 36 frames. The
    last sampled frame of each segment is written as a PNG into
    ``FOLDER_PATH``; the final frame of the video is always written as the
    last slide.

    Args:
        vid_path: Path of the video file to read frames from.
        hashes: List of dicts with the keys ``"frame_id"`` and
            ``"distance"``; ``None`` is treated as an empty list.
        threshold: Distance above which an entry starts a new segment.

    Returns:
        List of the PNG file paths, in video order.
    """
    min_length = 24 * 1.5
    vr = VideoReader(vid_path, ctx=cpu(0))
    # File name without directory, cut at the first dot.
    stem = vid_path.split("/")[-1].split(".")[0]

    slides = []
    i_start = 0
    for i, h in enumerate(hashes or []):
        # The first entry has no predecessor; without this guard
        # hashes[i - 1] would silently wrap around to the last entry.
        if i == 0 or h["distance"] <= threshold:
            continue
        last_frame_id = hashes[i - 1]["frame_id"]
        if last_frame_id - hashes[i_start]["frame_id"] > min_length:
            # NOTE(review): the numbers in this file name are indices into
            # ``hashes``, while the final slide below uses the video's last
            # frame index - kept as is to preserve existing file names.
            path = f'{FOLDER_PATH}/{stem}_{i_start}_{i - 1}.png'
            Image.fromarray(vr[last_frame_id].asnumpy()).save(path)
            slides.append(path)
            i_start = i

    path = f'{FOLDER_PATH}/{stem}_{i_start}_{len(vr) - 1}.png'
    Image.fromarray(vr[-1].asnumpy()).save(path)
    slides.append(path)
    return slides
def trigger_plots(f2f_distance_plot, hashes, threshold):
    """Return a Gradio update carrying a freshly drawn hash-distance plot.

    ``f2f_distance_plot`` is accepted because the component is wired in as
    an input, but its value is not used.
    """
    figure = plot_hash_distance(hashes, threshold)
    return gradio.update(value=figure)
def set_visible():
    """Return a Gradio update that makes the target component visible."""
    show = gradio.update(visible=True)
    return show
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a listing without indentation - confirm it matches the intended layout.
with gradio.Blocks() as demo:
    # Top area: video input and controls next to the raw hash output.
    with gradio.Row():
        with gradio.Column():
            with gradio.Row():
                vid = gradio.Video(mirror_webcam=False)
            with gradio.Row():
                btn_vid_proc = gradio.Button("Compute hashes")
            with gradio.Row():
                hist_plot = gradio.Plot(
                    label="Frame to frame hash distance histogram",
                    visible=False,
                )
        with gradio.Column():
            hashes = gradio.JSON()

    # Result area: hidden until the hashes component changes.
    with gradio.Column(visible=False) as result_row:
        btn_plot = gradio.Button("Plot & compute optimal threshold")
        threshold = gradio.Slider(
            minimum=1, maximum=30, value=5, label="Threshold"
        )
        f2f_distance_plot = gradio.Plot(label="Frame to frame hash distance")
        btn_slides = gradio.Button("Extract Slides")
        with gradio.Row():
            slideshow = gradio.Gallery(label="Extracted slides")
            slideshow.style(grid=6)

    # Event wiring.
    btn_vid_proc.click(fn=compute_batch_hashes, inputs=[vid], outputs=[hashes])
    hashes.change(fn=set_visible, inputs=[], outputs=[result_row])
    # Both handlers below fire on the same click.
    btn_plot.click(fn=compute_threshold, inputs=[hashes], outputs=[threshold])
    btn_plot.click(
        fn=trigger_plots,
        inputs=[f2f_distance_plot, hashes, threshold],
        outputs=[f2f_distance_plot],
    )
    threshold.change(
        fn=plot_hash_distance,
        inputs=[hashes, threshold],
        outputs=f2f_distance_plot,
    )
    btn_slides.click(
        fn=get_slides, inputs=[vid, hashes, threshold], outputs=[slideshow]
    )

demo.queue(default_enabled=True).launch()