Update convert_tflite_2_onnx.py
Browse files- convert_tflite_2_onnx.py +54 -0
convert_tflite_2_onnx.py
CHANGED
@@ -40,3 +40,57 @@ output = session.run(None, {input_name: image_data})
|
|
40 |
print(output)
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
print(output)
|
41 |
|
42 |
|
43 |
+
import onnxruntime as ort
|
44 |
+
import numpy as np
|
45 |
+
from PIL import Image
|
46 |
+
|
47 |
+
# Load ONNX model
|
48 |
+
onnx_model_path = 'model.onnx'
|
49 |
+
session = ort.InferenceSession(onnx_model_path)
|
50 |
+
|
51 |
+
# Function to preprocess a single image (resize and normalize)
|
52 |
+
def preprocess_image(image_path, input_size=(320, 320)):
|
53 |
+
image = Image.open(image_path).resize(input_size) # Resize to match model input size
|
54 |
+
image_data = np.array(image).astype('float32') # Convert to float32
|
55 |
+
image_data = np.expand_dims(image_data, axis=0) # Add batch dimension (1, height, width, channels)
|
56 |
+
return image_data
|
57 |
+
|
58 |
+
# Prepare a batch of images
|
59 |
+
image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg'] # List of image file paths
|
60 |
+
batch_size = len(image_paths)
|
61 |
+
|
62 |
+
# Preprocess each image and stack them into a batch
|
63 |
+
batch_images = np.vstack([preprocess_image(image_path) for image_path in image_paths])
|
64 |
+
|
65 |
+
# Check input name from the ONNX model
|
66 |
+
input_name = session.get_inputs()[0].name
|
67 |
+
|
68 |
+
# Run batch inference
|
69 |
+
outputs = session.run(None, {input_name: batch_images})
|
70 |
+
|
71 |
+
# Postprocessing: Extract scores, bounding boxes, and labels for each image in the batch
|
72 |
+
scores_batch, bboxes_batch, labels_batch = outputs[0], outputs[1], outputs[2]
|
73 |
+
|
74 |
+
# Iterate over the batch of results and filter based on score threshold
|
75 |
+
score_threshold = 0.5
|
76 |
+
|
77 |
+
for i in range(batch_size):
|
78 |
+
scores = scores_batch[i] # Scores for i-th image
|
79 |
+
bboxes = bboxes_batch[i] # Bounding boxes for i-th image
|
80 |
+
labels = labels_batch[i] # Labels for i-th image
|
81 |
+
|
82 |
+
# Filter indices where scores are greater than the threshold
|
83 |
+
valid_indices = np.where(scores > score_threshold)
|
84 |
+
|
85 |
+
# Filter the outputs based on valid indices
|
86 |
+
filtered_scores = scores[valid_indices]
|
87 |
+
filtered_bboxes = bboxes[valid_indices]
|
88 |
+
filtered_labels = labels[valid_indices]
|
89 |
+
|
90 |
+
print(f"Image {i+1}:")
|
91 |
+
print("Filtered Scores:", filtered_scores)
|
92 |
+
print("Filtered Bounding Boxes:", filtered_bboxes)
|
93 |
+
print("Filtered Labels:", filtered_labels)
|
94 |
+
print('---')
|
95 |
+
|
96 |
+
|