from io import BytesIO
from typing import Union

import boto3
import numpy as np
import streamlit as st
import torch
from datasets import load_dataset
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
# Initialize the SigLIP model, preprocessing transform and tokenizer once at import
# time to avoid reloading them on every call
model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
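
# Note: model and inputs stay on CPU in this module. If a GPU is available, an optional
# optimization sketch (not required by the functions below) would be along the lines of:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)
# plus moving the preprocessed/tokenized inputs to the same device inside encode_query.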
def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
"""
Encode the query using the OpenCLIP model.
Parameters
----------
query : Union[str, Image.Image]
The query, which can be a text string or an Image object.
Returns
-------
torch.Tensor
The encoded query vector.
"""
if isinstance(query, Image.Image):
query = preprocess(query).unsqueeze(0) # Preprocess the image and add batch dimension
with torch.no_grad():
query_embedding = model.encode_image(query) # Get image embedding
elif isinstance(query, str):
text = tokenizer(query, context_length=model.context_length)
with torch.no_grad():
query_embedding = model.encode_text(text) # Get text embedding
else:
raise ValueError("Query must be either a string or an Image.")
return query_embedding
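
# Usage sketch (hypothetical query values; the embedding dimension D depends on the
# SigLIP checkpoint, so the shape is indicative only):
#   text_emb = encode_query("a red bicycle")             # torch.Tensor of shape (1, D)
#   img_emb  = encode_query(Image.open("example.jpg"))   # same shape as the text embedding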
def load_hf_datasets(dataset_name):
    """
    Load a dataset from the quasara-io namespace on Hugging Face and return it as a DataFrame.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset on Hugging Face (without the 'quasara-io/' prefix).

    Returns
    -------
    pandas.DataFrame
        The 'Main' split of the dataset converted to a pandas DataFrame.
    """
    dataset = load_dataset(f"quasara-io/{dataset_name}")
    # Access only the 'Main' split and convert it to a pandas DataFrame
    main_dataset = dataset['Main']
    df = main_dataset.to_pandas()
    return df
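
# Usage sketch (assumes a 'StopSign_test' dataset with a 'Main' split exists under the
# quasara-io namespace, as used in main() below):
#   df = load_hf_datasets("StopSign_test")
#   print(df.columns)   # expected to include 'Vector' and 'File_Path'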
def get_image_vectors(df):
    """Stack the per-row 'Vector' column of the DataFrame into a single float32 tensor."""
    image_vectors = np.vstack(df['Vector'].to_numpy())
    return torch.tensor(image_vectors, dtype=torch.float32)
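
# Usage sketch (assumes each row of df stores one precomputed embedding in 'Vector'):
#   vectors = get_image_vectors(df)   # tensor of shape (num_rows, embedding_dim)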
def search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects):
    """
    Search the dataset for the vectors most similar to the query.

    Only full-image search with cosine similarity is currently implemented; the
    offset, scoring_func and search_in_small_objects arguments are accepted for
    API compatibility but are not yet used.
    """
    if search_in_images:
        # Encode the query and pull the precomputed image vectors from the dataframe
        query_vector = encode_query(query)
        image_vectors = get_image_vectors(df)
        # Detach from the graph and convert to NumPy arrays for scikit-learn
        query_vector = query_vector[0, :].detach().numpy()
        image_vectors = image_vectors.detach().numpy()
        # Cosine similarity between the query vector and every image vector
        cosine_similarities = cosine_similarity([query_vector], image_vectors)
        # Indices of the top K most similar image vectors
        top_k_indices = np.argsort(-cosine_similarities[0])[:limit]
        return top_k_indices
    raise NotImplementedError("Only full-image search is currently supported.")
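
# Usage sketch (text query; only the cosine-similarity image search runs, so
# scoring_func, offset and search_in_small_objects are effectively placeholders here):
#   top_k = search("black car", df, limit=10, offset=0, scoring_func="cosine",
#                  search_in_images=True, search_in_small_objects=False)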
def get_file_paths(df, top_k_indices, column_name = 'File_Path'):
"""
Retrieve the file paths (or any specific column) from the DataFrame using the top K indices.
Parameters:
- df: pandas DataFrame containing the data
- top_k_indices: numpy array of the top K indices
    - column_name: str, the name of the column to fetch (defaults to 'File_Path')
Returns:
- top_k_paths: list of file paths or values from the specified column
"""
# Fetch the specific column corresponding to the top K indices
top_k_paths = df.iloc[top_k_indices][column_name].tolist()
return top_k_paths
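
# Usage sketch (chains the search result back to the 'File_Path' column):
#   paths = get_file_paths(df, top_k)   # list of file names, most similar first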
def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name=None):
    """
    Retrieve and display images from AWS S3 in a Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, a list of file paths (object keys) to retrieve from S3
    - AWS_ACCESS_KEY_ID: str, AWS access key ID used to create the S3 client
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key used to create the S3 client
    - folder_name: str or None, optional prefix prepended to each file path

    Returns:
    - None (directly displays images in the Streamlit app)
    """
    # Initialize the S3 client with the supplied credentials
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )
    prefix = folder_name or ''  # Avoid a literal 'None' ending up in the object key
    # Iterate over file paths and display each image
    for file_path in file_paths:
        # Retrieve the image bytes from S3
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{prefix}{file_path}")
        img_data = s3_object['Body'].read()
        # Open the image with PIL and render it in Streamlit
        img = Image.open(BytesIO(img_data))
        st.image(img, caption=file_path, use_column_width=True)
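
# Usage sketch inside a Streamlit app (bucket name, folder and credentials below are
# placeholders; in practice they might come from st.secrets or environment variables):
#   get_images_from_s3_to_display(
#       "my-bucket", paths,
#       st.secrets["AWS_ACCESS_KEY_ID"], st.secrets["AWS_SECRET_ACCESS_KEY"],
#       folder_name="images/",
#   )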
def main():
    # Simple example run: text search for "black car" over the StopSign_test dataset
    dataset_name = "StopSign_test"
query = "black car"
limit = 10
offset = 0
scoring_func = "cosine"
search_in_images = True
search_in_small_objects = False
df = load_hf_datasets(dataset_name)
results = search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects)
    top_k_paths = get_file_paths(df, results)
return top_k_paths
if __name__ == "__main__":
main()