jrsimuix commited on
Commit
8614e23
1 Parent(s): 2830133

scripts to validate and test

Browse files
Files changed (2) hide show
  1. python/test_image.py +19 -0
  2. python/verify_onnx_model.py +31 -0
python/test_image.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+ # Load the ONNX model
6
+ session = ort.InferenceSession('./saved-model/model.onnx')
7
+
8
+ # Get input and output names
9
+ input_name = session.get_inputs()[0].name
10
+ output_name = session.get_outputs()[0].name
11
+
12
+ # Load and preprocess the image
13
+ img = Image.open('./training_images/shirt/00e745c9-97d9-429d-8c3f-d3db7a2d2991.jpg').resize((128, 128))
14
+ img_array = np.array(img).astype(np.float32) / 255.0 # Normalize pixel values to [0, 1]
15
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
16
+
17
+ # Run inference
18
+ outputs = session.run([output_name], {input_name: img_array})
19
+ print(f"Inference outputs: {outputs}")
python/verify_onnx_model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+
4
+ def verify_onnx_model(onnx_model_path):
5
+ # Load the ONNX model
6
+ onnx_session = ort.InferenceSession(onnx_model_path)
7
+
8
+ # Display model input details
9
+ input_name = onnx_session.get_inputs()[0].name
10
+ input_shape = onnx_session.get_inputs()[0].shape
11
+ input_type = onnx_session.get_inputs()[0].type
12
+ print(f"Input Name: {input_name}, Shape: {input_shape}, Type: {input_type}")
13
+
14
+ # Display model output details
15
+ output_name = onnx_session.get_outputs()[0].name
16
+ output_shape = onnx_session.get_outputs()[0].shape
17
+ output_type = onnx_session.get_outputs()[0].type
18
+ print(f"Output Name: {output_name}, Shape: {output_shape}, Type: {output_type}")
19
+
20
+ # Generate a dummy input matching the input shape
21
+ # Assuming input shape is [None, 128, 128, 3], where None is the batch size
22
+ dummy_input = np.random.rand(1, 128, 128, 3).astype(np.float32)
23
+
24
+ # Perform inference
25
+ result = onnx_session.run([output_name], {input_name: dummy_input})
26
+ print(f"Inference Result: {result}")
27
+
28
+ # Path to the ONNX model
29
+ onnx_model_path = './model.onnx'
30
+
31
+ verify_onnx_model(onnx_model_path)