hasvio01 commited on
Commit
a8499ca
1 Parent(s): 32174b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -13
app.py CHANGED
@@ -1,30 +1,51 @@
1
  import tensorflow as tf
2
  import gradio as gr
3
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # Load the trained model
6
- model = tf.keras.models.load_model('transferlearning_pokemon.keras')
7
 
8
  # Define class names (make sure this matches the classes used during training)
9
- class_names = ['Raichu', 'Vulpix', 'Machamp']
10
 
11
  # Define the prediction function
12
  def predict(image):
13
- image = tf.image.resize(image, (150, 150)) # Resize image to match model's input size
14
- image = np.expand_dims(image, axis=0) # Add batch dimension
 
 
15
  predictions = model.predict(image)
16
  predicted_class = np.argmax(predictions, axis=1)[0]
17
  confidence = np.max(predictions)
18
  return {class_names[predicted_class]: float(confidence)}
19
 
20
  # Create a Gradio interface
21
- iface = gr.Interface(
22
- fn=predict,
23
- inputs=gr.inputs.Image(type="numpy", label="Upload an image"),
24
- outputs=gr.outputs.Label(num_top_classes=3, label="Predictions"),
25
- title="Pokémon Classifier",
26
- description="Upload an image of a Pokémon and get the predicted class."
27
- )
28
 
29
  # Launch the Gradio interface
30
- iface.launch()
 
1
  import tensorflow as tf
2
  import gradio as gr
3
  import numpy as np
4
+ import os
5
+ from PIL import Image
6
+
7
+ print(tf.__version__)
8
+
9
+ print(f"Current working directory: {os.getcwd()}")
10
+ print(f"Contents of model directory: {os.listdir('model')}")
11
+
12
+ model_path = 'model/transferlearning_pokemon.h5'
13
+
14
+ # Check if the model exists
15
+ if os.path.exists(model_path):
16
+ print(f"Model found at {model_path}")
17
+ try:
18
+ # Load the trained model
19
+ model = tf.keras.models.load_model(model_path)
20
+ print("Model loaded successfully.")
21
+ except Exception as e:
22
+ print(f"Error loading model: {e}")
23
+ else:
24
+ print(f"Model not found at {model_path}. Please check the path.")
25
 
 
 
26
 
27
  # Define class names (make sure this matches the classes used during training)
28
+ class_names = ['Machamp', 'Raichu', 'Vulpix']
29
 
30
  # Define the prediction function
31
  def predict(image):
32
+ image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
33
+ image = image.resize((150, 150)) #resize the image to 28x28 and converts it to gray scale
34
+ image = np.array(image)
35
+ image = np.expand_dims(image, axis=0) # same as image[None, ...]
36
  predictions = model.predict(image)
37
  predicted_class = np.argmax(predictions, axis=1)[0]
38
  confidence = np.max(predictions)
39
  return {class_names[predicted_class]: float(confidence)}
40
 
41
  # Create a Gradio interface
42
+ input_image = gr.Image()
43
+ output_text = gr.Textbox(label="Predicted Value")
44
+ interface = gr.Interface(fn=predict,
45
+ inputs=input_image,
46
+ outputs=gr.Label(),
47
+ examples=["images/00000000.jpg", "images/00000001.jpg", "images/00000010.png", "images/00000017.jpg", "images/00000021.jpg", "images/00000067.jpg"],
48
+ description="A simple mlp classification model for image classification using the mnist dataset.")
49
 
50
  # Launch the Gradio interface
51
+ interface.launch()