import numpy as np
import torch
import boto3
import streamlit as st
from io import BytesIO
from typing import Union

from datasets import load_dataset
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity


# Initialize the model and tokenizer globally to avoid reloading on each query.
# The model is SigLIP (ViT-SO400M-14 at 384px), pulled from the Hugging Face Hub.
model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
model.eval()  # inference only

def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
    """
    Encode the query using the OpenCLIP model.

    Parameters
    ----------
    query : Union[str, Image.Image]
        The query, which can be a text string or an Image object.

    Returns
    -------
    torch.Tensor
        The encoded query vector.
    """
    if isinstance(query, Image.Image):
        query = preprocess(query).unsqueeze(0)  # Preprocess the image and add batch dimension
        with torch.no_grad():
            query_embedding = model.encode_image(query)  # Get image embedding
    elif isinstance(query, str):
        text = tokenizer(query, context_length=model.context_length)
        with torch.no_grad():
            query_embedding = model.encode_text(text)  # Get text embedding
    else:
        raise ValueError("Query must be either a string or an Image.")
    
    return query_embedding
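
# Example usage (a sketch; "example.jpg" is a hypothetical local file):
#   text_vec = encode_query("a red truck")               # tensor of shape (1, embed_dim)
#   image_vec = encode_query(Image.open("example.jpg"))  # tensor of shape (1, embed_dim)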

def load_hf_datasets(dataset_name):
    """
    Load a dataset from the quasara-io namespace on Hugging Face.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset on Hugging Face (without the namespace prefix).

    Returns
    -------
    pandas.DataFrame
        The 'Main' split of the dataset as a DataFrame.
    """
    dataset = load_dataset(f"quasara-io/{dataset_name}")
    # Access only the 'Main' split and convert it to a pandas DataFrame
    df = dataset['Main'].to_pandas()
    return df

def get_image_vectors(df):
    """Stack the precomputed embeddings in df['Vector'] into a float32 tensor."""
    image_vectors = np.vstack(df['Vector'].to_numpy())
    return torch.tensor(image_vectors, dtype=torch.float32)


def search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects):
    """
    Search the dataset with a text or image query and return the indices of the
    best-matching rows. Only cosine-scored whole-image search is implemented;
    small-object search is not yet supported.
    """
    if search_in_images:
        # Encode the query and load the precomputed image vectors
        query_vector = encode_query(query)
        image_vectors = get_image_vectors(df)

        # Cosine similarity between the query vector and each image vector
        query_vector = query_vector[0, :].detach().numpy()  # Detach and convert to a NumPy array
        image_vectors = image_vectors.detach().numpy()
        cosine_similarities = cosine_similarity([query_vector], image_vectors)

        # Indices of the most similar vectors, honouring both offset and limit
        top_k_indices = np.argsort(-cosine_similarities[0])[offset:offset + limit]
        return top_k_indices

    raise NotImplementedError("Only search_in_images=True is currently supported.")
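
# Example call (a sketch; assumes df has a 'Vector' column of embeddings):
#   idx = search("black car", df, limit=5, offset=0, scoring_func="cosine",
#                search_in_images=True, search_in_small_objects=False)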

def get_file_paths(df, top_k_indices, column_name='File_Path'):
    """
    Retrieve the file paths (or any specific column) from the DataFrame using the top K indices.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: numpy array of the top K indices
    - column_name: str, the name of the column to fetch (defaults to 'File_Path')

    Returns:
    - top_k_paths: list of file paths or values from the specified column
    """
    # Fetch the specified column for the top K rows
    top_k_paths = df.iloc[top_k_indices][column_name].tolist()
    return top_k_paths


def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name=None):
    """
    Retrieve and display images from AWS S3 in a Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, a list of file paths to retrieve from S3
    - AWS_ACCESS_KEY_ID: str, AWS access key ID for the S3 client
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for the S3 client
    - folder_name: str or None, optional key prefix prepended to each file path

    Returns:
    - None (directly displays images in the Streamlit app)
    """
    # Initialize S3 client
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )

    # Iterate over file paths and display each image
    for file_path in file_paths:
        # Build the object key, skipping the prefix when no folder is given
        key = f"{folder_name}{file_path}" if folder_name else file_path
        s3_object = s3.get_object(Bucket=bucket_name, Key=key)
        img_data = s3_object['Body'].read()

        # Open the image using PIL and display it with Streamlit
        # (use_container_width replaces the deprecated use_column_width)
        img = Image.open(BytesIO(img_data))
        st.image(img, caption=file_path, use_container_width=True)
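
# Example usage inside a Streamlit app (a sketch; the bucket name, prefix, and
# credentials below are hypothetical placeholders):
#   get_images_from_s3_to_display(
#       bucket_name="my-image-bucket",
#       file_paths=top_k_paths,
#       AWS_ACCESS_KEY_ID="...",
#       AWS_SECRET_ACCESS_KEY="...",
#       folder_name="images/",
#   )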



def main():
    # Run an example text-to-image search against a small test dataset
    dataset_name = "StopSign_test"
    query = "black car"
    limit = 10
    offset = 0
    scoring_func = "cosine"
    search_in_images = True
    search_in_small_objects = False

    df = load_hf_datasets(dataset_name)
    results = search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects)
    top_k_paths = get_file_paths(df, results)
    print(top_k_paths)
    return top_k_paths


if __name__ == "__main__":
    main()