Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -60,6 +60,29 @@ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
|
|
60 |
return similarity
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def filter_masks(
|
64 |
image: np.ndarray,
|
65 |
masks: List[Dict[str, Any]],
|
@@ -77,15 +100,8 @@ def filter_masks(
|
|
77 |
or mask["stability_score"] < stability_score_threshold
|
78 |
):
|
79 |
continue
|
80 |
-
|
81 |
filtered_masks.append(mask)
|
82 |
-
|
83 |
-
x, y, w, h = mask["bbox"]
|
84 |
-
masked = image * np.expand_dims(mask["segmentation"], -1)
|
85 |
-
crop = masked[y: y + h, x: x + w]
|
86 |
-
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
|
87 |
-
crop = PIL.Image.fromarray(crop)
|
88 |
-
cropped_masks.append(crop)
|
89 |
|
90 |
if query and filtered_masks:
|
91 |
scores = get_scores(cropped_masks, query)
|
@@ -167,9 +183,9 @@ demo = gr.Interface(
|
|
167 |
[
|
168 |
0.9,
|
169 |
0.8,
|
170 |
-
0.
|
171 |
os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
|
172 |
-
"
|
173 |
],
|
174 |
[
|
175 |
0.9,
|
|
|
60 |
return similarity
|
61 |
|
62 |
|
63 |
+
def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
|
64 |
+
x, y, w, h = mask["bbox"]
|
65 |
+
masked = image * np.expand_dims(mask["segmentation"], -1)
|
66 |
+
crop = masked[y : y + h, x : x + w]
|
67 |
+
if h > w:
|
68 |
+
top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
|
69 |
+
else:
|
70 |
+
top, bottom, left, right = (w - h) // 2, (w - h) // 2, 0, 0
|
71 |
+
# padding
|
72 |
+
crop = cv2.copyMakeBorder(
|
73 |
+
crop,
|
74 |
+
top,
|
75 |
+
bottom,
|
76 |
+
left,
|
77 |
+
right,
|
78 |
+
cv2.BORDER_CONSTANT,
|
79 |
+
value=(0, 0, 0),
|
80 |
+
)
|
81 |
+
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
|
82 |
+
crop = PIL.Image.fromarray(crop)
|
83 |
+
return crop
|
84 |
+
|
85 |
+
|
86 |
def filter_masks(
|
87 |
image: np.ndarray,
|
88 |
masks: List[Dict[str, Any]],
|
|
|
100 |
or mask["stability_score"] < stability_score_threshold
|
101 |
):
|
102 |
continue
|
|
|
103 |
filtered_masks.append(mask)
|
104 |
+
cropped_masks.append(crop_image(image, mask))
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
if query and filtered_masks:
|
107 |
scores = get_scores(cropped_masks, query)
|
|
|
183 |
[
|
184 |
0.9,
|
185 |
0.8,
|
186 |
+
0.001,
|
187 |
os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
|
188 |
+
"building",
|
189 |
],
|
190 |
[
|
191 |
0.9,
|