Spaces:
Sleeping
Sleeping
File size: 3,157 Bytes
e82a2b7 90baaac e82a2b7 90baaac 2a3e290 90baaac e82a2b7 90baaac e82a2b7 90baaac e82a2b7 90baaac e82a2b7 90baaac e82a2b7 90baaac 3dd227f 90baaac 3dd227f 90baaac e82a2b7 90baaac 3dd227f 90baaac 3dd227f 90baaac e82a2b7 3dd227f e82a2b7 90baaac e82a2b7 90baaac e82a2b7 90baaac e82a2b7 90baaac 31de6d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
from shiny import App, ui, render
import base64
from io import BytesIO
from PIL import Image, ImageOps
import numpy as np
import torch
from transformers import SamModel, SamProcessor
# --- Model setup -----------------------------------------------------------
# Pick the target device first; the weights themselves are still deserialised
# on CPU below so the checkpoint loads whether or not a GPU is present.
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Base SAM (ViT-B) processor and model from the Hugging Face hub; the
# fine-tuned weights from the local checkpoint are layered on top.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

model_path = "mito_model_checkpoint.pth"
# NOTE(review): torch.load unpickles the file -- only point this at a trusted
# checkpoint (weights_only=True is an option on newer torch versions).
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)

model.to(device)
model.eval()  # switch the model to inference mode
def preprocess_image(image, target_size=(256, 256)):
    """Scale *image* so it fits inside the *target_size* bounding box.

    ``ImageOps.contain`` keeps the aspect ratio, so the result is at most
    ``target_size`` in each dimension rather than exactly that size.

    Args:
        image: the PIL image to resize.
        target_size: ``(width, height)`` box the image must fit within.

    Returns:
        The resized PIL image.
    """
    return ImageOps.contain(image, target_size)
def postprocess_mask(mask, threshold=0.95):
    """Binarise a per-pixel score map into a black/white 8-bit mask.

    Args:
        mask: NumPy array of per-pixel scores (sigmoid probabilities when
            called from ``process_image``).
        threshold: values strictly greater than this become foreground.

    Returns:
        ``uint8`` array of the same shape: 255 where ``mask > threshold``,
        0 everywhere else.
    """
    is_foreground = mask > threshold
    return is_foreground.astype(np.uint8) * 255
def process_image(image_path):
    """Segment the image at *image_path* and return the mask as a PNG data URI.

    The file is loaded as RGB and resized with ``preprocess_image``. It is then
    run through the module-level SAM ``processor``/``model`` on ``device``. The
    first predicted mask is passed through ``torch.sigmoid`` and binarised by
    ``postprocess_mask``.

    Args:
        image_path: filesystem path of the image to segment.

    Returns:
        A ``data:image/png;base64,...`` string containing the 8-bit binary mask.
    """
    # Context manager releases the file handle promptly; convert() returns an
    # independent in-memory copy that stays valid after the file is closed.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    image = preprocess_image(image)  # Resize image before processing
    image_np = np.array(image)

    inputs = processor(images=image_np, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
    pred_masks = torch.sigmoid(outputs.pred_masks).cpu().numpy()

    # Use only the first mask: collapse every leading axis and take index 0,
    # leaving a 2-D (H, W) array. A bare squeeze() would stay >2-D (and break
    # Image.fromarray) if more than one mask were ever returned.
    first_mask = pred_masks.reshape(-1, *pred_masks.shape[-2:])[0]
    segmented_image = postprocess_mask(first_mask, threshold=0.95)

    # A 2-D uint8 array maps to Pillow's 8-bit grayscale ("L") mode
    # automatically; passing mode= explicitly is deprecated in recent Pillow.
    pil_img = Image.fromarray(segmented_image)
    buffered = BytesIO()
    pil_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/png;base64,{img_str}"
# Page layout: the upload control sits in the sidebar; the uploaded image and
# its segmentation are shown in the main panel. The captions are separate
# heading elements. The second positional argument of ui.output_image /
# ui.output_ui is a layout option (width / inline), not a caption, so passing
# the caption text there mis-configured both outputs.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file(
                "image_upload",
                "Upload Satellite Image",
                accept=".jpg,.jpeg,.png,.tif",
            ),
        ),
        ui.panel_main(
            ui.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.h4("Segmented Image"),
            # output_ui because the renderer returns an HTML <img> tag
            ui.output_ui("segmented_image"),
        ),
    )
)
def server(input, output, session):
    """Shiny server: connect the file upload to the two image outputs."""

    def _uploaded_path():
        """Return the datapath of the first uploaded file, or None if none yet."""
        file_info = input.image_upload()
        if not file_info:
            return None
        first = file_info[0] if isinstance(file_info, list) else file_info
        return first['datapath']

    @output
    @render.image
    def uploaded_image():
        file_path = _uploaded_path()
        if file_path:
            return {'src': file_path}
        return None  # nothing uploaded yet -> no image

    @output
    @render.ui  # Use render.ui for direct HTML output
    def segmented_image():
        file_path = _uploaded_path()
        if not file_path:
            return "No image processed."
        try:
            base64_img = process_image(file_path)
        except Exception as e:  # top-level UI boundary: report, don't crash
            print(f"Error processing image: {e}")
            # Show the failure to the user instead of the misleading
            # "No image processed." message.
            return f"Error processing image: {e}"
        # Return an HTML image tag with the base64 data URI
        return ui.tags.img(src=base64_img, style="max-width: 100%; height: auto;")
# Create the Shiny app object. Nothing here starts a server; the hosting
# process serves the module-level `app`.
app = App(app_ui, server)