Nekshay commited on
Commit
6d38c83
1 Parent(s): 04e6b4c

Update convert_tflite_2_onnx.py

Browse files
Files changed (1) hide show
  1. convert_tflite_2_onnx.py +54 -0
convert_tflite_2_onnx.py CHANGED
@@ -40,3 +40,57 @@ output = session.run(None, {input_name: image_data})
40
  print(output)
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  print(output)
41
 
42
 
43
+ import onnxruntime as ort
44
+ import numpy as np
45
+ from PIL import Image
46
+
47
+ # Load ONNX model
48
+ onnx_model_path = 'model.onnx'
49
+ session = ort.InferenceSession(onnx_model_path)
50
+
51
+ # Function to preprocess a single image (resize and normalize)
52
+ def preprocess_image(image_path, input_size=(320, 320)):
53
+ image = Image.open(image_path).resize(input_size) # Resize to match model input size
54
+ image_data = np.array(image).astype('float32') # Convert to float32
55
+ image_data = np.expand_dims(image_data, axis=0) # Add batch dimension (1, height, width, channels)
56
+ return image_data
57
+
58
+ # Prepare a batch of images
59
+ image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg'] # List of image file paths
60
+ batch_size = len(image_paths)
61
+
62
+ # Preprocess each image and stack them into a batch
63
+ batch_images = np.vstack([preprocess_image(image_path) for image_path in image_paths])
64
+
65
+ # Check input name from the ONNX model
66
+ input_name = session.get_inputs()[0].name
67
+
68
+ # Run batch inference
69
+ outputs = session.run(None, {input_name: batch_images})
70
+
71
+ # Postprocessing: Extract scores, bounding boxes, and labels for each image in the batch
72
+ scores_batch, bboxes_batch, labels_batch = outputs[0], outputs[1], outputs[2]
73
+
74
+ # Iterate over the batch of results and filter based on score threshold
75
+ score_threshold = 0.5
76
+
77
+ for i in range(batch_size):
78
+ scores = scores_batch[i] # Scores for i-th image
79
+ bboxes = bboxes_batch[i] # Bounding boxes for i-th image
80
+ labels = labels_batch[i] # Labels for i-th image
81
+
82
+ # Filter indices where scores are greater than the threshold
83
+ valid_indices = np.where(scores > score_threshold)
84
+
85
+ # Filter the outputs based on valid indices
86
+ filtered_scores = scores[valid_indices]
87
+ filtered_bboxes = bboxes[valid_indices]
88
+ filtered_labels = labels[valid_indices]
89
+
90
+ print(f"Image {i+1}:")
91
+ print("Filtered Scores:", filtered_scores)
92
+ print("Filtered Bounding Boxes:", filtered_bboxes)
93
+ print("Filtered Labels:", filtered_labels)
94
+ print('---')
95
+
96
+