SaladSlayer00 commited on
Commit
cee3072
1 Parent(s): d7feb62
Files changed (1) hide show
  1. app.py +30 -19
app.py CHANGED
@@ -1,34 +1,45 @@
1
- import gradio as gr
 
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
  from PIL import Image
4
- import torch
5
- import cv2
6
  import numpy as np
7
 
8
- def process_image(input_image):
 
 
9
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
10
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
11
- yellow = (0, 255, 255) # BGR
12
- font = cv2.FONT_HERSHEY_SIMPLEX
13
  stroke = 2
14
 
15
- # Convert PIL image to OpenCV format
16
- img = np.array(input_image)
17
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
18
 
19
- # Process the image
20
- inputs = processor(images=input_image, return_tensors="pt")
21
  outputs = model(**inputs)
22
- target_sizes = torch.tensor([input_image.size[::-1]])
23
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
24
 
 
25
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
26
- cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), yellow, stroke)
27
- cv2.putText(img, model.config.id2label[label.item()], (int(box[0]), int(box[1]-10)), font, 1, yellow, stroke, cv2.LINE_AA)
 
 
 
28
 
29
- # Convert back to PIL image
30
- return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
 
 
 
 
 
 
 
 
 
31
 
32
- # Create Gradio interface
33
- iface = gr.Interface(fn=process_image, inputs=gr.inputs.Image(), outputs="image")
34
- iface.launch()
 
1
+ import cv2
2
+ import torch
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
  from PIL import Image
5
+ import gradio as gr
 
6
  import numpy as np
7
 
8
+ # Function for DETR object detection
9
+ def inf(_, webcam_image):
10
+ # Initialize model and processor
11
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
13
+ yellow = (0, 255, 255) # in BGR
 
14
  stroke = 2
15
 
16
+ # Convert the webcam image to the correct format
17
+ img = cv2.cvtColor(webcam_image, cv2.COLOR_BGR2RGB)
18
+ pil_image = Image.fromarray(img)
19
 
20
+ # Process the image with DETR
21
+ inputs = processor(images=pil_image, return_tensors="pt")
22
  outputs = model(**inputs)
23
+ target_sizes = torch.tensor([pil_image.size[::-1]])
24
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
25
 
26
+ # Draw bounding boxes and labels
27
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
+ cv2.rectangle(webcam_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), yellow, stroke)
29
+ cv2.putText(webcam_image, model.config.id2label[label.item()], (int(box[0]), int(box[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 1, yellow, stroke, cv2.LINE_AA)
30
+
31
+ # Return the processed image
32
+ return webcam_image
33
 
34
+ # Gradio interface with webcam support
35
+ demo = gr.Interface(
36
+ inf,
37
+ [
38
+ gr.Markdown("## Real-Time Object Detection"),
39
+ gr.Image(source="webcam", streaming=True)
40
+ ],
41
+ "image",
42
+ live=True
43
+ )
44
+ demo.launch(server_name="0.0.0.0", share=True)
45