Spaces:
Runtime error
Runtime error
File size: 4,571 Bytes
6ffeb01 f7c3c0c 6ffeb01 1e03c2b 6ffeb01 f7c3c0c 6ffeb01 1e03c2b 6ffeb01 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
"""
Gradio app for pollen-vision
This script creates a Gradio app for pollen-vision. The app allows users to perform object detection and object segmentation using the OWL-ViT and MobileSAM models.
"""
from typing import Any, Dict, List, Union

import gradio as gr
import numpy as np
import numpy.typing as npt
from datasets import load_dataset
from pollen_vision.vision_models.object_detection import OwlVitWrapper
from pollen_vision.vision_models.object_segmentation import MobileSamWrapper
from pollen_vision.vision_models.utils import Annotator, get_bboxes
# Shared helpers, created once when the module is imported and reused by every
# call to object_detection / object_segmentation / query below.
# NOTE(review): the wrappers presumably load their model weights on
# construction, which would make importing this module slow — confirm.
owl_vit = OwlVitWrapper()
mobile_sam = MobileSamWrapper()
annotator = Annotator()
def object_detection(
    img: npt.NDArray[np.uint8], text_queries: List[str], score_threshold: float
) -> List[Dict[str, Any]]:
    """Run the module-level OWL-ViT wrapper on an image.

    Args:
        img: Image to analyse.
        text_queries: Labels forwarded to the wrapper as ``candidate_labels``.
        score_threshold: Value forwarded as ``detection_threshold``.

    Returns:
        The predictions returned by ``owl_vit.infer``.
    """
    return owl_vit.infer(
        im=img,
        candidate_labels=text_queries,
        detection_threshold=score_threshold,
    )
def object_segmentation(
    img: npt.NDArray[np.uint8], object_detection_predictions: List[Dict[str, Any]]
) -> List[npt.NDArray[np.uint8]]:
    """Segment the detected objects with the module-level MobileSAM wrapper.

    Args:
        img: Image the detections were computed on.
        object_detection_predictions: Detection results; their bounding boxes
            are extracted with ``get_bboxes`` and used to prompt the model.

    Returns:
        The masks returned by ``mobile_sam.infer``.
    """
    boxes = get_bboxes(predictions=object_detection_predictions)
    return mobile_sam.infer(im=img, bboxes=boxes)
def _parse_text_queries(text_queries: Union[str, List[str]]) -> List[str]:
    """Turn the raw text queries into a clean list of labels.

    A string is split on commas (the format the demo description asks for);
    a list is used as-is. Surrounding whitespace is stripped from every label
    and empty labels are dropped.
    """
    if not text_queries:
        return []
    if isinstance(text_queries, str):
        candidates = text_queries.split(",")
    else:
        candidates = text_queries
    return [label.strip() for label in candidates if label.strip()]


def query(
    task: str,
    img: npt.NDArray[np.uint8],
    text_queries: Union[str, List[str]],
    score_threshold: float,
) -> npt.NDArray[np.uint8]:
    """Entry point called by the Gradio interface.

    Runs object detection on ``img`` and, when the segmentation task is
    selected, object segmentation as well, then returns the annotated image.

    Args:
        task: Task name selected in the dropdown.
        img: Input image.
        text_queries: Objects to look for, either as a single comma-separated
            string (what the ``"text"`` input component provides) or as a list
            of labels.
        score_threshold: Detection threshold forwarded to the detector.

    Returns:
        The image annotated with the detections (and masks for the
        segmentation task). If no usable label is given, ``img`` is returned
        unchanged.
    """
    labels = _parse_text_queries(text_queries)
    if not labels:
        # Nothing to look for: skip inference instead of sending an empty
        # label list to the detector.
        return img

    object_detection_predictions = object_detection(
        img=img, text_queries=labels, score_threshold=score_threshold
    )

    if task == "Object detection + segmentation (OWL-ViT + MobileSAM)":
        masks = object_segmentation(
            img=img, object_detection_predictions=object_detection_predictions
        )
        return annotator.annotate(
            im=img, detection_predictions=object_detection_predictions, masks=masks
        )

    return annotator.annotate(
        im=img, detection_predictions=object_detection_predictions
    )
# Text shown in the interface (passed as ``description`` to gr.Interface).
# It mixes Markdown (**bold**, bullet list) with an HTML link.
description = """
Welcome to the demo of pollen-vision, a simple and unified Python library to zero-shot computer vision models curated
for robotics use cases. **Pollen-vision** is designed for ease of installation and use, composed of independent modules
that can be combined to create a 3D object detection pipeline, getting the position of the objects in 3D space (x, y, z).
\n\nIn this demo, you have the option to choose between two tasks: object detection and object detection + segmentation.
The models available are:
- **OWL-VIT** (Open World Localization - Vision Transformer, By Google Research): this model performs text-conditioned
zero-shot 2D object localization in RGB images.
- **Mobile SAM**: A lightweight version of the Segment Anything Model (SAM) by Meta AI. SAM is a zero shot image
segmentation model. It can be prompted with bounding boxes or points. (https://github.com/ChaoningZhang/MobileSAM)
\n\nYou can input images in this demo in three ways: either by trying out the provided examples, by uploading an image
of your choice, or by capturing an image from your computer's webcam.
Additionally, you should provide text queries representing a list of objects to detect. Separate each object with a comma.
The last input parameter is the detection threshold (ranging from 0 to 1), which defaults to 0.1.
\n\nCheck out our blog post introducing pollen-vision or its <a href="https://github.com/pollen-robotics/pollen-vision">
Github repository</a> for more info!
"""
# Input components, listed in the same order as the parameters of ``query``:
# task, image, text queries, score threshold.
demo_inputs = [
    gr.Dropdown(
        choices=[
            "Object detection (OWL-ViT)",
            "Object detection + segmentation (OWL-ViT + MobileSAM)",
        ],
        value="Object detection (OWL-ViT)",
        label="Choose a task",
    ),
    gr.Image(),
    "text",
    gr.Slider(minimum=0, maximum=1, value=0.1),
]
# Load the "train" split of the pollen-robotics/reachy-doing-things dataset
# when the module is imported; rows 11 and 12 provide the images used by the
# demo examples below.
# NOTE(review): presumably fetched from the Hugging Face Hub, so this needs
# network access (or a local cache) at startup — confirm.
rdt_dataset = load_dataset("pollen-robotics/reachy-doing-things", split="train")
img_kitchen_detection = rdt_dataset[11]["image"]
img_kitchen_segmentation = rdt_dataset[12]["image"]
# Example rows for the interface. Each row holds one value per input
# component, in the order of ``demo_inputs``: task, image, text queries,
# score threshold. The text queries are comma-separated strings, i.e. the
# same format a user is asked to type into the text box.
demo_examples = [
    [
        "Object detection (OWL-ViT)",
        img_kitchen_detection,
        "kettle, black mug, sink, blue mug, sponge, bag of chips",
        0.15,
    ],
    [
        "Object detection + segmentation (OWL-ViT + MobileSAM)",
        img_kitchen_segmentation,
        "blue mug, paper cup, kettle, sponge",
        0.12,
    ],
]
# Assemble the interface: ``query`` receives the values of ``demo_inputs`` and
# its return value is displayed as an image.
demo = gr.Interface(
    title="Use zero-shot computer vision models with pollen-vision",
    description=description,
    fn=query,
    inputs=demo_inputs,
    outputs="image",
    examples=demo_examples,
)

# Start the app as soon as the module is executed.
demo.launch()
|