PhyscalX commited on
Commit
ae507fe
1 Parent(s): 7c01d17

Use a safer process for submission

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -44,9 +44,9 @@ class Predictor(object):
44
  def __init__(self, model, kwargs):
45
  self.model = model
46
  self.kwargs = kwargs
47
- self.batch_size = kwargs.get("batch_size", 256)
48
  self.model.concept_projector.reset_weights(kwargs["concept_weights"])
49
- self.model.text_decoder.reset_cache(max_batch_size=self.batch_size)
50
 
51
  def preprocess_images(self, imgs):
52
  """Preprocess the inference images."""
@@ -85,21 +85,21 @@ class Predictor(object):
85
  mask_index = np.arange(rank_scores.shape[0]), rank_scores.argmax(1)
86
  iou_scores = outputs["iou_pred"][mask_index].cpu().numpy().reshape(batch_shape)
87
  # Upscale masks to the original image resolution.
88
- mask_pred = outputs["mask_pred"][mask_index][:, None]
89
  mask_pred = self.model.upscale_masks(mask_pred, im_batch.shape[1:-1])
90
  mask_pred = mask_pred.view(batch_shape + mask_pred.shape[2:])
91
  # Predict concepts.
92
  concepts, scores = self.model.predict_concept(outputs["sem_embeds"][mask_index])
93
  concepts, scores = [x.reshape(batch_shape) for x in (concepts, scores)]
94
  # Generate captions.
95
- sem_tokens = outputs["sem_tokens"][mask_index][:, None, :]
96
  captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
97
  # Postprecess results.
98
  results = []
99
  for i in range(batch_shape[0]):
100
  pred_h, pred_w = im_info[i, :2].astype("int")
101
  masks = mask_pred[i : i + 1, :, :pred_h, :pred_w]
102
- masks = self.model.upscale_masks(masks, imgs[i].shape[:2])[0]
103
  results.append(
104
  {
105
  "scores": np.stack([iou_scores[i], scores[i]], axis=-1),
@@ -165,7 +165,8 @@ def build_gradio_app(queues, command):
165
  return click_img, draw_img, anno_img
166
 
167
  def on_submit_btn(click_img, mask_img, prompt, multipoint):
168
- if prompt == 0:
 
169
  img, points = click_img["image"], click_img["points"]
170
  points = np.array(points).reshape((-1, 2, 3))
171
  if multipoint == 1:
@@ -175,7 +176,7 @@ def build_gradio_app(queues, command):
175
  poly = points[np.where(points[:, 2] <= 1)[0]][None, :, :]
176
  points = [lt, rb, poly] if len(lt) > 0 else [poly, np.array([[[0, 0, 4]]])]
177
  points = np.concatenate(points, axis=1)
178
- elif prompt == 1:
179
  img, points = mask_img["background"], []
180
  for layer in mask_img["layers"]:
181
  ys, xs = np.nonzero(layer[:, :, 0])
@@ -189,8 +190,8 @@ def build_gradio_app(queues, command):
189
  points = np.concatenate([points, pad_points], axis=1)
190
  img = img[:, :, (2, 1, 0)] if img is not None else img
191
  img = np.zeros((480, 640, 3), dtype="uint8") if img is None else img
192
- points = (np.array([[[0, 0, 4]]]) if len(points) == 0 else points).astype("float32")
193
- inputs = {"img": img, "points": points}
194
  with command.output_index.get_lock():
195
  command.output_index.value += 1
196
  img_id = command.output_index.value
 
44
  def __init__(self, model, kwargs):
45
  self.model = model
46
  self.kwargs = kwargs
47
+ self.prompt_size = kwargs.get("prompt_size", 256)
48
  self.model.concept_projector.reset_weights(kwargs["concept_weights"])
49
+ self.model.text_decoder.reset_cache(max_batch_size=self.prompt_size)
50
 
51
  def preprocess_images(self, imgs):
52
  """Preprocess the inference images."""
 
85
  mask_index = np.arange(rank_scores.shape[0]), rank_scores.argmax(1)
86
  iou_scores = outputs["iou_pred"][mask_index].cpu().numpy().reshape(batch_shape)
87
  # Upscale masks to the original image resolution.
88
+ mask_pred = outputs["mask_pred"][mask_index].unsqueeze_(1)
89
  mask_pred = self.model.upscale_masks(mask_pred, im_batch.shape[1:-1])
90
  mask_pred = mask_pred.view(batch_shape + mask_pred.shape[2:])
91
  # Predict concepts.
92
  concepts, scores = self.model.predict_concept(outputs["sem_embeds"][mask_index])
93
  concepts, scores = [x.reshape(batch_shape) for x in (concepts, scores)]
94
  # Generate captions.
95
+ sem_tokens = outputs["sem_tokens"][mask_index].unsqueeze_(1)
96
  captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
97
  # Postprecess results.
98
  results = []
99
  for i in range(batch_shape[0]):
100
  pred_h, pred_w = im_info[i, :2].astype("int")
101
  masks = mask_pred[i : i + 1, :, :pred_h, :pred_w]
102
+ masks = self.model.upscale_masks(masks, imgs[i].shape[:2]).flatten(0, 1)
103
  results.append(
104
  {
105
  "scores": np.stack([iou_scores[i], scores[i]], axis=-1),
 
165
  return click_img, draw_img, anno_img
166
 
167
  def on_submit_btn(click_img, mask_img, prompt, multipoint):
168
+ img, points = None, np.array([[[0, 0, 4]]])
169
+ if prompt == 0 and click_img is not None:
170
  img, points = click_img["image"], click_img["points"]
171
  points = np.array(points).reshape((-1, 2, 3))
172
  if multipoint == 1:
 
176
  poly = points[np.where(points[:, 2] <= 1)[0]][None, :, :]
177
  points = [lt, rb, poly] if len(lt) > 0 else [poly, np.array([[[0, 0, 4]]])]
178
  points = np.concatenate(points, axis=1)
179
+ elif prompt == 1 and mask_img is not None:
180
  img, points = mask_img["background"], []
181
  for layer in mask_img["layers"]:
182
  ys, xs = np.nonzero(layer[:, :, 0])
 
190
  points = np.concatenate([points, pad_points], axis=1)
191
  img = img[:, :, (2, 1, 0)] if img is not None else img
192
  img = np.zeros((480, 640, 3), dtype="uint8") if img is None else img
193
+ points = np.array([[[0, 0, 4]]]) if (len(points) == 0 or points.size == 0) else points
194
+ inputs = {"img": img, "points": points.astype("float32")}
195
  with command.output_index.get_lock():
196
  command.output_index.value += 1
197
  img_id = command.output_index.value