picekl commited on
Commit
079a64b
1 Parent(s): 80b778b

feat:sample script with onnx model prediction

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ SnakeCLEF2024-TestMetadata.csv filter=lfs diff=lfs merge=lfs -text
37
+ swinv2_tiny_window16_256.onnx filter=lfs diff=lfs merge=lfs -text
SnakeCLEF2024-TestMetadata.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84a019142f4f674985599deaf0d705d3d30069eaddc1191b247d5cefd779f08a
3
+ size 404453
script.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import traceback
3
+ import pandas as pd
4
+ import numpy as np
5
+ from PIL import Image
6
+ import onnxruntime as ort
7
+ import os
8
+ from tqdm import tqdm
9
+
10
+
11
+ def is_gpu_available():
12
+ """Check if the python package `onnxruntime-gpu` is installed."""
13
+ return ort.get_device() == "GPU"
14
+
15
+
16
+ class ONNXWorker:
17
+ """Run inference using ONNX runtime."""
18
+
19
+ def __init__(self, onnx_path: str):
20
+ print("Setting up ONNX runtime session.")
21
+ self.use_gpu = is_gpu_available()
22
+ if self.use_gpu:
23
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
24
+ else:
25
+ providers = ["CPUExecutionProvider"]
26
+ self.ort_session = ort.InferenceSession(onnx_path, providers=providers)
27
+
28
+ def predict_image(self, image: np.ndarray) -> list():
29
+ """Run inference using ONNX runtime.
30
+
31
+ :param image: Input image as numpy array.
32
+ :return: A list with logits and confidences.
33
+ """
34
+
35
+ logits, _ = self.ort_session.run(None, {"input": image.astype(dtype=np.uint8)})
36
+
37
+ return logits.tolist()
38
+
39
+
40
+ def make_submission(test_metadata, model_path, output_csv_path="./submission.csv", data_root_path="/tmp/data"):
41
+ """Make submission with given """
42
+
43
+ model = ONNXWorker(model_path)
44
+
45
+ predictions = []
46
+
47
+ for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
48
+ image_path = os.path.join(data_root_path, row.filename)
49
+ test_image = np.asarray(Image.open(image_path).convert("RGB"))
50
+ logits = model.predict_image(test_image)
51
+
52
+ predictions.append(np.argmax(logits))
53
+
54
+ test_metadata["class_id"] = predictions
55
+
56
+ user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
57
+ user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
58
+
59
+
60
+ if __name__ == "__main__":
61
+
62
+ ONNX_MODEL_PATH = "./swinv2_tiny_window16_256.onnx"
63
+
64
+ metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
65
+ test_metadata = pd.read_csv(metadata_file_path)
66
+
67
+ make_submission(
68
+ test_metadata=test_metadata,
69
+ model_path=ONNX_MODEL_PATH,
70
+ )
swinv2_tiny_window16_256.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff53172bace1485a6e582e1c3dc9719fa4b6acf7ba4481061a220220faaa2eb2
3
+ size 122122210