cfa_mvtec_test / app.py
masoudpz's picture
Change gradio share flag
d16f83e
from __future__ import annotations
from importlib import import_module
from pathlib import Path
import gradio as gr
import gradio.inputs
import gradio.outputs
import numpy as np
import os
from anomalib.deploy import Inferencer
def get_inferencer(weight_path: Path | str, metadata_path: Path | str | None = None) -> Inferencer:
    """Open an OpenVINO inferencer for the given model weights.

    This function always builds an ``anomalib.deploy.OpenVINOInferencer``; no other
    backend is selected and no validation of the paths is done here.

    Args:
        weight_path (Path | str): Path to the model weights. ``infer`` passes a plain
            string pointing at an ``.onnx`` file.
        metadata_path (Path | str | None, optional): Path to the model metadata file
            (``infer`` passes a ``metadata.json``). Defaults to None.

    Returns:
        Inferencer: The constructed ``OpenVINOInferencer``.
    """
    # Imported lazily so the OpenVINO-specific class is only looked up when a model
    # is actually opened.
    from anomalib.deploy import OpenVINOInferencer

    # Debug output: the paths are relative, so logging them helps diagnose
    # missing-model errors on the host.
    print(f"weight path: {weight_path}")
    print(f"metadata path: {metadata_path}")
    return OpenVINOInferencer(path=weight_path, metadata_path=metadata_path)
def infer(radio: str, image: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run anomaly detection on ``image`` with the model trained for class ``radio``.

    The model is loaded from ``cfa/mvtec/<class>/run/weights/openvino/`` relative to
    the current working directory, where ``<class>`` is ``radio`` lower-cased.

    Args:
        radio (str): MVTec class name selected in the UI, e.g. ``"Bottle"`` or
            ``"Metal_nut"``.
        image (np.ndarray): Image to run inference on.

    Raises:
        ValueError: If ``radio`` is empty or contains anything other than letters,
            digits and underscores.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: The heat map, predicted mask and
        segmentation result, in that order.
    """
    category = radio.lower()
    # The class name comes from the client request and is spliced into a filesystem
    # path, so only plain identifiers such as "metal_nut" are accepted (no "/", "..").
    if not category.replace("_", "").isalnum():
        raise ValueError(f"Invalid MVTec class name: {radio!r}")

    print(f"Radio Value: {category}")
    # The model paths below are relative, so the working directory decides where
    # they resolve; log it for debugging.
    print(f"{os.getcwd()}")

    model_dir = f"cfa/mvtec/{category}/run/weights/openvino"
    inferencer = get_inferencer(f"{model_dir}/model.onnx", f"{model_dir}/metadata.json")
    # NOTE(review): the model is re-opened on every request; consider caching the
    # inferencer per class if load time matters (check thread-safety first).
    predictions = inferencer.predict(image=image)
    return (predictions.heat_map, predictions.pred_mask, predictions.segmentations)
if __name__ == "__main__":
    # Single source of truth for the MVTec categories: the radio choices and the
    # examples are both derived from this list so they cannot drift apart.
    mvtec_classes = [
        "Bottle",
        "Cable",
        "Capsule",
        "Carpet",
        "Grid",
        "Hazelnut",
        "Leather",
        "Metal_nut",
        "Pill",
        "Screw",
        "Tile",
        "Toothbrush",
        "Transistor",
        "Wood",
        "Zipper",
    ]
    # NOTE(review): gradio.inputs / gradio.outputs and .style() are Gradio 3.x APIs
    # that later releases removed; migrate to gr.Image when upgrading Gradio.
    interface = gr.Interface(
        # ``infer`` already takes (radio, image) in the same order as ``inputs``.
        fn=infer,
        inputs=[
            gr.Radio(
                mvtec_classes,
                label="MVTEC Class Name",
                value="Bottle",
            ).style(height=400),
            gradio.inputs.Image(
                shape=None, image_mode="RGB", source="upload", tool="editor", type="numpy", label="Image"
            ).style(height=350),
        ],
        outputs=[
            gradio.outputs.Image(type="numpy", label="Predicted Heat Map").style(height=200),
            gradio.outputs.Image(type="numpy", label="Predicted Mask").style(height=200),
            gradio.outputs.Image(type="numpy", label="Segmentation Result").style(height=200),
        ],
        # One example per class; each sample image is named after the lower-cased class.
        examples=[[name, f"sample_images/{name.lower()}.png"] for name in mvtec_classes],
        title="Anomaly Detection",
        description="Anomaly Detection on Industrial Images",
        css=".output-image, .image-preview {height: 300px !important}",
        allow_flagging="never",
    )
    interface.launch(share=False)