import numpy as np
import streamlit as st
import torch
from astropy.io import fits
from PIL import Image
from torchvision.transforms.functional import to_pil_image

from data.rg_masks import get_transforms
from models import tiramisu


def load_fits(path):
    """Read a FITS file into a float32 array with a trailing channel axis."""
    array = fits.getdata(path).astype(np.float32)
    array = np.expand_dims(array, 2)
    return array


def load_image(path):
    """Read a standard image file and keep only its first channel."""
    image = Image.open(path)
    array = np.array(image)
    array = np.expand_dims(array[:, :, 0], 2)
    return array


def load_weights(model, fpath, device="cuda"):
    """Load a checkpoint saved with a 'state_dict' key into the model."""
    print("loading weights '{}'".format(fpath))
    weights = torch.load(fpath, map_location=torch.device(device))
    model.load_state_dict(weights["state_dict"])


# Apply a color overlay to the input image based on the segmentation mask:
# class 1 is drawn in the red channel, class 2 in green, class 3 in blue,
# and class 0 (background) stays black.
def apply_color_overlay(input_image, segmentation_mask, alpha=0.5):
    r = (segmentation_mask == 1).float()
    g = (segmentation_mask == 2).float()
    b = (segmentation_mask == 3).float()
    overlay = to_pil_image(torch.cat([r, g, b], dim=0))
    # Convert the input to RGB so the blend works even when it is grayscale.
    output = Image.blend(input_image.convert("RGB"), overlay, alpha=alpha)
    return output


# Streamlit app
def main():
    st.title("Tiramisu for semantic segmentation of radio astronomy images")
    st.write("Upload an image and see the segmentation result!")

    uploaded_image = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png", "fits"]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = tiramisu.FCDenseNet67(n_classes=4).to(device)
    load_weights(model, "weights/real.pth", device)
    model.eval()

    st.markdown(
        """
        Category Legend:
        - :blue[Extended]
        - :green[Compact]
        - :red[Spurious]
        """
    )

    if uploaded_image is not None:
        # Load the uploaded image as an (H, W, 1) array.
        if uploaded_image.name.endswith(".fits"):
            input_array = load_fits(uploaded_image)
        else:
            input_array = load_image(uploaded_image)

        # Rearrange to (C, H, W) and apply the preprocessing transforms.
        input_array = input_array.transpose(2, 0, 1)
        transforms = get_transforms(input_array.shape[1])
        image = transforms(input_array)
        image = image.to(device)

        # Run inference and take the per-pixel argmax over the four classes.
        with torch.no_grad():
            output = model(image)
        preds = output.argmax(1)

        pil_image = to_pil_image(image[0])

        # Apply the color overlay to the input image.
        segmented_image = apply_color_overlay(pil_image, preds)

        # Display the input image and the segmented output side by side.
        st.image(
            [pil_image, segmented_image],
            caption=["Input Image", "Segmented Output"],
            width=300,
        )


if __name__ == "__main__":
    main()
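

# A minimal standalone sketch (not called by the app) showing how
# apply_color_overlay renders a prediction mask; the 128x128 canvas, block
# positions, and grey level are illustrative assumptions, not values the app
# uses. The app itself is launched with `streamlit run <this_file>.py`.
def _overlay_demo():
    base = Image.new("L", (128, 128), color=40)        # dummy grayscale "input image"
    mask = torch.zeros(1, 128, 128, dtype=torch.long)  # same shape as output.argmax(1)
    mask[:, 16:48, 16:48] = 1    # rendered in red (Spurious in the legend)
    mask[:, 48:80, 48:80] = 2    # rendered in green (Compact)
    mask[:, 80:112, 80:112] = 3  # rendered in blue (Extended)
    return apply_color_overlay(base, mask, alpha=0.5)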