# DeepForest / app.py
# daviddao
# add comma
# 29889aa
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
from deepforest import main
# Initialize the DeepForest model once, at import time, so every call to
# predict_and_visualize reuses this same module-level instance, then load the
# released model version via use_release().
# NOTE(review): newer deepforest releases may deprecate use_release() in favor
# of load_model(); confirm against the installed deepforest version.
model = main.deepforest()
model.use_release()
def predict_and_visualize(image):
    """Run DeepForest tree detection on an image and return an annotated copy.

    Gradio hands the upload over as an in-memory array, but prediction here
    goes through ``model.predict_image(path=...)``. The array is therefore
    first written to a PNG inside a private, per-call temporary directory.

    Args:
        image: The uploaded image as a NumPy array (the input component is
            declared with ``type="numpy"``), or ``None`` if the form was
            submitted without an image.

    Returns:
        One of:
        - an RGB image array with the model's predictions drawn on it;
        - the input image unchanged, if the model returned no plot;
        - ``None``, when no image was provided.
    """
    if image is None:
        # Nothing uploaded: leave the output empty instead of letting
        # plt.imsave fail on a None array.
        return None

    # A fresh directory per call means concurrent requests cannot overwrite
    # each other's upload, which the old fixed "/tmp/..." path allowed. It
    # also works on systems without /tmp, and the file is removed when the
    # `with` block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, "uploaded_image.png")
        plt.imsave(temp_path, image)
        img = model.predict_image(path=temp_path, return_plot=True)

    if img is None:
        # NOTE(review): predict_image appears to return None instead of a plot
        # in some cases (e.g. when nothing is detected). Previously this
        # crashed on the channel slice below, so show the upload instead.
        return image

    # The returned plot is in BGR channel order, while the Gradio output
    # image is displayed as RGB, so reverse the channel (last) axis.
    return img[:, :, ::-1]
# Assemble the web UI: a single uploaded image goes in and a single annotated
# image comes out. The example entry is a relative path, so it presumably
# needs to exist in the directory the app is started from.
iface = gr.Interface(
    fn=predict_and_visualize,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.Image(label="Predicted Image"),
    title="DeepForest Tree Detection",
    description="Upload an image to detect trees using the DeepForest model.",
    examples=["./example.jpg"],
)

# Start serving the Gradio app.
iface.launch()