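# NOTE: the line below is web-page residue from the file viewer this script
# was copied from; it is not part of the program.
#   freddyaboulton's picture | compile | 8fbff22 | raw | history blame | 3.18 kB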
import spaces
import gradio as gr
import cv2
import tempfile
from PIL import Image, ImageDraw, ImageFont
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
import torch
import requests
# Load the RT-DETR pre/post-processor and the detector once, at import time.
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
# The weights are loaded in float16 and moved to the GPU, so every tensor fed
# to `model` must also be a CUDA float16 tensor.
model = RTDetrForObjectDetection.from_pretrained(
    "PekingU/rtdetr_r50vd", torch_dtype=torch.float16
).to("cuda")
model = torch.compile(model, mode="reduce-overhead")

# Compile by running inference
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# The timeout keeps startup from hanging forever on a stalled connection, and
# the context manager releases the streamed connection once the image is read.
with requests.get(url, stream=True, timeout=30) as response:
    # Fail with a clear HTTP error instead of an opaque image-decoding error.
    response.raise_for_status()
    image = Image.open(response.raw)
    # PIL decodes lazily; force the read while the stream is still open.
    image.load()
inputs = image_processor(images=image, return_tensors="pt").to("cuda", torch.float16)
with torch.no_grad():
    outputs = model(**inputs)
def draw_bounding_boxes(image, results, model, threshold=0.3):
    """Draw labelled red boxes on ``image`` for detections above ``threshold``.

    Args:
        image: Image to annotate; it is drawn on in place via ``ImageDraw``.
        results: Iterable of mappings, each with ``"scores"``, ``"labels"``
            and ``"boxes"`` entries that are iterated in lockstep.
        model: Object whose ``config.id2label`` maps a label id to its name.
        threshold: Detections whose score is not strictly greater than this
            value are skipped.

    Returns:
        The same ``image`` object that was passed in.
    """
    canvas = ImageDraw.Draw(image)
    for detection in results:
        triples = zip(detection["scores"], detection["labels"], detection["boxes"])
        for score, class_id, raw_box in triples:
            # Guard clause: ignore anything that does not clear the threshold.
            if not (score > threshold):
                continue
            name = model.config.id2label[class_id.item()]
            corners = list(map(round, raw_box.tolist()))
            canvas.rectangle(corners, outline="red", width=3)
            # The caption is anchored at the first two box coordinates.
            anchor = (corners[0], corners[1])
            canvas.text(anchor, f"{name}: {score:.2f}", fill="red")
    return image
import time
@spaces.GPU
def inference(image, conf_threshold):
    """Run RT-DETR on one frame and return the frame with boxes drawn on it.

    Args:
        image: The frame to analyse; its ``size`` attribute is read to scale
            the predicted boxes back to the frame's dimensions.
        conf_threshold: Minimum score for a detection to be kept, applied
            both in post-processing and when drawing.

    Returns:
        The value returned by ``draw_bounding_boxes`` (the annotated frame).
    """
    # The model was loaded as float16 on CUDA, so the inputs have to be moved
    # and cast the same way as in the module-level warm-up call; otherwise the
    # forward pass fails on a device/dtype mismatch.
    inputs = image_processor(images=image, return_tensors="pt").to(
        "cuda", torch.float16
    )

    # perf_counter is a monotonic clock intended for measuring durations.
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model(**inputs)
    results = image_processor.post_process_object_detection(
        outputs,
        # image.size is reversed here before being passed as the target size.
        target_sizes=torch.tensor([image.size[::-1]]),
        threshold=conf_threshold,
    )
    end = time.perf_counter()
    print("time: ", end - start)

    bbs = draw_bounding_boxes(image, results, model, threshold=conf_threshold)
    print("bbs: ", time.perf_counter() - end)
    return bbs
# Custom CSS handed to gr.Blocks; the two classes are attached to the Column
# and Group below through `elem_classes`.
# NOTE(review): `max-height: 600` has no unit and the second rule ends with
# `};` instead of `;}` -- browsers will likely drop/mis-parse those parts.
# Left untouched here because fixing them changes the rendered layout; confirm
# whether `600px` was intended.
css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""

# Static HTML snippets shown at the top of the page.
_TITLE_HTML = """
<h1 style='text-align: center'>
Near Real-Time Webcam Stream with RT-DETR
</h1>
"""

_LINKS_HTML = """
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://github.com/lyuwenyu/RT-DETR' target='_blank'>github</a>
</h3>
"""

with gr.Blocks(css=css) as app:
    gr.HTML(_TITLE_HTML)
    gr.HTML(_LINKS_HTML)

    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            # The same Image component is both the webcam input and the place
            # where the annotated frame is written back (see `outputs` below).
            image = gr.Image(type="pil", label="Image", sources="webcam")
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.85,
            )

    # Feed each streamed frame plus the slider value to `inference`.
    # NOTE(review): `stream_every` / `time_limit` are presumably in seconds --
    # confirm against the Gradio streaming documentation.
    image.stream(
        fn=inference,
        inputs=[image, conf_threshold],
        outputs=[image],
        stream_every=0.1,
        time_limit=30,
    )

# Only start the server when the file is run as a script.
if __name__ == "__main__":
    app.launch()