|
import streamlit as st |
|
import numpy as np |
|
import os |
|
import pathlib |
|
from inference import infer, InferenceModel |
|
from text import intro |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SatvisionDemoApp:
    """Streamlit front end for the Satvision-Base reconstruction demo.

    Lists the ``sv-*.png`` thumbnails in ``data/thumbnails`` and the matching
    ``sv-*.npy`` arrays in ``data/images``, lets the user pick one from a
    sidebar selectbox, runs it through ``InferenceModel.infer`` and shows the
    input, the masked input and the reconstruction side by side.
    """

    def __init__(self):
        """Discover the demo files on disk and construct the inference model."""
        self.thumbnail_dir = pathlib.Path('data/thumbnails')
        self.image_dir = pathlib.Path('data/images')

        # sorted() consumes the glob generator directly; no intermediate list.
        self.thumbnail_files = sorted(self.thumbnail_dir.glob('sv-*.png'))
        self.image_files = sorted(self.image_dir.glob('sv-*.npy'))

        # Bare file names: these are the selectbox options and the keys that
        # load_selected_image() maps back to an .npy file.
        self.thumbnail_names = [path.name for path in self.thumbnail_files]

        # NOTE(review): Streamlit re-executes the whole script on every widget
        # interaction, so this model is rebuilt on each rerun. If construction
        # is expensive, consider caching it (e.g. st.cache_resource, if the
        # installed Streamlit version provides it).
        self.inferenceModel = InferenceModel()

    def render_sidebar(self):
        """Show every thumbnail in the sidebar, captioned with its file name."""
        st.sidebar.header("Select an Image")
        for thumbnail in self.thumbnail_names:
            st.sidebar.image(str(self.thumbnail_dir / thumbnail),
                             use_column_width=True, caption=thumbnail)

    def render_main_app(self):
        """Render the title, the selected image's reconstruction and the intro.

        The image picker is a selectbox placed in the sidebar. The intro text
        and the reference figure are always rendered, even when no image can
        be shown.
        """
        st.title("Satvision-Base Demo")
        st.header("Image Reconstruction Process")

        # selectbox returns the chosen option (a file name), not an index.
        selected_name = st.sidebar.selectbox(
            "Select an Image",
            self.thumbnail_names)

        if selected_name is None:
            # Empty option list: there is nothing to load or reconstruct.
            st.warning(
                f"No 'sv-*.png' thumbnails found in {self.thumbnail_dir}.")
        else:
            self._render_reconstruction(selected_name)

        st.markdown(intro)
        st.image('data/figures/reconstruction.png')

    def _render_reconstruction(self, image_name):
        """Load *image_name*, run inference and show the three result columns."""
        try:
            selected_image = self.load_selected_image(image_name)
        except FileNotFoundError:
            # A thumbnail without a matching .npy array: report it in the UI
            # instead of surfacing a raw traceback.
            st.error(f"No image data for '{image_name}' in {self.image_dir}.")
            return

        image, masked_input, output = self.inferenceModel.infer(selected_image)

        col1, col2, col3 = st.columns(3, gap="large")
        with col1:
            st.image(image, use_column_width=True, caption="Input")
        with col2:
            st.image(masked_input, use_column_width=True,
                     caption="Input Masked")
        with col3:
            st.image(output, use_column_width=True, caption="Reconstruction")

    def load_selected_image(self, image_name):
        """Load the ``.npy`` array that pairs with thumbnail *image_name*.

        Args:
            image_name: Thumbnail file name such as ``'sv-001.png'``. Only
                its final suffix is swapped for ``.npy``.

        Returns:
            The loaded array with axis 0 moved to position 2. Presumably this
            turns (C, H, W) into (H, W, C) for display — TODO confirm the
            on-disk layout.

        Raises:
            FileNotFoundError: If the matching ``.npy`` file does not exist.
        """
        # with_suffix replaces only the trailing extension, unlike
        # str.replace, which would rewrite every '.png' inside the name.
        npy_name = pathlib.Path(image_name).with_suffix('.npy').name
        image = np.load(self.image_dir / npy_name)
        return np.moveaxis(image, 0, 2)
|
|
|
|
|
|
|
|
|
def main():
    """Build the demo app, then draw the main page and the sidebar gallery."""
    demo = SatvisionDemoApp()
    # Main page first (it also places the image selectbox in the sidebar),
    # then the thumbnail gallery after it.
    for render in (demo.render_main_app, demo.render_sidebar):
        render()


# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()