Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
import matplotlib.pyplot as plt | |
import cv2 | |
def show_anns(anns, *, ax=None, alpha=0.35):
    """Overlay segmentation masks on a matplotlib axes as translucent colors.

    Masks are painted from the largest ``'area'`` to the smallest, so a small
    mask is drawn after (and therefore on top of) any larger mask it overlaps.
    Each mask gets an independent random RGB color; pixels covered by no mask
    stay fully transparent.

    Args:
        anns: Sequence of annotation dicts. For each dict, ``'area'`` is used
            as the sort key and ``'segmentation'`` is used as an index into
            the RGBA overlay (presumably an HxW boolean array -- confirm
            against the mask generator that produces ``anns``).
        ax: Axes-like object to draw on. When ``None`` (the default) the
            current pyplot axes, ``plt.gca()``, is used.
        alpha: Opacity written into the alpha channel of every painted pixel.
            Defaults to ``0.35``.

    Returns:
        None. The overlay is drawn with ``ax.imshow``. Nothing is drawn when
        ``anns`` is ``None`` or empty.
    """
    if anns is None or len(anns) == 0:
        return

    if ax is None:
        ax = plt.gca()
    # Keep the axes limits set by the underlying image; the overlay must not
    # trigger a rescale.
    ax.set_autoscale_on(False)

    ordered = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    height, width = ordered[0]['segmentation'].shape[:2]

    # RGBA canvas: white with alpha 0, i.e. invisible until a mask paints it.
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in ordered:
        overlay[ann['segmentation']] = np.concatenate(
            [np.random.random(3), [alpha]]
        )
    ax.imshow(overlay)
import sys

# Add the parent of the working directory to the import path; this must run
# before the `tinysam` import below (presumably the package lives one level
# up -- confirm against the repository layout).
sys.path.append("..")
from tinysam import sam_model_registry, SamHierarchicalMaskGenerator

# Build the model from the local checkpoint and switch it to eval mode.
model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint="./weights/tinysam.pth")
sam.eval()
mask_generator = SamHierarchicalMaskGenerator(sam)

# cv2.imread signals failure by returning None rather than raising, so check
# explicitly instead of letting cvtColor fail with an opaque error.
image_path = 'fig/picture3.jpg'
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads images as BGR; convert to RGB before generating and plotting.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

masks = mask_generator.hierarchical_generate(image)

# Draw the image with the mask overlay on top and write the figure to disk.
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.savefig("test_everthing.png")