import os

import numpy as np
import streamlit as st
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.neighbors import NearestNeighbors


# allow_output_mutation=True stops Streamlit from re-hashing the (large)
# cached return value on every rerun to detect mutation.
@st.cache(allow_output_mutation=True)
def load_model():
    """Extract ResNet-50 feature vectors for every JPEG under ``lfw``.

    Despite the name, the model itself is not returned: the function loads
    a pre-trained ResNet-50, drops its final layer, runs every image found
    by walking the ``lfw`` directory through it, and returns the features.

    Returns:
        dict: maps each image path (``os.path.join(root, file)``) to the
        squeezed NumPy output of the truncated network for that image.
    """
    print("Loading model...")

    # Pre-trained ResNet-50 with its last child module removed, so the
    # forward pass returns the output of the penultimate layer.
    backbone = models.resnet50(pretrained=True)
    model = torch.nn.Sequential(*list(backbone.children())[:-1])
    model.eval()

    # Transform applied to each input image.
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    # Directory containing the input images.
    input_dir = "lfw"

    features_dict = {}
    for root, _dirs, files in os.walk(input_dir):
        for file in files:
            # Case-insensitive so "X.JPG" / "x.jpeg" are not silently skipped.
            if not file.lower().endswith((".jpg", ".jpeg")):
                continue

            path = os.path.join(root, file)

            # The context manager closes the file handle; convert("RGB")
            # guarantees three channels for the 3-value Normalize above.
            with Image.open(path) as raw_image:
                image = transform(raw_image.convert("RGB"))

            # Add a batch dimension before the forward pass.
            image = image.unsqueeze(0)

            with torch.no_grad():
                features = model(image).squeeze()

            features_dict[path] = features.numpy()

    return features_dict


@st.cache(allow_output_mutation=True)
def create_nearest_neighbors_object(features_dict):
    """Fit a Euclidean ``NearestNeighbors`` index on the feature vectors.

    Rows of the fitted matrix follow the iteration order of
    ``features_dict``, so a neighbor index ``i`` corresponds to
    ``list(features_dict)[i]``.

    Args:
        features_dict: mapping of image path to feature vector.

    Returns:
        NearestNeighbors: fitted index returning up to 11 neighbors per
        query (fewer when fewer vectors were supplied).

    Raises:
        ValueError: if ``features_dict`` is empty.
    """
    print("Creating nearest neighbors object...")

    if not features_dict:
        raise ValueError("features_dict is empty; no images were indexed")

    features_array = np.array(list(features_dict.values()))

    # 11 = the query itself + 10 similar images; clamp so that querying a
    # small dataset does not fail with n_neighbors > n_samples.
    n_neighbors = min(11, len(features_array))
    nn = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean")
    nn.fit(features_array)
    return nn
# Nearest-neighbor lookup and display
def get_nearest_neighbors(image_path):
    """Display the images most similar to ``image_path``.

    Looks up the query's feature vector in the module-level
    ``features_dict``, queries the module-level ``nn`` index, and renders
    up to 10 neighbors (excluding the query itself) with ``st.image``.

    Args:
        image_path: path of the query image; must be a key of
            ``features_dict`` (compared after ``os.path.normpath`` too, so
            either path separator works).

    Returns:
        list[tuple[str, float]]: ``(path, distance)`` for every image shown,
        nearest first. Empty if the query image is not indexed.
    """
    print(image_path)

    # Direct dict lookup; fall back to the normalized form so a "/"-joined
    # path still matches keys built with os.path.join on Windows.
    key = image_path
    if key not in features_dict:
        key = os.path.normpath(image_path)
    if key not in features_dict:
        st.error("No features found for image: " + str(image_path))
        return []

    query_features = features_dict[key].reshape(1, -1)
    distances, indices = nn.kneighbors(query_features)

    # Row i of the fitted index corresponds to the i-th key of features_dict.
    paths = list(features_dict)

    shown = []
    for index, distance in zip(indices[0], distances[0]):
        path = paths[index]
        # Skip the query itself by identity rather than assuming it is
        # always the first result (ties with duplicates can reorder it).
        if path == key:
            continue

        with Image.open(path) as image:
            st.image(
                image,
                caption="Distance: " + str(distance),
                use_column_width=True,
            )
        shown.append((path, float(distance)))

        if len(shown) == 10:
            break

    return shown


# Streamlit app
# NOTE(review): this bare assignment does not configure st.cache; the flag
# only takes effect when passed to the decorator itself.
allow_output_mutation = True
features_dict = load_model()
nn = create_nearest_neighbors_object(features_dict)

# Title
st.title("Similarity Search")

# Subtitle
st.write("This app finds the 10 most similar images to a query image.")

# Offer only sub-directories of lfw, in a stable alphabetical order.
people = sorted(
    name for name in os.listdir("lfw") if os.path.isdir(os.path.join("lfw", name))
)
query_image = st.selectbox("Or select an image from the list", people)

# The query image is lfw/<name>/<name>_0001.jpg; os.path.join keeps the
# path consistent with the keys produced while walking the dataset.
query_path = os.path.join("lfw", query_image, query_image + "_0001.jpg")
print(query_path)
st.image(
    query_path,
    caption="Query Image",
    use_column_width=True,
)

# Find the 10 most similar images
if st.button("Find Similar Images"):
    nearest10 = get_nearest_neighbors(query_path)