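# milestone4 / app.py
# Author: LemonPit. Last commit: e82a2b7 "Update app.py" (verified), 3.17 kB.
# (Header recovered from a web-page export; the "raw" and "history blame"
# entries were page navigation links, not part of the program.)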
from shiny import App, ui, render
import base64
from io import BytesIO
from PIL import Image, ImageOps
import numpy as np
import torch
from transformers import SamModel, SamProcessor
# Choose the compute device first. The fine-tuned checkpoint is always
# deserialised onto the CPU and only moved to `device` afterwards.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Start from the stock SAM ViT-B processor and architecture on the Hub...
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

# ...then overwrite the weights with the local fine-tuned state dict.
# NOTE(review): torch.load unpickles the file, so only ship trusted
# checkpoints here (consider weights_only=True on torch versions that have it).
model_path = "SAM/mito_model_checkpoint.pth"
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# nn.Module.to() returns the module itself, so eval() applies to the same object.
model.to(device).eval()
def preprocess_image(image, target_size=(256, 256)):
    """Scale *image* so that it fits inside *target_size*.

    ``ImageOps.contain`` keeps the aspect ratio, so the returned image is
    scaled (up or down) until it is as large as possible while still fitting
    within ``target_size``. It is not necessarily exactly that size.

    Args:
        image: A PIL image.
        target_size: ``(width, height)`` bounding box to fit into.

    Returns:
        A new, resized PIL image.
    """
    fitted = ImageOps.contain(image, target_size)
    return fitted
def postprocess_mask(mask, threshold=0.95):
    """Binarise a probability map into an 8-bit black/white mask.

    Args:
        mask: NumPy array of per-pixel scores.
        threshold: Values strictly greater than this become foreground.

    Returns:
        A ``uint8`` array of the same shape holding 255 for pixels above
        *threshold* and 0 everywhere else (NaN compares false, so it maps to 0).
    """
    return np.where(mask > threshold, 255, 0).astype(np.uint8)
def process_image(image_path, threshold=0.95):
    """Run the fine-tuned SAM model on an image file and return its mask.

    The file is opened as RGB and scaled by ``preprocess_image`` (aspect ratio
    kept). It is passed through ``processor``/``model`` with no prompts and
    ``multimask_output=False``. The predicted logits are mapped back to the
    image's own size, turned into probabilities, and binarised by
    ``postprocess_mask``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (the caller passes
            the uploaded file's temporary path).
        threshold: Sigmoid-probability cut-off handed to ``postprocess_mask``.
            Defaults to 0.95, the value previously hard-coded here.

    Returns:
        A ``data:image/png;base64,...`` URI of a single-channel PNG whose
        width/height match the preprocessed image (0 = at or below the
        threshold, 255 = above it).
    """
    # Decode inside a context manager so the file handle is released
    # promptly instead of whenever the Image object is garbage-collected.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    image = preprocess_image(image)  # fit within 256x256, aspect ratio kept
    image_np = np.array(image)

    inputs = processor(images=image_np, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # pred_masks are low-resolution logits for the processor's resized and
    # padded square input, so for a non-square image they include the padding
    # and do not line up with the picture. post_process_masks crops the padding
    # and upsamples to the original (preprocessed) size. binarize=False keeps
    # raw logits so the sigmoid + threshold below behave as before.
    upscaled = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
        binarize=False,
    )
    # Explicitly take image 0, prompt 0, mask 0 to get a 2-D (H, W) map,
    # instead of relying on squeeze() to drop every singleton axis.
    probs = torch.sigmoid(upscaled[0][0, 0]).numpy()
    segmented_image = postprocess_mask(probs, threshold=threshold)

    # A 2-D uint8 array is stored as 8-bit grayscale ("L") automatically. The
    # explicit mode= argument is deprecated in recent Pillow releases.
    pil_img = Image.fromarray(segmented_image)
    buffered = BytesIO()
    pil_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"
# Page layout: upload control in the sidebar; the uploaded image and its
# predicted mask are stacked in the main panel.
# NOTE(review): panel_sidebar/panel_main are deprecated in newer Shiny releases
# in favour of ui.sidebar(); they are kept here for compatibility.
# NOTE(review): the upload label says "Satellite Image" while the checkpoint
# path mentions "mito". Confirm which imagery the model targets.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file("image_upload", "Upload Satellite Image", accept=".jpg,.jpeg,.png,.tif"),
        ),
        ui.panel_main(
            # output_image/output_ui take no label argument. A second positional
            # argument is read as the CSS width (output_image) or the `inline`
            # flag (output_ui), so the captions are rendered as headings instead.
            ui.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.h4("Segmented Image"),
            ui.output_ui("segmented_image"),  # server returns an <img> tag here
        ),
    ),
)
def server(input, output, session):
    """Connect the ``image_upload`` input to the two outputs in ``app_ui``."""

    def _upload_path(info):
        # Upload info arrives either as a list of file dicts (use the first)
        # or as a single dict; either way the temp path is under 'datapath'.
        entry = info[0] if isinstance(info, list) else info
        return entry['datapath']

    @output
    @render.image
    def uploaded_image():
        upload = input.image_upload()
        if not upload:
            return None
        return {'src': _upload_path(upload)}

    @output
    @render.ui  # HTML output: an <img> tag carrying a data URI
    def segmented_image():
        upload = input.image_upload()
        if upload:
            try:
                path = _upload_path(upload)
                if path:
                    data_uri = process_image(path)
                    return ui.tags.img(src=data_uri, style="max-width: 100%; height: auto;")
            except Exception as exc:
                # Best effort: log to stdout and fall back to the placeholder text.
                print(f"Error processing image: {exc}")
        return "No image processed."
# Create the Shiny app object. Hosting tools such as `shiny run app.py` import
# this module and serve `app` themselves, so start the built-in server only
# when the file is executed directly (`python app.py`). An unconditional
# app.run() would block during import.
app = App(app_ui, server)

if __name__ == "__main__":
    app.run()