import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw

# mm libs
from mmdet.datasets.coco_panoptic import CocoPanopticDataset
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.visualization import DetLocalVisualizer
from mmengine import Config, print_log
from mmengine.structures import InstanceData
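# `spaces` is Hugging Face's ZeroGPU helper package; `@spaces.GPU` (applied to
# segment_point below, an assumption since the import is otherwise unused)
# requests a GPU slice per call on ZeroGPU hardware.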
IMG_SIZE = 1024
TITLE = "<center><strong><font size='8'>OMG-Seg: Is One Model Good Enough For All Segmentation?</font></strong></center>"
CSS = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# Build the model from the mmengine config and load its weights.
model_cfg = Config.fromfile('app/configs/m2_convl.py')
model = MODELS.build(model_cfg.model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
model = model.eval()
model.init_weights()

mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]
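# The mean/std above are the standard ImageNet per-channel statistics on the
# 0-255 scale, i.e. mmdet's usual image-normalization convention.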
visualizer = DetLocalVisualizer()

examples = [
    ["assets/000000000139.jpg"],
    ["assets/000000000285.jpg"],
    ["assets/000000000632.jpg"],
    ["assets/000000000724.jpg"],
]
class IMGState:
    """Per-session state: the loaded image and the current point prompt."""

    def __init__(self):
        self.img = None
        self.selected_points = []
        self.available_to_set = True

    def set_img(self, img):
        self.img = img
        self.available_to_set = False

    def clear(self):
        self.img = None
        self.selected_points = []
        self.available_to_set = True

    def clean(self):
        self.selected_points = []

    def available(self):
        return self.available_to_set

    @classmethod
    def cls_clean(cls, state):
        """Drop the prompts but keep the image; redraw it without markers."""
        state.clean()
        return Image.fromarray(state.img), None

    @classmethod
    def cls_clear(cls, state):
        """Reset everything."""
        state.clear()
        return None, None
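# Each browser session gets its own copy of this state via gr.State, so
# concurrent users never share images or prompts. cls_clean / cls_clear must
# be classmethods because Gradio passes only the state object to the handler.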
def store_img(img, img_state):
    """Resize the image so its longer side is IMG_SIZE, then cache it."""
    w, h = img.size
    scale = IMG_SIZE / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
    img_numpy = np.array(img)
    img_state.set_img(img_numpy)
    print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current')
    return img, None
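# Assumption: the model expects inputs of roughly IMG_SIZE (1024) on the
# longer side. The aspect ratio is preserved here; padding to a stride of 32
# is deferred to segment_point.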
def get_points_with_draw(image, img_state, evt: gr.SelectData):
    """Record the clicked point and draw a marker on the cached image."""
    x, y = evt.index[0], evt.index[1]
    print_log(f"Point: {x}_{y}", logger='current')
    point_radius, point_color = 10, (97, 217, 54)
    img_state.selected_points.append([x, y])
    # Keep only the most recent click: the demo uses a single point prompt.
    if len(img_state.selected_points) > 0:
        img_state.selected_points = img_state.selected_points[-1:]
    image = Image.fromarray(img_state.img)
    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )
    return image
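# Redrawing from the cached clean array (img_state.img) rather than the
# displayed image guarantees at most one marker is visible at a time,
# consistent with the single-point prompt kept above.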
@spaces.GPU  # assumed: allocates a GPU per call on ZeroGPU Spaces (the import above is otherwise unused)
def segment_point(image, img_state, mode):
    output_img = img_state.img
    h, w = output_img.shape[:2]

    # To CHW tensor, normalize, and pad both sides up to a multiple of 32.
    img_tensor = torch.tensor(output_img, device=device, dtype=torch.float32).permute((2, 0, 1))[None]
    img_tensor = (img_tensor - mean) / std
    im_w = w if w % 32 == 0 else w // 32 * 32 + 32
    im_h = h if h % 32 == 0 else h // 32 * 32 + 32
    img_tensor = F.pad(img_tensor, (0, im_w - w, 0, im_h - h), 'constant', 0)

    if len(img_state.selected_points) > 0:
        # Point-prompt (SAM-style) path: expand the click into a small box prompt.
        input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
        batch_data_samples = [DetDataSample()]
        selected_point = torch.cat([input_points - 3, input_points + 3], 1)
        gt_instances = InstanceData(
            point_coords=selected_point,
        )
        pb_labels = torch.zeros(len(gt_instances), dtype=torch.long, device=device)
        gt_instances.bp = pb_labels
        batch_data_samples[0].gt_instances = gt_instances
        batch_data_samples[0].data_tag = 'sam'
        batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
        batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
        is_prompt = True
    else:
        # Whole-image path: COCO-style panoptic / instance prediction.
        batch_data_samples = [DetDataSample()]
        batch_data_samples[0].data_tag = 'coco'
        batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
        batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
        is_prompt = False

    with torch.no_grad():
        results = model.predict(img_tensor, batch_data_samples, rescale=False)
    masks = results[0]

    if is_prompt:
        # Crop away the padding, binarize, and blend a green overlay.
        masks = masks[0, :h, :w]
        masks = masks > 0.  # threshold raw logits; no sigmoid needed
        rgb_shape = tuple(list(masks.shape) + [3])
        color = np.zeros(rgb_shape, dtype=np.uint8)
        color[masks] = np.array([97, 217, 54])
        output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
        output_img = Image.fromarray(output_img)
    else:
        if mode == 'Panoptic Segmentation':
            output_img = visualizer._draw_panoptic_seg(
                output_img,
                masks['pan_results'].to('cpu').numpy(),
                classes=CocoPanopticDataset.METAINFO['classes'],
                palette=CocoPanopticDataset.METAINFO['palette']
            )
        elif mode == 'Instance Segmentation':
            # Drop low-confidence detections before drawing.
            masks['ins_results'] = masks['ins_results'][masks['ins_results'].scores > .2]
            output_img = visualizer._draw_instances(
                output_img,
                masks['ins_results'].to('cpu').numpy(),
                classes=CocoPanopticDataset.METAINFO['classes'],
                palette=CocoPanopticDataset.METAINFO['palette']
            )
    return image, output_img
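# Assumption about the predict() contract, inferred from the usage above: with
# data_tag='sam' it returns raw mask logits per sample; with data_tag='coco'
# it returns a dict carrying 'pan_results' / 'ins_results'.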
def register_title():
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(TITLE)
def register_point_mode():
    with gr.Tab("Point mode"):
        img_state = gr.State(IMGState())
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                img_p = gr.Image(label="Input Image", type="pil")
            with gr.Column(scale=1):
                segm_p = gr.Image(label="Segment", interactive=False, type="pil")
        with gr.Row():
            with gr.Column():
                mode = gr.Radio(
                    ["Panoptic Segmentation", "Instance Segmentation"],
                    label="Mode",
                    value="Panoptic Segmentation",
                    info="Please select the segmentation mode. (Ignored if a point prompt is given.)"
                )
        with gr.Row():
            with gr.Column():
                segment_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                clean_btn = gr.Button("Clean Prompts", variant="secondary")
        with gr.Row():
            with gr.Column():
                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[img_p, img_state],
                    outputs=[img_p, segm_p],
                    examples_per_page=4,
                    fn=store_img,
                    run_on_click=True
                )

        img_p.upload(
            store_img,
            [img_p, img_state],
            [img_p, segm_p]
        )
        img_p.select(
            get_points_with_draw,
            [img_p, img_state],
            img_p
        )
        segment_btn.click(
            segment_point,
            [img_p, img_state, mode],
            [img_p, segm_p]
        )
        clean_btn.click(
            IMGState.cls_clean,
            img_state,
            [img_p, segm_p]
        )
        img_p.clear(
            IMGState.cls_clear,
            img_state,
            [img_p, segm_p]
        )
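# Event wiring, in order: uploading caches and resizes the image, clicking the
# image records a point prompt, "Segment" runs inference, "Clean Prompts"
# keeps the image but drops the prompt, and clearing the image resets the
# whole session state.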
def build_demo():
    with gr.Blocks(css=CSS, title="OMG-Seg") as _demo:
        register_title()
        register_point_mode()
    return _demo
if __name__ == '__main__':
    demo = build_demo()
    demo.queue(api_open=False)
    # Bind to all interfaces so the app is reachable inside the Spaces container.
    demo.launch(server_name='0.0.0.0')