# Author: vincentclaes
# Commit 5ed6ee0: "return confidence score"
# (raw | history | blame) — 1.74 kB
import pathlib

import gradio as gr
import torch
from loguru import logger
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
logger.info("starting gradio app")

# Absolute directory containing this script.
CURRENT_DIR = pathlib.Path(__file__).resolve().parent
APP_NAME = "Mona Lisa Detection"
# Hugging Face Hub repository from which both the processor and the model
# are loaded; kept in one place so the two loads cannot drift apart.
MODEL_ID = "drift-ai/autotrain-mona-lisa-detection-38345101350"

logger.debug("loading processor and model.")
# use_auth_token=True tells from_pretrained to authenticate with the locally
# stored Hugging Face token (presumably the repo is private — TODO confirm).
processor = AutoFeatureExtractor.from_pretrained(MODEL_ID, use_auth_token=True)
model = AutoModelForImageClassification.from_pretrained(
    MODEL_ID, use_auth_token=True
)
logger.debug("loading processor and model succeeded.")
def process_image(image, model=model, processor=processor):
    """Classify ``image`` as Mona Lisa / not Mona Lisa.

    Args:
        image: The input image, passed unchanged to ``processor``.
        model: Image-classification model; defaults to the module-level model.
        processor: Feature extractor that turns ``image`` into model inputs;
            defaults to the module-level processor.

    Returns:
        A single-entry dict mapping the predicted label to its softmax
        probability (a Python float), as consumed by ``gr.Label``.
    """
    logger.info("Making a prediction ...")
    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: disable autograd so no computation graph is recorded,
    # which saves memory and time on every request.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index -> label mapping for this model, hard-coded rather than read
    # from the model config.
    labels = {1: "Not Mona Lisa", 0: "Mona Lisa"}
    # Assumes a single image, i.e. a batch of one (``.item()`` requires it).
    predicted_class_idx = logits.argmax(-1).item()
    probabilities = logits.softmax(dim=-1)[0]
    result = {labels[predicted_class_idx]: probabilities[predicted_class_idx].item()}
    logger.info("Predicted class: {}", result)
    logger.info("Prediction finished.")
    return result
# Example images shown in the UI. They are resolved against the script's
# directory rather than the process working directory, so the examples still
# load when the app is started from elsewhere.
examples = [
    str(CURRENT_DIR / file_name)
    for file_name in (
        "mona-lisa-1.jpg",
        "mona-lisa-2.jpg",
        "mona-lisa-3.jpg",
        "not-mona-lisa-1.jpg",
        "not-mona-lisa-2.jpg",
        "not-mona-lisa-3.jpg",
    )
]
if __name__ == "__main__":
    title = "Mona Lisa Detection."
    app = gr.Interface(
        fn=process_image,
        # gr.Image replaces the deprecated gr.inputs.Image namespace, which
        # was removed in Gradio 4.
        inputs=[
            gr.Image(type="pil", label="Image"),
        ],
        outputs=gr.Label(label="Predictions:", show_label=True),
        examples=examples,
        examples_per_page=32,
        title=title,
    )
    # Request queueing is enabled via .queue(); the old
    # Interface(enable_queue=...) argument is deprecated and rejected by
    # Gradio 4.
    app.queue().launch()