picekl commited on
Commit
387d5e9
1 Parent(s): 89347b6

feat: resizing before inference

Browse files
Files changed (1) hide show
  1. script.py +17 -2
script.py CHANGED
@@ -21,8 +21,20 @@ class ONNXWorker:
21
  providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
22
  else:
23
  providers = ["CPUExecutionProvider"]
 
 
24
  self.ort_session = ort.InferenceSession(onnx_path, providers=providers)
25
 
 
 
 
 
 
 
 
 
 
 
26
  def predict_image(self, image: np.ndarray) -> list():
27
  """Run inference using ONNX runtime.
28
 
@@ -44,8 +56,11 @@ def make_submission(test_metadata, model_path, output_csv_path="./submission.csv
44
 
45
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
46
  image_path = os.path.join(images_root_path, row.filename)
47
- test_image = np.asarray(Image.open(image_path).convert("RGB"))
48
- logits = model.predict_image(test_image)
 
 
 
49
 
50
  predictions.append(np.argmax(logits))
51
 
 
21
  providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
22
  else:
23
  providers = ["CPUExecutionProvider"]
24
+
25
+ print(f"Using {providers}")
26
  self.ort_session = ort.InferenceSession(onnx_path, providers=providers)
27
 
28
+ def _resize_image(self, image: np.ndarray) -> np.ndarray:
29
+ """
30
+
31
+ :param image:
32
+ :return:
33
+ """
34
+
35
+ newsize = (300, 300)
36
+ im1 = im1.resize(newsize)
37
+
38
  def predict_image(self, image: np.ndarray) -> list():
39
  """Run inference using ONNX runtime.
40
 
 
56
 
57
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
58
  image_path = os.path.join(images_root_path, row.filename)
59
+
60
+ test_image = Image.open(image_path).convert("RGB")
61
+ test_image_resized = np.asarray(test_image.resize((256, 256)))
62
+
63
+ logits = model.predict_image(test_image_resized)
64
 
65
  predictions.append(np.argmax(logits))
66