messis-demo / pages /2_Perform_Crop_Classification.py
yvokeller's picture
switch to AWS COG for satellite data
13efa6d
raw
history blame
3.61 kB
import streamlit as st
import leafmap.foliumap as leafmap
from transformers import PretrainedConfig
from folium import Icon
from messis.messis import Messis
from inference import perform_inference
# Use Streamlit's wide page layout for this page.
st.set_page_config(layout="wide")
# Path to a local stacked-features GeoTIFF file.
GEOTIFF_PATH = "./data/stacked_features.tif"
# Load the model
@st.cache_resource
def load_model():
    """Load the Messis model together with its configuration.

    Both objects are fetched with ``from_pretrained`` from the same
    repository id and pinned to the same revision. The function is wrapped
    in ``st.cache_resource``.

    Returns:
        tuple: ``(model, config)`` in that order.
    """
    repo_id = 'crop-classification/messis'
    pinned_revision = '47d9ca4'

    # The configuration is requested first, then the model weights.
    hub_config = PretrainedConfig.from_pretrained(repo_id, revision=pinned_revision)
    messis_model = Messis.from_pretrained(
        repo_id,
        cache_dir='./hf_cache/',
        revision=pinned_revision,
    )
    return messis_model, hub_config
# Module-level model and config; perform_inference_step() reads them as globals.
model, config = load_model()
def perform_inference_step():
    """Render the "Step 2: Perform Crop Classification" page.

    Expects ``st.session_state["selected_location"]`` to hold a
    ``(lat, lon)`` pair. When it is missing, an error and a link back to the
    location-selection page are shown and the function returns early.

    The sidebar exposes a timestep slider and a band selector that decide
    which bands of the remote COG are displayed. Clicking the sidebar button
    runs ``perform_inference`` and overlays the result on the map, coloured
    by the ``Correct`` property of each feature.

    The predictions are stored in ``st.session_state`` together with the
    location they were computed for. ``st.button`` only returns True during
    the rerun triggered by the click, so without this any later widget
    interaction (e.g. moving the timestep slider) would drop the overlay and
    require running inference again.
    """
    st.title("Step 2: Perform Crop Classification")

    if "selected_location" not in st.session_state:
        st.error("No location selected. Please select a location first.")
        st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
        return

    lat, lon = st.session_state["selected_location"]

    # Sidebar
    st.sidebar.header("Settings")

    # Timestep Slider
    timestep = st.sidebar.slider("Select Timestep", 1, 9, 5)

    # Band Dropdown
    band_options = {
        "RGB": [1, 2, 3],  # Adjust indices based on the actual bands in your GeoTIFF
        "NIR": [4],
        "SWIR1": [5],
        "SWIR2": [6],
    }
    # Display range (min, max) used to rescale each band selection.
    vmin_vmax = {
        "RGB": (89, 1878),
        "NIR": (165, 5468),
        "SWIR1": (120, 3361),
        "SWIR2": (94, 2700),
    }
    selected_band = st.sidebar.selectbox(
        "Select Satellite Band to Display",
        options=list(band_options.keys()),
        index=0,
    )

    # Bands are stacked timestep after timestep, six bands per timestep, so
    # the 1-based band index is shifted by the number of preceding timesteps.
    bands_per_timestep = 6
    band_offset = (timestep - 1) * bands_per_timestep
    selected_bands = [band + band_offset for band in band_options[selected_band]]

    instructions = """
Click the button "Perform Crop Classification".
_Note:_
- Messis will classify the crop types for the fields in your selected location.
- Hover over the fields to see the predicted and true crop type.
- The satellite images might take a few seconds to load.
"""
    st.sidebar.header("Instructions")
    st.sidebar.markdown(instructions)

    # Initialize the map
    m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)

    # Perform inference on click and remember the result for later reruns.
    predictions_key = "crop_classification_predictions"
    if st.sidebar.button("Perform Crop Classification", type="primary"):
        predictions = perform_inference(lon, lat, model, config, debug=True)
        st.session_state[predictions_key] = {
            "location": (lat, lon),
            "data": predictions,
        }
        st.success("Inference completed!")

    # Re-add the stored predictions on every rerun, but only if they belong
    # to the currently selected location.
    stored = st.session_state.get(predictions_key)
    if stored is not None and stored["location"] == (lat, lon):
        m.add_data(
            stored["data"],
            layer_name="Predictions",
            column="Correct",
            add_legend=False,
            style_function=lambda x: {
                "fillColor": "green" if x["properties"]["Correct"] else "red",
                "color": "black",
                "weight": 0,
                "fillOpacity": 0.25,
            },
        )

    # Satellite imagery for the selected timestep and band, served as a COG
    vmin, vmax = vmin_vmax[selected_band]
    m.add_cog_layer(
        url="https://messis-demo.s3.amazonaws.com/stacked_features_cog.tif",
        name="AWS COG",
        bands=selected_bands,
        rescale=f"{vmin},{vmax}",
        zoom_to_layer=True,
    )

    # Show the POI on the map
    poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
    m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)

    # Display the map in the Streamlit app
    m.to_streamlit()
# Render the page when this file is executed as the main script.
if __name__ == "__main__":
    perform_inference_step()