|
|
|
import base64 |
|
from io import BytesIO |
|
import os |
|
from typing import Dict, List, Any, Tuple
|
import cv2 |
|
import groundingdino |
|
from groundingdino.util.inference import load_model, load_image, predict, annotate |
|
import tempfile |
|
|
|
|
|
# Current working directory, captured once at import time.
HOME = os.getcwd()

# Install directory of the groundingdino package; the SwinT-OGC model
# architecture config that EndpointHandler loads ships inside it under "config/".
PACKAGE_HOME = os.path.split(groundingdino.__file__)[0]
CONFIG_PATH = os.path.join(PACKAGE_HOME, "config", "GroundingDINO_SwinT_OGC.py")
|
|
|
class EndpointHandler():
    """Custom inference-endpoint handler for GroundingDINO.

    Loads the GroundingDINO SwinT-OGC model once. For each request it runs
    text-prompted object detection on a base64-encoded image. It returns the
    image annotated with the detected boxes.
    """

    def __init__(self, path=""):
        """Load the model checkpoint.

        Args:
            path: Model repository directory. The checkpoint is read from
                ``<path>/weights/groundingdino_swint_ogc.pth``; the model
                architecture comes from the module-level ``CONFIG_PATH``.
        """
        self.model = load_model(
            CONFIG_PATH,
            os.path.join(path, "weights", "groundingdino_swint_ogc.pth"),
        )

        # Default detection thresholds passed to ``predict``; a request may
        # override either one through ``data["parameters"]``.
        self.box_threshold = 0.35
        self.text_threshold = 0.25

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Detect the objects described by a text prompt in an image.

        Args:
            data: Request payload of the form::

                    {
                        "inputs": {"image": "<base64 image bytes>", "prompt": "<text>"},
                        "parameters": {"box_threshold": 0.35, "text_threshold": 0.25},
                    }

                ``"parameters"`` and each of its keys are optional. Missing
                values fall back to ``self.box_threshold`` and
                ``self.text_threshold``. The handler pops ``"inputs"`` and
                ``"parameters"`` from ``data``, and ``"image"`` and
                ``"prompt"`` from ``data["inputs"]``, so the payload is mutated.

        Returns:
            A one-element list holding a dict with these keys:
            ``"image"`` is the annotated image as base64-encoded JPEG text,
            ``"prompt"`` echoes the prompt back, and ``"num_found"`` is the
            number of detected boxes.

        Raises:
            KeyError: If ``"inputs"``, ``"image"`` or ``"prompt"`` is missing.
            RuntimeError: If the annotated image cannot be JPEG-encoded.
        """
        inputs = data.pop("inputs")
        image_base64 = inputs.pop("image")
        prompt = inputs.pop("prompt")
        box_threshold, text_threshold = self._resolve_thresholds(data.pop("parameters", None))

        image_data = base64.b64decode(image_base64)

        # load_image is given a file path, so the decoded bytes go through a
        # temp file that is fully written and closed before it is reopened.
        image_path = self._write_temp_image(image_data)
        try:
            image_source, image = load_image(image_path)
            boxes, logits, phrases = predict(
                model=self.model,
                image=image,
                caption=prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
            )
        finally:
            try:
                os.remove(image_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not fail the request.
                pass

        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)

        # The first value returned by cv2.imencode is a success flag; a failed
        # encode must not be sent back as if it were a valid image.
        ok, annotated_image = cv2.imencode(".jpg", annotated_frame)
        if not ok:
            raise RuntimeError("Failed to JPEG-encode the annotated image")
        annotated_image_b64 = base64.b64encode(annotated_image.tobytes()).decode("utf-8")

        return [{
            "image": annotated_image_b64,
            "prompt": prompt,
            "num_found": boxes.size(0),
        }]

    def _resolve_thresholds(self, parameters) -> Tuple[float, float]:
        """Return ``(box_threshold, text_threshold)`` for one request.

        Values in ``parameters`` (a dict, or ``None``) take precedence over
        the instance defaults. All values are coerced with ``float``.
        """
        params = parameters or {}
        return (
            float(params.get("box_threshold", self.box_threshold)),
            float(params.get("text_threshold", self.text_threshold)),
        )

    @staticmethod
    def _write_temp_image(image_data: bytes) -> str:
        """Write ``image_data`` to a new ``.jpg`` temp file and return its path.

        The file is closed before returning. Its contents are therefore
        flushed to disk, and it can be reopened by name on any platform.
        The caller is responsible for deleting it.
        """
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
            f.write(image_data)
        return f.name
|
|