Spaces:
Running
on
T4
Running
on
T4
liuyizhang
committed on
Commit
•
63e6e86
1
Parent(s):
b269211
update app.py
Browse files
app.py
CHANGED
@@ -434,8 +434,7 @@ def concatenate_images_vertical(image1, image2):
|
|
434 |
|
435 |
return new_image
|
436 |
|
437 |
-
def relate_anything(input_image_mask, k):
|
438 |
-
input_image = input_image_mask['image']
|
439 |
logger.info(f'relate_anything_1_{input_image.size}_')
|
440 |
w, h = input_image.size
|
441 |
max_edge = 1500
|
@@ -478,15 +477,17 @@ def relate_anything(input_image_mask, k):
|
|
478 |
concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
|
479 |
pil_image_list.append(concate_pil_image)
|
480 |
|
481 |
-
|
482 |
-
yield pil_image_list
|
483 |
-
|
484 |
|
485 |
mask_source_draw = "draw a mask on input image"
|
486 |
mask_source_segment = "type what to detect below"
|
487 |
|
488 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
489 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation):
|
|
|
|
|
|
|
|
|
490 |
text_prompt = text_prompt.strip()
|
491 |
if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
492 |
if text_prompt == '':
|
@@ -510,7 +511,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
510 |
size = image_pil.size
|
511 |
|
512 |
output_images = []
|
513 |
-
|
514 |
# run grounding dino model
|
515 |
if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
|
516 |
pass
|
@@ -538,11 +539,12 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
538 |
"labels": pred_phrases,
|
539 |
}
|
540 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
541 |
-
image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
|
542 |
-
image_with_box.save(image_path)
|
543 |
-
detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
544 |
-
os.remove(image_path)
|
545 |
-
output_images.append(detection_image_result)
|
|
|
546 |
|
547 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
548 |
if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
@@ -600,13 +602,12 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
600 |
mask = masks[0][0].cpu().numpy()
|
601 |
mask_pil = Image.fromarray(mask)
|
602 |
|
603 |
-
image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
|
604 |
-
#
|
605 |
-
#
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
output_images.append(image_result)
|
610 |
|
611 |
if task_type == 'inpainting':
|
612 |
# inpainting pipeline
|
@@ -645,24 +646,23 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
645 |
mask_imgs.append(mask_pil_exp)
|
646 |
mask_pil = mix_masks(mask_imgs)
|
647 |
|
648 |
-
image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
|
649 |
-
#
|
650 |
-
#
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
output_images.append(image_result)
|
655 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
|
656 |
|
657 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
658 |
|
659 |
-
image_path = os.path.join(output_dir, f"grounded_sam_inpainting_output_{file_temp}.jpg")
|
660 |
-
image_inpainting.save(image_path)
|
661 |
-
image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
662 |
-
os.remove(image_path)
|
663 |
-
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
664 |
-
output_images.append(image_inpainting)
|
665 |
# output_images.append(image_result)
|
|
|
666 |
return output_images, gr.Gallery.update(label='result images')
|
667 |
else:
|
668 |
logger.info(f"task_type:{task_type} error!")
|
@@ -674,10 +674,10 @@ def change_radio_display(task_type, mask_source_radio):
|
|
674 |
inpaint_prompt_visible = False
|
675 |
mask_source_radio_visible = False
|
676 |
num_relation_visible = False
|
677 |
-
run_button_visible = True
|
678 |
-
relate_all_button_visible = False
|
679 |
-
gsa_gallery_visible = True
|
680 |
-
ram_gallery_visible = False
|
681 |
if task_type == "inpainting":
|
682 |
inpaint_prompt_visible = True
|
683 |
if task_type == "inpainting" or task_type == "remove":
|
@@ -687,11 +687,12 @@ def change_radio_display(task_type, mask_source_radio):
|
|
687 |
if task_type == "relate anything":
|
688 |
text_prompt_visible = False
|
689 |
num_relation_visible = True
|
690 |
-
run_button_visible = False
|
691 |
-
relate_all_button_visible = True
|
692 |
-
gsa_gallery_visible = False
|
693 |
-
ram_gallery_visible = True
|
694 |
-
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
|
|
695 |
|
696 |
if __name__ == "__main__":
|
697 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
@@ -715,7 +716,7 @@ if __name__ == "__main__":
|
|
715 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
716 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
717 |
run_button = gr.Button(label="Run", visible=True)
|
718 |
-
relate_all_button = gr.Button(label="Run", visible=False)
|
719 |
with gr.Accordion("Advanced options", open=False) as advanced_options:
|
720 |
box_threshold = gr.Slider(
|
721 |
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
|
@@ -734,17 +735,19 @@ if __name__ == "__main__":
|
|
734 |
remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
|
735 |
|
736 |
with gr.Column():
|
737 |
-
|
738 |
-
).style(preview=True,
|
739 |
-
|
740 |
-
|
|
|
|
|
741 |
|
742 |
run_button.click(fn=run_anything_task, inputs=[
|
743 |
-
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[
|
744 |
-
relate_all_button.click(fn=relate_anything, inputs=[input_image, num_relation], outputs=[ram_gallery], show_progress=True, queue=True)
|
745 |
|
746 |
-
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation
|
747 |
-
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation
|
748 |
|
749 |
DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
750 |
DESCRIPTION += 'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
|
|
|
434 |
|
435 |
return new_image
|
436 |
|
437 |
+
def relate_anything(input_image, k):
|
|
|
438 |
logger.info(f'relate_anything_1_{input_image.size}_')
|
439 |
w, h = input_image.size
|
440 |
max_edge = 1500
|
|
|
477 |
concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
|
478 |
pil_image_list.append(concate_pil_image)
|
479 |
|
480 |
+
return pil_image_list
|
|
|
|
|
481 |
|
482 |
mask_source_draw = "draw a mask on input image"
|
483 |
mask_source_segment = "type what to detect below"
|
484 |
|
485 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
486 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation):
|
487 |
+
if (task_type == 'relate anything'):
|
488 |
+
output_images = relate_anything(input_image['image'], num_relation)
|
489 |
+
return output_images, gr.Gallery.update(label='relate images')
|
490 |
+
|
491 |
text_prompt = text_prompt.strip()
|
492 |
if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
|
493 |
if text_prompt == '':
|
|
|
511 |
size = image_pil.size
|
512 |
|
513 |
output_images = []
|
514 |
+
output_images.append(input_image['image'])
|
515 |
# run grounding dino model
|
516 |
if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
|
517 |
pass
|
|
|
539 |
"labels": pred_phrases,
|
540 |
}
|
541 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
542 |
+
# image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
|
543 |
+
# image_with_box.save(image_path)
|
544 |
+
# detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
545 |
+
# os.remove(image_path)
|
546 |
+
# output_images.append(detection_image_result)
|
547 |
+
output_images.append(image_with_box)
|
548 |
|
549 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
550 |
if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
|
|
602 |
mask = masks[0][0].cpu().numpy()
|
603 |
mask_pil = Image.fromarray(mask)
|
604 |
|
605 |
+
# image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
|
606 |
+
# mask_pil.convert("RGB").save(image_path)
|
607 |
+
# image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
608 |
+
# os.remove(image_path)
|
609 |
+
# output_images.append(image_result)
|
610 |
+
output_images.append(mask_pil.convert("RGB"))
|
|
|
611 |
|
612 |
if task_type == 'inpainting':
|
613 |
# inpainting pipeline
|
|
|
646 |
mask_imgs.append(mask_pil_exp)
|
647 |
mask_pil = mix_masks(mask_imgs)
|
648 |
|
649 |
+
# image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
|
650 |
+
# mask_pil.convert("RGB").save(image_path)
|
651 |
+
# image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
652 |
+
# os.remove(image_path)
|
653 |
+
# output_images.append(image_result)
|
654 |
+
output_images.append(mask_pil.convert("RGB"))
|
|
|
655 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
|
656 |
|
657 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
658 |
|
659 |
+
# image_path = os.path.join(output_dir, f"grounded_sam_inpainting_output_{file_temp}.jpg")
|
660 |
+
# image_inpainting.save(image_path)
|
661 |
+
# image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
662 |
+
# os.remove(image_path)
|
663 |
+
# logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
|
|
664 |
# output_images.append(image_result)
|
665 |
+
output_images.append(image_inpainting)
|
666 |
return output_images, gr.Gallery.update(label='result images')
|
667 |
else:
|
668 |
logger.info(f"task_type:{task_type} error!")
|
|
|
674 |
inpaint_prompt_visible = False
|
675 |
mask_source_radio_visible = False
|
676 |
num_relation_visible = False
|
677 |
+
# run_button_visible = True
|
678 |
+
# relate_all_button_visible = False
|
679 |
+
# gsa_gallery_visible = True
|
680 |
+
# ram_gallery_visible = False
|
681 |
if task_type == "inpainting":
|
682 |
inpaint_prompt_visible = True
|
683 |
if task_type == "inpainting" or task_type == "remove":
|
|
|
687 |
if task_type == "relate anything":
|
688 |
text_prompt_visible = False
|
689 |
num_relation_visible = True
|
690 |
+
# run_button_visible = False
|
691 |
+
# relate_all_button_visible = True
|
692 |
+
# gsa_gallery_visible = False
|
693 |
+
# ram_gallery_visible = True
|
694 |
+
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
695 |
+
#, gr.Button.update(visible=run_button_visible), gr.Button.update(visible=relate_all_button_visible), gr.Gallery.update(visible=gsa_gallery_visible), gr.Gallery.update(visible=ram_gallery_visible)
|
696 |
|
697 |
if __name__ == "__main__":
|
698 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
|
|
716 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
717 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
718 |
run_button = gr.Button(label="Run", visible=True)
|
719 |
+
# relate_all_button = gr.Button(label="Run", visible=False)
|
720 |
with gr.Accordion("Advanced options", open=False) as advanced_options:
|
721 |
box_threshold = gr.Slider(
|
722 |
label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
|
|
|
735 |
remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
|
736 |
|
737 |
with gr.Column():
|
738 |
+
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gsa_allery", visible=True
|
739 |
+
).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
|
740 |
+
# gsa_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gsa_allery", visible=True
|
741 |
+
# ).style(preview=True, grid=[2], full_width=True, full_height=True)
|
742 |
+
# ram_gallery = gr.Gallery(label="Your Result", show_label=True, elem_id="ram_gallery", visible=False
|
743 |
+
# ).style(preview=True, columns=5, object_fit="scale-down")
|
744 |
|
745 |
run_button.click(fn=run_anything_task, inputs=[
|
746 |
+
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[image_gallery, image_gallery], show_progress=True, queue=True)
|
747 |
+
# relate_all_button.click(fn=relate_anything, inputs=[input_image, num_relation], outputs=[ram_gallery], show_progress=True, queue=True)
|
748 |
|
749 |
+
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
750 |
+
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
|
751 |
|
752 |
DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
753 |
DESCRIPTION += 'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
|