app.py
ADDED
import io
import gradio as gr
import matplotlib.pyplot as plt
import requests, validators
import torch
import pathlib
from PIL import Image
import datasets
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import os


# Avoid "OMP: Error #15" crashes when more than one OpenMP runtime gets loaded
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

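# NOTE: this path points at the Google Drive folder from the Colab training run;
# the Space presumably needs the saved model files at this same location (or the
# path updated), otherwise the two from_pretrained calls below fail at startup.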
feature_extractor = AutoFeatureExtractor.from_pretrained("/content/drive/MyDrive/Week 1 Project/saved_model_files")
model = AutoModelForImageClassification.from_pretrained("/content/drive/MyDrive/Week 1 Project/saved_model_files")

labels = ['angular_leaf_spot', 'bean_rust', 'healthy']

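# Assumption: this label order mirrors the model's id2label mapping (the class
# order of the Beans dataset); if it did not, the confidences returned below
# would be attached to the wrong names.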
def classify(im):
    '''Classify the health status of a plant from a leaf image.'''

    # Preprocess the image and run it through the fine-tuned ViT
    features = feature_extractor(im, return_tensors='pt')
    with torch.no_grad():
        logits = model(**features).logits

    # Turn the logits into a {label: probability} dict for gr.Label
    probability = torch.nn.functional.softmax(logits, dim=-1)
    probs = probability[0].detach().numpy()
    confidences = {label: float(probs[i]) for i, label in enumerate(labels)}

    return confidences

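# Illustrative only (not executed by the Space): calling classify(Image.open("leaf.jpg"))
# on a clearly healthy leaf should return something like
# {'angular_leaf_spot': 0.01, 'bean_rust': 0.01, 'healthy': 0.98}.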
def get_original_image(url_input):
    '''Get image from URL'''
    if validators.url(url_input):
        image = Image.open(requests.get(url_input, stream=True).raw)
        return image

def detect_plant_health(url_input, image_input, webcam_input):
    '''Classify whichever image was supplied: URL first, then upload, then webcam.'''

    if validators.url(url_input):
        image = Image.open(requests.get(url_input, stream=True).raw)

    elif image_input:
        image = image_input

    elif webcam_input:
        image = webcam_input

    else:
        # Nothing was provided from any of the three inputs
        return {}

    # Make prediction
    label_probs = classify(image)

    return label_probs

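# gr.Label sorts and displays the returned {label: confidence} dict, so this one
# function can serve as the handler for all three tabs below.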
# Helpers for example selection (Gradio 3.x update API); defined but not wired
# to any event in the Blocks below.
def set_example_image(example: list) -> dict:
    return gr.Image.update(value=example[0])

def set_example_url(example: list) -> dict:
    return gr.Textbox.update(value=example[0]), gr.Image.update(value=get_original_image(example[0]))


title = """<h1 id="title">Plant Health Classification with ViT</h1>"""

description = """
This plant health classifier detects the health of bean plants from images of their leaves. It was built by fine-tuning a Vision Transformer (ViT), [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224), on the [Beans](https://huggingface.co/datasets/beans) dataset.
The fine-tuned model reaches an accuracy of 98.4% on the (unseen) test set and 100% on the validation set.

How to use the app:
- Upload an image in one of three ways: from your local device, via an image URL, or with your webcam
- The app takes a few seconds to return a prediction over the following labels:
  - *'angular_leaf_spot'*
  - *'bean_rust'*
  - *'healthy'*
- Feel free to click on the image examples as well.
"""
urls = ["https://www.healthbenefitstimes.com/green-beans/","https://huggingface.co/nateraw/vit-base-beans/resolve/main/angular_leaf_spot.jpeg", "https://huggingface.co/nateraw/vit-base-beans/resolve/main/bean_rust.jpeg"]
images = [[path.as_posix()] for path in sorted(pathlib.Path('images').rglob('*.j*g'))]

twitter_link = """
[![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
"""

css = '''
h1#title {
    text-align: center;
}
'''
demo = gr.Blocks(css=css)

with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(twitter_link)

    with gr.Tabs():
        with gr.TabItem('Image Upload'):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type='pil', shape=(750, 750))
                    label_from_upload = gr.Label()

            with gr.Row():
                example_images = gr.Examples(examples=images, inputs=[img_input])

            img_but = gr.Button('Classify')

        with gr.TabItem('Image URL'):
            with gr.Row():
                with gr.Column():
                    url_input = gr.Textbox(lines=2, label='Enter a valid image URL here...')
                    original_image = gr.Image(shape=(750, 750))
                    url_input.change(get_original_image, url_input, original_image)
                with gr.Column():
                    label_from_url = gr.Label()

            with gr.Row():
                example_url = gr.Examples(examples=urls, inputs=[url_input])

            url_but = gr.Button('Classify')

        with gr.TabItem('WebCam'):
            with gr.Row():
                with gr.Column():
                    web_input = gr.Image(source='webcam', type='pil', shape=(750, 750), streaming=True)
                with gr.Column():
                    label_from_webcam = gr.Label()

            cam_but = gr.Button('Classify')

    url_but.click(detect_plant_health, inputs=[url_input, img_input, web_input], outputs=[label_from_url], queue=True)
    img_but.click(detect_plant_health, inputs=[url_input, img_input, web_input], outputs=[label_from_upload], queue=True)
    cam_but.click(detect_plant_health, inputs=[url_input, img_input, web_input], outputs=[label_from_webcam], queue=True)
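    # Each Classify button passes all three inputs; detect_plant_health uses the
    # first one that is set (URL, then upload, then webcam), so a URL left in the
    # textbox takes precedence over an uploaded or webcam image.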

    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-plant-health)")


demo.launch(debug=True, enable_queue=True)
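# Dependency note (assumed from the imports above): the Space environment would
# need at least gradio (the 3.x API used here), transformers, torch, datasets,
# Pillow, requests, validators and matplotlib installed.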