SaladSlayer00 commited on
Commit
010f3a7
1 Parent(s): cee3072

the REAL APP

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -1,45 +1,42 @@
1
- import cv2
2
- import torch
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
  from PIL import Image
5
- import gradio as gr
 
6
  import numpy as np
7
 
8
- # Function for DETR object detection
9
- def inf(_, webcam_image):
10
- # Initialize model and processor
11
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
13
- yellow = (0, 255, 255) # in BGR
14
- stroke = 2
15
 
16
- # Convert the webcam image to the correct format
17
- img = cv2.cvtColor(webcam_image, cv2.COLOR_BGR2RGB)
 
18
  pil_image = Image.fromarray(img)
19
 
20
- # Process the image with DETR
21
  inputs = processor(images=pil_image, return_tensors="pt")
22
  outputs = model(**inputs)
23
  target_sizes = torch.tensor([pil_image.size[::-1]])
24
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
25
 
26
- # Draw bounding boxes and labels
27
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
- cv2.rectangle(webcam_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), yellow, stroke)
29
- cv2.putText(webcam_image, model.config.id2label[label.item()], (int(box[0]), int(box[1]-10)), cv2.FONT_HERSHEY_SIMPLEX, 1, yellow, stroke, cv2.LINE_AA)
 
 
30
 
31
- # Return the processed image
32
- return webcam_image
 
33
 
34
- # Gradio interface with webcam support
35
  demo = gr.Interface(
36
- inf,
37
- [
38
- gr.Markdown("## Real-Time Object Detection"),
39
- gr.Image(source="webcam", streaming=True)
40
- ],
41
- "image",
42
  live=True
43
  )
44
- demo.launch(server_name="0.0.0.0", share=True)
45
 
 
 
1
+ import gradio as gr
 
2
  from transformers import DetrImageProcessor, DetrForObjectDetection
3
  from PIL import Image
4
+ import torch
5
+ import cv2
6
  import numpy as np
7
 
8
+ # Initialize the model and processor
9
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
10
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
 
 
 
 
11
 
12
+ def process_frame(webcam_image):
13
+ # Convert the webcam image from Gradio to the format expected by the model
14
+ img = cv2.cvtColor(np.array(webcam_image), cv2.COLOR_RGB2BGR)
15
  pil_image = Image.fromarray(img)
16
 
17
+ # Process the image
18
  inputs = processor(images=pil_image, return_tensors="pt")
19
  outputs = model(**inputs)
20
  target_sizes = torch.tensor([pil_image.size[::-1]])
21
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
22
 
23
+ # Draw bounding boxes and labels on the image
24
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
25
+ box = [int(round(i, 0)) for i in box.tolist()]
26
+ cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), (0, 255, 255), 2)
27
+ label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}"
28
+ cv2.putText(img, label_text, (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)
29
 
30
+ # Convert back to RGB for Gradio display
31
+ processed_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
32
+ return Image.fromarray(processed_image)
33
 
34
+ # Gradio interface
35
  demo = gr.Interface(
36
+ fn=process_frame,
37
+ inputs=gr.Image(source="webcam", streaming=True),
38
+ outputs="image",
 
 
 
39
  live=True
40
  )
 
41
 
42
+ demo.launch()