# Spaces:
# Sleeping
# Sleeping
import streamlit as st
import leafmap.foliumap as leafmap
from transformers import PretrainedConfig
from folium import Icon

from messis.messis import Messis
from inference import perform_inference
# Use the full browser width for this page.
st.set_page_config(layout="wide")

# Local stacked-feature GeoTIFF. No live code visible in this section reads
# it; the map imagery is served from a remote COG instead.
GEOTIFF_PATH = "./data/stacked_features.tif"
# Load the model
@st.cache_resource
def load_model():
    """Load the Messis model and its configuration from the Hugging Face Hub.

    Both objects come from the ``crop-classification/messis`` repository at a
    pinned revision. The model files are cached on disk under ``./hf_cache/``.

    The function is wrapped in ``st.cache_resource`` because Streamlit re-runs
    the whole script on every widget interaction; without the cache the model
    would be re-instantiated each time a sidebar control changes. The cached
    objects are shared between reruns and sessions, so callers should treat
    them as read-only.

    Returns:
        tuple: ``(model, config)`` as returned by ``Messis.from_pretrained``
        and ``PretrainedConfig.from_pretrained`` respectively.
    """
    repo_id = 'crop-classification/messis'
    revision = '47d9ca4'  # pinned so the demo does not change with new pushes

    config = PretrainedConfig.from_pretrained(repo_id, revision=revision)
    model = Messis.from_pretrained(repo_id, cache_dir='./hf_cache/', revision=revision)
    return model, config


model, config = load_model()
def perform_inference_step():
    """Render step 2 of the app: run crop classification and show the map.

    Reads the ``(lat, lon)`` pair stored under ``"selected_location"`` in
    ``st.session_state``. If it is missing, an error and a link back to the
    location-selection page are shown and the function returns early.

    Otherwise the sidebar offers a timestep slider, a band selector and a
    button that runs ``perform_inference`` for the selected location. The map
    shows the satellite imagery for the chosen timestep/band, the prediction
    polygons (green when the ``Correct`` property is truthy, red otherwise)
    and a marker at the selected location.

    Predictions are kept in ``st.session_state`` together with the location
    they were computed for. ``st.sidebar.button`` is only true for the single
    rerun triggered by the click, so without this the prediction layer would
    disappear as soon as the timestep or band is changed.
    """
    st.title("Step 2: Perform Crop Classification")

    if "selected_location" not in st.session_state:
        st.error("No location selected. Please select a location first.")
        st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
        return

    lat, lon = st.session_state["selected_location"]
    location = (lat, lon)

    # Sidebar
    st.sidebar.header("Settings")

    # Timestep Slider
    timestep = st.sidebar.slider("Select Timestep", 1, 9, 5)

    # Band Dropdown
    band_options = {
        "RGB": [1, 2, 3],  # Adjust indices based on the actual bands in your GeoTIFF
        "NIR": [4],
        "SWIR1": [5],
        "SWIR2": [6],
    }

    # (min, max) pixel values used to rescale each band group for display.
    vmin_vmax = {
        "RGB": (89, 1878),
        "NIR": (165, 5468),
        "SWIR1": (120, 3361),
        "SWIR2": (94, 2700),
    }

    selected_band = st.sidebar.selectbox(
        "Select Satellite Band to Display",
        options=list(band_options.keys()),
        index=0,
    )

    # Calculate the band indices based on the selected timestep: each timestep
    # is offset by 6 bands (assumes the stacked raster stores 6 bands per
    # timestep -- TODO confirm against the COG).
    band_offset = (timestep - 1) * 6
    selected_bands = [band + band_offset for band in band_options[selected_band]]

    instructions = """
    Click the button "Perform Crop Classification".

    _Note:_
    - Messis will classify the crop types for the fields in your selected location.
    - Hover over the fields to see the predicted and true crop type.
    - The satellite images might take a few seconds to load.
    """
    st.sidebar.header("Instructions")
    st.sidebar.markdown(instructions)

    # Initialize the map
    m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)

    # Perform inference and remember the result for later reruns.
    predictions_key = "crop_predictions"
    if st.sidebar.button("Perform Crop Classification", type="primary"):
        predictions = perform_inference(lon, lat, model, config, debug=True)
        st.session_state[predictions_key] = (location, predictions)
        st.success("Inference completed!")

    # Draw stored predictions only if they belong to the current location, so
    # results from a previously selected location are never shown.
    stored = st.session_state.get(predictions_key)
    if stored is not None and stored[0] == location:
        m.add_data(
            stored[1],
            layer_name="Predictions",
            column="Correct",
            add_legend=False,
            style_function=lambda x: {
                "fillColor": "green" if x["properties"]["Correct"] else "red",
                "color": "black",
                "weight": 0,
                "fillOpacity": 0.25,
            },
        )

    # Add COG: satellite imagery for the selected timestep and band.
    vmin, vmax = vmin_vmax[selected_band]
    m.add_cog_layer(
        url="https://messis-demo.s3.amazonaws.com/stacked_features_cog.tif",
        name="AWS COG",
        bands=selected_bands,
        rescale=f"{vmin},{vmax}",
        zoom_to_layer=True,
    )

    # Show the POI on the map
    poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
    m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)

    # Display the map in the Streamlit app
    m.to_streamlit()
# Render the page only when this file is executed as a script (which is how
# Streamlit runs it), not when it is imported as a module.
if __name__ == "__main__":
    perform_inference_step()