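# OCTO demo script: given an input image and an object to place, it tags the
# scene with RAM++, filters the tags with GroundingDINO, asks GPT-4 to pick the
# best placement surface, localises that surface with GroundingDINO + SAM, and
# marks the chosen point, all behind a Gradio UI.
# The block below performs one-time setup: it creates the weights folder,
# downloads the model checkpoints and a sample image, and installs the Python
# dependencies via os.system.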
import os

if not os.path.isdir("weights"):
    os.mkdir("weights")

os.system("python -m pip install --upgrade pip")
os.system(
    "wget https://raw.githubusercontent.com/asharma381/cs291I/main/backend/original_images/000749.png"
)
os.system(
    "wget -q -O weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
)
os.system(
    "wget -q -O weights/ram_plus_swin_large_14m.pth https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
)
os.system(
    "wget -q -O weights/groundingdino_swint_ogc.pth https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
)
os.system("pip install git+https://github.com/xinyu1205/recognize-anything.git")
os.system("pip install git+https://github.com/IDEA-Research/GroundingDINO.git")
os.system("pip install git+https://github.com/facebookresearch/segment-anything.git")
os.system("pip install openai==0.27.4")
os.system("pip install tenacity")

from typing import List, Tuple

import cv2
import gradio as gr
import groundingdino.config.GroundingDINO_SwinT_OGC
import numpy as np
import openai
import torch
from groundingdino.util.inference import Model
from PIL import Image, ImageDraw
from ram import get_transform
from ram import inference_ram as inference
from ram.models import ram_plus
from scipy.spatial.distance import cdist
from segment_anything import SamPredictor, sam_model_registry
from supervision import Detections
from tenacity import retry, wait_fixed

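# Lazily initialised global model handles: each model is loaded on first use
# and cached for all subsequent calls.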
device = "cuda" if torch.cuda.is_available() else "cpu"

ram_model = None
ram_threshold_multiplier = 1
gdino_model = None
sam_model = None
sam_predictor = None

print("CUDA Available:", torch.cuda.is_available())

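# Stage 1a: open-vocabulary image tagging with RAM++. The per-class thresholds
# are scaled by `threshold_multiplier`; the running `ram_threshold_multiplier`
# keeps that scaling consistent across repeated calls.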
def get_tags_ram(
    image: Image.Image, threshold_multiplier=0.8, weights_folder="weights"
) -> List[str]:
    global ram_model, ram_threshold_multiplier
    if ram_model is None:
        print("Loading RAM++ Model...")
        ram_model = ram_plus(
            pretrained=f"{weights_folder}/ram_plus_swin_large_14m.pth",
            vit="swin_l",
            image_size=384,
        )
        ram_model.eval()
        ram_model = ram_model.to(device)

    ram_model.class_threshold *= threshold_multiplier / ram_threshold_multiplier
    ram_threshold_multiplier = threshold_multiplier
    transform = get_transform()

    image = transform(image).unsqueeze(0).to(device)
    res = inference(image, ram_model)
    return [s.strip() for s in res[0].split("|")]

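# Ground a list of class names (or a single text prompt) to bounding boxes with
# GroundingDINO; returns supervision Detections plus the matched phrases.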
def get_gdino_result(
    image: Image.Image,
    classes: List[str],
    box_threshold: float = 0.25,
    weights_folder="weights",
) -> Tuple[Detections, List[str]]:
    global gdino_model

    if gdino_model is None:
        print("Loading GroundingDINO Model...")
        config_path = groundingdino.config.GroundingDINO_SwinT_OGC.__file__
        gdino_model = Model(
            model_config_path=config_path,
            model_checkpoint_path=f"{weights_folder}/groundingdino_swint_ogc.pth",
            device=device,
        )

    detections, phrases = gdino_model.predict_with_caption(
        image=np.array(image),
        caption=", ".join(classes),
        box_threshold=box_threshold,
        text_threshold=0.25,
    )

    return detections, phrases

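# Lazily load the SAM ViT-H backbone from the downloaded checkpoint.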
def get_sam_model(weights_folder="weights"):
    global sam_model
    if sam_model is None:
        sam_checkpoint = f"{weights_folder}/sam_vit_h_4b8939.pth"
        sam_model = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
        sam_model.to(device=device)
    return sam_model

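# Stage 1b: keep only the RAM++ tags that GroundingDINO can actually ground to
# a box, discarding detections that cover (nearly) the whole image.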
def filter_tags_gdino(image: Image.Image, tags: List[str]) -> List[str]:
    detections, phrases = get_gdino_result(image, tags)
    filtered_tags = []
    for tag in tags:
        for phrase, area in zip(phrases, detections.area):
            if area < 0.9 * image.size[0] * image.size[1] and tag in phrase:
                filtered_tags.append(tag)
                break
    return filtered_tags

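# Read a prompt template from disk, returning an empty string on failure.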
def read_file_to_string(file_path: str) -> str:
    content = ""

    try:
        with open(file_path, "r", encoding="utf8") as file:
            content = file.read()
    except FileNotFoundError:
        print(f"The file {file_path} was not found.")
    except Exception as e:
        print(f"An error occurred while reading {file_path}: {e}")

    return content

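# OpenAI helpers. The script pins openai==0.27.4, so the legacy
# openai.ChatCompletion API is used; tenacity retries the call every 2 seconds
# on failure (e.g. rate limits).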
@retry(wait=wait_fixed(2))
def completion_with_backoff(**kwargs):
    return openai.ChatCompletion.create(**kwargs)

def gpt4(
    usr_prompt: str, sys_prompt: str = "", api_key: str = "", model: str = "gpt-4"
) -> str:
    openai.api_key = api_key

    message = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": usr_prompt},
    ]

    response = completion_with_backoff(
        model=model,
        messages=message,
        temperature=0.2,
        max_tokens=1000,
        frequency_penalty=0.0,
    )

    return response["choices"][0]["message"]["content"]

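# Stage 2: ask GPT-4 to pick the best grounded tag (placement surface) for the
# requested object. The prompts are built from user_template.txt and
# system_template.txt, which are expected to be present in the working
# directory.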
def select_best_tag(
    filtered_tags: List[str], object_to_place: str, api_key: str = ""
) -> str:
    user_template = read_file_to_string("user_template.txt").format(object=object_to_place)
    user_prompt = user_template + "\n".join(filtered_tags)
    system_prompt = read_file_to_string("system_template.txt")
    return gpt4(user_prompt, system_prompt, api_key=api_key)

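# Stage 3: ground the selected tag with GroundingDINO (lowering the box
# threshold until at least one box is found), segment the boxes with SAM, and
# return the point inside the combined mask that is farthest from the mask
# boundary. The mask is downsampled by RESIZE_RATIO first so the pairwise
# distance computation stays small.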
def get_location_gsam(
    image: Image.Image, prompt: str, weights_folder="weights"
) -> Tuple[int, int]:
    global sam_predictor

    BOX_THRESHOLD = 0.25
    RESIZE_RATIO = 3

    detections, phrases = get_gdino_result(
        image=image,
        classes=[prompt],
        box_threshold=BOX_THRESHOLD,
    )

    # If nothing is detected, relax the box threshold until something is.
    while len(detections.xyxy) == 0:
        BOX_THRESHOLD -= 0.02
        detections, phrases = get_gdino_result(
            image=image,
            classes=[prompt],
            box_threshold=BOX_THRESHOLD,
        )

    sam_model = get_sam_model(weights_folder)

    if sam_predictor is None:
        print("Loading SAM Model...")
        sam_predictor = SamPredictor(sam_model)

    # Segment every detected box and keep the highest-scoring mask per box.
    sam_predictor.set_image(np.array(image))
    result_masks = []
    for box in detections.xyxy:
        masks, scores, logits = sam_predictor.predict(box=box, multimask_output=True)
        index = np.argmax(scores)
        result_masks.append(masks[index])
    detections.mask = np.array(result_masks)

    # Union of all masks, downsampled to keep the distance matrix small.
    combined_mask = detections.mask[0]
    for mask in detections.mask[1:]:
        combined_mask += mask
    combined_mask[combined_mask > 1] = 1
    mask = cv2.resize(
        combined_mask.astype("uint8"),
        (
            combined_mask.shape[1] // RESIZE_RATIO,
            combined_mask.shape[0] // RESIZE_RATIO,
        ),
    )

    mask_2_pad = np.pad(mask, pad_width=2, mode="constant", constant_values=0)
    mask_1_pad = np.pad(mask, pad_width=1, mode="constant", constant_values=0)

    # Mark background pixels whose 3x3 neighbourhood is entirely background
    # with 2, so that only background pixels bordering the mask stay labelled 0.
    windows = np.lib.stride_tricks.sliding_window_view(mask_2_pad, (3, 3))
    windows_all_zero = (windows == 0).all(axis=(2, 3))

    result = np.where(windows_all_zero, 2, mask_1_pad)
    mask_0_coordinates = np.argwhere(result == 0)
    mask_1_coordinates = np.argwhere(result == 1)
    # Pick the mask pixel farthest from the boundary (max of the min distances).
    distances = cdist(mask_1_coordinates, mask_0_coordinates, "euclidean")
    max_min_distance_index = np.argmax(np.min(distances, axis=1))
    y, x = mask_1_coordinates[max_min_distance_index]

    return int(x) * RESIZE_RATIO, int(y) * RESIZE_RATIO

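# Full pipeline: tag the scene (Stage 1), let GPT-4 choose where to place the
# object (Stage 2), locate that surface with GroundingDINO + SAM (Stage 3),
# and mark the chosen point with a red dot.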
def run_octo_pipeline(input_image, object, api_key):
    print("Inside run_octo_pipeline with input_image=", input_image, "object=", object)

    print("Loading Image...")
    image = input_image.convert("RGB")

    print("Stage 1...")
    tags = get_tags_ram(image, threshold_multiplier=0.8)
    print("RAM++ Tags", tags)
    filtered_tags = filter_tags_gdino(image, tags)
    print("Filtered Tags", filtered_tags)

    print("Stage 2...")
    selected_tag = select_best_tag(filtered_tags, object, api_key=api_key)
    print("GPT-4 Selected Tag", selected_tag)

    print("Stage 3...")
    x, y = get_location_gsam(image, selected_tag)
    print("G-SAM Location", f"({x},{y})")

    draw = ImageDraw.Draw(image)
    radius = 10
    bbox = (x - radius, y - radius, x + radius, y + radius)
    draw.ellipse(bbox, fill="red")
    return [image]

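# Gradio UI: an input image, the object to place, and an OpenAI API key go in;
# the annotated image comes back in a gallery.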
block = gr.Blocks()

with block:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", value="000749.png")
            object = gr.Textbox(label="Object", placeholder="Enter an object")
            api_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter OpenAI API Key")

        with gr.Column():
            gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                preview=True,
                object_fit="scale-down",
            )

iface = gr.Interface(
    fn=run_octo_pipeline, inputs=[input_image, object, api_key], outputs=gallery
)
iface.launch()