"""Gradio demo: segment brain tumors in a NIfTI volume with a MONAI bundle.

At import time the script downloads/loads a TorchScript model via
``monai.bundle.load``, builds the pre- and post-processing pipelines, and
launches a Gradio interface whose callback is :func:`predict`.
"""
import os

import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# NOTE(review): the bundle name says "spleen_ct" while the repo below is
# "brats_mri" -- looks inconsistent, but BUNDLE_PATH (and therefore where
# inference.json is read from) depends on it, so confirm before changing.
BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)

examples = [
    ['examples/BRATS_485.nii.gz'],
]

# Only the first element of the returned triple (the model) is used.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')

# Load the volume, move channels first, reorient to RAS, then normalise each
# channel over its non-zero voxels.
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

# The inferer is instantiated from the bundle's own inference configuration.
inferer = parser.get_parsed_content(
    'inferer', lazy=True, eval_expr=True, instantiate=True
)

# Sigmoid + 0.5 threshold turns the prediction into a binary mask; the input
# image is rescaled to [0, 1] so it can be displayed.
post_transforms = Compose(
    [
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)


def predict(input_file, z_axis, model=model, device=device):
    """Run segmentation on an uploaded volume and return one axial slice.

    Args:
        input_file: The uploaded file as delivered by ``gr.File``. Either an
            object exposing the on-disk path as ``.name`` or a plain path
            string is accepted.
        z_axis: Requested slice index along the last spatial axis. It is cast
            to ``int`` and clamped to the valid range of the loaded volume.
        model: Network passed to the inferer; defaults to the loaded bundle.
        device: Torch device string used for inference.

    Returns:
        A tuple ``(input_slice, pred_slice, z_index)``: the 2-D slice of input
        channel 1, the matching 2-D slice of prediction channel 0, and the
        slice index that was actually used.

    Raises:
        ValueError: If no file was supplied.
    """
    if input_file is None:
        raise ValueError("No NIfTI file was provided.")

    # Depending on the Gradio version/config the upload arrives as a
    # tempfile-like wrapper (path in `.name`) or directly as a path string.
    image_path = getattr(input_file, 'name', input_file)

    data = preproc_transforms({'image': [image_path]})

    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # inputs[None, ...] prepends the batch dimension the network expects.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)

    data = post_transforms(data)

    input_image = data['image'].numpy()
    pred_image = data['pred'].cpu().detach().numpy()

    # The slider range (0-200) is fixed in the UI and may exceed the depth of
    # the uploaded volume, and sliders can yield floats; sanitise the index so
    # slicing never raises IndexError.
    depth = min(input_image.shape[-1], pred_image.shape[-1])
    z_index = max(0, min(int(z_axis), depth - 1))

    # Channel 1 of the input was labelled "t1c" by the original author
    # (0: t1, 2: t2, 3: flair) -- TODO confirm against the dataset's channel
    # order. Prediction channels 1 and 2 exist but are not displayed.
    input_t1c_image = input_image[1, :, :, z_index]
    pred_1_image = pred_image[0, 0, :, :, z_index]

    return input_t1c_image, pred_1_image, z_index


iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='Nifti file'),
        gr.Slider(0, 200, label='z-axis', value=100),
    ],
    outputs=[
        gr.Image(label='input image'),
        gr.Image(label='segmentation'),
        gr.Slider(0, 200, label='z-axis', value=100),
    ],
    title='Segment Brain Tumors using MONAI',
    examples=examples,
)

iface.launch()