File size: 1,954 Bytes
d2e9ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (C) 2022, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import argparse
import json

import gradio as gr
import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download
from PIL import Image

REPO = "pyronear/rexnet1_0x"


# Fetch the model configuration from the hub repository and parse it
config_path = hf_hub_download(REPO, filename="config.json")
with open(config_path, "rb") as f:
    cfg = json.load(f)

# Fetch the ONNX checkpoint from the same repository and build the inference session
model_path = hf_hub_download(REPO, filename="model.onnx")
ort_session = onnxruntime.InferenceSession(model_path)

def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Preprocess an image for inference

    The image is resized to the spatial size given by the last two entries of
    `cfg["input_shape"]`, scaled to [0, 1], then normalized channel-wise with
    `cfg["mean"]` and `cfg["std"]`.

    Args:
        pil_img: a valid pillow image whose array form has a channel axis (H, W, C)

    Returns:
        the resized and normalized image of shape (1, C, H, W)
    """

    # Resizing
    # The last two entries of the input shape are (H, W), whereas pillow expects
    # (width, height): swap them so that non-square input shapes are not transposed
    height, width = cfg["input_shape"][-2:]
    img = pil_img.resize((width, height), Image.BILINEAR)
    # (H, W, C) --> (C, H, W)
    img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
    # Normalization (statistics are broadcast over the spatial dimensions)
    img -= np.array(cfg["mean"])[:, None, None]
    img /= np.array(cfg["std"])[:, None, None]

    # Add the leading batch dimension
    return img[None, ...]

def predict(image):
    """Run the classifier on an image and return a score for each class

    Args:
        image: the input image, forwarded as is to `preprocess_image`

    Returns:
        a dictionary mapping each entry of `cfg["classes"]` to its sigmoid score as a float
    """
    # Build the input feed, keyed by the name of the session's first input
    batch = preprocess_image(image)
    input_name = ort_session.get_inputs()[0].name

    # Forward pass (None requests all the model outputs)
    raw_outputs = ort_session.run(None, {input_name: batch})

    # Sigmoid over the first sample of the first output
    logits = raw_outputs[0][0]
    scores = 1 / (1 + np.exp(-logits))

    return dict(zip(cfg["classes"], map(float, scores)))


# Input component, configured with the "pil" image type expected by `predict`
img = gr.inputs.Image(type="pil")
# Output component: label display with num_top_classes set to 1
outputs = gr.outputs.Label(num_top_classes=1)

# HTML footer with links to the repository and the documentation
article_html = (
    "<p style='text-align: center'><a href='https://github.com/pyronear/pyro-vision'>"
    "Github Repo</a> | "
    "<a href='https://pyronear.org/pyro-vision/'>Documentation</a></p>"
)

# Assemble the demo around `predict` and start serving it
demo = gr.Interface(
    fn=predict,
    inputs=[img],
    outputs=outputs,
    title="PyroVision: image classification demo",
    article=article_html,
    live=True,
)
demo.launch()