curt-park commited on
Commit
6ecbb25
1 Parent(s): 4f7d8df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -60,6 +60,29 @@ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
60
  return similarity
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def filter_masks(
64
  image: np.ndarray,
65
  masks: List[Dict[str, Any]],
@@ -77,15 +100,8 @@ def filter_masks(
77
  or mask["stability_score"] < stability_score_threshold
78
  ):
79
  continue
80
-
81
  filtered_masks.append(mask)
82
-
83
- x, y, w, h = mask["bbox"]
84
- masked = image * np.expand_dims(mask["segmentation"], -1)
85
- crop = masked[y: y + h, x: x + w]
86
- crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
87
- crop = PIL.Image.fromarray(crop)
88
- cropped_masks.append(crop)
89
 
90
  if query and filtered_masks:
91
  scores = get_scores(cropped_masks, query)
@@ -167,9 +183,9 @@ demo = gr.Interface(
167
  [
168
  0.9,
169
  0.8,
170
- 0.05,
171
  os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
172
- "water",
173
  ],
174
  [
175
  0.9,
 
60
  return similarity
61
 
62
 
63
+ def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
64
+ x, y, w, h = mask["bbox"]
65
+ masked = image * np.expand_dims(mask["segmentation"], -1)
66
+ crop = masked[y : y + h, x : x + w]
67
+ if h > w:
68
+ top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
69
+ else:
70
+ top, bottom, left, right = (w - h) // 2, (w - h) // 2, 0, 0
71
+ # padding
72
+ crop = cv2.copyMakeBorder(
73
+ crop,
74
+ top,
75
+ bottom,
76
+ left,
77
+ right,
78
+ cv2.BORDER_CONSTANT,
79
+ value=(0, 0, 0),
80
+ )
81
+ crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
82
+ crop = PIL.Image.fromarray(crop)
83
+ return crop
84
+
85
+
86
  def filter_masks(
87
  image: np.ndarray,
88
  masks: List[Dict[str, Any]],
 
100
  or mask["stability_score"] < stability_score_threshold
101
  ):
102
  continue
 
103
  filtered_masks.append(mask)
104
+ cropped_masks.append(crop_image(image, mask))
 
 
 
 
 
 
105
 
106
  if query and filtered_masks:
107
  scores = get_scores(cropped_masks, query)
 
183
  [
184
  0.9,
185
  0.8,
186
+ 0.001,
187
  os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
188
+ "building",
189
  ],
190
  [
191
  0.9,