|
import gradio as gr |
|
import os |
|
import tensorflow |
|
|
|
from typing import Tuple, Dict |
|
import tensorflow as tf |
|
import numpy as np |
|
from PIL import Image |
|
from timeit import default_timer as timer |
|
|
|
|
|
|
|
# Class labels, one per line: line i is the label reported for model output
# index i (see the class_names[...] lookup in predict).
# Opened as explicit UTF-8 so decoding does not depend on the platform locale.
with open('class_names.txt', encoding='utf-8') as f:
    class_names = [breed.strip() for breed in f]
|
|
|
|
|
# Trained classifier, loaded once at import time from a `.h5` file.
# NOTE(review): this path is resolved against the current working directory and
# carries a "demo/dog_breed_classifier/" prefix, whereas class_names.txt and
# examples/ are opened with bare relative paths — confirm all three resolve
# from the directory the app is actually launched in.
effnet = tensorflow.keras.models.load_model('demo/dog_breed_classifier/dog_breed_effnet_augmentation.h5')

# Keras' EfficientNetV2 preprocessing function, applied to every image in
# predict() before inference.
# NOTE(review): in recent Keras versions this is documented as a pass-through
# (rescaling is built into EfficientNetV2 models) — confirm for the installed
# TensorFlow version.
effnet_preprocess_input = tensorflow.keras.applications.efficientnet_v2.preprocess_input
|
|
|
|
|
# Example images offered under the input widget, one single-input row each.
# Sorted because os.listdir returns entries in arbitrary order, and hidden
# files (e.g. ".DS_Store", ".gitkeep") are skipped since they are not images.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]
|
|
|
|
|
def predict(img, *, top_k: int = 3) -> Tuple[Dict[str, float], float]:
    """Classify a dog image and time the prediction.

    Args:
        img: Input image as a PIL ``Image`` (the Gradio input component is
            ``gr.Image(type="pil")``).
        top_k: Number of highest-scoring classes to return. Defaults to 3.

    Returns:
        A tuple ``(labels_to_scores, seconds)``:
        ``labels_to_scores`` maps each of the ``top_k`` best class names to
        the model's score for it as a float, inserted highest-first.
        ``seconds`` is the wall-clock time the prediction took, rounded to 5
        decimal places.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    start_time = timer()

    # Force 3-channel RGB: RGBA or grayscale inputs would otherwise give arrays
    # with 4 channels or no channel axis at all.
    # assumes the model expects 224x224 RGB input — TODO confirm against training.
    img = img.convert("RGB").resize((224, 224))

    x = effnet_preprocess_input(np.array(img))
    x = tf.expand_dims(x, axis=0)  # add the batch dimension

    # Convert the output to NumPy once rather than indexing the tensor per class.
    probs = np.asarray(effnet(x))[0]

    # Indices of the top_k scores, highest first. Scores are returned as-is
    # (not multiplied by 100); gr.Label formats confidences for display itself.
    top_indices = np.argsort(probs)[::-1][:top_k]
    pred_labels_and_probs = {class_names[i]: float(probs[i]) for i in top_indices}

    pred_time = round(timer() - start_time, 5)

    return pred_labels_and_probs, pred_time
|
|
|
|
|
|
|
# Page text passed to gr.Interface below as title, description and article.
# NOTE(review): the non-ASCII characters in these strings look like emoji whose
# UTF-8 bytes were mis-decoded through a Greek single-byte code page; confirm
# the intended emoji and restore them (the exact originals are ambiguous).
title = "πΆ Dog Breeds Classifier πΎ"

description = "π An EfficientNetV2S feature extractor computer vision model to classify images of 120 different breeds. πΈ"

article = "π Created at [GitHub](https://github.com/adinmg/dog_breed_classifier)."
|
|
|
|
|
# Gradio UI: one PIL image in; the top predictions and the inference time out.
# num_top_classes=3 matches the three classes predict returns by default.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start the server only when this file is executed directly, so importing the
# module (e.g. to reuse `demo`) does not block on a running server.
if __name__ == "__main__":
    demo.launch()
|
|