"""Shiny app that segments an uploaded image with a fine-tuned SAM model.

The user uploads an image; the app shows the upload and, next to it, a
binary mask produced by the model, embedded as a base64 PNG data URI.
"""

import base64
from io import BytesIO

import numpy as np
import torch
from PIL import Image, ImageOps
from shiny import App, render, ui
from transformers import SamModel, SamProcessor

# Load the processor and the base model, then overwrite the weights with the
# fine-tuned checkpoint.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")
model_path = "SAM/mito_model_checkpoint.pth"
# NOTE(review): torch.load unpickles the checkpoint file; only load trusted
# checkpoints (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()


def preprocess_image(image, target_size=(256, 256)):
    """Resize the image to fit within ``target_size``.

    Uses ``ImageOps.contain``, so the aspect ratio is kept and the result
    may be smaller than ``target_size`` along one axis.
    """
    return ImageOps.contain(image, target_size)


def postprocess_mask(mask, threshold=0.95):
    """Binarize ``mask``: values above ``threshold`` become 255, others 0.

    Returns a ``uint8`` NumPy array of the same shape as ``mask``.
    """
    return (mask > threshold).astype(np.uint8) * 255


def process_image(image_path):
    """Segment the image at ``image_path`` and return it as a PNG data URI.

    The image is converted to RGB, resized, run through the model without
    gradients, passed through a sigmoid and thresholded at 0.95. The
    resulting mask is PNG-encoded and returned as a
    ``data:image/png;base64,...`` string.
    """
    # Close the underlying file handle as soon as the pixel data is copied;
    # convert() returns an independent, fully loaded image.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")
    image = preprocess_image(image)  # Resize image before processing
    image_np = np.array(image)

    inputs = processor(images=image_np, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    pred_masks = torch.sigmoid(outputs.pred_masks).cpu().numpy()
    # Squeeze out singleton dimensions before thresholding.
    # NOTE(review): assumes the squeezed result is a single 2-D mask, which
    # Image.fromarray(..., mode="L") below requires - confirm for batched use.
    segmented_image = postprocess_mask(pred_masks.squeeze(), threshold=0.95)

    pil_img = Image.fromarray(segmented_image, mode="L")
    buffered = BytesIO()
    pil_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


def _uploaded_datapath(file_info):
    """Return the ``datapath`` of the upload (first entry when given a list)."""
    if isinstance(file_info, list):
        return file_info[0]['datapath']
    return file_info['datapath']


app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file(
                "image_upload",
                "Upload Satellite Image",
                accept=".jpg,.jpeg,.png,.tif",
            )
        ),
        ui.panel_main(
            # The captions are rendered as headings: the second positional
            # argument of output_image is the width and that of output_ui is
            # the inline flag, so passing a label there is a bug.
            ui.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.h4("Segmented Image"),
            ui.output_ui("segmented_image"),  # output_ui for HTML content
        ),
    )
)


def server(input, output, session):
    """Register the render functions for the uploaded and segmented images."""

    @output
    @render.image
    def uploaded_image():
        file_info = input.image_upload()
        if file_info:
            return {'src': _uploaded_datapath(file_info)}

    @output
    @render.ui  # Use render.ui for direct HTML output
    def segmented_image():
        file_info = input.image_upload()
        if file_info:
            # UI boundary: report any processing failure and fall through to
            # the placeholder text instead of crashing the session.
            try:
                file_path = _uploaded_datapath(file_info)
                if file_path:
                    base64_img = process_image(file_path)
                    # Return an HTML image tag with the base64 data URI
                    return ui.tags.img(
                        src=base64_img,
                        style="max-width: 100%; height: auto;",
                    )
            except Exception as e:
                print(f"Error processing image: {e}")
        return "No image processed."


# Create the Shiny app
app = App(app_ui, server)

# Start the server only when executed directly, so that importing this module
# (e.g. through `shiny run`) does not block on app.run().
if __name__ == "__main__":
    app.run()