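"""Gradio web UI for CT lymph node segmentation.

Wires together file upload, model inference, and the 2D/3D result viewers.
"""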
import os
from typing import Optional

import gradio as gr

from .convert import nifti_to_obj
from .css_style import css
from .inference import run_model
from .logger import flush_logs
from .logger import read_logs
from .logger import setup_logger
from .utils import load_ct_to_numpy
from .utils import load_pred_volume_to_numpy

LOGGER = setup_logger()


class WebUI:
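    """Compose the Gradio layout and launch the lymph node segmentation demo."""
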
    def __init__(
        self,
        model_name: Optional[str] = None,
        cwd: str = "/home/user/app/",
        share: int = 1,
    ):
        # slice stacks shown in the 2D viewer; populated by process()
        self.images = []
        self.pred_images = []

        # fixed upper bound on the number of slices the slider can address
        self.nb_slider_items = 820

        self.model_name = model_name
        self.cwd = cwd
        self.share = share

        self.class_name = "Lymph Nodes"
        self.class_names = {
            "Lymph Nodes": "CT_LymphNodes",
        }

        self.result_names = {
            "Lymph Nodes": "LymphNodes",
        }

        # components created up front but rendered later, inside run()
        self.slider = gr.Slider(
            minimum=1,
            maximum=self.nb_slider_items,
            value=1,
            step=1,
            label="Which 2D slice to show",
        )
        self.volume_renderer = gr.Model3D(
            clear_color=[0.0, 0.0, 0.0, 0.0],
            label="3D Model",
            show_label=True,
            visible=True,
            elem_id="model-3d",
            camera_position=[90, 180, 768],
        ).style(height=512)

    def set_class_name(self, value):
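        """Remember which structure the user selected for segmentation."""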
LOGGER.info(f"Changed task to: {value}") |
|
self.class_name = value |
|
|
|

    def combine_ct_and_seg(self, img, pred):
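        """Pair a CT slice with its annotation, as gr.AnnotatedImage expects."""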
        return (img, [(pred, self.class_name)])

    def upload_file(self, file):
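        """Log and return the path of the uploaded file."""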
        out = file.name
        LOGGER.info(f"File uploaded: {out}")
        return out

    def process(self, mesh_file_name):
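        """Run inference on the uploaded CT and prepare the 2D/3D results."""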
        path = mesh_file_name.name
        run_model(
            path,
            model_path=os.path.join(self.cwd, "resources/models/"),
            task=self.class_names[self.class_name],
            name=self.result_names[self.class_name],
        )
        LOGGER.info("Converting prediction NIfTI to OBJ...")
        nifti_to_obj("prediction.nii.gz")

        LOGGER.info("Loading CT to numpy...")
        self.images = load_ct_to_numpy(path)

        LOGGER.info("Loading prediction volume to numpy...")
        self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")

        return "./prediction.obj"

    def get_img_pred_pair(self, k):
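        """Build the annotated 2D view for the requested slice."""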
        k = int(k) - 1  # the slider is 1-indexed, the slice stacks are 0-indexed
        out = gr.AnnotatedImage(
            self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
            visible=True,
            elem_id="model-2d",
        ).style(
            color_map={self.class_name: "#ffae00"},
            height=512,
            width=512,
        )
        return out

    def toggle_sidebar(self, state):
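        """Flip the sidebar visibility and return the updated state."""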
        state = not state
        return gr.update(visible=state), state

    def run(self):
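        """Assemble the Blocks layout, then queue and launch the app."""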
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                # collapsible sidebar streaming logs from the logger every second
                with gr.Column(visible=True, scale=0.2) as sidebar_left:
                    logs = gr.Textbox(
                        placeholder="\n" * 16,
                        label="Logs",
                        info="Verbose output from inference will be displayed below.",
                        lines=38,
                        max_lines=38,
                        autoscroll=True,
                        elem_id="logs",
                        show_copy_button=True,
                        scroll_to_output=False,
                        container=True,
                        line_breaks=True,
                    )
                    demo.load(read_logs, None, logs, every=1)

                with gr.Column():
                    with gr.Row():
                        with gr.Column(scale=0.2, min_width=150):
                            sidebar_state = gr.State(True)

                            btn_toggle_sidebar = gr.Button(
                                "Toggle Sidebar",
                                elem_id="toggle-button",
                            )
                            btn_toggle_sidebar.click(
                                self.toggle_sidebar,
                                [sidebar_state],
                                [sidebar_left, sidebar_state],
                            )

                            btn_clear_logs = gr.Button("Clear logs", elem_id="logs-button")
                            btn_clear_logs.click(flush_logs, [], [])

                        file_output = gr.File(file_count="single", elem_id="upload")
                        file_output.upload(self.upload_file, file_output, file_output)

                        model_selector = gr.Dropdown(
                            list(self.class_names.keys()),
                            label="Task",
                            info="Which structure to segment.",
                            multiselect=False,
                            size="sm",
                        )
                        model_selector.input(
                            fn=lambda x: self.set_class_name(x),
                            inputs=model_selector,
                            outputs=None,
                        )

                        with gr.Column(scale=0.2, min_width=150):
                            run_btn = gr.Button(
                                "Run analysis",
                                variant="primary",
                                elem_id="run-button",
                            ).style(
                                full_width=False,
                                size="lg",
                            )
                            run_btn.click(
                                fn=lambda x: self.process(x),
                                inputs=file_output,
                                outputs=self.volume_renderer,
                            )

                    with gr.Row():
                        gr.Examples(
                            examples=[
                                os.path.join(self.cwd, "test_thorax_CT.nii.gz"),
                            ],
                            inputs=file_output,
                            outputs=file_output,
                            fn=self.upload_file,
                            cache_examples=True,
                        )

                    gr.Markdown(
                        """
                        **NOTE:** Inference may take several minutes (lymph nodes: ~8 minutes); see the logs to the left. \\
                        The segmentation will appear in the 2D and 3D viewers below when finished.
                        """
                    )

                    with gr.Row():
                        with gr.Box():
                            with gr.Column():
                                t = gr.AnnotatedImage(
                                    visible=True,
                                    elem_id="model-2d",
                                ).style(
                                    color_map={self.class_name: "#ffae00"},
                                    height=512,
                                    width=512,
                                )

                                self.slider.input(
                                    self.get_img_pred_pair,
                                    self.slider,
                                    t,
                                )

                                self.slider.render()

                        with gr.Box():
                            self.volume_renderer.render()

        demo.queue().launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=self.share,
        )
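

# Example usage. The import path below is an assumption; this module uses
# relative imports, so it must be imported from its package (e.g. from the
# app entry point):
#
#     from src.gui import WebUI
#
#     ui = WebUI(cwd="/home/user/app/", share=1)
#     ui.run()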