import os
import sys

# Make the bundled sam_hq and local_groundingdino packages importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import copy
import glob
import logging
from urllib.parse import urlparse

import numpy as np
import torch
from PIL import Image
from torch.hub import download_url_to_file

import comfy.model_management
import folder_paths
from sam_hq.predictor import SamPredictorHQ
from sam_hq.build_sam_hq import sam_model_registry
from local_groundingdino.datasets import transforms as T
from local_groundingdino.util.utils import clean_state_dict as local_groundingdino_clean_state_dict
from local_groundingdino.util.slconfig import SLConfig as local_groundingdino_SLConfig
from local_groundingdino.models import build_model as local_groundingdino_build_model

logger = logging.getLogger('comfyui_segment_anything')

sam_model_dir_name = "sams"
sam_model_list = {
    "sam_vit_h (2.56GB)": {
        "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    },
    "sam_vit_l (1.25GB)": {
        "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"
    },
    "sam_vit_b (375MB)": {
        "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    },
    "sam_hq_vit_h (2.57GB)": {
        "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"
    },
    "sam_hq_vit_l (1.25GB)": {
        "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
    },
    "sam_hq_vit_b (379MB)": {
        "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth"
    },
    "mobile_sam (39MB)": {
        # Use the raw-file URL; the /blob/ page URL serves HTML, not the checkpoint.
        "model_url": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
    }
}

groundingdino_model_dir_name = "grounding-dino"
groundingdino_model_list = {
    "GroundingDINO_SwinT_OGC (694MB)": {
        "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
        "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
    },
    "GroundingDINO_SwinB (938MB)": {
        "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py",
        "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
    },
}
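
# Checkpoints listed above are fetched on first use into
# <ComfyUI>/models/sams and <ComfyUI>/models/grounding-dino respectively
# (see get_local_filepath below).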


def get_bert_base_uncased_model_path():
    comfy_bert_model_base = os.path.join(folder_paths.models_dir, 'bert-base-uncased')
    if glob.glob(os.path.join(comfy_bert_model_base, '**/model.safetensors'), recursive=True):
        logger.info('grounding-dino is using models/bert-base-uncased')
        return comfy_bert_model_base
    # Fall back to resolving 'bert-base-uncased' from the Hugging Face hub.
    return 'bert-base-uncased'


def list_files(dirpath, extensions=()):
    # An empty tuple (rather than a mutable default list) matches nothing.
    return [f for f in os.listdir(dirpath)
            if os.path.isfile(os.path.join(dirpath, f)) and f.split('.')[-1] in extensions]


def list_sam_model():
    return list(sam_model_list.keys())


def load_sam_model(model_name):
    sam_checkpoint_path = get_local_filepath(
        sam_model_list[model_name]["model_url"], sam_model_dir_name)
    model_file_name = os.path.basename(sam_checkpoint_path)
    model_type = model_file_name.split('.')[0]
    # Plain SAM checkpoints carry a hash suffix (e.g. sam_vit_h_4b8939) that
    # must be stripped to match the registry key; HQ and MobileSAM checkpoint
    # names already match their keys.
    if 'hq' not in model_type and 'mobile' not in model_type:
        model_type = '_'.join(model_type.split('_')[:-1])
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path)
    sam_device = comfy.model_management.get_torch_device()
    sam.to(device=sam_device)
    sam.eval()
    sam.model_name = model_file_name
    return sam


def get_local_filepath(url, dirname, local_file_name=None):
    if not local_file_name:
        parsed_url = urlparse(url)
        local_file_name = os.path.basename(parsed_url.path)
    destination = folder_paths.get_full_path(dirname, local_file_name)
    if destination:
        logger.warning(f'using extra model: {destination}')
        return destination
    folder = os.path.join(folder_paths.models_dir, dirname)
    os.makedirs(folder, exist_ok=True)
    destination = os.path.join(folder, local_file_name)
    if not os.path.exists(destination):
        logger.warning(f'downloading {url} to {destination}')
        download_url_to_file(url, destination)
    return destination
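
# For example (a hypothetical call, reusing the tables above), this resolves
# to an already-downloaded local file or fetches it into models/sams:
#   path = get_local_filepath(sam_model_list["sam_vit_b (375MB)"]["model_url"],
#                             sam_model_dir_name)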


def load_groundingdino_model(model_name):
    dino_model_args = local_groundingdino_SLConfig.fromfile(
        get_local_filepath(
            groundingdino_model_list[model_name]["config_url"],
            groundingdino_model_dir_name,
        ),
    )
    if dino_model_args.text_encoder_type == 'bert-base-uncased':
        dino_model_args.text_encoder_type = get_bert_base_uncased_model_path()
    dino = local_groundingdino_build_model(dino_model_args)
    # Load the checkpoint on CPU first, then move the model to the target device.
    checkpoint = torch.load(
        get_local_filepath(
            groundingdino_model_list[model_name]["model_url"],
            groundingdino_model_dir_name,
        ),
        map_location='cpu',
    )
    dino.load_state_dict(local_groundingdino_clean_state_dict(
        checkpoint['model']), strict=False)
    device = comfy.model_management.get_torch_device()
    dino.to(device=device)
    dino.eval()
    return dino


def list_groundingdino_model():
    return list(groundingdino_model_list.keys())


def groundingdino_predict(
    dino_model,
    image,
    prompt,
    threshold
):
    def load_dino_image(image_pil):
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image, _ = transform(image_pil, None)  # 3, h, w
        return image

    def get_grounding_output(model, image, caption, box_threshold):
        # GroundingDINO expects a lower-case caption terminated by a period.
        caption = caption.lower().strip()
        if not caption.endswith("."):
            caption = caption + "."
        device = comfy.model_management.get_torch_device()
        image = image.to(device)
        with torch.no_grad():
            outputs = model(image[None], captions=[caption])
        logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
        boxes = outputs["pred_boxes"][0]  # (nq, 4)
        # Keep only queries whose best token logit clears the threshold.
        filt_mask = logits.max(dim=1)[0] > box_threshold
        boxes_filt = boxes[filt_mask]  # (num_filt, 4)
        return boxes_filt.cpu()

    dino_image = load_dino_image(image.convert("RGB"))
    boxes_filt = get_grounding_output(
        dino_model, dino_image, prompt, threshold
    )
    H, W = image.size[1], image.size[0]
    # Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixels.
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]
    return boxes_filt
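
# The per-box loop above is equivalent to this vectorized form (a sketch,
# left as a comment so behavior stays byte-for-byte identical):
#   scale = torch.tensor([W, H, W, H], dtype=boxes_filt.dtype)
#   boxes_filt = boxes_filt * scale
#   boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
#   boxes_filt[:, 2:] += boxes_filt[:, :2]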


def create_pil_output(image_np, masks, boxes_filt):
    # boxes_filt is accepted for signature parity with create_tensor_output
    # but is not used here.
    output_masks, output_images = [], []
    for mask in masks:
        # Union the per-box masks, then convert the boolean result to an
        # 8-bit image for portability across Pillow versions.
        merged_mask = np.any(mask, axis=0)
        output_masks.append(Image.fromarray(merged_mask.astype(np.uint8) * 255))
        image_np_copy = copy.deepcopy(image_np)
        image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
        output_images.append(Image.fromarray(image_np_copy))
    return output_images, output_masks


def create_tensor_output(image_np, masks, boxes_filt):
    output_masks, output_images = [], []
    for mask in masks:
        # Zero out every pixel outside the union of the per-box masks.
        image_np_copy = copy.deepcopy(image_np)
        image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
        output_image, output_mask = split_image_mask(
            Image.fromarray(image_np_copy))
        output_masks.append(output_mask)
        output_images.append(output_image)
    return (output_images, output_masks)


def split_image_mask(image):
    image_rgb = image.convert("RGB")
    image_rgb = np.array(image_rgb).astype(np.float32) / 255.0
    image_rgb = torch.from_numpy(image_rgb)[None,]
    if 'A' in image.getbands():
        mask = np.array(image.getchannel('A')).astype(np.float32) / 255.0
        mask = torch.from_numpy(mask)[None,]
    else:
        # No alpha channel to split off; fall back to an empty placeholder mask.
        mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
    return (image_rgb, mask)


def sam_segment(
    sam_model,
    image,
    boxes
):
    if boxes.shape[0] == 0:
        return None
    # TODO: more elegant
    sam_is_hq = hasattr(sam_model, 'model_name') and 'hq' in sam_model.model_name
    predictor = SamPredictorHQ(sam_model, sam_is_hq)
    image_np = np.array(image)
    image_np_rgb = image_np[..., :3]
    predictor.set_image(image_np_rgb)
    transformed_boxes = predictor.transform.apply_boxes_torch(
        boxes, image_np.shape[:2])
    sam_device = comfy.model_management.get_torch_device()
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes.to(sam_device),
        multimask_output=False)
    masks = masks.permute(1, 0, 2, 3).cpu().numpy()
    return create_tensor_output(image_np, masks, boxes)
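
# Shape note: with multimask_output=False, predict_torch returns masks of
# shape (num_boxes, 1, H, W); the permute above yields (1, num_boxes, H, W),
# so create_tensor_output sees a single batch entry whose per-box masks are
# then unioned with np.any(mask, axis=0).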


class SAMModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (list_sam_model(), ),
            }
        }

    CATEGORY = "segment_anything"
    FUNCTION = "main"
    RETURN_TYPES = ("SAM_MODEL", )

    def main(self, model_name):
        sam_model = load_sam_model(model_name)
        return (sam_model, )


class GroundingDinoModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (list_groundingdino_model(), ),
            }
        }

    CATEGORY = "segment_anything"
    FUNCTION = "main"
    RETURN_TYPES = ("GROUNDING_DINO_MODEL", )

    def main(self, model_name):
        dino_model = load_groundingdino_model(model_name)
        return (dino_model, )


class GroundingDinoSAMSegment:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "sam_model": ('SAM_MODEL', {}),
                "grounding_dino_model": ('GROUNDING_DINO_MODEL', {}),
                "image": ('IMAGE', {}),
                "prompt": ("STRING", {}),
                "threshold": ("FLOAT", {
                    "default": 0.3,
                    "min": 0,
                    "max": 1.0,
                    "step": 0.01
                }),
            }
        }

    CATEGORY = "segment_anything"
    FUNCTION = "main"
    RETURN_TYPES = ("IMAGE", "MASK")

    def main(self, grounding_dino_model, sam_model, image, prompt, threshold):
        res_images = []
        res_masks = []
        for item in image:
            item = Image.fromarray(
                np.clip(255. * item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA')
            boxes = groundingdino_predict(
                grounding_dino_model,
                item,
                prompt,
                threshold
            )
            if boxes.shape[0] == 0:
                # Skip frames with no detections instead of aborting the batch.
                continue
            (images, masks) = sam_segment(
                sam_model,
                item,
                boxes
            )
            res_images.extend(images)
            res_masks.extend(masks)
        if len(res_images) == 0:
            _, height, width, _ = image.size()
            empty_mask = torch.zeros((1, height, width), dtype=torch.uint8, device="cpu")
            return (empty_mask, empty_mask)
        return (torch.cat(res_images, dim=0), torch.cat(res_masks, dim=0))


class InvertMask:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
            }
        }

    CATEGORY = "segment_anything"
    FUNCTION = "main"
    RETURN_TYPES = ("MASK",)

    def main(self, mask):
        out = 1.0 - mask
        return (out,)


class IsMaskEmptyNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("NUMBER", )
    RETURN_NAMES = ("boolean_number", )
    FUNCTION = "main"
    CATEGORY = "segment_anything"

    def main(self, mask):
        # 1 if every mask pixel is zero, else 0.
        return (torch.all(mask == 0).int().item(), )
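

# ComfyUI discovers custom nodes through a NODE_CLASS_MAPPINGS dict exported
# by the module. A minimal registration block is sketched below; the display
# names are assumptions, not taken from this file.
NODE_CLASS_MAPPINGS = {
    'SAMModelLoader (segment anything)': SAMModelLoader,
    'GroundingDinoModelLoader (segment anything)': GroundingDinoModelLoader,
    'GroundingDinoSAMSegment (segment anything)': GroundingDinoSAMSegment,
    'InvertMask (segment anything)': InvertMask,
    'IsMaskEmpty': IsMaskEmptyNode,
}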