# Source page: vierundvi / automatic_label_simple_demo.py
# Page header residue (not part of the script):
#   mart9992's picture | m | 2cd560a | raw | history blame | 5 kB
import cv2
import numpy as np
import supervision as sv
from typing import List
from PIL import Image
import torch
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
# Tag2Text
# from ram.models import tag2text_caption
from ram.models import ram
# from ram import inference_tag2text
from ram import inference_ram
import torchvision
import torchvision.transforms as TS
# Hyper-Params

# Input image and the device the models are moved to (CUDA when available).
SOURCE_IMAGE_PATH = "./assets/demo9.jpg"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# GroundingDINO config / checkpoint.
GROUNDING_DINO_CONFIG_PATH = (
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"

# Segment Anything encoder variant and checkpoint.
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

# Tagging model checkpoints (the Tag2Text one is only referenced by the
# commented-out Tag2Text code path).
TAG2TEXT_CHECKPOINT_PATH = "./tag2text_swin_14m.pth"
RAM_CHECKPOINT_PATH = "./ram_swin_large_14m.pth"

# Thresholds: tagging, box/text detection, and IoU for NMS.
TAG2TEXT_THRESHOLD = 0.64
BOX_THRESHOLD = 0.2
TEXT_THRESHOLD = 0.2
IOU_THRESHOLD = 0.5
# Building GroundingDINO inference model
# NOTE(review): Model is constructed with its default device here; confirm it
# matches DEVICE on CPU-only machines.
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
)

# Building SAM Model and SAM Predictor
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
# Move SAM to the selected device, consistent with the RAM model; without this
# the SAM weights stay where the checkpoint loader left them.
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)
# Image preprocessing for the tagging model: resize to 384x384, convert to a
# tensor, then normalize each channel with the mean/std below.
normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = TS.Compose([TS.Resize((384, 384)), TS.ToTensor(), normalize])
# Tag indices 3012..3428 (inclusive): attributes and actions, which are
# difficult to ground and are therefore filtered out.
DELETE_TAG_INDEX = list(range(3012, 3429))
# tag2text_model = tag2text_caption(
# pretrained=TAG2TEXT_CHECKPOINT_PATH,
# image_size=384,
# vit='swin_b',
# delete_tag_index=DELETE_TAG_INDEX
# )
# # threshold for tagging
# # we reduce the threshold to obtain more tags
# tag2text_model.threshold = TAG2TEXT_THRESHOLD
# tag2text_model.eval()
# tag2text_model = tag2text_model.to(DEVICE)
# Build the RAM tagging model from its checkpoint, place it on DEVICE and
# switch it to eval mode.
ram_model = ram(pretrained=RAM_CHECKPOINT_PATH, image_size=384, vit="swin_l")
ram_model = ram_model.to(DEVICE)
ram_model.eval()
# load image
image = cv2.imread(SOURCE_IMAGE_PATH)  # bgr
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, which would otherwise surface as an obscure cvtColor error.
    raise FileNotFoundError(f"Could not read image: {SOURCE_IMAGE_PATH}")

# Build the tagging-model input: RGB PIL image -> 384x384 -> normalized
# tensor with a leading batch dimension, on DEVICE.
image_pillow = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # rgb
image_pillow = image_pillow.resize((384, 384))
image_pillow = transform(image_pillow).unsqueeze(0).to(DEVICE)

# Only used by the commented-out Tag2Text inference call below.
specified_tags = 'None'
# res = inference_tag2text(image_pillow , tag2text_model, specified_tags)
res = inference_ram(image_pillow, ram_model)

# Currently ", " is better for detecting single tags
# while ". " is a little worse in some case
# res[0] is a " | "-separated tag string; each tag becomes one class prompt.
AUTOMATIC_CLASSES = res[0].split(" | ")
print(f"Tags: {res[0].replace(' |', ',')}")
# detect objects
# Each tag is used as a class prompt. The text threshold now uses its own
# TEXT_THRESHOLD constant (previously BOX_THRESHOLD was passed for both, which
# only worked because the two values happen to be equal).
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=AUTOMATIC_CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)
# NMS post process
print(f"Before NMS: {len(detections.xyxy)} boxes")

# Indices of the boxes that survive non-maximum suppression at IOU_THRESHOLD.
nms_idx = (
    torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        IOU_THRESHOLD,
    )
    .numpy()
    .tolist()
)

# Keep only the surviving rows in every per-detection array.
detections.xyxy, detections.confidence, detections.class_id = (
    detections.xyxy[nms_idx],
    detections.confidence[nms_idx],
    detections.class_id[nms_idx],
)

print(f"After NMS: {len(detections.xyxy)} boxes")
# annotate image with detections
box_annotator = sv.BoxAnnotator()

# Build "<class name> <confidence>" labels straight from the per-detection
# arrays, so the code does not depend on how many fields iterating a
# Detections object yields.
# NOTE(review): GroundingDINO appears to leave class_id as None when a
# detected phrase matches none of the prompted classes; indexing the class
# list with None would raise TypeError, so fall back to "unknown".
labels = [
    f"{'unknown' if class_id is None else AUTOMATIC_CLASSES[class_id]} {confidence:0.2f}"
    for class_id, confidence in zip(detections.class_id, detections.confidence)
]
annotated_frame = box_annotator.annotate(
    scene=image.copy(), detections=detections, labels=labels
)

# save the annotated grounding dino image
cv2.imwrite("groundingdino_auto_annotated_image.jpg", annotated_frame)
# Prompting SAM with detected boxes
def segment(sam_predictor: "SamPredictor", image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Return one mask per box, choosing SAM's highest-scoring candidate.

    Args:
        sam_predictor: Predictor exposing ``set_image`` and ``predict``.
        image: Image handed to ``set_image``; its first two dimensions are
            used as the mask height/width when there are no boxes.
        xyxy: Iterable of box prompts, one per detection.

    Returns:
        Array stacking the selected mask for every box. When ``xyxy`` is
        empty, an empty boolean array of shape ``(0, H, W)`` is returned so
        callers always receive a mask stack with a consistent rank.
    """
    sam_predictor.set_image(image)

    result_masks = []
    for box in xyxy:
        # multimask_output=True yields several candidates; keep the best one.
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        result_masks.append(masks[np.argmax(scores)])

    if not result_masks:
        # np.array([]) would be a float array of shape (0,), not a mask stack.
        return np.zeros((0, *image.shape[:2]), dtype=bool)
    return np.array(result_masks)
# convert detections to masks
# SAM is given the RGB version of the BGR image loaded by OpenCV.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
detections.mask = segment(sam_predictor, image_rgb, detections.xyxy)
# annotate image with detections
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()

# Build "<class name> <confidence>" labels straight from the per-detection
# arrays, so the code does not depend on how many fields iterating a
# Detections object yields.
# NOTE(review): GroundingDINO appears to leave class_id as None when a
# detected phrase matches none of the prompted classes; indexing the class
# list with None would raise TypeError, so fall back to "unknown".
labels = [
    f"{'unknown' if class_id is None else AUTOMATIC_CLASSES[class_id]} {confidence:0.2f}"
    for class_id, confidence in zip(detections.class_id, detections.confidence)
]

# Draw masks first, then boxes and labels on top of them.
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
annotated_image = box_annotator.annotate(
    scene=annotated_image, detections=detections, labels=labels
)

# save the annotated grounded-sam image
cv2.imwrite("ram_grounded_sam_auto_annotated_image.jpg", annotated_image)