import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image

# 1. Class names setup
# Index i of the model's prediction vector is reported as class_names[i], so
# this order must match the order of the model's output units.
class_names = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

# 2. Model preparation
# NOTE: safe_mode=False allows Keras to deserialize arbitrary Python (e.g.
# Lambda layers) stored in the model file, which can execute code on load.
# Only ever point this at a model file from a trusted source.
model = tf.keras.models.load_model(
    'oct_classification_final_model_lg.keras',
    custom_objects=None,
    compile=True,
    safe_mode=False,
)


# 3. Prediction function (predict())
def load_and_prep_imgg(img: Image.Image, img_shape: int = 224, scale: bool = True) -> tf.Tensor:
    """Convert a PIL image into a float32 tensor ready to feed the model.

    Args:
        img: Input image in any PIL mode (RGB, grayscale "L", RGBA, palette...).
        img_shape: Side length in pixels of the square image after resizing.
        scale: If True, divide pixel values by 255 so they lie in [0, 1].

    Returns:
        A ``tf.float32`` tensor of shape ``(img_shape, img_shape, 3)``.
    """
    # Normalise every input to 3 channels up front. The previous check
    # (``shape[-1] == 1``) never matched grayscale input, because np.array()
    # of an "L"-mode image is 2-D (H, W) with no channel axis; RGBA and
    # palette images also slipped through with the wrong channel count.
    img = img.convert('RGB')
    img = img.resize((img_shape, img_shape))
    img = tf.convert_to_tensor(np.asarray(img), dtype=tf.float32)
    return img / 255.0 if scale else img


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify an OCT scan and time the inference.

    Args:
        img: PIL image, as supplied by the Gradio ``Image(type='pil')`` input.

    Returns:
        A ``(probabilities, pred_time)`` tuple. ``probabilities`` maps every
        name in ``class_names`` to the model's score for that class -- the
        dict format ``gr.Label`` needs in order to rank the top classes.
        ``pred_time`` is the wall-clock time in seconds spent preprocessing
        and predicting, rounded to 4 decimal places.
    """
    start_time = timer()
    image = load_and_prep_imgg(img)
    # The model consumes batches: add a leading batch axis of size 1, then
    # take the single row of scores back out.
    pred_probs = model.predict(tf.expand_dims(image, axis=0))[0]

    # Assumes the model's final layer emits per-class probabilities
    # (e.g. softmax) -- TODO confirm against the training code.
    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, pred_probs)
    }
    pred_class = class_names[int(pred_probs.argmax())]
    print(f"Predicted macular diseases is: {pred_class} with probability: {pred_probs.max():.2f}")

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
# 4. Gradio app - our Gradio interface + launch command
title = 'Macular Disease Classification'
description = 'Feature Extraction VGG model to classify Macular Diseases by OCT'
article = 'Created with TensorFlow Model Deployment'

# Create example list: one single-input example ([path]) per image file in
# ./examples. Sorting keeps the gallery order independent of the filesystem;
# hidden files (e.g. .DS_Store) and sub-directories are skipped because Gradio
# would try to load them as images. A missing directory yields no examples
# instead of crashing at startup.
examples_dir = 'examples'
if os.path.isdir(examples_dir):
    example_list = [
        [os.path.join(examples_dir, example)]
        for example in sorted(os.listdir(examples_dir))
        if not example.startswith('.')
        and os.path.isfile(os.path.join(examples_dir, example))
    ]
else:
    example_list = []

# Create a Gradio demo: predict() returns (class -> probability dict,
# elapsed seconds), matching the Label and Number outputs below.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=3, label='prediction'),
        gr.Number(label='Prediction time (s)'),
    ],
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch the demo
demo.launch(debug=False)