hhhhhh0103 committed
Commit 6e3a622 • 1 Parent(s): ddf363d

Create app.py

Files changed (1)
  1. app.py +257 -0
app.py ADDED
@@ -0,0 +1,257 @@
+ from transformers import SamModel, SamProcessor, pipeline
+ from PIL import Image
+ import cv2
+ import random
+ import numpy as np
+ import torch
+ from torch.nn.functional import cosine_similarity
+ import gradio as gr
+
+ class RoiMatching():
+     def __init__(self, img1, img2, device='cuda:1', v_min=200, v_max=7000, mode='embedding'):
+         """
+         Initialize the matcher.
+         :param img1: PIL image
+         :param img2: PIL image
+         :param device: torch device string, e.g. 'cpu' or 'cuda:1'
+         :param v_min: minimum mask area (in pixels) to keep
+         :param v_max: maximum mask area (in pixels) to keep
+         :param mode: 'embedding' (prototype similarity) or 'overlapping'
+         """
+         self.img1 = img1
+         self.img2 = img2
+         self.device = device
+         self.v_min = v_min
+         self.v_max = v_max
+         self.mode = mode
+
+     def _sam_everything(self, imgs):
+         # Automatic "segment everything" with SAM via the mask-generation pipeline.
+         generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=self.device)
+         outputs = generator(imgs, points_per_batch=64, pred_iou_thresh=0.90, stability_score_thresh=0.9)
+         return outputs
+
+     def _mask_criteria(self, masks, v_min=200, v_max=7000):
+         # Drop masks whose area falls outside [v_min, v_max].
+         remove_list = set()
+         for _i, mask in enumerate(masks):
+             if mask.sum() < v_min or mask.sum() > v_max:
+                 remove_list.add(_i)
+         masks = [mask for idx, mask in enumerate(masks) if idx not in remove_list]
+         # For near-duplicate pairs (>= 90% of the smaller mask covered by the
+         # intersection), keep only the larger mask.
+         n = len(masks)
+         remove_list = set()
+         for i in range(n):
+             for j in range(i + 1, n):
+                 mask1, mask2 = masks[i], masks[j]
+                 intersection = (mask1 & mask2).sum()
+                 smaller_mask_area = min(mask1.sum(), mask2.sum())
+                 if smaller_mask_area > 0 and (intersection / smaller_mask_area) >= 0.9:
+                     if mask1.sum() < mask2.sum():
+                         remove_list.add(i)
+                     else:
+                         remove_list.add(j)
+         return [mask for idx, mask in enumerate(masks) if idx not in remove_list]
+
+     def _roi_proto(self, image, masks):
+         # One prototype embedding per mask: masked-average pooling over the
+         # SAM image embedding, which has shape (1, 256, 64, 64).
+         model = SamModel.from_pretrained("facebook/sam-vit-huge").to(self.device)
+         processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+         inputs = processor(image, return_tensors="pt").to(self.device)
+         image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
+         embs = []
+         for _m in masks:
+             # Resize the binary mask to the 64x64 embedding grid.
+             tmp_m = _m.astype(np.uint8)
+             tmp_m = cv2.resize(tmp_m, (64, 64), interpolation=cv2.INTER_NEAREST)
+             tmp_m = torch.tensor(tmp_m.astype(bool), device=self.device, dtype=torch.float32)
+             tmp_m = tmp_m.unsqueeze(0).unsqueeze(0)  # add batch and channel dims
+             # Zero out features outside the mask, then average only the
+             # in-mask positions via nanmean.
+             tmp_emb = image_embeddings * tmp_m  # (1, 256, 64, 64)
+             tmp_emb[tmp_emb == 0] = torch.nan
+             emb = torch.nanmean(tmp_emb, dim=(2, 3))
+             emb[torch.isnan(emb)] = 0
+             embs.append(emb)
+         return embs
+
+     def _cosine_similarity(self, vec1, vec2):
+         # Ensure vec1 and vec2 are 2D tensors of shape [1, N]
+         vec1 = vec1.view(1, -1)
+         vec2 = vec2.view(1, -1)
+         return cosine_similarity(vec1, vec2).item()
+
+     def _similarity_matrix(self, protos1, protos2):
+         # Pairwise cosine similarity between the two sets of ROI prototypes.
+         similarity_matrix = torch.zeros(len(protos1), len(protos2), device=self.device)
+         for i, vec_a in enumerate(protos1):
+             for j, vec_b in enumerate(protos2):
+                 similarity_matrix[i, j] = self._cosine_similarity(vec_a, vec_b)
+         # Min-max normalize the similarity matrix to [0, 1]
+         similarity_matrix = (similarity_matrix - similarity_matrix.min()) / (similarity_matrix.max() - similarity_matrix.min())
+         return similarity_matrix
+
+     def _roi_match(self, matrix, masks1, masks2, sim_criteria=0.8):
+         # Greedy one-to-one matching: repeatedly take the highest-similarity
+         # pair above the threshold, then invalidate its row and column.
+         index_pairs = []
+         while torch.any(matrix > sim_criteria):
+             max_idx = torch.argmax(matrix)
+             max_sim_idx = (max_idx // matrix.shape[1], max_idx % matrix.shape[1])
+             if matrix[max_sim_idx[0], max_sim_idx[1]] > sim_criteria:
+                 index_pairs.append(max_sim_idx)
+             matrix[max_sim_idx[0], :] = -1
+             matrix[:, max_sim_idx[1]] = -1
+         masks1_new = []
+         masks2_new = []
+         for i, j in index_pairs:
+             masks1_new.append(masks1[i])
+             masks2_new.append(masks2[j])
+         return masks1_new, masks2_new
+
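+     # Worked example of the greedy matching above (illustrative values):
+     # with sim_criteria=0.8 and
+     #     matrix = [[0.95, 0.10],
+     #               [0.20, 0.85]]
+     # the loop pairs (0, 0) first, wipes row 0 and column 0 with -1,
+     # then pairs (1, 1), so index_pairs == [(0, 0), (1, 1)].
+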
+     def _overlap_pair(self, masks1, masks2):
+         # Pair ROIs by spatial overlap instead of embedding similarity.
+         self.masks1_cor = []
+         self.masks2_cor = []
+         im1 = np.asarray(self.img1)
+         im2 = np.asarray(self.img2)
+         k = 0
+         for mask in masks1[:-1]:
+             k += 1
+             print('mask1 {} is finding its corresponding region mask...'.format(k))
+             m1 = mask
+             a1 = mask.sum()
+             v1 = np.mean(np.expand_dims(m1, axis=-1) * im1)
+             overlap = m1 * masks2[-1].astype(np.int64)
+             if (overlap > 0).sum() / a1 > 0.3:
+                 counts = np.bincount(overlap.flatten())
+                 sorted_indices = np.argsort(counts)[::-1]
+                 top_two = sorted_indices[1:3]
+                 if top_two[-1] == 0:
+                     cor_ind = 0
+                 elif abs(counts[top_two[-1]] - counts[top_two[0]]) / max(counts[top_two[-1]], counts[top_two[0]]) < 0.2:
+                     cor_ind = 0
+                 else:
+                     m21 = masks2[top_two[0] - 1]
+                     m22 = masks2[top_two[1] - 1]
+                     a21 = m21.sum()
+                     a22 = m22.sum()
+                     v21 = np.mean(np.expand_dims(m21, axis=-1) * im2)
+                     v22 = np.mean(np.expand_dims(m22, axis=-1) * im2)
+                     # Area-based judgement between the two candidates.
+                     if np.abs(a21 - a1) > np.abs(a22 - a1):
+                         cor_ind = 0
+                     else:
+                         cor_ind = 1
+                     print('area judge to cor_ind {}'.format(cor_ind))
+                     # Mean-intensity judgement overrides the area result:
+                     # prefer the candidate whose mean intensity is closer to mask1's.
+                     if np.abs(v21 - v1) < np.abs(v22 - v1):
+                         cor_ind = 0
+                     else:
+                         cor_ind = 1
+                 self.masks2_cor.append(masks2[top_two[cor_ind] - 1])
+                 self.masks1_cor.append(mask)
+
+     def get_paired_roi(self):
+         # Segment everything in both images, filter the masks, then pair them.
+         self.masks1 = self._sam_everything(self.img1)  # dict with a 'masks' list
+         self.masks2 = self._sam_everything(self.img2)
+         self.masks1 = self._mask_criteria(self.masks1['masks'], v_min=self.v_min, v_max=self.v_max)
+         self.masks2 = self._mask_criteria(self.masks2['masks'], v_min=self.v_min, v_max=self.v_max)
+
+         match self.mode:  # structural pattern matching requires Python >= 3.10
+             case 'embedding':
+                 if len(self.masks1) > 0 and len(self.masks2) > 0:
+                     self.embs1 = self._roi_proto(self.img1, self.masks1)
+                     self.embs2 = self._roi_proto(self.img2, self.masks2)
+                     self.sim_matrix = self._similarity_matrix(self.embs1, self.embs2)
+                     self.masks1, self.masks2 = self._roi_match(self.sim_matrix, self.masks1, self.masks2)
+             case 'overlapping':
+                 self._overlap_pair(self.masks1, self.masks2)
+
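+ # Minimal usage sketch of RoiMatching outside Gradio (the paths below are the
+ # bundled demo images; any two PIL images work):
+ #     img1 = Image.open('./example/prostate_2d/image1.png')
+ #     img2 = Image.open('./example/prostate_2d/image2.png')
+ #     rm = RoiMatching(img1, img2, device='cpu')
+ #     rm.get_paired_roi()
+ #     pairs = list(zip(rm.masks1, rm.masks2))  # corresponding ROI masks
+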
+ def visualize_masks(image1, masks1, image2, masks2):
+     # Convert PIL images to numpy arrays
+     background1 = np.array(image1)
+     background2 = np.array(image2)
+
+     # Convert RGB to BGR (OpenCV uses the BGR color format)
+     background1 = cv2.cvtColor(background1, cv2.COLOR_RGB2BGR)
+     background2 = cv2.cvtColor(background2, cv2.COLOR_RGB2BGR)
+
+     # Create a blank overlay mask for each image
+     mask1 = np.zeros_like(background1)
+     mask2 = np.zeros_like(background2)
+
+     # Optional fixed palette (BGR); unused when random colors are drawn below.
+     distinct_colors = [
+         (255, 0, 0),      # Blue
+         (0, 255, 0),      # Green
+         (0, 0, 255),      # Red
+         (255, 255, 0),    # Cyan
+         (255, 0, 255),    # Magenta
+         (0, 255, 255),    # Yellow
+         (128, 0, 0),      # Navy
+         (0, 128, 0),      # Dark green
+         (0, 0, 128),      # Maroon
+         (128, 128, 0),    # Teal
+         (128, 0, 128),    # Purple
+         (0, 128, 128),    # Olive
+         (192, 192, 192),  # Silver
+     ]
+
+     def random_color():
+         """Generate a random color with high saturation and value in HSV space."""
+         hue = random.randint(0, 179)  # OpenCV hue range is 0-179
+         saturation = random.randint(200, 255)
+         value = random.randint(200, 255)
+         color = np.array([[[hue, saturation, value]]], dtype=np.uint8)
+         return cv2.cvtColor(color, cv2.COLOR_HSV2BGR)[0][0]
+
+     # Overlay each paired mask on both images with the same color
+     for idx, (mask1_item, mask2_item) in enumerate(zip(masks1, masks2)):
+         # color = distinct_colors[idx % len(distinct_colors)]
+         color = random_color()
+         # Convert binary masks to uint8
+         mask1_item = np.uint8(mask1_item)
+         mask2_item = np.uint8(mask2_item)
+
+         # Foreground masks where the binary mask is True
+         fg_mask1 = np.where(mask1_item, 255, 0).astype(np.uint8)
+         fg_mask2 = np.where(mask2_item, 255, 0).astype(np.uint8)
+
+         # Paint the shared color onto both overlay masks
+         mask1[fg_mask1 > 0] = color
+         mask2[fg_mask2 > 0] = color
+
+     # Blend the colored masks onto the backgrounds, then convert back to RGB
+     # for display in Gradio
+     result1 = cv2.addWeighted(background1, 1, mask1, 0.5, 0)
+     result2 = cv2.addWeighted(background2, 1, mask2, 0.5, 0)
+     result1 = cv2.cvtColor(result1, cv2.COLOR_BGR2RGB)
+     result2 = cv2.cvtColor(result2, cv2.COLOR_BGR2RGB)
+
+     return result1, result2
+
+ def predict(im1, im2):
+     # Match ROIs between the two uploaded images and draw each matched pair
+     # in the same color on both images.
+     RM = RoiMatching(im1, im2, device='cpu')
+     RM.get_paired_roi()
+     visualized_image1, visualized_image2 = visualize_masks(im1, RM.masks1, im2, RM.masks2)
+     return visualized_image1, visualized_image2
+
+ examples = [
+     ['./example/prostate_2d/image1.png', './example/prostate_2d/image2.png'],
+     ['./example/cardiac_2d/image1.png', './example/cardiac_2d/image2.png'],
+     ['./example/pathology/1B_B7_R.png', './example/pathology/1B_B7_T.png'],
+ ]
+
+ gradio_app = gr.Interface(
+     predict,
+     inputs=[
+         gr.Image(label="Image 1", sources=['upload', 'webcam'], type="pil"),
+         gr.Image(label="Image 2", sources=['upload', 'webcam'], type="pil"),
+     ],
+     outputs=[
+         gr.Image(label="Paired ROIs in Image 1"),
+         gr.Image(label="Paired ROIs in Image 2"),
+     ],
+     title="SAMReg: One Registration is Worth Two Segmentations",
+     examples=examples,
+     description="<p> \
+     <strong>Register anything with ROI-based registration representation.</strong> <br>\
+     Choose an example below &#128293; &#128293; &#128293; <br>\
+     Or upload your own pair: <br>\
+     1. Upload the two images to be registered as 'Image 1' and 'Image 2'. <br>\
+     2. Submit to see the paired ROIs, with each matched pair drawn in the same color. <br>\
+     <br> \
+     💎 SAM segments everything in both images; the ROIs are then paired by embedding similarity. <br>\
+     💎 The examples below were never seen in training and are randomly selected for testing in the wild. \
+     </p>",
+ )
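+
+ # A standard launch guard is assumed here so the interface starts when
+ # app.py is run directly (e.g. locally or on a Hugging Face Space).
+ if __name__ == "__main__":
+     gradio_app.launch()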