"""Gradio web demo for MRSegmentator.

Lets a user upload a NIfTI image (or pick a bundled example), runs
MRSegmentator on it, shows an axial overlay with a slice slider, and offers
the resulting segmentation for download.
"""

import tempfile
from pathlib import Path

import gradio as gr
import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix

import utils

description_markdown = """
- **GitHub:** https://github.com/hhaentze/mrsegmentator
- **Paper:** https://arxiv.org/abs/2405.06463
- **Please Note:** This tool is intended for research purposes only.
"""

css = """
h1 {
    text-align: center;
    display:block;
}

.markdown-block {
    background-color: #0b0f1a; /* Dark navy background */
    color: white; /* White text */
    padding: 10px; /* Padding around the text */
    border-radius: 5px; /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1);
    display: inline-flex; /* Use inline-flex to shrink to content size */
    flex-direction: column;
    justify-content: center; /* Vertically center content */
    align-items: center; /* Horizontally center items within */
    margin: auto; /* Center the block */
}

.markdown-block ul, .markdown-block ol {
    background-color: #1e2936;
    border-radius: 5px;
    padding: 10px;
    box-shadow: 0 0 10px rgba(0,0,0,0.3);
    padding-left: 20px; /* Adjust padding for bullet alignment */
    text-align: left; /* Ensure text within list is left-aligned */
    list-style-position: inside;/* Ensures bullets/numbers are inside the content flow */
}

footer {
    display:none !important
}
"""

# Bundled sample images (served from images/); their segmentations are
# pre-computed under segmentations/ so no inference is needed for them.
examples = ["amos_0555.nii.gz", "amos_0517.nii.gz", "amos_0541.nii.gz", "amos_0571.nii.gz"]


def save_file(segmentation, path):
    """Return a downloadable path for ``segmentation``.

    If the input is one of the bundled sample files, the pre-computed
    segmentation under ``segmentations/`` is returned directly. Otherwise the
    segmentation is written next to the input image (same directory, file name
    with a ``seg`` postfix from ``add_postfix``) and that new path is returned.

    Args:
        segmentation: SimpleITK image holding the predicted label map.
        path: Path of the input image as handed over by Gradio.

    Returns:
        The path (str) of the segmentation file to offer for download.
    """
    path = Path(path)
    seg_name = add_postfix(path.name, "seg")
    if path.name in examples:
        # Only the basename matters: Gradio may hand us the example via a
        # relative path ("images/...") or a copy in its cache directory.
        return "segmentations/" + seg_name
    # Write beside the input rather than over it, so the uploaded image stays
    # intact and pressing "Run" again still segments the original image.
    seg_path = path.with_name(seg_name)
    sitk.WriteImage(segmentation, str(seg_path))
    return str(seg_path)


def infer(image_path):
    """Run MRSegmentator on a single image and return its segmentation.

    Inference runs on the GPU when CUDA is available and on the CPU otherwise.
    The output is written to a temporary directory that is deleted once the
    result has been loaded into memory.

    Args:
        image_path: Path of the input image.

    Returns:
        The segmentation as a SimpleITK image.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        inference.infer(
            [image_path],
            tmpdirname,
            # NOTE(review): presumably the model folds to ensemble — confirm
            # against mrsegmentator.inference.infer.
            [0, 1, 2, 3, 4],
            cpu_only=not torch.cuda.is_available(),
        )
        # inference.infer names its output after the input plus a "seg" postfix.
        filename = add_postfix(Path(image_path).name, "seg")
        # ReadImage loads the volume into memory, so it outlives tmpdirname.
        segmentation = sitk.ReadImage(str(Path(tmpdirname) / filename))
    return segmentation


def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Segment the selected image and refresh the viewer and download link.

    Args:
        input_file: Value of the ``gr.File`` input: a path string, or an
            object exposing the path as ``.name``.
        image_state: ``gr.State`` list of image volumes; the last one is shown.
        seg_state: ``gr.State`` list of segmentation volumes; the new result
            is appended to it.
        slider: Relative slice position passed to ``utils.display``.

    Returns:
        Tuple of (overlay for the ``gr.AnnotatedImage``, updated
        ``seg_state``, path of the segmentation file to download).

    Raises:
        gr.Error: If no image has been uploaded/selected or loaded yet.
    """
    if input_file is None:
        raise gr.Error("Please upload an image (.nii/.nii.gz) or select an example first.")
    # Checked before inference so we fail fast instead of after a long run.
    if not image_state:
        raise gr.Error("The input image has not been loaded yet. Please try again.")

    # Normalize to a plain path string regardless of how Gradio wraps it.
    image_path = str(getattr(input_file, "name", input_file))
    filename = Path(image_path).name

    # Bundled examples have pre-computed segmentations; everything else is inferred.
    if filename in examples:
        segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
    else:
        segmentation = infer(image_path)

    seg_path = save_file(segmentation, image_path)
    seg_state.append(utils.sitk2numpy(segmentation))
    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path


with gr.Blocks(css=css, title="MRSegmentator") as iface:
    gr.Markdown("# Robust Multi-Modality Segmentation of 40 Classes in MRI and CT Imaging")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Per-session volumes shared between the event handlers below.
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                # ".gz" is kept because browsers match only the last extension.
                file_types=[".nii", ".nii.gz", ".gz"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)
        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    slider.change(
        utils.display,
        inputs=[image_state, seg_state, slider],
        outputs=[overlay_image_np],
    )
    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )
    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])


if __name__ == "__main__":
    iface.queue()
    # To expose the app on all interfaces (e.g. inside a container), use:
    # iface.launch(server_name="0.0.0.0", server_port=8080)
    iface.launch()