freddyaboulton HF staff committed on
Commit
790227b
1 Parent(s): 66947f7
Files changed (1) hide show
  1. app.py +34 -13
app.py CHANGED
@@ -8,10 +8,13 @@ from PIL import Image, ImageDraw, ImageFont
8
  image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
9
  model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
10
 
 
11
  def draw_bounding_boxes(image, results, model, threshold=0.3):
12
  draw = ImageDraw.Draw(image)
13
  for result in results:
14
- for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
 
 
15
  if score > threshold:
16
  label = model.config.id2label[label_id.item()]
17
  box = [round(i) for i in box.tolist()]
@@ -22,13 +25,14 @@ def draw_bounding_boxes(image, results, model, threshold=0.3):
22
 
23
  @spaces.GPU
24
  def inference(image, conf_threshold):
25
-
26
  inputs = image_processor(images=image, return_tensors="pt")
27
 
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
 
31
- results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)
 
 
32
 
33
  return draw_bounding_boxes(image, results, model, threshold=conf_threshold)
34
 
@@ -37,7 +41,14 @@ def app():
37
  with gr.Blocks():
38
  with gr.Row():
39
  with gr.Column():
40
- image = gr.Image(type="pil", label="Image", visible=True, sources="webcam", height=500, width=500)
 
 
 
 
 
 
 
41
  conf_threshold = gr.Slider(
42
  label="Confidence Threshold",
43
  minimum=0.0,
@@ -50,10 +61,11 @@ def app():
50
  inputs=[image, conf_threshold],
51
  outputs=[image],
52
  stream_every=0.2,
53
- time_limit=30
54
  )
55
 
56
- css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
 
57
  .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
58
 
59
  with gr.Blocks(css=css) as app:
@@ -62,16 +74,25 @@ with gr.Blocks(css=css) as app:
62
  <h1 style='text-align: center'>
63
  Near Real-Time Webcam Stream with RT-DETR
64
  </h1>
65
- """)
 
66
  gr.HTML(
67
  """
68
  <h3 style='text-align: center'>
69
  <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://github.com/lyuwenyu/RT-DETR' target='_blank'>github</a>
70
  </h3>
71
- """)
72
- with gr.Column(elem_classes=['my-column']):
73
- with gr.Group(elem_classes=["my-group"]):
74
- image = gr.Image(type="pil", label="Image", visible=True, sources="webcam", height=500, width=500)
 
 
 
 
 
 
 
 
75
  conf_threshold = gr.Slider(
76
  label="Confidence Threshold",
77
  minimum=0.0,
@@ -84,7 +105,7 @@ with gr.Blocks(css=css) as app:
84
  inputs=[image, conf_threshold],
85
  outputs=[image],
86
  stream_every=0.2,
87
- time_limit=30
88
  )
89
- if __name__ == '__main__':
90
  app.launch()
 
8
  image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
9
  model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
10
 
11
+
12
  def draw_bounding_boxes(image, results, model, threshold=0.3):
13
  draw = ImageDraw.Draw(image)
14
  for result in results:
15
+ for score, label_id, box in zip(
16
+ result["scores"], result["labels"], result["boxes"]
17
+ ):
18
  if score > threshold:
19
  label = model.config.id2label[label_id.item()]
20
  box = [round(i) for i in box.tolist()]
 
25
 
26
  @spaces.GPU
27
  def inference(image, conf_threshold):
 
28
  inputs = image_processor(images=image, return_tensors="pt")
29
 
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
 
33
+ results = image_processor.post_process_object_detection(
34
+ outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3
35
+ )
36
 
37
  return draw_bounding_boxes(image, results, model, threshold=conf_threshold)
38
 
 
41
  with gr.Blocks():
42
  with gr.Row():
43
  with gr.Column():
44
+ image = gr.Image(
45
+ type="pil",
46
+ label="Image",
47
+ visible=True,
48
+ sources="webcam",
49
+ height=500,
50
+ width=500,
51
+ )
52
  conf_threshold = gr.Slider(
53
  label="Confidence Threshold",
54
  minimum=0.0,
 
61
  inputs=[image, conf_threshold],
62
  outputs=[image],
63
  stream_every=0.2,
64
+ time_limit=30,
65
  )
66
 
67
+
68
+ css = """.my-group {max-width: 600px !important; max-height: 600 !important;}
69
  .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
70
 
71
  with gr.Blocks(css=css) as app:
 
74
  <h1 style='text-align: center'>
75
  Near Real-Time Webcam Stream with RT-DETR
76
  </h1>
77
+ """
78
+ )
79
  gr.HTML(
80
  """
81
  <h3 style='text-align: center'>
82
  <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://github.com/lyuwenyu/RT-DETR' target='_blank'>github</a>
83
  </h3>
84
+ """
85
+ )
86
+ with gr.Column(elem_classes=["my-column"]):
87
+ with gr.Group(elem_classes=["my-group"]):
88
+ image = gr.Image(
89
+ type="pil",
90
+ label="Image",
91
+ visible=True,
92
+ sources="webcam",
93
+ height=500,
94
+ width=500,
95
+ )
96
  conf_threshold = gr.Slider(
97
  label="Confidence Threshold",
98
  minimum=0.0,
 
105
  inputs=[image, conf_threshold],
106
  outputs=[image],
107
  stream_every=0.2,
108
+ time_limit=30,
109
  )
110
+ if __name__ == "__main__":
111
  app.launch()