# Hugging Face Space (status at capture time: Sleeping)
from shiny import App, ui, render | |
import base64 | |
from io import BytesIO | |
from PIL import Image, ImageOps | |
import numpy as np | |
import torch | |
from transformers import SamModel, SamProcessor | |
# --- One-time model setup (executed when the module is imported) ---

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Start from the public SAM base weights and matching pre-processor.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Overwrite the base weights with the checkpoint stored next to the app.
# The checkpoint is deserialised onto the CPU first and only afterwards moved
# to the selected device, so loading works on machines without CUDA.
# NOTE(review): torch.load unpickles the file -- only ship checkpoints from a
# trusted source.
model_path = "SAM/mito_model_checkpoint.pth"
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates
def preprocess_image(image, target_size=(256, 256)):
    """Shrink or grow *image* so it fits inside *target_size*.

    The work is delegated to ``ImageOps.contain``, which keeps the aspect
    ratio; the result therefore fits within ``target_size`` but is not
    necessarily exactly that size.
    """
    return ImageOps.contain(image, target_size)
def postprocess_mask(mask, threshold=0.95):
    """Binarise a probability mask into a black/white ``uint8`` image.

    Args:
        mask: Array-like of numeric values (an ``ndarray``, or anything
            ``np.asarray`` accepts such as a nested list).
        threshold: Values strictly greater than this become 255; all other
            values become 0.

    Returns:
        A ``uint8`` ndarray with the same shape as *mask*, containing only
        0 and 255.
    """
    # Coerce first so plain lists/tuples work too; for an ndarray this is a
    # no-op and the result is identical to the previous behaviour.
    above = np.asarray(mask) > threshold
    return above.astype(np.uint8) * 255
def process_image(image_path):
    """Segment the image stored at *image_path* and return the mask as a data URI.

    The image is converted to RGB, resized with ``preprocess_image``, run
    through the module-level SAM ``processor``/``model`` pair, and the
    sigmoid of the predicted mask is thresholded at 0.95 with
    ``postprocess_mask``.

    Args:
        image_path: Path of the image file to segment.

    Returns:
        A ``data:image/png;base64,...`` string containing the binary mask
        encoded as a greyscale PNG.
    """
    # Open inside a context manager so the underlying file handle is released
    # as soon as the pixels have been copied by convert(); previously the
    # handle was left for the garbage collector to close.
    with Image.open(image_path) as source:
        image = preprocess_image(source.convert("RGB"))  # resize before inference
    image_np = np.array(image)

    inputs = processor(images=image_np, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only -- no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    pred_masks = torch.sigmoid(outputs.pred_masks).cpu().numpy()

    # squeeze() drops the singleton dimensions of the prediction.
    # NOTE(review): assumes a single image and a single mask so that the
    # squeezed array is 2-D, as required by mode "L" below -- confirm.
    segmented_image = postprocess_mask(pred_masks.squeeze(), threshold=0.95)

    pil_img = Image.fromarray(segmented_image, mode="L")
    buffered = BytesIO()
    pil_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"
# Page layout: a sidebar with the file-upload control and a main panel that
# shows the uploaded image followed by the segmentation result.
#
# The captions are rendered as explicit headings. They used to be passed as
# the second positional argument of output_image()/output_ui(), where they
# were interpreted as `width` and `inline` respectively instead of as labels.
#
# NOTE(review): panel_sidebar()/panel_main() are the older Shiny layout API;
# newer releases use ui.sidebar() -- confirm against the pinned shiny version.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file(
                "image_upload",
                "Upload Satellite Image",
                accept=".jpg,.jpeg,.png,.tif",
            ),
        ),
        ui.panel_main(
            ui.tags.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.tags.h4("Segmented Image"),
            ui.output_ui("segmented_image"),  # output_ui because the server returns HTML
        ),
    )
)
def server(input, output, session):
    """Shiny server function: wires the upload input to the two outputs.

    Registers:
        * ``uploaded_image`` -- shows the uploaded file as-is.
        * ``segmented_image`` -- shows the mask produced by ``process_image``
          as an ``<img>`` tag, or the text "No image processed." when there
          is no upload or processing failed.
    """

    def _datapath(file_info):
        """Return the temp-file path from the upload info (a list or a single entry)."""
        entry = file_info[0] if isinstance(file_info, list) else file_info
        return entry['datapath']

    # The render decorators are what register these functions as outputs;
    # without them the functions are defined but never called by Shiny.
    @output
    @render.image
    def uploaded_image():
        file_info = input.image_upload()
        if file_info:
            return {'src': _datapath(file_info)}
        return None

    # Use render.ui for direct HTML output
    @output
    @render.ui
    def segmented_image():
        file_info = input.image_upload()
        if file_info:
            try:
                file_path = _datapath(file_info)
                if file_path:
                    base64_img = process_image(file_path)
                    # Return an HTML image tag with the base64 data URI
                    return ui.tags.img(src=base64_img, style="max-width: 100%; height: auto;")
            except Exception as e:
                # Best-effort: log the failure and fall through to the
                # placeholder text instead of breaking the session.
                print(f"Error processing image: {e}")
        return "No image processed."
# Create the Shiny app. `app` is the ASGI application object that
# `shiny run` / uvicorn look up when they import this module.
app = App(app_ui, server)

# Only start the built-in server when the file is executed directly
# (`python app.py`). Running it unconditionally would start a blocking
# server as a side effect of merely importing the module.
if __name__ == "__main__":
    app.run()