import os
import socket
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
from pathlib import Path
from loguru import logger
import cv2
import torch
import time
import base64
import requests
import json


DL4EO_API_URL = "https://dl4eo--groundingdino-predict.modal.run"
DL4EO_API_KEY = os.environ['DL4EO_API_KEY']

LINE_WIDTH = 2

logger.info(f"Gradio version: {gr.__version__}")

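
# predict_image() sends one RGB image and a text prompt to the remote Grounding DINO
# endpoint and returns the annotated image, its size, the number of predicted boxes
# and the inference duration reported by the server.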
def predict_image(image, text_prompt, box_threshold, text_threshold):
    if isinstance(image, Image.Image):
        img = np.array(image)
    elif isinstance(image, np.ndarray):
        img = image
        image = Image.fromarray(img)  # PIL handle is needed for drawing the boxes below
    else:
        img = None

    if not isinstance(img, np.ndarray) or len(img.shape) != 3 or img.shape[2] != 3:
        raise TypeError("predict_image(): input 'image' should be a single RGB image in PIL or NumPy array format.")
    image_base64 = base64.b64encode(np.ascontiguousarray(img)).decode()

    payload = {
        'image': image_base64,
        'shape': img.shape,
        'text_prompt': text_prompt,
        'box_threshold': box_threshold,
        'text_threshold': text_threshold,
    }
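
    # Authenticate against the Modal inference endpoint with the bearer token from DL4EO_API_KEY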
    headers = {
        'Authorization': 'Bearer ' + DL4EO_API_KEY,
        'Content-Type': 'application/json'
    }

    response = requests.post(DL4EO_API_URL, json=payload, headers=headers)

    if response.status_code != 200:
        raise Exception(
            f"Received status code={response.status_code} in inference API: {response.text}"
        )

    json_data = json.loads(response.content)
    duration = json_data['duration']
    boxes = json_data['boxes']

    # Draw the predicted boxes on the PIL image; img.shape is (height, width, channels),
    # so the horizontal limit is img.shape[1] and the vertical limit is img.shape[0]
    draw = ImageDraw.Draw(image)

    for box in boxes:
        left, top, right, bottom = box

        # Nudge boxes that touch the border slightly outside the image so the
        # outline does not cover edge pixels
        if left <= 0: left = -LINE_WIDTH
        if top <= 0: top = -LINE_WIDTH
        if right >= img.shape[1] - 1: right = img.shape[1] - 1 + LINE_WIDTH
        if bottom >= img.shape[0] - 1: bottom = img.shape[0] - 1 + LINE_WIDTH

        draw.rectangle([left, top, right, bottom], outline="red", width=LINE_WIDTH)

    return image, str(image.size), len(boxes), duration
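

# Demo Pléiades / Pléiades Neo images and prompts used to pre-populate the Examples component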
example_data = [
    ["./demo/Pleiades_Neo_Tucson_USA.jpg", 'plane', 0.24, 0.24],
    ["./demo/Pleiades_Neo_Tucson_USA.jpg", 'building', 0.24, 0.24],
    ["./demo/Pleiades_HD15_Miami_Marina.jpg", "motorboat", 0.3, 0.0],
    ["./demo/Pleiades_HD15_Miami_Marina.jpg", "palm tree", 0.15, 0.3],
    ["./demo/Pleiades_HD15_Miami_Marina.jpg", "building", 0.3, 0.0],
]

css = """
.image-preview {
    height: 820px !important;
    width: 800px !important;
}
"""

TITLE = "Open detection on optical satellite images"

demo = gr.Blocks(title=TITLE, css=css).queue()

with demo:
    gr.Markdown(f"<h1><center>{TITLE}</center></h1>")
    with gr.Row():
        with gr.Column(scale=0):
            input_image = gr.Image(type="pil", interactive=True, scale=1)
            text_prompt = gr.Textbox(label="Text prompt")
            run_button = gr.Button(value="Run", scale=0)
            with gr.Accordion("Advanced options", open=True):
                box_threshold = gr.Slider(label="Box threshold", minimum=0.0, maximum=1.0, value=0.24, step=0.01)
                text_threshold = gr.Slider(label="Text threshold", minimum=0.0, maximum=1.0, value=0.24, step=0.01)
            dimensions = gr.Textbox(label="Image size", interactive=False)
            detections = gr.Number(label="Predicted objects", interactive=False)
            stopwatch = gr.Number(label="Execution time (sec.)", interactive=False, precision=3)

        with gr.Column(scale=2):
            output_image = gr.Image(type="pil", elem_classes='image-preview', interactive=False, width=800, height=800)

    run_button.click(fn=predict_image, inputs=[input_image, text_prompt, box_threshold, text_threshold], outputs=[output_image, dimensions, detections, stopwatch])
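
    # With cache_examples=True, Gradio runs predict_image on these examples once at startup
    # and serves the cached outputs, so clicking an example does not hit the live API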
    gr.Examples(
        examples=example_data,
        inputs=[input_image, text_prompt, box_threshold, text_threshold],
        outputs=[output_image, dimensions, detections, stopwatch],
        fn=predict_image,
        cache_examples=True,
        label='Try these images!'
    )

    gr.Markdown("<p>This demo is provided by <a href='https://www.linkedin.com/in/faudi/'>Jeff Faudi</a> \
        and <a href='https://www.dl4eo.com/'>DL4EO</a>. The demonstration images are Pléiades \
        images provided by CNES with distribution by Airbus DS. The model architecture and weights \
        are provided by <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a>. \
        The model has not been trained specifically on satellite imagery and should be fine-tuned for this task. \
        This is for demonstration only. Please contact <a href='mailto:[email protected]'>me</a> \
        for more information on how you could get access to a commercial model or API.</p>")

demo.launch(
    inline=False,
    show_api=False,
    debug=False
)