liuyizhang
commited on
Commit
•
2e4e1c8
1
Parent(s):
81fed1b
update app.py
Browse files
app.py
CHANGED
@@ -209,7 +209,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
|
|
209 |
|
210 |
size = image_pil.size
|
211 |
|
212 |
-
if task_type == '
|
213 |
# initialize SAM
|
214 |
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
215 |
image = np.array(image_path)
|
@@ -233,7 +233,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
|
|
233 |
|
234 |
# masks: [1, 1, 512, 512]
|
235 |
|
236 |
-
if task_type == '
|
237 |
pred_dict = {
|
238 |
"boxes": boxes_filt,
|
239 |
"size": [size[1], size[0]], # H,W
|
@@ -245,7 +245,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
|
|
245 |
image_with_box.save(image_path)
|
246 |
image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
247 |
return image_result
|
248 |
-
elif task_type == '
|
249 |
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
250 |
|
251 |
# draw output image
|
@@ -302,8 +302,8 @@ if __name__ == "__main__":
|
|
302 |
with gr.Column():
|
303 |
input_image = gr.Image(source='upload', type="pil")
|
304 |
task_type = gr.Radio(["detection", "segment", "inpainting"], value="detection",
|
305 |
-
label='Task type
|
306 |
-
text_prompt = gr.Textbox(label="Detection Prompt")
|
307 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt", visible=True)
|
308 |
run_button = gr.Button(label="Run")
|
309 |
with gr.Accordion("Advanced options", open=False):
|
|
|
209 |
|
210 |
size = image_pil.size
|
211 |
|
212 |
+
if task_type == 'segment' or task_type == 'inpainting':
|
213 |
# initialize SAM
|
214 |
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
215 |
image = np.array(image_path)
|
|
|
233 |
|
234 |
# masks: [1, 1, 512, 512]
|
235 |
|
236 |
+
if task_type == 'detection':
|
237 |
pred_dict = {
|
238 |
"boxes": boxes_filt,
|
239 |
"size": [size[1], size[0]], # H,W
|
|
|
245 |
image_with_box.save(image_path)
|
246 |
image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
247 |
return image_result
|
248 |
+
elif task_type == 'segment':
|
249 |
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
250 |
|
251 |
# draw output image
|
|
|
302 |
with gr.Column():
|
303 |
input_image = gr.Image(source='upload', type="pil")
|
304 |
task_type = gr.Radio(["detection", "segment", "inpainting"], value="detection",
|
305 |
+
label='Task type',interactive=True, visible=True)
|
306 |
+
text_prompt = gr.Textbox(label="Detection Prompt", placeholder="Cannot be empty")
|
307 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt", visible=True)
|
308 |
run_button = gr.Button(label="Run")
|
309 |
with gr.Accordion("Advanced options", open=False):
|