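# NOTE(review): the next three lines look like hosting-page status text that was
# captured along with the file, not program code; kept here as comments.
# Spaces:
# Runtime error
# Runtime error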
import os | |
from PIL import Image | |
from cv2 import imread, cvtColor, COLOR_BGR2GRAY, COLOR_BGR2BGRA, COLOR_BGRA2RGB, threshold, THRESH_BINARY_INV, findContours, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE, contourArea, minEnclosingCircle | |
import numpy as np | |
import torch | |
import matplotlib.pyplot as plt | |
def convert_images_to_grayscale(folder_path):
    """Convert every image file in ``folder_path`` to grayscale, in place.

    Each matching file is opened with PIL, converted to mode ``L``
    (grayscale) and then back to ``RGB`` so the saved file still has three
    channels, and finally written back over the original file.

    Args:
        folder_path: Directory whose image files are rewritten. Only file
            names ending in ``.png``, ``.jpg``, ``.jpeg``, ``.bmp`` or
            ``.gif`` are processed; the comparison ignores letter case.

    Returns:
        None. If ``folder_path`` is not an existing directory, a message is
        printed and nothing else happens.
    """
    # Check if the folder exists
    if not os.path.isdir(folder_path):
        print(f"The folder path {folder_path} does not exist.")
        return

    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    # Iterate over all files in the folder
    for filename in os.listdir(folder_path):
        # Compare on the lower-cased name so files such as ``PHOTO.JPG``
        # are not silently skipped.
        if not filename.lower().endswith(image_suffixes):
            continue

        image_path = os.path.join(folder_path, filename)

        # ``convert`` loads the pixel data and returns a new image object,
        # so the result remains usable after the source file is closed.
        with Image.open(image_path) as img:
            grayscale_img = img.convert('L').convert('RGB')

        # Write only after the source handle is closed, so the file is never
        # overwritten while it is still open for reading.
        grayscale_img.save(image_path)
def crop_center_largest_contour(folder_path):
    """Crop each image in ``folder_path`` around its largest non-white region.

    For every entry in the folder the image is read with OpenCV, pixels that
    are not pure white in grayscale are treated as foreground, and the
    external contour with the largest area is located. The image is then
    cropped to a square centred on that contour's minimum enclosing circle
    (using one third of the circle's radius) and saved back over the
    original file as an RGB image.

    Entries that cannot be processed are left untouched: files OpenCV cannot
    decode, images with no foreground contour, and images whose crop would
    be empty.

    Args:
        folder_path: Directory whose images are cropped in place.

    Returns:
        None.
    """
    for each_image in os.listdir(folder_path):
        image_path = os.path.join(folder_path, each_image)
        image = imread(image_path)
        # ``imread`` reports an unreadable or non-image file by returning
        # None rather than raising, so skip those entries explicitly.
        if image is None:
            continue

        gray_image = cvtColor(image, COLOR_BGR2GRAY)

        # Threshold the image to get the non-white pixels: with the inverted
        # binary mode, values above 254 become 0 and everything else 255.
        _, binary_mask = threshold(gray_image, 254, 255, THRESH_BINARY_INV)

        # Find the largest contour
        contours, _ = findContours(binary_mask, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
        if not contours:
            # Entirely white image: there is nothing to centre the crop on.
            continue
        largest_contour = max(contours, key=contourArea)

        # Get the minimum enclosing circle
        (x, y), radius = minEnclosingCircle(largest_contour)
        center = (int(x), int(y))
        radius = int(radius/3)  # Divide by three (arbitrary) to make shape better

        # Crop the image to the bounding box of the circle, clamped to the
        # image borders.
        x_min = max(0, center[0] - radius)
        x_max = min(image.shape[1], center[0] + radius)
        y_min = max(0, center[1] - radius)
        y_max = min(image.shape[0], center[1] + radius)
        cropped_image = image[y_min:y_max, x_min:x_max]
        if cropped_image.size == 0:
            # A very small contour gives a zero radius and an empty crop,
            # which the colour conversion below cannot handle.
            continue

        # OpenCV stores channels as BGR; convert to RGB before handing the
        # array to PIL for saving.
        cropped_image_rgba = cvtColor(cropped_image, COLOR_BGR2BGRA)
        cropped_pil_image = Image.fromarray(cvtColor(cropped_image_rgba, COLOR_BGRA2RGB))
        cropped_pil_image.save(image_path)
def extract_embeddings(transformation_chain, model: torch.nn.Module):
    """Utility to compute embeddings.

    Builds a batch-processing function bound to ``transformation_chain`` and
    ``model``.

    Args:
        transformation_chain: Callable applied to each image; it must return
            a tensor, and all results must be stackable with ``torch.stack``.
        model: Model called as ``model(pixel_values=...)`` whose output
            exposes ``last_hidden_state``. If the model has a ``device``
            attribute it is used; otherwise the device of its first
            parameter is used, falling back to CPU for a parameter-free
            module.

    Returns:
        A function ``pp(batch)`` that reads ``batch["image"]`` and returns
        ``{"embeddings": tensor}``, where the tensor is
        ``last_hidden_state[:, 0]`` moved to the CPU (presumably the
        first-token / CLS embedding -- confirm against the model in use).
    """
    # A plain ``torch.nn.Module`` has no ``device`` attribute, so look it up
    # defensively instead of assuming one exists.
    device = getattr(model, "device", None)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    def pp(batch):
        images = batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(image) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        # Inference only: no gradients are needed for embedding extraction.
        with torch.no_grad():
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp
def compute_scores(emb_one, emb_two):
    """Computes row-wise cosine similarity between two embedding batches.

    The similarity is taken along dimension 1 (the default of
    ``torch.nn.functional.cosine_similarity``), so both inputs need at least
    two dimensions and must be broadcastable against each other.

    Args:
        emb_one: Tensor of embeddings.
        emb_two: Tensor of embeddings, broadcast against ``emb_one``.

    Returns:
        The similarity scores as plain Python numbers (a flat list of floats
        for 2-D inputs).
    """
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    # ``Tensor.numpy()`` raises for tensors that require grad or live on a
    # non-CPU device; detaching and moving to the CPU first avoids both, and
    # ``tolist()`` yields the same Python values without the NumPy round trip.
    return scores.detach().cpu().tolist()
def fetch_similar(image, transformation_chain, device, model, all_candidate_embeddings, candidate_ids, top_k=3):
    """Fetches the `top_k` similar images with `image` as the query.

    Args:
        image: Query image, passed to ``transformation_chain``.
        transformation_chain: Callable turning the image into a tensor.
        device: Device the query tensor is moved to before the forward pass.
        model: Model called as ``model(pixel_values=...)`` whose output
            exposes ``last_hidden_state``.
        all_candidate_embeddings: Embeddings of every candidate, scored
            against the query with ``compute_scores``.
        candidate_ids: Identifiers aligned with the candidate embeddings;
            each must be a string whose part before the first ``"_"``
            parses as an integer.
        top_k: Number of best-scoring candidates to return.

    Returns:
        A list with the integer prefixes of the ``top_k`` highest-scoring
        candidate identifiers, best first.
    """
    # Wrap the transformed query in a batch of one and embed it.
    pixel_values = transformation_chain(image).unsqueeze(0).to(device)
    with torch.no_grad():
        query_embeddings = model(pixel_values=pixel_values).last_hidden_state[:, 0].cpu()

    # Score every candidate against the query at once and key the scores by
    # candidate identifier.
    score_by_id = dict(
        zip(candidate_ids, compute_scores(all_candidate_embeddings, query_embeddings))
    )

    # Rank identifiers from most to least similar and keep the best `top_k`.
    ranked_ids = sorted(score_by_id, key=score_by_id.get, reverse=True)
    return [int(entry.split("_")[0]) for entry in ranked_ids[:top_k]]
def plot_images(images):
    """Draw the query image and its similar images in a six-column grid.

    The first image is titled as the query; every following image is titled
    "Similar Image # i" with its position in ``images``. Axes are hidden for
    every subplot.

    Args:
        images: Sequence of images convertible with ``np.array``; the first
            entry is treated as the query image.

    Returns:
        None. The images are drawn onto a new matplotlib figure.
    """
    plt.figure(figsize=(20, 10))
    columns = 6
    # Ceiling division gives exactly as many rows as needed. The previous
    # ``int(len(images) / columns + 1)`` reserved an empty extra row whenever
    # the image count was a multiple of ``columns``.
    rows = max(1, -(-len(images) // columns))
    for i, image in enumerate(images):
        ax = plt.subplot(rows, columns, i + 1)
        if i == 0:
            ax.set_title("Query Image\n")
        else:
            ax.set_title("Similar Image # " + str(i))
        # Cast to integers so the pixel values are displayed as integer
        # intensities.
        plt.imshow(np.array(image).astype("int"))
        plt.axis("off")