File size: 3,809 Bytes
1801c3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import base64
import html
import os
from dataclasses import dataclass
from typing import Final

import faiss
import numpy as np
import pandas as pd
import streamlit as st

from pipeline import clip_wrapper


class SemanticSearcher:
    """Nearest-neighbour text-to-frame search over precomputed embeddings.

    The dataset is split into an embedding matrix (every column whose name
    starts with ``dim_``) and per-frame metadata (all remaining columns).
    The matrix is loaded into an exact inner-product FAISS index, so a row's
    position in ``dataset`` is also its FAISS id.
    """

    def __init__(self, dataset: pd.DataFrame):
        """Build the index from ``dataset``.

        Raises:
            ValueError: if ``dataset`` has no ``dim_*`` embedding columns.
        """
        dim_columns = dataset.filter(regex="^dim_").columns
        if len(dim_columns) == 0:
            raise ValueError("dataset has no embedding columns matching '^dim_'")

        # NOTE(review): assumes texts2vec returns a (n_texts, dim) tensor in the
        # same embedding space as the dim_ columns — confirm in clip_wrapper.
        self.embedder = clip_wrapper.ClipWrapper().texts2vec
        self.metadata = dataset.drop(columns=dim_columns)
        self.index = faiss.IndexFlatIP(len(dim_columns))
        # FAISS only accepts C-contiguous float32 matrices.
        self.index.add(np.ascontiguousarray(dataset[dim_columns].to_numpy(np.float32)))

    @staticmethod
    def _valid_hits(scores: np.ndarray, ids: np.ndarray) -> list[tuple[float, int]]:
        """Pair one query's scores with row ids, dropping FAISS padding.

        FAISS pads the id array with -1 when fewer than ``k`` results exist;
        passing -1 to ``iloc`` would silently select the last row.
        """
        return [(float(score), int(row_id)) for score, row_id in zip(scores, ids) if row_id >= 0]

    def search(self, query: str, k: int = 10) -> list["SearchResult"]:
        """Return up to ``k`` frames most similar to ``query``, best first.

        Args:
            query: Free-text search query.
            k: Maximum number of results; values <= 0 yield an empty list.

        Returns:
            One ``SearchResult`` per hit, ordered by descending inner-product
            score as returned by the FAISS index.
        """
        if k <= 0:
            return []
        query_vec = np.ascontiguousarray(
            self.embedder([query]).detach().numpy(), dtype=np.float32
        )
        scores, ids = self.index.search(query_vec, k)
        hits = self._valid_hits(scores[0], ids[0])
        if not hits:
            return []
        rows = self.metadata.iloc[[row_id for _, row_id in hits]]
        return [
            SearchResult(
                video_id=row["video_id"],
                frame_idx=row["frame_idx"],
                timestamp=row["timestamp"],
                score=score,
            )
            for (score, _), (_, row) in zip(hits, rows.iterrows())
        ]


DATASET_PATH: Final[str] = os.environ.get("DATASET_PATH", "data/dataset.parquet")


@st.cache_resource(show_spinner=False)
def _load_searcher(dataset_path: str) -> SemanticSearcher:
    """Build the searcher once per server process and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching, the parquet file, CLIP model and FAISS index would be
    rebuilt for every query. ``show_spinner=False`` keeps the cache from
    emitting a Streamlit element before ``st.set_page_config`` runs.

    NOTE(review): the cached instance is shared by all sessions/threads;
    search only reads the index, but confirm ClipWrapper is safe to call
    concurrently.
    """
    return SemanticSearcher(pd.read_parquet(dataset_path))


SEARCHER: Final[SemanticSearcher] = _load_searcher(DATASET_PATH)


@dataclass()
class SearchResult:
    """A single hit produced by ``SemanticSearcher.search``.

    Attributes:
        video_id: Identifier of the source video; used as the YouTube ``v``
            parameter when building a watch URL.
        frame_idx: Index of the matched frame; also names its thumbnail file.
        timestamp: Position of the frame in the video, passed (truncated) as
            the YouTube ``t`` start offset.
        score: Similarity reported by the inner-product FAISS index.
    """

    video_id: str
    frame_idx: int
    timestamp: float
    score: float


def get_video_url(video_id: str, timestamp: float) -> str:
    """Return a YouTube watch URL for ``video_id`` starting at ``timestamp``.

    The start offset is truncated toward zero to a whole number via ``int``.
    """
    start_seconds = int(timestamp)
    return "https://www.youtube.com/watch?v={}&t={}".format(video_id, start_seconds)


def _render_result(result: SearchResult, image_root: str) -> None:
    """Draw one hit as a thumbnail linking to the video at the frame's timestamp."""
    # Escape before interpolating into HTML: '&' in the query string becomes
    # '&amp;' (decoded back by the browser), and a stray quote in the data
    # cannot break out of the attribute.
    url = html.escape(get_video_url(result.video_id, result.timestamp), quote=True)
    alt = html.escape(f"frame {result.frame_idx}", quote=True)
    image_path = os.path.join(image_root, str(result.video_id), f"{result.frame_idx}.jpg")
    try:
        with open(image_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("ascii")
    except FileNotFoundError:
        # Degrade to a plain link instead of aborting the whole results page.
        st.markdown(
            f'<a href="{url}">{alt} (thumbnail not found)</a>',
            unsafe_allow_html=True,
        )
        return
    st.markdown(
        f"""
        <a href="{url}">
        <img src="data:image/jpeg;base64,{encoded}" alt="{alt}" width="100%">
        </a>
        """,
        unsafe_allow_html=True,
    )


def display_search_results(
    results: list[SearchResult],
    col_count: int = 3,
    image_root: str = "data/images",
) -> None:
    """Render search hits as a grid of clickable frame thumbnails.

    Each thumbnail is read from ``{image_root}/{video_id}/{frame_idx}.jpg``,
    inlined as a base64 data URI, and wrapped in a link to the YouTube video
    at the hit's timestamp.

    Args:
        results: Hits to show, laid out left-to-right, top-to-bottom.
        col_count: Number of thumbnails per row.
        image_root: Directory containing the extracted frame images.

    Raises:
        ValueError: if ``col_count`` is less than 1.
    """
    if col_count < 1:
        raise ValueError(f"col_count must be >= 1, got {col_count}")
    if not results:
        return

    # The stylesheet is page-global, so inject it once rather than per result.
    # NOTE(review): no element currently uses .video-container (it served an
    # earlier embedded-video layout); consider deleting this rule.
    st.markdown(
        """
        <style>
        .video-container {
            position: relative;
            padding-bottom: 56.25%;
            padding-top: 30px;
            height: 0;
            overflow: hidden;
        }
        .video-container iframe,
        .video-container object,
        .video-container embed {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # One st.columns row per chunk of col_count results; the final row may be
    # partially filled.
    for start in range(0, len(results), col_count):
        columns = st.columns(col_count)
        for column, result in zip(columns, results[start : start + col_count]):
            with column:
                _render_result(result, image_root)


def main():
    """Render the search page: a query box and, once it is filled, the hit grid."""
    st.set_page_config(page_title="video-semantic-search", layout="wide")
    st.header("Video Semantic Search")

    # The widget returns the same value Streamlit stores in session_state["query"].
    query = st.text_input("What are you looking for?", key="query")
    if not query:
        return
    display_search_results(SEARCHER.search(query))


if __name__ == "__main__":
    # Entry point. NOTE(review): presumably launched via `streamlit run`, which
    # executes this file as the __main__ module — confirm for your deployment.
    main()