File size: 4,218 Bytes
370415f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d16f83e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations

from importlib import import_module
from pathlib import Path

import gradio as gr
import gradio.inputs
import gradio.outputs
import numpy as np
import os
from anomalib.deploy import Inferencer


def get_inferencer(weight_path: Path, metadata_path: Path | None = None) -> Inferencer:
    """Build an OpenVINO inferencer for the given model files.

    Args:
        weight_path (Path): Path to model weights.
        metadata_path (Path | None, optional): Path to the model metadata,
            forwarded unchanged to the inferencer. Defaults to None.

    Returns:
        Inferencer: An ``OpenVINOInferencer`` instance from ``anomalib.deploy``.
    """
    # The module is resolved at call time, so OpenVINOInferencer is only looked
    # up when an inferencer is actually requested.
    openvino_inferencer = import_module("anomalib.deploy").OpenVINOInferencer

    print(f"weight path: {weight_path}")
    print(f"metadata path: {metadata_path}")
    return openvino_inferencer(path=weight_path, metadata_path=metadata_path)


def infer(radio: str, image: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run anomaly detection on an image with the model of the selected category.

    Args:
        radio (str): Category name selected in the UI (e.g. ``"Bottle"``). It is
            lower-cased to locate the model directory under ``cfa/mvtec``.
        image (np.ndarray): Image to run inference on.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: The ``heat_map``, ``pred_mask``
        and ``segmentations`` attributes of the prediction, in that order.
    """
    category = radio.lower()
    print(f"Radio Value: {category}")
    print(f"{os.getcwd()}")

    # Model files are addressed relative to the current working directory,
    # which is why it is printed above.
    weights_dir = Path("cfa/mvtec") / category / "run" / "weights" / "openvino"
    weight_path = weights_dir / "model.onnx"
    metadata_path = weights_dir / "metadata.json"

    # NOTE(review): a new inferencer is built on every call; consider caching
    # one per category if model loading turns out to dominate latency.
    inferencer = get_inferencer(weight_path, metadata_path)

    # Perform inference for the given image.
    predictions = inferencer.predict(image=image)
    return (predictions.heat_map, predictions.pred_mask, predictions.segmentations)


if __name__ == "__main__":
    # Category names offered in the radio selector. Each one also has a sample
    # image named after its lower-cased form in ``sample_images/``.
    categories = [
        "Bottle",
        "Cable",
        "Capsule",
        "Carpet",
        "Grid",
        "Hazelnut",
        "Leather",
        "Metal_nut",
        "Pill",
        "Screw",
        "Tile",
        "Toothbrush",
        "Transistor",
        "Wood",
        "Zipper",
    ]

    interface = gr.Interface(
        # ``infer`` already takes (radio, image) in the order of ``inputs``.
        fn=infer,
        inputs=[
            gr.Radio(
                categories,
                label="MVTEC Class Name",
                value="Bottle",
            ).style(height=400),
            gradio.inputs.Image(
                shape=None, image_mode="RGB", source="upload", tool="editor", type="numpy", label="Image"
            ).style(height=350),
        ],
        outputs=[
            gradio.outputs.Image(type="numpy", label="Predicted Heat Map").style(height=200),
            gradio.outputs.Image(type="numpy", label="Predicted Mask").style(height=200),
            gradio.outputs.Image(type="numpy", label="Segmentation Result").style(height=200),
        ],
        # One example per category, derived from the list above so the two
        # cannot drift apart.
        examples=[[name, f"sample_images/{name.lower()}.png"] for name in categories],
        title="Anomaly Detection",
        description="Anomaly Detection on Industrial Images",
        css=".output-image, .image-preview {height: 300px !important}",
        allow_flagging="never",
    )

    interface.launch(share=False)