AIBoy1993 committed
Commit ae97c0d
1 Parent(s): 6cc1ca0

Upload 2 files

Files changed (2):
  1. app.py +104 -21
  2. inference.py +69 -37
app.py CHANGED
@@ -1,8 +1,35 @@
 import os
+import cv2
+import numpy as np
 import gradio as gr
 from inference import run_inference
 
 
+# points color and marker
+colors = [(255, 0, 0), (0, 255, 0)]
+markers = [1, 5]
+
+# image examples
+# in each list, the first element is the image path,
+# the second is the id (used for the original_image State),
+# the third is an empty list (used for the selected_points State)
+image_examples = [
+    [os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"), 0, []],
+    [os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"), 1, []],
+    [os.path.join(os.path.dirname(__file__), "./images/1.jpg"), 2, []],
+    [os.path.join(os.path.dirname(__file__), "./images/2.jpg"), 3, []],
+    [os.path.join(os.path.dirname(__file__), "./images/3.jpg"), 4, []],
+    [os.path.join(os.path.dirname(__file__), "./images/4.jpg"), 5, []],
+    [os.path.join(os.path.dirname(__file__), "./images/5.jpg"), 6, []],
+    [os.path.join(os.path.dirname(__file__), "./images/6.jpg"), 7, []],
+    [os.path.join(os.path.dirname(__file__), "./images/7.jpg"), 8, []],
+    [os.path.join(os.path.dirname(__file__), "./images/8.jpg"), 9, []]
+]
+# video examples
+video_examples = [
+    os.path.join(os.path.dirname(__file__), "./images/video1.mp4"),
+    os.path.join(os.path.dirname(__file__), "./images/video2.mp4")
+]
 
 
 with gr.Blocks() as demo:
@@ -19,7 +46,7 @@ with gr.Blocks() as demo:
     # select device
     device = gr.Dropdown(["cpu", "cuda"], value='cpu', label="Select Device")
 
-    # parameters
+    # SAM parameters
    with gr.Accordion(label='Parameters', open=False):
        with gr.Row():
            points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
@@ -45,11 +72,21 @@ with gr.Blocks() as demo:
                                        info='''The box IoU cutoff used by non-maximal suppression to filter duplicate
                                        masks between different crops.''')
 
-    # Show image
+    # Segment image
     with gr.Tab(label='Image'):
         with gr.Row().style(equal_height=True):
             with gr.Column():
+                # input image
+                original_image = gr.State(value=None)   # store the original image without points; default None
                 input_image = gr.Image(type="numpy")
+                # point prompt
+                with gr.Column():
+                    selected_points = gr.State([])      # store points
+                    with gr.Row():
+                        gr.Markdown('You can click on the image to select point prompts. Default: foreground_point.')
+                        undo_button = gr.Button('Undo point')
+                    radio = gr.Radio(['foreground_point', 'background_point'], label='point labels')
+                # text prompt used to generate box prompts
                 text = gr.Textbox(label='Text prompt (optional)', info=
                     'If you type words, the OWL-ViT model will be used to detect the objects in the image, '
                     'and the boxes will be fed into the SAM model to predict masks. Please use English.',
@@ -57,28 +94,26 @@ with gr.Blocks() as demo:
                 owl_vit_threshold = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="OWL ViT Object Detection threshold",
                                               info='''A small threshold will generate more objects, but may cause OOM.
                                               A large threshold may not detect any objects, resulting in an error.''')
+                # run button
                 button = gr.Button("Auto!")
+            # show the image with mask
             with gr.Tab(label='Image+Mask'):
                 output_image = gr.Image(type='numpy')
+            # show only the mask
             with gr.Tab(label='Mask'):
                 output_mask = gr.Image(type='numpy')
+        def process_example(img, ori_img, sel_p):
+            return ori_img, []
 
-        gr.Examples(
-            examples=[os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/1.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/2.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/3.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/4.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/5.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/6.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/7.jpg"),
-                      os.path.join(os.path.dirname(__file__), "./images/8.jpg"),
-                      ],
-            inputs=input_image,
-            outputs=output_image,
+        example = gr.Examples(
+            examples=image_examples,
+            inputs=[input_image, original_image, selected_points],
+            outputs=[original_image, selected_points],
+            fn=process_example,
+            run_on_click=True
         )
-    # Show video
+
+    # Segment video
     with gr.Tab(label='Video'):
         with gr.Row().style(equal_height=True):
             with gr.Column():
@@ -90,17 +125,65 @@ with gr.Blocks() as demo:
                 **Note:** processing video will take a long time; please upload a short video.
                 ''')
             gr.Examples(
-                examples=[os.path.join(os.path.dirname(__file__), "./images/video1.mp4"),
-                          os.path.join(os.path.dirname(__file__), "./images/video2.mp4")
-                          ],
+                examples=video_examples,
                 inputs=input_video,
                 outputs=output_video
             )
 
+    # once the user uploads an image, the original image is stored in `original_image`
+    def store_img(img):
+        return img, []  # when a new image is uploaded, `selected_points` should be emptied
+    input_image.upload(
+        store_img,
+        [input_image],
+        [original_image, selected_points]
+    )
+
+    # the user clicks the image to add points, which are then drawn on the image
+    def get_point(img, sel_pix, point_type, evt: gr.SelectData):
+        if point_type == 'foreground_point':
+            sel_pix.append((evt.index, 1))    # append the foreground_point
+        elif point_type == 'background_point':
+            sel_pix.append((evt.index, 0))    # append the background_point
+        else:
+            sel_pix.append((evt.index, 1))    # default foreground_point
+        # draw points
+        for point, label in sel_pix:
+            cv2.drawMarker(img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
+        if img[..., 0][0, 0] == img[..., 2][0, 0]:  # BGR to RGB
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        return img if isinstance(img, np.ndarray) else np.array(img)
+    input_image.select(
+        get_point,
+        [input_image, selected_points, radio],
+        [input_image],
+    )
+
+    # undo the last selected point
+    def undo_points(orig_img, sel_pix):
+        if isinstance(orig_img, int):   # if orig_img is an int, the image was selected from the examples
+            temp = cv2.imread(image_examples[orig_img][0])
+            temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+        else:
+            temp = orig_img.copy()
+        # redraw the remaining points
+        if len(sel_pix) != 0:
+            sel_pix.pop()
+            for point, label in sel_pix:
+                cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
+        if temp[..., 0][0, 0] == temp[..., 2][0, 0]:  # BGR to RGB
+            temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+        return temp if isinstance(temp, np.ndarray) else np.array(temp)
+    undo_button.click(
+        undo_points,
+        [original_image, selected_points],
+        [input_image]
+    )
+
     # button image
     button.click(run_inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                                         min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
-                                        crop_nms_thresh, owl_vit_threshold, input_image, text],
+                                        crop_nms_thresh, owl_vit_threshold, original_image, text, selected_points],
                  outputs=[output_image, output_mask])
     # button video
     button_video.click(run_inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
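
The point-selection flow added above hinges on Gradio's `select` event: clicking a `gr.Image` calls the handler with a `gr.SelectData` whose `index` field carries the pixel coordinates of the click, while a `gr.State` list accumulates the points across clicks (the list is mutated in place, as in `get_point`). A minimal standalone sketch of that mechanism, using the same Gradio 3.x API as the app, is shown here; the component names and the launch guard are illustrative and not part of this commit:

import gradio as gr

with gr.Blocks() as sketch:
    points = gr.State([])                   # accumulates (x, y) clicks across events
    img = gr.Image(type="numpy")
    log = gr.Textbox(label="clicked points")

    def on_select(sel, evt: gr.SelectData):
        sel.append(tuple(evt.index))        # evt.index is the (x, y) position of the click
        return str(sel)

    # inputs: the State list; output: the textbox showing what has been collected so far
    img.select(on_select, [points], [log])

if __name__ == '__main__':
    sketch.launch()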
inference.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import cv2
 import torch
 import numpy as np
@@ -13,6 +14,19 @@ models = {
     'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
 }
 
+image_examples = [
+    [os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"), 0, []],
+    [os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"), 1, []],
+    [os.path.join(os.path.dirname(__file__), "./images/1.jpg"), 2, []],
+    [os.path.join(os.path.dirname(__file__), "./images/2.jpg"), 3, []],
+    [os.path.join(os.path.dirname(__file__), "./images/3.jpg"), 4, []],
+    [os.path.join(os.path.dirname(__file__), "./images/4.jpg"), 5, []],
+    [os.path.join(os.path.dirname(__file__), "./images/5.jpg"), 6, []],
+    [os.path.join(os.path.dirname(__file__), "./images/6.jpg"), 7, []],
+    [os.path.join(os.path.dirname(__file__), "./images/7.jpg"), 8, []],
+    [os.path.join(os.path.dirname(__file__), "./images/8.jpg"), 9, []]
+]
+
 
 def plot_boxes(img, boxes):
     img_pil = Image.fromarray(np.uint8(img * 255)).convert('RGB')
@@ -82,44 +96,55 @@ def generator_inference(device, model_type, points_per_side, pred_iou_thresh, st
     return 'output.mp4'
 
 
-def predictor_inference(device, model_type, input_x, input_text, owl_vit_threshold=0.1):
+def predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold=0.1):
     # sam model
     sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
     predictor = SamPredictor(sam)
     predictor.set_image(input_x)  # Process the image to produce an image embedding
 
-    # split input text
-    input_text = [input_text.split(',')]
-
-    # OWL-ViT model
-    # processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
-    # owlvit_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
-    processor = OwlViTProcessor.from_pretrained('./checkpoints/models--google--owlvit-base-patch32')
-    owlvit_model = OwlViTForObjectDetection.from_pretrained("./checkpoints/models--google--owlvit-base-patch32").to(device)
-
-    # get outputs
-    input_text = processor(text=input_text, images=input_x, return_tensors="pt").to(device)
-    outputs = owlvit_model(**input_text)
-    target_size = torch.Tensor([input_x.shape[:2]]).to(device)
-    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_size,
-                                                      threshold=owl_vit_threshold)
-
-    # get the box with the best score
-    scores = torch.sigmoid(outputs.logits)
-    # best_scores, best_idxs = torch.topk(scores, k=1, dim=1)
-    # best_idxs = best_idxs.squeeze(1).tolist()
-
-    i = 0  # Retrieve predictions for the first image for the corresponding text queries
-    boxes_tensor = results[i]["boxes"]  # [best_idxs]
-    print(boxes_tensor.size())
-    boxes = boxes_tensor.cpu().detach().numpy()
-    transformed_boxes = predictor.transform.apply_boxes_torch(torch.Tensor(boxes).to(device),
-                                                              input_x.shape[:2])  # apply transform to original boxes
+    if input_text != '':
+        # split input text
+        input_text = [input_text.split(',')]
+        print(input_text)
+        # OWL-ViT model
+        processor = OwlViTProcessor.from_pretrained('./checkpoints/models--google--owlvit-base-patch32')
+        owlvit_model = OwlViTForObjectDetection.from_pretrained("./checkpoints/models--google--owlvit-base-patch32").to(device)
+        # get outputs
+        input_text = processor(text=input_text, images=input_x, return_tensors="pt").to(device)
+        outputs = owlvit_model(**input_text)
+        target_size = torch.Tensor([input_x.shape[:2]]).to(device)
+        results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_size,
+                                                          threshold=owl_vit_threshold)
+
+        # get the box with the best score
+        scores = torch.sigmoid(outputs.logits)
+        # best_scores, best_idxs = torch.topk(scores, k=1, dim=1)
+        # best_idxs = best_idxs.squeeze(1).tolist()
+
+        i = 0  # Retrieve predictions for the first image for the corresponding text queries
+        boxes_tensor = results[i]["boxes"]  # [best_idxs]
+        boxes = boxes_tensor.cpu().detach().numpy()
+        # boxes = boxes[np.newaxis, :, :]
+        transformed_boxes = predictor.transform.apply_boxes_torch(torch.Tensor(boxes).to(device),
+                                                                  input_x.shape[:2])  # apply transform to original boxes
+        # transformed_boxes = transformed_boxes.unsqueeze(0)
+        print(transformed_boxes.size(), boxes.shape)
+    else:
+        transformed_boxes = None
+
+    # points
+    if len(selected_points) != 0:
+        points = torch.Tensor([p for p, _ in selected_points]).to(device).unsqueeze(1)
+        labels = torch.Tensor([int(l) for _, l in selected_points]).to(device).unsqueeze(1)
+        transformed_points = predictor.transform.apply_coords_torch(points, input_x.shape[:2])
+        print(points.size(), transformed_points.size(), labels.size(), input_x.shape, points)
+    else:
+        transformed_points, labels = None, None
 
     # predict segmentation according to the boxes
     masks, scores, logits = predictor.predict_torch(
-        point_coords=None,
-        point_labels=None,
+        point_coords=transformed_points,
+        point_labels=labels,
         boxes=transformed_boxes,  # only one box
         multimask_output=False,
     )
@@ -130,11 +155,13 @@ def predictor_inference(device, model_type, input_x, input_text, owl_vit_thresho
     for i in range(3):
         mask_all[ann[0] == True, i] = color_mask[i]
     img = input_x / 255 * 0.3 + mask_all * 0.7
-    img = plot_boxes(img, boxes_tensor)  # image + mask + boxes
+    if input_text != '':
+        img = plot_boxes(img, boxes_tensor)  # image + mask + boxes
 
     # free the memory
-    owlvit_model.cpu()
-    del owlvit_model
+    if input_text != '':
+        owlvit_model.cpu()
+        del owlvit_model
     del input_text
     gc.collect()
     torch.cuda.empty_cache()
@@ -144,11 +171,16 @@ def predictor_inference(device, model_type, input_x, input_text, owl_vit_thresho
 
 def run_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
                   stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, owl_vit_threshold, input_x,
-                  input_text):
-    print('prompt text: ', input_text)
-    if input_text != '' and not isinstance(input_x, str):  # user input text
+                  input_text, selected_points):
+    # if input_x is an int, the image was selected from the examples
+    if isinstance(input_x, int):
+        input_x = cv2.imread(image_examples[input_x][0])
+        input_x = cv2.cvtColor(input_x, cv2.COLOR_BGR2RGB)
+    if (input_text != '' and not isinstance(input_x, str)) or len(selected_points) != 0:  # user provided text or points
         print('use predictor_inference')
-        return predictor_inference(device, model_type, input_x, input_text, owl_vit_threshold)
+        print('prompt text: ', input_text)
+        print('prompt points length: ', len(selected_points))
+        return predictor_inference(device, model_type, input_x, input_text, selected_points, owl_vit_threshold)
     else:
         print('use generator_inference')
         return generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
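
On the inference side, `SamPredictor.predict_torch` expects `point_coords` of shape (B, N, 2) and `point_labels` of shape (B, N), after the coordinates have been mapped into the model's input frame with `predictor.transform.apply_coords_torch`; the `unsqueeze(1)` calls in the new code therefore treat each click as its own single-point prompt, yielding one mask per click. The following is a hedged sketch of that call path with a dummy image, an assumed vit_b checkpoint path, and a made-up click; none of these values come from the commit.

import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry

# assumed checkpoint path; swap in whatever SAM weights are available locally
sam = sam_model_registry['vit_b'](checkpoint='./checkpoints/sam_vit_b_01ec64.pth')
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)       # dummy RGB image
predictor.set_image(image)

selected_points = [((320, 240), 1)]                   # one foreground click, as the UI stores it
points = torch.Tensor([p for p, _ in selected_points]).unsqueeze(1)        # (N, 1, 2)
labels = torch.Tensor([l for _, l in selected_points]).unsqueeze(1)        # (N, 1)
coords = predictor.transform.apply_coords_torch(points, image.shape[:2])   # map into SAM's input frame

masks, scores, logits = predictor.predict_torch(
    point_coords=coords,
    point_labels=labels,
    boxes=None,
    multimask_output=False,
)
print(masks.shape)   # (N, 1, H, W): one mask per clicked point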