|
from typing import Dict, List, Any |
|
import os |
|
import requests |
|
from flask import Flask, Response, request, jsonify |
|
from segment_anything import SamPredictor, sam_model_registry |
|
|
|
class EndpointHandler:
    """Inference handler that returns a Segment Anything (SAM) image embedding.

    Each request names an image by URL (``imageUrl``); the handler downloads
    it, feeds it to a ``SamPredictor`` and returns the predictor's image
    embedding converted to nested Python lists.
    """

    # Seconds to wait for the image download. Without a timeout ``requests``
    # can block the worker indefinitely on an unresponsive host.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Build the SAM ``vit_b`` model and wrap it in a ``SamPredictor``.

        Args:
            path: Directory the checkpoint location is resolved against.
                The default ``""`` keeps the previous behaviour of resolving
                it relative to the current working directory.
        """
        model_type = "vit_b"

        print('current working directory', os.getcwd())

        # os.path.join("", p) == p, so the default path is unchanged.
        # NOTE(review): segment_anything checkpoints are normally PyTorch
        # ``.pth`` state dicts; confirm this ``.h5`` file actually loads.
        model_path = os.path.join(path, "models/tf_model.h5")

        sam = sam_model_registry[model_type](checkpoint=model_path)
        self.predictor = SamPredictor(sam)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Return the SAM image embedding for the image at ``imageUrl``.

        Args:
            data: Request payload. ``imageUrl`` is read from
                ``data["inputs"]`` when that key exists, otherwise from
                ``data`` itself. The payload is not modified.

        Returns:
            The image embedding as nested lists on success, or
            ``{"error": message}`` when the URL is missing or the download
            fails. Plain Python objects are returned (rather than Flask
            ``jsonify`` responses, which require an active Flask application
            context this handler never sets up) so the serving layer can
            serialize them.
        """
        # Read with .get instead of .pop so the caller's dict is left intact.
        inputs = data.get("inputs", data)
        # A non-dict ``inputs`` has no ``imageUrl`` field; treat it as missing
        # instead of crashing on a missing ``.get``/``.pop`` attribute.
        image_url = inputs.get("imageUrl") if isinstance(inputs, dict) else None

        if not image_url:
            return {"error": "image_url not provided"}

        # NOTE(review): image_url comes straight from the request, so this is
        # an outbound fetch to a caller-chosen host (SSRF risk) — consider an
        # allow-list of hosts/schemes.
        try:
            response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
        except requests.RequestException as e:
            return {"error": f"Error downloading image: {str(e)}"}
        image = response.content

        # NOTE(review): ``image`` is the raw encoded response body (bytes).
        # SamPredictor.set_image presumably expects a decoded HxWxC RGB uint8
        # array — decode first (e.g. with PIL) — TODO confirm and fix.
        self.predictor.set_image(image)

        return self.predictor.get_image_embedding().cpu().numpy().tolist()
|
|