|
import numpy as np |
|
from sentence_transformers import SentenceTransformer, util |
|
from open_clip import create_model_from_pretrained, get_tokenizer |
|
import torch |
|
from datasets import load_dataset |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import torch.nn as nn |
|
import boto3 |
|
import streamlit as st |
|
from PIL import Image |
|
from io import BytesIO |
|
from typing import List, Union |
|
|
|
|
|
|
|
# Hugging Face Hub identifier used for both the model weights and the
# tokenizer, kept in one place so the two always refer to the same checkpoint.
_SIGLIP_HUB_ID = 'hf-hub:timm/ViT-SO400M-14-SigLIP-384'

# Loaded once at import time: the model, its image preprocessing transform,
# and the matching text tokenizer. encode_query() reads all three.
model, preprocess = create_model_from_pretrained(_SIGLIP_HUB_ID)
tokenizer = get_tokenizer(_SIGLIP_HUB_ID)
|
|
|
|
|
|
|
def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
    """
    Embed a text or image query with the module-level OpenCLIP model.

    Parameters
    ----------
    query : Union[str, Image.Image]
        Either a text string or a PIL image.

    Returns
    -------
    torch.Tensor
        The output of ``model.encode_image`` (for images) or
        ``model.encode_text`` (for strings), computed with gradient
        tracking disabled.

    Raises
    ------
    ValueError
        If ``query`` is neither a string nor an image.
    """
    is_image = isinstance(query, Image.Image)
    if not is_image and not isinstance(query, str):
        raise ValueError("Query must be either a string or an Image.")

    if is_image:
        # unsqueeze(0) adds a leading dimension to the preprocessed image
        # (presumably the batch axis the encoder expects).
        model_input = preprocess(query).unsqueeze(0)
        encoder = model.encode_image
    else:
        model_input = tokenizer(query, context_length=model.context_length)
        encoder = model.encode_text

    with torch.no_grad():
        return encoder(model_input)
|
|
|
def load_hf_datasets(dataset_name):
    """
    Fetch a dataset from the ``quasara-io`` Hugging Face namespace and return
    its 'Main' split as a pandas DataFrame.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset inside the ``quasara-io`` namespace.

    Returns
    -------
    pandas.DataFrame
        The 'Main' split converted with ``to_pandas()``.
    """
    repo_id = f"quasara-io/{dataset_name}"
    return load_dataset(repo_id)['Main'].to_pandas()
|
|
|
def get_image_vectors(df):
    """
    Stack the entries of the 'Vector' column into a single float32 tensor.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a 'Vector' column holding one vector per row.

    Returns
    -------
    torch.Tensor
        The row vectors stacked vertically (one row per DataFrame row),
        with dtype ``torch.float32``.
    """
    per_row_vectors = df['Vector'].to_numpy()
    stacked = np.vstack(per_row_vectors)
    return torch.tensor(stacked, dtype=torch.float32)
|
|
|
|
|
def search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects):
    """
    Rank the rows of ``df`` by cosine similarity to ``query``.

    Parameters
    ----------
    query : str or PIL.Image.Image
        The query; it is embedded with ``encode_query``.
    df : pandas.DataFrame
        Data to search; its 'Vector' column is read via ``get_image_vectors``.
    limit : int or None
        Maximum number of indices to return. ``None`` means no upper bound.
    offset : int or None
        Number of best-ranked results to skip before the returned window
        (for paging). ``None`` and negative values are treated as 0.
    scoring_func : str
        Currently ignored: cosine similarity is always used.
    search_in_images : bool
        When falsy, no search is run and an empty result is returned,
        because image search is the only mode implemented here.
    search_in_small_objects : bool
        Currently ignored.

    Returns
    -------
    numpy.ndarray
        Positional row indices into ``df``, ordered from most to least
        similar. Empty (integer dtype) when ``search_in_images`` is falsy.
    """
    if not search_in_images:
        # Return an empty integer index array so that callers indexing with
        # the result (e.g. df.iloc[...]) get an empty selection, not an error.
        return np.empty(0, dtype=np.intp)

    query_vector = encode_query(query)
    image_vectors = get_image_vectors(df)

    # .cpu() is a no-op for CPU tensors and makes .numpy() safe if the model
    # (and therefore the query embedding) lives on another device.
    query_vector = query_vector[0, :].detach().cpu().numpy()
    image_vectors = image_vectors.detach().cpu().numpy()
    cosine_similarities = cosine_similarity([query_vector], image_vectors)

    # Negating the scores makes argsort return the best matches first.
    ranked_indices = np.argsort(-cosine_similarities[0])

    # Apply the paging window; previously `offset` was accepted but ignored.
    start = max(offset or 0, 0)
    stop = None if limit is None else start + limit
    return ranked_indices[start:stop]
|
|
|
def get_file_paths(df, top_k_indices, column_name='File_Path'):
    """
    Look up one column's values for the rows at the given positions.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: positional row indices (e.g. the numpy array returned
      by a top-K search); rows are selected with ``df.iloc``
    - column_name: str, the column to read (default 'File_Path')

    Returns:
    - list of the selected rows' values from ``column_name``, in the same
      order as ``top_k_indices``
    """
    selected_rows = df.iloc[top_k_indices]
    column_values = selected_rows[column_name]
    return column_values.tolist()
|
|
|
|
|
def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name=None):
    """
    Retrieve images from AWS S3 and display them in a Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, object paths to retrieve from the bucket
    - AWS_ACCESS_KEY_ID: str, access key id passed to the boto3 client
    - AWS_SECRET_ACCESS_KEY: str, secret access key passed to the boto3 client
    - folder_name: optional str, prefix placed directly in front of each
      file path to form the object key (no separator is inserted, so include
      a trailing '/' if one is needed). None or '' means no prefix.

    Returns:
    - None (images are rendered directly with ``st.image``)
    """
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )

    # With the default folder_name=None, formatting it straight into the key
    # produced keys starting with the literal text "None"; use no prefix instead.
    key_prefix = folder_name or ""

    for file_path in file_paths:
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{key_prefix}{file_path}")
        img_data = s3_object['Body'].read()

        # Decode the downloaded bytes in memory; nothing is written to disk.
        img = Image.open(BytesIO(img_data))
        st.image(img, caption=file_path, use_column_width=True)
|
|
|
|
|
|
|
def main():
    """
    Run a fixed demo search and return the matching file paths.

    Loads the "StopSign_test" dataset, searches it for the text query
    "black car" (top 10, no offset, image search only) and returns the
    'File_Path' values of the best-matching rows.
    """
    df = load_hf_datasets("StopSign_test")
    results = search(
        query="black car",
        df=df,
        limit=10,
        offset=0,
        scoring_func="cosine",
        search_in_images=True,
        search_in_small_objects=False,
    )
    return get_file_paths(df, results)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|