ademibeh commited on
Commit
a21c3cc
1 Parent(s): 3deb922

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ from tensorflow.keras.preprocessing import image as keras_image
5
+ from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
6
+ from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess_input
7
+ from tensorflow.keras.models import load_model
8
+
9
+ # Load your trained models
10
+ resnet_model = load_model('/Users/beharademi/Documents/ZHAW/6. Sem LOKAL/KI/module exam/ademibeh/resnet_best_model.keras') # Update path
11
+ mobilenet_model = load_model('/Users/beharademi/Documents/ZHAW/6. Sem LOKAL/KI/module exam/ademibeh/mobilenet_best_model.keras') # Update path
12
+
13
+ def predict_comic_character(img, model_type):
14
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
15
+ img = img.resize((224, 224)) # Resize the image to fit model input
16
+ img_array = keras_image.img_to_array(img)
17
+ img_array = np.expand_dims(img_array, axis=0)
18
+
19
+ if model_type == 'ResNet50':
20
+ img_array = resnet_preprocess_input(img_array)
21
+ prediction = resnet_model.predict(img_array)
22
+ elif model_type == 'MobileNetV2':
23
+ img_array = mobilenet_preprocess_input(img_array)
24
+ prediction = mobilenet_model.predict(img_array)
25
+ else:
26
+ return {"error": "Invalid model type selected"}
27
+
28
+ classes = ['Superman', 'Batman', 'WonderWoman', 'Riddler', 'Spider-Man', 'Iron-Man',
29
+ 'Hulk', 'The Joker', 'Magneto', 'Wolverine', 'Deadpool', 'Catwoman']
30
+
31
+ return {classes[i]: float(prediction[0][i]) for i in range(len(classes))}
32
+
33
+ # Define the Gradio interface
34
+ interface = gr.Interface(
35
+ fn=predict_comic_character,
36
+ inputs="image",
37
+ outputs="label",
38
+ title="Comic Character Classifier",
39
+ description="Upload an image of a comic character and the classifier will predict the character.",
40
+ )
41
+
42
+ # Launch the interface
43
+ interface.launch()