katielink's picture
Update app.py
d32f7b5
raw
history blame
2.52 kB
import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
Compose,
LoadImaged,
EnsureChannelFirstd,
Orientationd,
NormalizeIntensityd,
Activationsd,
AsDiscreted,
ScaleIntensityd,
)
# Name of the bundle; also the last path component of BUNDLE_PATH below.
# NOTE(review): the name says "spleen" while the example file and the repo
# passed to bundle.load are BraTS -- looks like a copy/paste leftover that
# presumably only affects the local directory name; confirm before renaming.
BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'

# Bundle location: "<torch hub dir>/bundle/<BUNDLE_NAME>".
BUNDLE_PATH = os.path.join(
    torch.hub.get_dir(),
    'bundle',
    BUNDLE_NAME,
)

# Example inputs offered in the Gradio UI (one row per example).
examples = [['examples/BRATS_485.nii.gz']]
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the bundle's model from the Hugging Face Hub repository. Only the first
# of the three returned values (the model) is kept; load_ts_module=True is
# passed so the TorchScript module variant is requested.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)

# Config parser for the bundle's inference.json, read from BUNDLE_PATH.
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
# Dictionary-based preprocessing applied to the "image" entry before inference.
preproc_transforms = Compose(
    [
        # Read the image file(s) listed under "image".
        LoadImaged(keys=["image"]),
        # Move the channel dimension to the front.
        EnsureChannelFirstd(keys="image"),
        # Reorient the volume to RAS axis codes.
        Orientationd(keys=["image"], axcodes="RAS"),
        # Normalize intensities per channel, using non-zero voxels only.
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
# Instantiate the inferer described under the "inferer" key of the bundle config.
inferer = parser.get_parsed_content(
    'inferer',
    lazy=True,
    eval_expr=True,
    instantiate=True,
)

# Post-processing applied after inference:
#   * "pred"  -> sigmoid activation, then binarized at a 0.5 threshold
#   * "image" -> intensities rescaled to [0, 1]
post_transforms = Compose(
    [
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)
def predict(input_file, z_axis, model=model, device=device):
    """Segment an uploaded NIfTI volume and return one slice of the result.

    Args:
        input_file: The upload from the ``gr.File`` input. Either an object
            exposing the file path as ``.name`` or a plain path string.
        z_axis: Requested slice index along the last image axis. It is cast
            to ``int`` and clamped to the loaded volume's depth, because the
            UI slider range is fixed and independent of the volume size.
        model: Network handed to the inferer (defaults to the module-level
            model).
        device: Torch device used for inference (defaults to the module-level
            device).

    Returns:
        A tuple ``(input_slice, pred_slice, z_index)``: channel 1 of the
        rescaled input at the selected slice, channel 0 of the thresholded
        prediction at the same slice, and the slice index actually used.

    Raises:
        ValueError: If no file was provided.
    """
    if input_file is None:
        raise ValueError('No NIfTI file was provided.')
    # Accept both a tempfile-like upload (path in ``.name``) and a bare path.
    file_path = getattr(input_file, 'name', input_file)

    data = {'image': [file_path]}
    data = preproc_transforms(data)

    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # Add a leading batch dimension before handing the volume to the inferer.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)

    input_image = data['image'].numpy()
    pred_image = data['pred'].cpu().detach().numpy()

    # Clamp the requested slice to the volume: the slider allows values that
    # can exceed the last axis of the loaded image, which used to raise
    # IndexError. int() guards against a float coming from the slider.
    depth = min(input_image.shape[-1], pred_image.shape[-1])
    z_index = max(0, min(int(z_axis), depth - 1))

    # Only input channel 1 and prediction channel 0 are displayed.
    # NOTE(review): the variable name suggests channel 1 is the T1c modality
    # -- confirm against the dataset's channel order.
    input_t1c_image = input_image[1, :, :, z_index]
    pred_1_image = pred_image[0, 0, :, :, z_index]

    return input_t1c_image, pred_1_image, z_index
# Gradio UI: a NIfTI upload plus a slice slider in; the input slice, the
# segmentation slice and the slice index out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label='Nifti file'),
        gr.Slider(minimum=0, maximum=200, value=100, label='z-axis'),
    ],
    outputs=[
        gr.Image(label='input image'),
        gr.Image(label='segmentation'),
        gr.Slider(minimum=0, maximum=200, value=100, label='z-axis'),
    ],
    title='Segment Brain Tumors using MONAI',
    examples=examples,
)

# Start the web app.
iface.launch()