# no gpu required (falls back to CPU when CUDA is unavailable)
from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"

# OWLv2 zero-shot object detector: turns text prompts into bounding boxes
checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=device)

# RobustSAM: a SAM variant tuned for degraded (blurred, low-light, rainy, hazy) images
sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")


def query(image, texts, threshold):
    texts = texts.split(",")

    # Detect boxes for each comma-separated candidate label
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold
    )

    result_labels = []
    for pred in predictions:
        score = pred["score"]
        if score > threshold:
            label = pred["label"]
            box = [round(pred["box"]["xmin"], 2),
                   round(pred["box"]["ymin"], 2),
                   round(pred["box"]["xmax"], 2),
                   round(pred["box"]["ymax"], 2)]

            # Prompt RobustSAM with the detected bounding box
            inputs = sam_processor(
                image,
                input_boxes=[[[box]]],
                return_tensors="pt"
            ).to(device)

            with torch.no_grad():
                outputs = sam_model(**inputs)

            # Resize the predicted mask back to the original image size
            mask = sam_processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu()
            )[0][0][0].numpy()
            mask = mask[np.newaxis, ...]
            result_labels.append((mask, label))
    return image, result_labels


description = (
    "Welcome to RobustSAM by Snap Research. "
    "This Space uses RobustSAM, an enhanced version of the Segment Anything Model (SAM) with improved performance on low-quality images while maintaining zero-shot segmentation capabilities. "
    "Thanks to its integration with OWLv2, RobustSAM becomes text-promptable, allowing for flexible and accurate segmentation, even with degraded image quality. Try one of the examples or input an image with comma-separated candidate labels to see the enhanced segmentation results."
)

demo = gr.Interface(
    query,
    inputs=[gr.Image(type="pil", label="Image Input"),
            gr.Textbox(label="Candidate Labels"),
            gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
    outputs=gr.AnnotatedImage(label="Segmented Image"),
    title="RobustSAM",
    description=description,
    examples=[
        ["./blur.jpg", "insect", 0.1],
        ["./lowlight.jpg", "bus, window", 0.1],
        ["./rain.jpg", "tree, leaves", 0.1],
        ["./haze.jpg", "", 0.1],
    ],
    cache_examples=True
)
demo.launch()
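# ---------------------------------------------------------------------------
# Optional: a minimal sketch of calling query() directly, without the Gradio
# UI (useful for a quick local sanity check). The image path "./blur.jpg" is
# taken from the examples above and assumed to exist locally. Comment out
# demo.launch() before running this, since launch() blocks the process.
#
#     from PIL import Image
#     img = Image.open("./blur.jpg").convert("RGB")
#     _, annotations = query(img, "insect", 0.1)
#     for mask, label in annotations:
#         print(label, mask.shape)  # each mask is a (1, H, W) boolean array
# ---------------------------------------------------------------------------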