etahamad's picture
handle low accuracy results
39f4ad3
import gradio as gr
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
import numpy as np
# Load the trained Keras classifier once, at module import; the request
# handler below reuses this single instance for every prediction.
model = tf.keras.models.load_model(filepath='plant_model_v5-beta.h5')
# Define the class names
class_names = {
0: 'Apple___Apple_scab',
1: 'Apple___Black_rot',
2: 'Apple___Cedar_apple_rust',
3: 'Apple___healthy',
4: 'Not a plant',
5: 'Blueberry___healthy',
6: 'Cherry___Powdery_mildew',
7: 'Cherry___healthy',
8: 'Corn___Cercospora_leaf_spot Gray_leaf_spot',
9: 'Corn___Common_rust',
10: 'Corn___Northern_Leaf_Blight',
11: 'Corn___healthy',
12: 'Grape___Black_rot',
13: 'Grape___Esca_(Black_Measles)',
14: 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
15: 'Grape___healthy',
16: 'Orange___Haunglongbing_(Citrus_greening)',
17: 'Peach___Bacterial_spot',
18: 'Peach___healthy',
19: 'Pepper,_bell___Bacterial_spot',
20: 'Pepper,_bell___healthy',
21: 'Potato___Early_blight',
22: 'Potato___Late_blight',
23: 'Potato___healthy',
24: 'Raspberry___healthy',
25: 'Soybean___healthy',
26: 'Squash___Powdery_mildew',
27: 'Strawberry___Leaf_scorch',
28: 'Strawberry___healthy',
29: 'Tomato___Bacterial_spot',
30: 'Tomato___Early_blight',
31: 'Tomato___Late_blight',
32: 'Tomato___Leaf_Mold',
33: 'Tomato___Septoria_leaf_spot',
34: 'Tomato___Spider_mites Two-spotted_spider_mite',
35: 'Tomato___Target_Spot',
36: 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
37: 'Tomato___Tomato_mosaic_virus',
38: 'Tomato___healthy'
}
def edge_and_cut(img, threshold1, threshold2, color=255, half_width=10):
    """Draw a rectangle around the region containing Canny edges.

    Runs ``cv2.Canny`` on *img*, takes the bounding box of every non-zero
    edge pixel, and paints the four sides of that box onto a copy of *img*.

    Args:
        img: Image array passed straight to ``cv2.Canny``; it is not modified.
        threshold1: First hysteresis threshold passed to ``cv2.Canny``.
        threshold2: Second hysteresis threshold passed to ``cv2.Canny``.
        color: Value written into the box pixels. The default scalar 255
            broadcasts over every channel, so on a 3-channel image it paints
            white; pass e.g. ``(255, 0, 0)`` for a coloured box on such images.
        half_width: Half the thickness, in pixels, of each side of the box.

    Returns:
        A copy of *img* with the box drawn, or an unmodified copy when
        Canny finds no edges.
    """
    emb_img = img.copy()
    edges = cv2.Canny(img, threshold1, threshold2)
    # Coordinates of all edge pixels, vectorised (no per-pixel Python loop).
    rows, cols = np.nonzero(edges)
    if rows.size == 0:
        return emb_img
    row_min, row_max = int(rows.min()), int(rows.max())
    col_min, col_max = int(cols.min()), int(cols.max())

    def _band(center):
        # Clamp the start at 0: a negative start would wrap to the far end of
        # the axis and yield an empty slice, silently dropping that side.
        return slice(max(center - half_width, 0), center + half_width)

    emb_img[_band(row_min), col_min:col_max] = color  # top side
    emb_img[_band(row_max), col_min:col_max] = color  # bottom side
    emb_img[row_min:row_max, _band(col_min)] = color  # left side
    emb_img[row_min:row_max, _band(col_max)] = color  # right side
    return emb_img
def classify_and_visualize(image, *, threshold=0.60):
    """Classify a leaf image and, when the model is confident, outline it.

    Args:
        image: Image array from the Gradio ``"image"`` input, or None when
            the user submits without uploading an image.
        threshold: Minimum top-class score required to report a class name.
            Below it, a hint message is returned and the image is passed
            through unchanged. Keyword-only, so the Gradio input wiring is
            unaffected.

    Returns:
        ``(label, confidence, image)`` matching the interface's
        ``["text", "number", "image"]`` outputs. When the score reaches
        *threshold*, ``image`` is annotated by ``edge_and_cut``.
    """
    if image is None:
        # Nothing uploaded: tf.image.resize(None, ...) would raise.
        return "Please upload an image.", 0.0, None

    # Resize to 256x256 (presumably the model's training size; confirm)
    # and scale pixel values by 1/255, as a batch of one.
    img_array = tf.image.resize(image, [256, 256])
    img_array = tf.expand_dims(img_array, 0) / 255.0

    # Per-class scores for the single image (presumably softmax outputs).
    scores = model.predict(img_array)[0]
    predicted_class = int(np.argmax(scores))
    # Plain Python float rather than np.float32 for the "number" output.
    confidence = float(scores[predicted_class])

    if confidence < threshold:
        class_name = "The image you uploaded might not be in the dataset. Try making your leaf background white."
        return class_name, confidence, image

    return class_names[predicted_class], confidence, edge_and_cut(image, 200, 400)
# Gradio front end: one image input; outputs are the label text, the
# confidence number and the returned image. The two bundled example files
# are offered as one-click inputs.
iface = gr.Interface(
    fn=classify_and_visualize,
    inputs="image",
    outputs=["text", "number", "image"],
    examples=[[path] for path in ('examples/grot.jpg', 'examples/corn.jpg')],
    interpretation="default",
)
# Start the Gradio server.
iface.launch()