canadianjosieharrison committed
Commit: b45ac3d
Parent: 4148997

Update image_helpers.py

Files changed (1)
  image_helpers.py +149 -113
image_helpers.py CHANGED
@@ -1,113 +1,149 @@
- import os
- from PIL import Image
- from cv2 import imread, cvtColor, COLOR_BGR2GRAY, COLOR_BGR2BGRA, COLOR_BGRA2RGB, threshold, THRESH_BINARY_INV, findContours, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE, contourArea, minEnclosingCircle
- import numpy as np
- import torch
- import matplotlib.pyplot as plt
-
- def convert_images_to_grayscale(folder_path):
-     # Check if the folder exists
-     if not os.path.isdir(folder_path):
-         print(f"The folder path {folder_path} does not exist.")
-         return
-
-     # Iterate over all files in the folder
-     for filename in os.listdir(folder_path):
-         if filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
-             image_path = os.path.join(folder_path, filename)
-
-             # Open an image file
-             with Image.open(image_path) as img:
-                 # Convert image to grayscale
-                 grayscale_img = img.convert('L').convert('RGB')
-                 grayscale_img.save(os.path.join(folder_path, filename))
-
- def crop_center_largest_contour(folder_path):
-     for each_image in os.listdir(folder_path):
-         image_path = os.path.join(folder_path, each_image)
-         image = imread(image_path)
-         gray_image = cvtColor(image, COLOR_BGR2GRAY)
-
-         # Threshold the image to get the non-white pixels
-         _, binary_mask = threshold(gray_image, 254, 255, THRESH_BINARY_INV)
-
-         # Find the largest contour
-         contours, _ = findContours(binary_mask, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
-         largest_contour = max(contours, key=contourArea)
-
-         # Get the minimum enclosing circle
-         (x, y), radius = minEnclosingCircle(largest_contour)
-         center = (int(x), int(y))
-         radius = int(radius/3) # Divide by three (arbitrary) to make shape better
-
-         # Crop the image to the bounding box of the circle
-         x_min = max(0, center[0] - radius)
-         x_max = min(image.shape[1], center[0] + radius)
-         y_min = max(0, center[1] - radius)
-         y_max = min(image.shape[0], center[1] + radius)
-         cropped_image = image[y_min:y_max, x_min:x_max]
-         cropped_image_rgba = cvtColor(cropped_image, COLOR_BGR2BGRA)
-         cropped_pil_image = Image.fromarray(cvtColor(cropped_image_rgba, COLOR_BGRA2RGB))
-         cropped_pil_image.save(image_path)
-
- def extract_embeddings(transformation_chain, model: torch.nn.Module):
-     """Utility to compute embeddings."""
-     device = model.device
-
-     def pp(batch):
-         images = batch["image"]
-         image_batch_transformed = torch.stack(
-             [transformation_chain(image) for image in images]
-         )
-         new_batch = {"pixel_values": image_batch_transformed.to(device)}
-         with torch.no_grad():
-             embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
-         return {"embeddings": embeddings}
-
-     return pp
-
- def compute_scores(emb_one, emb_two):
-     """Computes cosine similarity between two vectors."""
-     scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
-     return scores.numpy().tolist()
-
-
- def fetch_similar(image, transformation_chain, device, model, all_candidate_embeddings, candidate_ids, top_k=3):
-     """Fetches the `top_k` similar images with `image` as the query."""
-     # Prepare the input query image for embedding computation.
-     image_transformed = transformation_chain(image).unsqueeze(0)
-     new_batch = {"pixel_values": image_transformed.to(device)}
-
-     # Compute the embedding.
-     with torch.no_grad():
-         query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
-
-     # Compute similarity scores with all the candidate images at one go.
-     # We also create a mapping between the candidate image identifiers
-     # and their similarity scores with the query image.
-     sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
-     similarity_mapping = dict(zip(candidate_ids, sim_scores))
-
-     # Sort the mapping dictionary and return `top_k` candidates.
-     similarity_mapping_sorted = dict(
-         sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
-     )
-     id_entries = list(similarity_mapping_sorted.keys())[:top_k]
-
-     ids = list(map(lambda x: int(x.split("_")[0]), id_entries))
-     return ids
-
- def plot_images(images):
-
-     plt.figure(figsize=(20, 10))
-     columns = 6
-     for (i, image) in enumerate(images):
-         ax = plt.subplot(int(len(images) / columns + 1), columns, i + 1)
-         if i == 0:
-             ax.set_title("Query Image\n")
-         else:
-             ax.set_title(
-                 "Similar Image # " + str(i)
-             )
-         plt.imshow(np.array(image).astype("int"))
-         plt.axis("off")
+ import os
+ from PIL import Image
+ from cv2 import imread, cvtColor, COLOR_BGR2GRAY, COLOR_BGR2BGRA, COLOR_BGRA2RGB, threshold, THRESH_BINARY_INV, findContours, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE, contourArea, minEnclosingCircle
+ import numpy as np
+ import torch
+ import matplotlib.pyplot as plt
+
+ def convert_images_to_grayscale(folder_path):
+     # Check if the folder exists
+     if not os.path.isdir(folder_path):
+         print(f"The folder path {folder_path} does not exist.")
+         return
+
+     # Iterate over all files in the folder
+     for filename in os.listdir(folder_path):
+         if filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
+             image_path = os.path.join(folder_path, filename)
+
+             # Open an image file
+             with Image.open(image_path) as img:
+                 # Convert image to grayscale
+                 grayscale_img = img.convert('L').convert('RGB')
+                 grayscale_img.save(os.path.join(folder_path, filename))
+
+ def crop_center_largest_contour(folder_path):
+     for each_image in os.listdir(folder_path):
+         image_path = os.path.join(folder_path, each_image)
+         image = imread(image_path)
+         gray_image = cvtColor(image, COLOR_BGR2GRAY)
+
+         # Threshold the image to get the non-white pixels
+         _, binary_mask = threshold(gray_image, 254, 255, THRESH_BINARY_INV)
+
+         # Find the largest contour
+         contours, _ = findContours(binary_mask, RETR_EXTERNAL, CHAIN_APPROX_SIMPLE)
+         largest_contour = max(contours, key=contourArea)
+
+         # Get the minimum enclosing circle
+         (x, y), radius = minEnclosingCircle(largest_contour)
+         center = (int(x), int(y))
+         radius = int(radius/3) # Divide by three (arbitrary) to make shape better
+
+         # Crop the image to the bounding box of the circle
+         x_min = max(0, center[0] - radius)
+         x_max = min(image.shape[1], center[0] + radius)
+         y_min = max(0, center[1] - radius)
+         y_max = min(image.shape[0], center[1] + radius)
+         cropped_image = image[y_min:y_max, x_min:x_max]
+         cropped_image_rgba = cvtColor(cropped_image, COLOR_BGR2BGRA)
+         cropped_pil_image = Image.fromarray(cvtColor(cropped_image_rgba, COLOR_BGRA2RGB))
+         cropped_pil_image.save(image_path)
+
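These two helpers operate on a folder in place: the first converts every image to grayscale, the second overwrites each image with a crop around its largest non-white contour. A minimal usage sketch, assuming the module is importable as image_helpers and using a placeholder folder path:

from image_helpers import convert_images_to_grayscale, crop_center_largest_contour

# Hypothetical directory of .png/.jpg images; both helpers modify the files in place.
image_folder = "data/material_images"
convert_images_to_grayscale(image_folder)
crop_center_largest_contour(image_folder)
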
+ def calculate_variance(patch):
+     # Convert patch to numpy array
+     patch_array = np.array(patch)
+     # Calculate the variance
+     variance = np.var(patch_array)
+     return variance
+
+ def crop_least_variant_patch(folder_path):
+     for each_image in os.listdir(folder_path):
+         image_path = os.path.join(folder_path, each_image)
+         image = Image.open(image_path)
+         # define window size
+         width, height = image.size
+         window_size = round(height * .2)
+         stride = round(window_size * .2)
+         min_variance = float('inf')
+         best_patch = None
+         # slide window across image
+         for x in range(0, width - window_size + 1, stride):
+             for y in range(0, height - window_size + 1, stride):
+                 patch = image.crop((x,y,x + window_size, y + window_size))
+                 patch_w, patch_h = patch.size
+                 total_pixels = patch_w * patch_h
+                 white_pixels = np.sum(np.all(np.array(patch) == [255, 255, 255], axis=2))
+                 if white_pixels < (total_pixels / 2):
+                     # calculate variance / standard deviation
+                     variance = calculate_variance(patch)
+                     if variance < min_variance:
+                         # update minimum var / sd
+                         min_variance = variance
+                         best_patch = patch
+         try:
+             best_patch.save(image_path)
+         except AttributeError as e:
+             print("No good homogenous patch to save.")
+
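crop_least_variant_patch is what this commit adds: for each image it slides a square window (20% of the image height, moving in steps of 20% of that window) across the image, skips windows that are more than half white, and keeps the window with the lowest pixel variance, overwriting the original file with that patch. A short sketch of the call and the window arithmetic, with the folder path again a placeholder:

from image_helpers import crop_least_variant_patch

# For an image 500 px tall: window_size = round(500 * 0.2) = 100 px and
# stride = round(100 * 0.2) = 20 px, so ((width - 100) // 20 + 1) * ((500 - 100) // 20 + 1)
# candidate windows are scored per image.
crop_least_variant_patch("data/material_images")
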
+ def extract_embeddings(transformation_chain, model: torch.nn.Module):
+     """Utility to compute embeddings."""
+     device = model.device
+
+     def pp(batch):
+         images = batch["image"]
+         image_batch_transformed = torch.stack(
+             [transformation_chain(image) for image in images]
+         )
+         new_batch = {"pixel_values": image_batch_transformed.to(device)}
+         with torch.no_grad():
+             embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
+         return {"embeddings": embeddings}
+
+     return pp
+
+ def compute_scores(emb_one, emb_two):
+     """Computes cosine similarity between two vectors."""
+     scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
+     return scores.numpy().tolist()
+
+
+ def fetch_similar(image, transformation_chain, device, model, all_candidate_embeddings, candidate_ids, top_k=3):
+     """Fetches the `top_k` similar images with `image` as the query."""
+     # Prepare the input query image for embedding computation.
+     image_transformed = transformation_chain(image).unsqueeze(0)
+     new_batch = {"pixel_values": image_transformed.to(device)}
+
+     # Compute the embedding.
+     with torch.no_grad():
+         query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
+
+     # Compute similarity scores with all the candidate images at one go.
+     # We also create a mapping between the candidate image identifiers
+     # and their similarity scores with the query image.
+     sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
+     similarity_mapping = dict(zip(candidate_ids, sim_scores))
+
+     # Sort the mapping dictionary and return `top_k` candidates.
+     similarity_mapping_sorted = dict(
+         sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
+     )
+     id_entries = list(similarity_mapping_sorted.keys())[:top_k]
+
+     ids = list(map(lambda x: int(x.split("_")[0]), id_entries))
+     return ids
+
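extract_embeddings returns a batched function meant to be mapped over a dataset that has an "image" column, and fetch_similar reuses the same model and transform chain for a query image. A rough end-to-end sketch; the ViT checkpoint, the torchvision transform values, the imagefolder path, and the "<index>_<label>" id scheme are assumptions for illustration, not anything fixed by this file:

import torch
import torchvision.transforms as T
from datasets import load_dataset
from transformers import AutoModel

from image_helpers import extract_embeddings, fetch_similar

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed checkpoint: any ViT-style encoder whose last_hidden_state[:, 0] serves as the embedding.
model = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
model.eval()

# Assumed preprocessing; any chain that turns a PIL image into a (3, H, W) tensor works.
transformation_chain = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Hypothetical dataset with "image" (PIL) and "label" columns.
dataset = load_dataset("imagefolder", data_dir="data/material_images", split="train")

# Map the batched embedding function over the candidate images.
extract_fn = extract_embeddings(transformation_chain, model)
dataset_emb = dataset.map(extract_fn, batched=True, batch_size=16)

# Stack candidate embeddings and build "<index>_<label>" ids, since fetch_similar
# recovers the integer index by splitting each id on "_".
all_candidate_embeddings = torch.tensor(dataset_emb["embeddings"])
candidate_ids = [f"{i}_{label}" for i, label in enumerate(dataset_emb["label"])]

query_image = dataset[0]["image"]
similar_ids = fetch_similar(
    query_image, transformation_chain, device, model,
    all_candidate_embeddings, candidate_ids, top_k=3,
)
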
+ def plot_images(images):
+
+     plt.figure(figsize=(20, 10))
+     columns = 6
+     for (i, image) in enumerate(images):
+         ax = plt.subplot(int(len(images) / columns + 1), columns, i + 1)
+         if i == 0:
+             ax.set_title("Query Image\n")
+         else:
+             ax.set_title(
+                 "Similar Image # " + str(i)
+             )
+         plt.imshow(np.array(image).astype("int"))
+         plt.axis("off")
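plot_images lays the images out on a 6-column grid and titles the first one as the query, so it is usually fed the query image followed by the retrieved candidates. A small sketch that continues the retrieval example above (query_image, similar_ids, and dataset are the same assumed names):

import matplotlib.pyplot as plt
from image_helpers import plot_images

# First entry is the query, the rest are the candidates returned by fetch_similar.
plot_images([query_image] + [dataset[i]["image"] for i in similar_ids])
plt.show()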