# Garage / app.py
# anvilarth's picture
# Update app.py
# 955b30c verified
import os
import cv2
import time
import torch
import spaces
import subprocess
import numpy as np
import gradio as gr
import urllib.request
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from Garage.models.GroundedSegmentAnything.segment_anything.segment_anything import SamPredictor, build_sam, sam_model_registry
from Garage.models.GroundedSegmentAnything.GroundingDINO.groundingdino.util.inference import Model
from Garage import Augmenter
# Download URLs of the Segment Anything (SAM) checkpoints, keyed by the
# encoder variant name.
MODEL_DICT = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",  # noqa: E501
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",  # noqa: E501
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",  # noqa: E501
}

# GroundingDINO: model configuration file and the location of its weights.
GROUNDING_DINO_CONFIG_PATH = (
    "Garage/models/GroundedSegmentAnything/GroundingDINO_SwinT_OGC.py"
)
GROUNDING_DINO_CHECKPOINT_PATH = (
    "https://github.com/IDEA-Research/GroundingDINO/releases/download/"
    "v0.1.0-alpha/groundingdino_swint_ogc.pth"
)

# SAM: encoder variant used by the demo and the matching checkpoint URL.
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = MODEL_DICT[SAM_ENCODER_VERSION]
class GradioWindow():
    """Gradio demo: segment an object with GroundedSAM, then augment it.

    The window keeps the current segmentation state on the instance:

    * ``segmentation_mask`` - raw masks (PIL images) from the last detection
      or mask upload;
    * ``mask`` - the masks currently shown (pixel-accurate or bounding-box
      versions of ``segmentation_mask``);
    * ``selected_mask`` - the mask the user picked, or ``None``.

    NOTE(review): this state lives on the single application object, so
    concurrent browser sessions would presumably share it - confirm whether
    per-session state (``gr.State``) is needed.
    """

    # Overlay colour (RGB, 0-255) of the first / selected mask.
    _HIGHLIGHT_RGB = (30, 144, 255)
    # Grey levels above this value count as foreground in 8-bit masks.
    _MASK_THRESHOLD = 200

    def __init__(self) -> None:
        """Create the augmenter, load the models and build the UI."""
        self.points = []
        self.mask = []
        self.selected_mask = None
        self.segmentation_mask = []
        self.concatenated_masks = None
        # Example masks: index -> [object name, mask file].
        self.examples_masks = {
            0: ["dog", "examples/dog_mask.jpg"],
            1: ["bread", "examples/bread_mask.jpg"],
            2: ["room", "examples/room_mask.jpg"],
            3: ["spoon", "examples/spoon_mask.jpg"],
            4: ["cat", "examples/image_mask.jpg"],
        }

        self.GROUNDING_DINO_CONFIG_PATH = GROUNDING_DINO_CONFIG_PATH
        self.GROUNDING_DINO_CHECKPOINT_PATH = GROUNDING_DINO_CHECKPOINT_PATH
        self.model_type = SAM_ENCODER_VERSION
        self.SAM_CHECKPOINT_PATH = SAM_CHECKPOINT_PATH

        # Pinned to CPU. To use a GPU when present:
        # torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = "cpu"

        self.augmenter = Augmenter(device=self.device)
        self.setup_model()
        self.main()

    @staticmethod
    def _prompt_html(text: str = "") -> str:
        """Wrap *text* in the centred ``<div>`` used for the prompt label."""
        return (
            '<div class="message" style="text-align: center; '
            f'font-size: 18px;">{text}</div>'
        )

    @staticmethod
    def _to_binary_mask(mask) -> np.ndarray:
        """Return *mask* (PIL image or array) as a 2-D boolean array.

        Boolean masks (as stored after a detection) are returned unchanged;
        8-bit masks are thresholded. Multi-channel masks use channel 0.
        """
        pixels = np.array(mask)
        if pixels.ndim == 3:
            pixels = pixels[:, :, 0]
        if pixels.dtype == bool:
            # Comparing True/False with the 8-bit threshold would always be
            # False and silently wipe the mask.
            return pixels
        return pixels > GradioWindow._MASK_THRESHOLD

    def main(self):
        """Build the Blocks layout and connect the callbacks.

        NOTE(review): the nesting of the layout containers was reconstructed
        from a source without indentation - confirm against the deployed UI.
        """
        with gr.Blocks() as self.demo:
            with gr.Row():
                input_img = gr.Image(type="pil", label="Input image", interactive=True)
                selected_mask = gr.Image(type="pil", label="Selected Mask", interactive=True)
                segmented_img = gr.Image(type="pil", label="Selected Segment")

            with gr.Row():
                with gr.Group():
                    gr.Markdown(
                        "## Grounded Segmentation\n"
                        "#### This tool segments the object in the image based on the text prompt via GroundedSAM model. "
                        "You can also load the mask of the object to segment or choose one of the examples below.\n"
                    )
                    self.current_object = gr.Textbox(label="Current object")
                    with gr.Accordion("Advanced options", open=False):
                        self.use_mask = gr.Checkbox(label="Use segmentation mask", value=False)
                        box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Box threshold")
                        text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text threshold")
                    segment_object = gr.Button("Segment object")

                with gr.Column():
                    gr.Examples(
                        label="Images Examples",
                        examples=[
                            ["examples/dog.jpg"],
                            ["examples/bread.png"],
                            ["examples/room.jpg"],
                            ["examples/spoon.png"],
                            ["examples/image.jpg"],
                        ],
                        inputs=[input_img],
                        examples_per_page=5,
                    )
                    gr.Examples(
                        label="Mask Examples",
                        examples=[
                            [self.examples_masks[index][1]]
                            for index in range(5)
                        ],
                        inputs=[selected_mask, input_img],
                        outputs=[segmented_img, self.current_object, self.use_mask],
                        fn=self.set_mask,
                        run_on_click=True,
                    )

            with gr.Row():
                with gr.Column():
                    with gr.Group():
                        gr.Markdown(
                            "## Augmentation\n"
                            "#### This tool generates an augmented image based on the input image, the object to augment, and the target object. "
                            "If you don't specify the target object, the model will generate a random object. "
                            "You can also specify the number of steps, guidance scale, and seed for the generation process.\n"
                        )
                        self.target_object = gr.Textbox(label="Target object")
                        with gr.Accordion("Generation options", open=False):
                            self.iter_number = gr.Number(value=50, label="Steps")
                            self.guidance_scale = gr.Number(value=5, label="Guidance Scale")
                            self.seed = gr.Number(value=1, label="Seed")
                            self.return_prompt = gr.Checkbox(value=True, label="Show generated prompt")
                        enter_prompt = gr.Button("Augment Image")

                with gr.Column():
                    augmented_img = gr.Image(type="pil", label="Augmented Image")
                    generated_prompt = gr.Markdown(self._prompt_html(), visible=True)

            # Connect the UI and logic
            selected_mask.upload(
                self.set_mask,
                inputs=[selected_mask, input_img],
                outputs=[segmented_img, self.current_object, self.use_mask],
            )
            segment_object.click(
                self.detect,
                inputs=[input_img, self.current_object, self.use_mask,
                        box_threshold, text_threshold],
                outputs=[segmented_img, selected_mask],
            )
            self.use_mask.change(
                fn=self.change_mask_type,
                inputs=[input_img, self.use_mask],
                outputs=[selected_mask, segmented_img],
            )
            segmented_img.select(
                self.select_mask,
                inputs=[input_img],
                outputs=[selected_mask, segmented_img],
            )
            enter_prompt.click(
                self.augment_image,
                inputs=[input_img, self.current_object, self.target_object,
                        self.iter_number, self.guidance_scale, self.seed,
                        self.return_prompt],
                outputs=[augmented_img, generated_prompt],
            )

    def setup_model(self) -> None:
        """Load SAM and GroundingDINO onto ``self.device``.

        Sets ``self.sam``, ``self.sam_predictor`` and
        ``self.grounding_dino_model``; returns nothing.
        """
        self.sam = sam_model_registry[self.model_type]()
        # map_location keeps the load working on machines without the device
        # the checkpoint was saved from.
        state_dict = torch.hub.load_state_dict_from_url(
            MODEL_DICT[self.model_type], map_location=self.device
        )
        self.sam.load_state_dict(state_dict)
        self.sam.to(device=self.device)
        self.sam_predictor = SamPredictor(self.sam)

        # NOTE(review): the checkpoint "path" is a URL - confirm that the
        # vendored GroundingDINO loader accepts URLs.
        self.grounding_dino_model = Model(
            model_config_path=self.GROUNDING_DINO_CONFIG_PATH,
            model_checkpoint_path=self.GROUNDING_DINO_CHECKPOINT_PATH,
            device=self.device,
        )
        print("MODELS LOADED! Device:", self.device)

    def change_mask_type(self, image, is_segmmask):
        """Switch between pixel-accurate and bounding-box masks.

        Args:
            image: Input image (PIL) the masks belong to.
            is_segmmask: ``True`` to use the segmentation masks themselves,
                ``False`` to use their bounding boxes.

        Returns:
            ``(binary mask image, image with the masks overlaid)``; both
            outputs are left untouched when there is no input image.
        """
        self.selected_mask = None
        self.mask = []
        if image is None:
            return gr.update(), gr.update()

        if is_segmmask:
            masks = [self._to_binary_mask(m) for m in self.segmentation_mask]
        else:
            masks = [
                np.array(self.get_bbox_mask(m)) for m in self.segmentation_mask
            ]
        # concatenate_masks also refreshes self.mask with these masks.
        res, common_mask = self.concatenate_masks(masks, image)
        return common_mask, res

    def get_bbox_mask(self, mask):
        """Return an ``"L"`` image with the bounding box of *mask* filled.

        The result is all black when *mask* has no non-zero pixel.
        """
        bbox = mask.getbbox()
        new_mask = Image.new("L", mask.size, 0)  # start from an all-black mask
        if bbox:
            ImageDraw.Draw(new_mask).rectangle(bbox, fill=255)
        return new_mask

    def select_mask(self, image: Image.Image, evt: gr.SelectData):
        """Select the first displayed mask that contains the clicked pixel.

        Args:
            image: Input image (PIL).
            evt: Gradio select event; ``evt.index`` holds ``(x, y)``.

        Returns:
            ``(selected mask or None, image with the selection overlaid)``.
        """
        self.points = [evt.index[0], evt.index[1]]
        self.selected_mask = None
        if image is None:
            return gr.update(), gr.update()

        x, y = self.points
        overlay = np.zeros_like(image)
        for candidate in self.mask:
            pixels = np.array(candidate)
            inside = 0 <= y < pixels.shape[0] and 0 <= x < pixels.shape[1]
            if inside and pixels[y, x]:
                self.selected_mask = Image.fromarray(pixels)
                overlay[pixels > 0] = self._HIGHLIGHT_RGB
                break

        res = self.show_mask(overlay, image)
        self.concatenated_masks = res
        return self.selected_mask, res

    def set_mask(self, mask: Image.Image, image: Image.Image):
        """Use an uploaded (or example) mask as the current segmentation.

        Args:
            mask: Mask image; pixels brighter than the threshold are
                foreground.
            image: Input image the mask belongs to.

        Returns:
            ``(overlay image, example object name or None, True)`` - the
            last item ticks the "Use segmentation mask" checkbox.

        Raises:
            gr.Error: If the mask or the input image is missing.
        """
        if mask is None or image is None:
            raise gr.Error("Please provide both an input image and a mask.")

        mask_pixels = np.array(mask)

        # Recognise the bundled example masks so the object name can be
        # filled in automatically.
        current_object = None
        for name, path in self.examples_masks.values():
            try:
                with Image.open(path) as example:
                    is_example = np.array_equal(np.array(example), mask_pixels)
            except OSError:
                # A missing example file must not break a user upload.
                continue
            if is_example:
                current_object = name
                break

        self.selected_mask = mask
        self.segmentation_mask = [mask]

        bin_mask = self._to_binary_mask(mask_pixels)
        # concatenate_masks stores [bin_mask] in self.mask as a side effect.
        _, common_mask = self.concatenate_masks([bin_mask], image)
        res = self.show_mask(common_mask, image)
        self.concatenated_masks = res
        return res, current_object, True

    def detect(self, image: Image.Image, prompt: str, is_segmmask: bool,
               box_threshold: float, text_threshold: float):
        """Detect *prompt* with GroundingDINO and segment each box with SAM.

        Args:
            image: Input image (PIL).
            prompt: Name of the object to look for.
            is_segmmask: Show pixel-accurate masks (``True``) or their
                bounding boxes (``False``).
            box_threshold: GroundingDINO box threshold.
            text_threshold: GroundingDINO text threshold.

        Returns:
            ``(image with masks overlaid, binary mask image)``. When nothing
            is detected, the unchanged image and an all-black mask.

        Raises:
            gr.Error: If the image or the prompt is missing.
        """
        if image is None:
            raise gr.Error("Please upload an input image first.")
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter the object to segment.")

        # PIL delivers RGB. Upstream GroundingDINO's predict_with_classes
        # takes an OpenCV-style BGR array, while SamPredictor.set_image
        # defaults to RGB, so only the detector input is channel-swapped.
        rgb = np.array(image.convert("RGB"))
        detections = self.grounding_dino_model.predict_with_classes(
            image=cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR),
            classes=[prompt],
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        detections.mask = self.segment(
            sam_predictor=self.sam_predictor,
            image=rgb,
            xyxy=detections.xyxy,
        )

        # A new detection invalidates any earlier selection and masks, also
        # when nothing was found.
        self.selected_mask = None
        self.segmentation_mask = [Image.fromarray(m) for m in detections.mask]

        if len(detections.mask) == 0:
            self.mask = []
            return np.array(image), Image.fromarray(np.zeros_like(np.array(image)))

        if is_segmmask:
            masks = detections.mask
        else:
            masks = [
                np.array(self.get_bbox_mask(m)) for m in self.segmentation_mask
            ]
        image, common_mask = self.concatenate_masks(masks, image)
        return image, common_mask

    def concatenate_masks(self, masks, image: Image.Image) -> tuple:
        """Merge *masks* into one overlay and one binary mask.

        Side effect: ``self.mask`` is replaced by the given masks (as PIL
        images).

        Args:
            masks: Iterable of 2-D arrays with the same height and width as
                *image*; values > 0 are foreground.
            image: Input image (PIL).

        Returns:
            ``(image with coloured masks overlaid, black/white mask)``, both
            PIL images. The first mask uses the fixed highlight colour, the
            others random colours.
        """
        self.mask = []
        canvas = np.zeros_like(image)
        union = np.zeros(canvas.shape[:2], dtype=bool)
        for index, mask in enumerate(masks):
            if index == 0:
                color = np.array(self._HIGHLIGHT_RGB)
            else:
                color = np.random.random(3) * 255
            self.mask.append(Image.fromarray(mask))
            foreground = mask > 0
            canvas[foreground] = color
            union |= foreground

        overlay = self.show_mask(canvas, image)
        # Build the binary mask from the union of the masks rather than from
        # the colours, so a colour channel that is 0 cannot drop pixels.
        binary = np.zeros_like(canvas)
        binary[union] = 255
        return Image.fromarray(overlay), Image.fromarray(binary)

    def show_mask(self, mask: Image.Image, image: Image.Image,
                  random_color: bool = False) -> np.ndarray:
        """Blend a colour mask over an image (70 % image, 30 % mask).

        Args:
            mask: Colour mask of shape (H', W', 3); resized to the image
                size with nearest-neighbour interpolation if needed.
            image: Image of shape (H, W, 3).
            random_color: Unused; kept for backward compatibility.

        Returns:
            np.ndarray: Array of shape (H, W, 3) with the mask visualized on
            top of the image.
        """
        mask_pixels, image_pixels = np.array(mask), np.array(image)
        height, width = image_pixels.shape[:2]
        if mask_pixels.shape[:2] != (height, width):
            mask_pixels = cv2.resize(
                mask_pixels, (width, height), interpolation=cv2.INTER_NEAREST
            )
        return cv2.addWeighted(image_pixels, 0.7, mask_pixels, 0.3, 0)

    def segment(self, sam_predictor: SamPredictor, image: np.ndarray,
                xyxy: np.ndarray) -> np.ndarray:
        """Return the best-scoring SAM mask for every box in *xyxy*.

        Args:
            sam_predictor: Predictor to run.
            image: Image array handed to ``sam_predictor.set_image``.
            xyxy: Iterable of boxes, one prompt per box.

        Returns:
            np.ndarray: One mask per box (empty array when there is no box).
        """
        sam_predictor.set_image(image)
        best_masks = []
        for box in xyxy:
            masks, scores, _ = sam_predictor.predict(
                box=box,
                multimask_output=True,
            )
            best_masks.append(masks[np.argmax(scores)])
        return np.array(best_masks)

    # @spaces.GPU(duration=120)
    def augment_image(self, image: Image.Image,
                      current_object: str, new_objects_list: str,
                      ddim_steps: int, guidance_scale: float, seed: int,
                      return_prompt: bool) -> tuple:
        """Replace the masked object in *image* using the augmenter.

        The selected mask is used when there is one; otherwise one of the
        displayed masks is picked at random.

        Args:
            image: Input image (PIL).
            current_object: Name of the object under the mask.
            new_objects_list: Comma-separated target object names.
            ddim_steps: Number of generation steps.
            guidance_scale: Guidance scale.
            seed: Random seed.
            return_prompt: Whether to show the generated prompt.

        Returns:
            ``(augmented image, HTML snippet with the generated prompt)``.

        Raises:
            gr.Error: If there is no input image or no mask yet.
        """
        if image is None:
            raise gr.Error("Please upload an input image first.")
        if self.selected_mask is not None:
            mask = self.selected_mask
        elif self.mask:
            mask = self.mask[np.random.choice(len(self.mask))]
        else:
            raise gr.Error("Please segment an object or upload a mask first.")

        # Accept "a, b" as well as "a,b"; an empty box still yields [""].
        targets = [name.strip() for name in (new_objects_list or "").split(",")]

        result, (prompt, _) = self.augmenter(
            image=image,
            mask=mask,
            current_object=current_object,
            new_objects_list=targets,
            ddim_steps=ddim_steps,
            guidance_scale=guidance_scale,
            seed=seed,
            return_prompt=return_prompt,
        )

        if not return_prompt:
            prompt = ""
        return result, self._prompt_html(f"Generated prompt: {prompt}")
if __name__ == "__main__":
    # Script entry point: build the window, then serve its Blocks app.
    window = GradioWindow()
    demo = window.demo
    # share=False: no public share link is requested.
    demo.launch(share=False)
    # Release the server once launch() has returned.
    demo.close()