"""Streamlit app: visual semantic search over frames of music videos."""

import base64
import hashlib
import html
import os
import subprocess
from dataclasses import dataclass
from typing import Optional

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from streamlit import runtime
from streamlit.logger import get_logger
from streamlit.runtime.scriptrunner import get_script_run_ctx

from pipeline import clip_wrapper
from pipeline.process_videos import DATAFRAME_PATH

NUM_FRAMES_TO_RETURN = 21

logger = get_logger(__name__)


@dataclass
class SearchResult:
    """One frame returned by a semantic search, with its similarity score."""

    video_id: str
    frame_idx: int
    timestamp: float
    base64_image: str
    score: float


class SemanticSearcher:
    """Inner-product nearest-neighbour search over per-frame embeddings.

    The dataset holds one row per frame: embedding columns whose names start
    with ``dim_``, plus metadata columns (``video_id``, ``frame_idx``,
    ``timestamp``, ``base64_image``) that are copied into each SearchResult.
    """

    def __init__(self, dataset: pd.DataFrame):
        dim_columns = dataset.filter(regex="^dim_").columns
        self.embedder = clip_wrapper.ClipWrapper().texts2vec
        self.metadata = dataset.drop(columns=dim_columns)
        self.index = faiss.IndexFlatIP(len(dim_columns))
        self.index.add(
            np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32))
        )

    def search(self, query: str) -> list[SearchResult]:
        """Return up to NUM_FRAMES_TO_RETURN frames ranked by similarity to *query*."""
        # Presumably texts2vec returns a (1, dim) torch tensor — TODO confirm.
        # faiss requires a C-contiguous float32 matrix.
        query_vec = np.ascontiguousarray(
            self.embedder([query]).detach().numpy(), dtype=np.float32
        )
        # Never ask for more neighbours than the index holds.
        k = min(NUM_FRAMES_TO_RETURN, self.index.ntotal)
        if k == 0:
            return []
        scores, ids = self.index.search(query_vec, k)
        # faiss pads missing neighbours with id -1. Passed to iloc, that would
        # silently select the *last* row, so drop such entries explicitly.
        keep = ids[0] >= 0
        rows = self.metadata.iloc[ids[0][keep]]
        return [
            SearchResult(
                video_id=row["video_id"],
                frame_idx=row["frame_idx"],
                timestamp=row["timestamp"],
                base64_image=row["base64_image"],
                score=float(score),
            )
            for score, (_, row) in zip(scores[0][keep], rows.iterrows())
        ]


@st.cache_resource
def get_semantic_searcher():
    """Build the searcher once per server process (cached by Streamlit)."""
    return SemanticSearcher(pd.read_parquet(DATAFRAME_PATH))


@st.cache_data
def get_git_hash() -> Optional[str]:
    """Return the current git commit hash, or None if it cannot be determined."""
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
        )
    except (subprocess.CalledProcessError, OSError):
        # Not a git checkout (non-zero exit) or git is not installed (OSError).
        return None
    return output.decode().strip()


def get_video_url(video_id: str, timestamp: float) -> str:
    """Return a YouTube URL for *video_id* starting at *timestamp* (whole seconds)."""
    return f"https://www.youtube.com/watch?v={video_id}&t={int(timestamp)}"


def _render_result_html(result: SearchResult) -> str:
    """Build the HTML for one result: a thumbnail linking to the video, plus a caption."""
    url = html.escape(get_video_url(result.video_id, result.timestamp), quote=True)
    # NOTE(review): assumes base64_image holds raw base64 of a JPEG — confirm
    # against the pipeline that produces DATAFRAME_PATH.
    # Kept on one line so the Markdown renderer cannot mistake indented HTML
    # for a code block.
    return (
        f'<a href="{url}" target="_blank" rel="noopener noreferrer">'
        f'<img src="data:image/jpeg;base64,{result.base64_image}" '
        f'style="width:100%;border-radius:4px;" '
        f'alt="frame {result.frame_idx}"></a>'
        f"<p>frame {result.frame_idx} timestamp {int(result.timestamp)}</p>"
    )


def display_search_results(results: list[SearchResult], col_count: int = 3) -> None:
    """Render *results* as a grid of clickable thumbnails, *col_count* per row.

    Raises:
        ValueError: if col_count is less than 1.
    """
    if col_count < 1:
        raise ValueError(f"col_count must be >= 1, got {col_count}")
    for start in range(0, len(results), col_count):
        # A fresh row of columns for each chunk. A short final chunk leaves
        # its trailing columns empty.
        columns = st.columns(col_count)
        for column, result in zip(columns, results[start : start + col_count]):
            with column:
                st.markdown(_render_result_html(result), unsafe_allow_html=True)


def get_remote_ip() -> Optional[str]:
    """Return the client IP of the current Streamlit session, or None if unavailable."""
    try:
        ctx = get_script_run_ctx()
        if ctx is None:
            return None
        session_info = runtime.get_instance().get_client(ctx.session_id)
    except Exception:
        # Best effort: these are Streamlit runtime internals and may be absent
        # (e.g. when not served by `streamlit run`).
        logger.debug("Could not resolve remote ip", exc_info=True)
        return None
    if session_info is None:
        return None
    return session_info.request.remote_ip


def _sha256_prefix(value: str, length: int = 10) -> str:
    """Return the first *length* hex chars of sha256(*value*), for anonymised logging."""
    return hashlib.sha256(value.encode()).hexdigest()[:length]


def main():
    """Render the search page and, if a query was entered, its results."""
    st.set_page_config(page_title="video-semantic-search", layout="wide")
    st.header("Visual content search over music videos")
    st.markdown("_App by Ben Tenmann and Sidney Radcliffe_")
    searcher = get_semantic_searcher()
    num_videos = searcher.metadata["video_id"].nunique(dropna=False)
    query = st.text_input(
        f"What are you looking for? Search over {num_videos} music videos.",
        key="query",
    )
    if query:
        # Log only hashes so neither the IP nor the raw query is stored.
        remote_ip = get_remote_ip()
        ip_sha256 = _sha256_prefix(remote_ip) if remote_ip is not None else "unknown"
        logger.info(
            "sha256(ip)=%s sha256(query)=%s", ip_sha256, _sha256_prefix(query)
        )
        st.text("Click image to open video")
        display_search_results(searcher.search(query))
    git_hash = get_git_hash()
    if git_hash:
        st.text(f"Build: {git_hash[:7]}")


if __name__ == "__main__":
    main()