freddyaboulton HF staff commited on
Commit
6a95f1f
1 Parent(s): cbc2dd6
Files changed (3) hide show
  1. app.py +65 -21
  2. draw_boxes.py +41 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,21 +1,67 @@
1
  import spaces
2
  import gradio as gr
3
  import cv2
4
- import tempfile
5
- from ultralytics import YOLOv10
6
 
7
- model = YOLOv10.from_pretrained(f'jameslahm/yolov10n')
 
 
 
 
 
8
 
9
  @spaces.GPU
10
- def yolov10_inference(image, conf_threshold):
11
- width, _ = image.size
12
- import time
13
- start = time.time()
14
- results = model.predict(source=image, imgsz=width, conf=conf_threshold)
15
- end = time.time()
16
- annotated_image = results[0].plot()
17
- print("time", end - start)
18
- return annotated_image[:, :, ::-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
@@ -26,18 +72,18 @@ with gr.Blocks(css=css) as app:
26
  gr.HTML(
27
  """
28
  <h1 style='text-align: center'>
29
- YOLOv10 Webcam Stream
30
  </h1>
31
  """)
32
  gr.HTML(
33
  """
34
  <h3 style='text-align: center'>
35
- <a href='https://arxiv.org/abs/2405.14458' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yolov10' target='_blank'>github</a>
36
  </h3>
37
  """)
38
  with gr.Column(elem_classes=["my-column"]):
39
  with gr.Group(elem_classes=["my-group"]):
40
- image = gr.Image(type="pil", label="Image", sources="webcam")
41
  conf_threshold = gr.Slider(
42
  label="Confidence Threshold",
43
  minimum=0.0,
@@ -45,12 +91,10 @@ with gr.Blocks(css=css) as app:
45
  step=0.05,
46
  value=0.30,
47
  )
48
- image.stream(
49
- fn=yolov10_inference,
50
- inputs=[image, conf_threshold],
51
- outputs=[image],
52
- stream_every=0.1,
53
- time_limit=30
54
  )
55
 
56
  if __name__ == '__main__':
 
1
  import spaces
2
  import gradio as gr
3
  import cv2
4
+ from PIL import Image
 
5
 
6
+ from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
7
+
8
+ from draw_boxes import draw_bounding_boxes
9
+
10
+ image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
11
+ model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
12
 
13
  @spaces.GPU
14
+ def stream_object_detection(video, conf_threshold):
15
+ cap = cv2.VideoCapture(video)
16
+
17
+ video_codec = cv2.VideoWriter_fourcc(*"x264") # type: ignore
18
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
19
+ desired_fps = fps // 3
20
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
21
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
22
+
23
+ iterating, frame = cap.read()
24
+
25
+ n_frames = 0
26
+ n_chunks = 0
27
+ name = str(current_dir / f"output_{n_chunks}.ts")
28
+ segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore
29
+ batch = []
30
+
31
+ while iterating:
32
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
33
+ if n_frames % 3 == 0:
34
+ batch.append(frame)
35
+ if len(batch) == desired_fps:
36
+ inputs = image_processor(images=batch, return_tensors="pt")
37
+
38
+ with torch.no_grad():
39
+ outputs = model(**inputs)
40
+
41
+ boxes = image_processor.post_process_object_detection(
42
+ outputs,
43
+ target_sizes=torch.tensor([batch[0].shape[::-1]] * len(batch)),
44
+ threshold=conf_threshold)
45
+
46
+ for array, box in zip(batch, boxes):
47
+ pil_image = draw_bounding_boxes(Image.from_array(array), boxes[0], model, 0.3)
48
+ frame = numpy.array(pil_image)
49
+ # Convert RGB to BGR
50
+ frame = frame[:, :, ::-1].copy()
51
+ segment_file.write(frame)
52
+
53
+ segment_file.release()
54
+ n_frames = 0
55
+ n_chunks += 1
56
+ yield name
57
+ name = str(current_dir / f"output_{n_chunks}.ts")
58
+ segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore
59
+
60
+ iterating, frame = cap.read()
61
+ n_frames += 1
62
+
63
+ segment_file.release()
64
+ yield name
65
 
66
 
67
  css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
 
72
  gr.HTML(
73
  """
74
  <h1 style='text-align: center'>
75
+ Video Object Detection with RT-DETR
76
  </h1>
77
  """)
78
  gr.HTML(
79
  """
80
  <h3 style='text-align: center'>
81
+ <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>github</a>
82
  </h3>
83
  """)
84
  with gr.Column(elem_classes=["my-column"]):
85
  with gr.Group(elem_classes=["my-group"]):
86
+ video = gr.Video(label="Video Source")
87
  conf_threshold = gr.Slider(
88
  label="Confidence Threshold",
89
  minimum=0.0,
 
91
  step=0.05,
92
  value=0.30,
93
  )
94
+ video.upload(
95
+ fn=stream_object_detection,
96
+ inputs=[video, conf_threshold],
97
+ outputs=[video],
 
 
98
  )
99
 
100
  if __name__ == '__main__':
draw_boxes.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw, ImageFont
2
+ import numpy as np
3
+ import colorsys
4
+
5
+ def get_color(label):
6
+ # Simple hash function to generate consistent colors for each label
7
+ hash_value = hash(label)
8
+ hue = (hash_value % 100) / 100.0
9
+ saturation = 0.7
10
+ value = 0.9
11
+ rgb = colorsys.hsv_to_rgb(hue, saturation, value)
12
+ return tuple(int(x * 255) for x in rgb)
13
+
14
+ def draw_bounding_boxes(image: Image, results: dict, model, threshold=0.3):
15
+ draw = ImageDraw.Draw(image)
16
+ font = ImageFont.load_default()
17
+
18
+ for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
19
+ if score > threshold:
20
+ label = model.config.id2label[label_id.item()]
21
+ box = [round(i, 2) for i in box.tolist()]
22
+ color = get_color(label)
23
+
24
+ # Draw bounding box
25
+ draw.rectangle(box, outline=color, width=3)
26
+
27
+ # Prepare text
28
+ text = f"{label}: {score:.2f}"
29
+ text_bbox = draw.textbbox((0, 0), text, font=font)
30
+ text_width = text_bbox[2] - text_bbox[0]
31
+ text_height = text_bbox[3] - text_bbox[1]
32
+
33
+ # Draw text background
34
+ draw.rectangle([box[0], box[1] - text_height - 4, box[0] + text_width, box[1]], fill=color)
35
+
36
+ # Draw text
37
+ draw.text((box[0], box[1] - text_height - 4), text, fill="white", font=font)
38
+
39
+ return image
40
+
41
+ import numpy as np
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  safetensors==0.4.3
2
- git+https://github.com/THU-MIG/yolov10.git
3
  gradio-client @ git+https://github.com/gradio-app/gradio@66349fe26827e3a3c15b738a1177e95fec7f5554#subdirectory=client/python
4
  https://gradio-pypi-previews.s3.amazonaws.com/66349fe26827e3a3c15b738a1177e95fec7f5554/gradio-4.42.0-py3-none-any.whl
 
1
  safetensors==0.4.3
2
+ transformers
3
  gradio-client @ git+https://github.com/gradio-app/gradio@66349fe26827e3a3c15b738a1177e95fec7f5554#subdirectory=client/python
4
  https://gradio-pypi-previews.s3.amazonaws.com/66349fe26827e3a3c15b738a1177e95fec7f5554/gradio-4.42.0-py3-none-any.whl