amirkhanbloch commited on
Commit
8dd85af
1 Parent(s): 158698a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -136
app.py CHANGED
@@ -1,144 +1,67 @@
1
- import gradio as gr
2
- import numpy as np
3
- import cv2
4
- from PIL import Image
5
- from ultralytics import YOLO
6
- import os
7
- import google.generativeai as genai
8
-
9
- api_key = os.environ.get("GOOGLE_API_KEY")
10
- if api_key is None:
11
- api_key = os.getenv("GOOGLE_API_KEY")
12
- if api_key is None:
13
- raise ValueError(
14
- "GOOGLE_API_KEY environment variable not set. "
15
- "Please set it in your environment or pass it to the function."
16
- )
17
- genai.configure(api_key=api_key)
18
-
19
- # Generation config for Google Gemini
20
- generation_config = {
21
- "temperature": 1,
22
- "top_p": 0.95,
23
- "top_k": 0,
24
- "max_output_tokens": 8192,
25
- }
26
-
27
- # Safety settings for Google Gemini
28
- safety_settings = [
29
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
30
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
31
- {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
32
- {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
33
- ]
34
-
35
- # Load the models
36
- yolo_model_crop_disease = YOLO("models/crop_disease_model.pt")
37
- yolo_model_tomato = YOLO("models/tomato_freshness_model.pt")
38
 
39
- # Load the Gemini model for text generation
40
- def load_gemini_model():
41
- model = genai.GenerativeModel(
42
- model_name="gemini-1.5-pro",
43
- generation_config=generation_config,
44
- safety_settings=safety_settings
45
- )
46
- return model
47
 
48
- gemini_model = load_gemini_model()
49
-
50
- # Inference function for YOLOv8
51
- # Inference function to choose between models
52
- def inference(image, model_type):
53
- # Load the appropriate YOLO model based on the user's selection
54
- if model_type == "Crop Disease Detection":
55
- results = yolo_model_crop_disease(image, conf=0.4)
56
- else:
57
- results = yolo_model_tomato(image, conf=0.4)
58
-
59
- # Initialize output and class details
60
- infer = np.zeros(image.shape, dtype=np.uint8)
61
- classes = dict()
62
- names_infer = []
63
-
64
- # Process the detection results
65
- for r in results:
66
- infer = r.plot() # Visualize detection results
67
- classes = r.names # Retrieve class names
68
- names_infer = r.boxes.cls.tolist() # Get detected class indices
69
-
70
- return infer, names_infer, classes
71
 
72
- # Function to generate description using Gemini model based on predictions
73
- def generate_description(detected_classes, class_names, user_text, model_type):
74
- # Map the detected class indices to their corresponding class names
75
- detected_objects = [class_names[cls] for cls in detected_classes]
76
-
77
- # Modify the prompt based on the selected model
78
- if model_type == "Crop Disease Detection":
79
- prompt = f"""
80
- You are crop disease pathologist with extensive knowledge in agriculture.
81
- Your task is interpret the diagnoses of the infected crops.
82
- The following crop diseases have been detected based on the analysis: {', '.join(detected_objects)}.
83
-
84
- Please provide a detailed explanation of each disease including:
85
- - The nature of the disease
86
- - Typical symptoms and effects on crops
87
- - Recommended treatment options
88
- - Preventative measures to avoid future occurrences.
89
- """
90
- else:
91
- prompt = f"""
92
- The following condition of the tomato has been detected: {', '.join(detected_objects)}.
93
-
94
- Please provide a detailed explanation on:
95
- - Whether the tomato is fresh or rotten
96
- - How this condition is identified (e.g., characteristics)
97
- - Any handling recommendations for the tomato (e.g., consumption, disposal).
98
- """
99
 
100
- # Generate content using the Gemini model
101
- response = gemini_model.generate_content(prompt)
102
 
 
 
 
103
  return response.text
104
 
105
- # Gradio app
106
- with gr.Blocks() as iface:
107
- with gr.Row():
108
- with gr.Column():
109
- img = gr.Image(type="numpy", label="Upload Image")
110
- conf_threshold = gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold")
111
- iou_threshold = gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold")
112
- model_type = gr.Dropdown(choices=["Crop Disease Detection", "Tomato Freshness Detection"], label="Select Model")
113
-
114
- with gr.Column():
115
- processed_image_output = gr.Image(type="pil", label="Processed Image")
116
- with gr.Column():
117
- chatbot = gr.Chatbot()
118
- #msg = gr.Textbox(label="Your Question")
119
- submit = gr.Button("Submit")
120
- clear = gr.Button("Clear")
121
-
122
- def respond(img, conf_threshold, iou_threshold, chat_history, model_type):
123
- # Run YOLOv8 inference on the image based on the selected model
124
- processed_img, names_infer, classes = inference(img, model_type)
125
-
126
- # Get the last user message from the chat history, if any
127
- if chat_history:
128
- last_user_message = chat_history[-1][0]
129
- else:
130
- last_user_message = "" # Default to empty string if no history
131
-
132
- # Convert detected objects to text and generate a response using Gemini
133
- response = generate_description(names_infer, classes, last_user_message, model_type)
134
-
135
- # Append the user's question and AI's response to the chat history
136
- chat_history.append((last_user_message, response)) # Fixed: Add user message
137
 
138
- return processed_img, chat_history, response
139
-
140
- submit.click(respond, [img, conf_threshold, iou_threshold, chatbot, model_type], [processed_image_output, chatbot])
141
- clear.click(lambda: None, None, chatbot, queue=False)
142
-
143
- if __name__ == "__main__":
144
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ load_dotenv() ## load all the environment variables
 
 
 
 
 
 
 
4
 
5
+ import streamlit as st
6
+ import os
7
+ import google.generativeai as genai
8
+ from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ ## Function to load Google Gemini Pro Vision API And get response
 
13
 
14
+ def get_gemini_repsonse(input,image,prompt):
15
+ model=genai.GenerativeModel('gemini-pro-vision')
16
+ response=model.generate_content([input,image[0],prompt])
17
  return response.text
18
 
19
+ def input_image_setup(uploaded_file):
20
+ # Check if a file has been uploaded
21
+ if uploaded_file is not None:
22
+ # Read the file into bytes
23
+ bytes_data = uploaded_file.getvalue()
24
+
25
+ image_parts = [
26
+ {
27
+ "mime_type": uploaded_file.type, # Get the mime type of the uploaded file
28
+ "data": bytes_data
29
+ }
30
+ ]
31
+ return image_parts
32
+ else:
33
+ raise FileNotFoundError("No file uploaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ ##initialize our streamlit app
36
+
37
+ st.set_page_config(page_title="Crop Disease Detection App")
38
+
39
+ st.header("Gemini Crop Disease App")
40
+ input=st.text_input("Input Prompt: ",key="input")
41
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
42
+ image=""
43
+ if uploaded_file is not None:
44
+ image = Image.open(uploaded_file)
45
+ st.image(image, caption="Uploaded Image.", use_column_width=True)
46
+
47
+
48
+ submit=st.button("Predict Crop/Plant Health")
49
+
50
+ input_prompt="""
51
+ "You are an expert in computer vision and agriculture who can easily predict the disease of the plant. "
52
+ "Analyze the following image and provide 6 outputs in a structured table format: "
53
+ "1. Crop in the image, "
54
+ "2. Whether it is infected or healthy, "
55
+ "3. Type of disease (if any), "
56
+ "4. How confident out of 100% whether image is healthy or infected "
57
+ "5. Reason for the disease such as whether it is happening due to fungus, bacteria, insect bite, poor nutrition, etc., "
58
+ "6. Precautions for it."
59
+ """
60
+
61
+ ## If submit button is clicked
62
+
63
+ if submit:
64
+ image_data=input_image_setup(uploaded_file)
65
+ response=get_gemini_repsonse(input_prompt,image_data,input)
66
+ st.subheader("The Response is")
67
+ st.write(response)