# Hugging Face Space status when this file was captured: Runtime error
import gradio as gr | |
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
#!pip install tensorflow tensorflow-datasets gradio pillow matplotlib | |
# Path to the saved Keras model; being relative, it is resolved against the
# working directory the app is started from.
model_path = "pokemon-model_transferlearning.keras"
# The model is only used for inference (model.predict below), so there is no
# need to restore the training configuration (optimizer/loss/metrics).
# compile=False skips that step, which also avoids load failures when the saved
# training config cannot be deserialized by the installed Keras version.
model = tf.keras.models.load_model(model_path, compile=False)
# Class labels, in the order of the model's output units.
# NOTE(review): this order must match the label order used at training time — confirm.
POKEMON_CLASSES = ['Articuno', 'Bulbasaur', 'Charmander']


def _to_probabilities(scores):
    """Return *scores* as a probability distribution over the last axis.

    Scores that already form a distribution (all non-negative and summing to 1,
    e.g. the output of a final softmax layer) are returned unchanged, so that a
    second softmax does not flatten the confidences. Otherwise (raw logits) a
    numerically stable softmax is applied.
    """
    scores = np.asarray(scores, dtype=np.float64)
    already_normalized = np.all(scores >= 0) and np.allclose(
        scores.sum(axis=-1), 1.0, atol=1e-3
    )
    if already_normalized:
        return scores
    # Subtracting the row max before exp avoids overflow for large logits.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


def predict_pokemon(image):
    """Classify a single image as one of POKEMON_CLASSES.

    Args:
        image: Image pixel array as supplied by ``gr.Image`` (numpy array), or
            None when there is no image (e.g. the input was cleared while the
            interface is in live mode).

    Returns:
        Dict mapping each class name to its probability, rounded to 2 decimals.
        An empty dict when *image* is None.

    Raises:
        ValueError: if the model's number of outputs does not match the number
            of class labels (a label/model mismatch that ``zip`` would
            otherwise hide by silently truncating).
    """
    if image is None:
        return {}

    # Preprocess: numpy -> PIL, force 3 channels (grayscale/RGBA inputs would
    # otherwise produce the wrong channel count), resize to the 150x150 input.
    # NOTE(review): assumes the model expects (1, 150, 150, 3) uint8-range input — confirm.
    pil_image = Image.fromarray(np.asarray(image).astype('uint8')).convert('RGB')
    pil_image = pil_image.resize((150, 150))
    batch = np.expand_dims(np.array(pil_image), axis=0)  # add batch dimension

    prediction = model.predict(batch)
    probabilities = _to_probabilities(prediction)[0]

    if len(probabilities) != len(POKEMON_CLASSES):
        raise ValueError(
            f"Model produced {len(probabilities)} outputs but "
            f"{len(POKEMON_CLASSES)} class labels are defined"
        )
    return {
        name: round(float(prob), 2)
        for name, prob in zip(POKEMON_CLASSES, probabilities)
    }
# Create the Gradio interface: one image input and a label output that shows
# the per-class probabilities returned by predict_pokemon.
input_image = gr.Image()
iface = gr.Interface(
    fn=predict_pokemon,
    inputs=input_image,
    outputs=gr.Label(),
    live=True,  # re-run the prediction automatically whenever the input changes
    examples=["images/01.jpg", "images/02.png", "images/03.png", "images/04.jpg", "images/05.png", "images/06.png"],
    # The previous text ("mlp ... mnist dataset") was copy-pasted from another
    # project and did not describe this app.
    description=(
        "A transfer-learning image classifier that recognizes three Pokemon: "
        "Articuno, Bulbasaur and Charmander."
    ),
)
iface.launch()