Segment Anything 8-Bit ONNX
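How to run (requires `onnxruntime`, `numpy`, `Pillow`, and `matplotlib`; place `sam_encoder.onnx`, `sam_decoder.onnx`, and `example.png` in the working directory):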
import onnxruntime as ort
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# --- Configuration -----------------------------------------------------------
# Path to the input image.
image_path = "example.png"
# Side length of the square input the SAM image encoder expects.
target_size = 1024

# --- Image preprocessing -----------------------------------------------------
# SAM-style preprocessing: scale the image uniformly so its longest side equals
# target_size, normalize, then zero-pad bottom/right to a square. Padding alone
# (without resizing) fails on images larger than target_size (negative pad
# widths) and misaligns the decoder's mask post-processing, which assumes this
# longest-side resize when cropping masks back to the original size.
image = Image.open(image_path).convert("RGB")
orig_width, orig_height = image.size
scale = target_size / max(orig_width, orig_height)
# Round to nearest integer, as SAM's ResizeLongestSide.get_preprocess_shape does.
new_width = int(orig_width * scale + 0.5)
new_height = int(orig_height * scale + 0.5)
resized_image = image.resize((new_width, new_height), Image.BILINEAR)

input_tensor = np.array(resized_image)  # (H, W, 3) RGB pixels
# SAM's default per-channel pixel mean/std, in 0-255 units (RGB order).
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])
input_tensor = (input_tensor - mean) / std
# HWC -> NCHW with a batch dimension of 1.
input_tensor = input_tensor.transpose(2, 0, 1)[None, :, :, :].astype(np.float32)
# Pad after normalizing so the padded region is exactly zero; both pad amounts
# are non-negative because the longest side is now exactly target_size.
pad_height = target_size - input_tensor.shape[2]
pad_width = target_size - input_tensor.shape[3]
input_tensor = np.pad(input_tensor, ((0, 0), (0, 0), (0, pad_height), (0, pad_width)))

# --- Image encoder -------------------------------------------------------------
encoder = ort.InferenceSession("sam_encoder.onnx")
embeddings = encoder.run(None, {"images": input_tensor})[0]

# --- Prompt preparation --------------------------------------------------------
# One click at (x=150, y=100) in original-image pixel coordinates.
point = np.array([[150, 100]], dtype=np.float64)  # (num_points, 2) as (x, y)
point_labels = np.array([1])  # label 1 = foreground click (SAM convention)
# Without a box prompt, the exported SAM decoder expects one extra padding
# point (0, 0) with label -1; coords and labels must have the same point count.
coords = np.concatenate([point, np.zeros((1, 2))], axis=0)[None, :, :]
onnx_label = np.concatenate([point_labels, [-1]])[None, :].astype(np.float32)
# Map coordinates into the resized encoder frame with the same per-axis ratio
# that was applied to the image (the padding point stays at the origin).
coords[..., 0] = coords[..., 0] * (new_width / orig_width)
coords[..., 1] = coords[..., 1] * (new_height / orig_height)
onnx_coord = coords.astype(np.float32)

# --- Remaining decoder inputs ------------------------------------------------
# No previous low-resolution mask is supplied: an all-zero mask plus a
# has_mask_input flag of 0.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
# Run the prompt decoder: image embedding + point prompt + (empty) mask prompt,
# with the original image size so masks come back at full resolution.
decoder = ort.InferenceSession("sam_decoder.onnx")
decoder_feed = {
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32),
}
# The decoder yields three outputs; only the first (the masks) is used here.
masks_output, _, _ = decoder.run(None, decoder_feed)
# Binarize the first mask of the first batch item: values > 0 become 255,
# everything else 0, as a uint8 image.
mask = np.where(masks_output[0][0] > 0, 255, 0).astype("uint8")