File size: 13,381 Bytes
767f684 ab163d2 767f684 ab163d2 767f684 ab163d2 767f684 ab163d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 |
from monai.transforms import Transform, Compose, LoadImage, EnsureChannelFirst
import torch
import skimage
import torch
import SimpleITK as sitk
import numpy as np
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import SimpleITK as sitk
from matplotlib.colors import ListedColormap
import base64
import numpy as np
from cv2 import dilate
from scipy.ndimage import label
from Model_Seg import RgbaToGrayscale
def image_to_base64(image_path):
    """Return the contents of the file at ``image_path`` as a base64 string.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        str: Base64 encoding of the raw file bytes, decoded as UTF-8 text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
class CustomCLAHE(Transform):
    """Contrast-Limited Adaptive Histogram Equalization (CLAHE) as a custom transform, as described by Qiu et al.

    Attributes:
        p1 (float): Weighting factor, determines the degree of contour enhancement. Default is 0.6.
        p2 (None or int): Kernel size for the adaptive histogram. Default is None.
        p3 (float): Clip limit for histogram equalization. Default is 0.01.
    """

    def __init__(self, p1=0.6, p2=None, p3=0.01):
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3

    def __call__(self, data):
        """Apply the CLAHE pipeline to the input data.

        The first channel of the image is rescaled to [0, 1], median filtered,
        sharpened by subtracting ``p1`` times a Gaussian-blurred copy (sigma=1),
        rescaled to [0, 1] again and finally passed through adaptive histogram
        equalization.

        Args:
            data (Union[dict, torch.Tensor, np.ndarray]): Either the image itself
                (channel first) or a dictionary holding it under the key ``"image"``.

        Returns:
            Union[dict, torch.Tensor]: The processed image as a ``torch.Tensor`` with a
            single leading channel. If a dictionary was passed, the same dictionary is
            returned with its ``"image"`` entry replaced (the input dict is modified).
        """
        is_dict = isinstance(data, dict)
        im = data["image"] if is_dict else data

        # Tensor-like inputs expose .numpy(); plain arrays (which the documented
        # contract also allows) are used as they are instead of raising AttributeError.
        im = im.numpy() if hasattr(im, "numpy") else np.asarray(im)

        # Keep only the first channel, then restore a leading channel axis.
        im = im[0]
        im = im[None, :, :]

        im = skimage.exposure.rescale_intensity(im, in_range="image", out_range=(0, 1))
        im_noi = skimage.filters.median(im)
        # Subtracting a weighted blurred copy emphasises edges/contours.
        im_fil = im_noi - self.p1 * skimage.filters.gaussian(im_noi, sigma=1)
        im_fil = skimage.exposure.rescale_intensity(im_fil, in_range="image", out_range=(0, 1))
        im_ce = skimage.exposure.equalize_adapthist(im_fil, kernel_size=self.p2, clip_limit=self.p3)

        if is_dict:
            data["image"] = torch.Tensor(im_ce)
        else:
            data = torch.Tensor(im_ce)
        return data
def custom_colormap():
    """Build the colormap used to draw segmentation classes on top of an image.

    Returns:
        matplotlib.colors.ListedColormap: Four RGBA entries, one per class index 0-3.
    """
    transparent_background = (0, 0, 0, 0)  # class 0: fully transparent
    half_green = (0, 1, 0, 0.5)            # class 1: green, 50% transparency
    half_red = (1, 0, 0, 0.5)              # class 2: red, 50% transparency
    half_yellow = (1, 1, 0, 0.5)           # class 3: yellow, 50% transparency
    return ListedColormap(
        [transparent_background, half_green, half_red, half_yellow]
    )
def read_image(image_path):
    """Load an image from disk as a squeezed numpy array.

    The MONAI pipeline (LoadImage -> EnsureChannelFirst -> RgbaToGrayscale) is
    tried first and its result is cast to ``uint8``. If that raises, the file is
    read with SimpleITK instead (no dtype cast in that path).

    Args:
        image_path: Path of the image file.

    Returns:
        numpy.ndarray or None: The image with singleton dimensions removed, or
        ``None`` when both loaders fail (the SimpleITK error is printed).
    """
    monai_reader = Compose([
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        RgbaToGrayscale(),  # Convert RGBA to grayscale
    ])

    try:
        return monai_reader(image_path).numpy().astype(np.uint8).squeeze()
    except Exception:
        # Deliberate best effort: fall through to the SimpleITK reader below.
        pass

    try:
        return sitk.GetArrayFromImage(sitk.ReadImage(image_path)).squeeze()
    except Exception as e:
        print("Failed Loading the Image: ", e)
        return None
def overlay_mask(image_path, image_mask):
    """Render a segmentation mask on top of the image stored at ``image_path``.

    Args:
        image_path: Path of the image; loaded through ``read_image``.
        image_mask: 2D array of class indices, drawn with ``custom_colormap()``.

    Returns:
        tuple: ``(overlay_image_np, original_image_np)`` where ``overlay_image_np`` is
        the rendered PNG overlay decoded into a numpy array and ``original_image_np``
        is the loaded image as ``uint8``.
    """
    original_image_np = read_image(image_path).squeeze().astype(np.uint8)
    # mask is displayed as-is; intensities are mapped by the colormap
    image_mask_disp = image_mask

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(original_image_np, cmap='gray')
        plt.imshow(image_mask_disp, cmap=custom_colormap(), alpha=0.5)
        plt.axis('off')

        # Save the overlay to an in-memory buffer
        buffer = BytesIO()
        plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; without
        # this, each call would leak one 10x10 inch figure.
        plt.close(fig)

    buffer.seek(0)
    overlay_image_np = np.array(Image.open(buffer))
    return overlay_image_np, original_image_np
def bounding_box_mask(image, label):
    """Generate a bounding box mask around the labeled region of an image.

    Args:
        image (SimpleITK.Image): The input image; only its spacing is used.
        label (SimpleITK.Image): The labeled image containing the region of interest.
            After squeezing it is expected to be 2D (rows x columns).

    Returns:
        SimpleITK.Image: A ``uint8`` image (with a leading singleton dimension) in which
        pixels inside the axis-aligned bounding box of the non-zero label pixels are 1
        and all others are 0. It carries the same spacing as ``image``.

    Raises:
        IndexError: If ``label`` contains no non-zero pixel.

    Note:
        The parameter name ``label`` shadows the module-level ``scipy.ndimage.label``
        import inside this function; it is kept for backward compatibility.
    """
    # get original spacing
    original_spacing = image.GetSpacing()

    label_array = np.squeeze(sitk.GetArrayFromImage(label))

    # rows / columns that contain at least one labeled pixel
    foreground = label_array != 0
    rows_with_label = np.nonzero(np.any(foreground, axis=1))[0]
    cols_with_label = np.nonzero(np.any(foreground, axis=0))[0]
    if rows_with_label.size == 0 or cols_with_label.size == 0:
        raise IndexError(
            "label contains no non-zero pixels; cannot compute a bounding box"
        )

    # np.nonzero returns indices in ascending order, so first/last are min/max
    top, bottom = rows_with_label[0], rows_with_label[-1]
    left, right = cols_with_label[0], cols_with_label[-1]

    # everything outside the box is background by construction
    bounding_box_array = np.zeros_like(label_array)
    bounding_box_array[top : bottom + 1, left : right + 1] = 1

    # add channel dimension
    bounding_box_array = bounding_box_array[None, ...].astype(np.uint8)

    # get Image from Array Mask and apply original spacing
    bounding_box_image = sitk.GetImageFromArray(bounding_box_array)
    bounding_box_image.SetSpacing(original_spacing)
    return bounding_box_image
def threshold_based_crop(image):
    """Crop an image to the bounding box of its Otsu foreground.

    Otsu's threshold estimator separates background and foreground (in medical
    imaging the background is usually air); the image is then cropped to the
    foreground's axis-aligned bounding box.

    Args:
        image (SimpleITK image): An image where the anatomy and background
            intensities form a bi-modal distribution (the assumption underlying
            Otsu's method).

    Return:
        Cropped image based on the foreground's axis-aligned bounding box.
    """
    inside_value, outside_value = 0, 255

    otsu_mask = sitk.OtsuThreshold(image, inside_value, outside_value)
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(otsu_mask)

    # The bounding box tuple is split in the middle: the first half is passed
    # as the region index, the second half as the region size.
    bounding_box = shape_stats.GetBoundingBox(outside_value)
    half = len(bounding_box) // 2
    roi_index = bounding_box[:half]
    roi_size = bounding_box[half:]
    return sitk.RegionOfInterest(image, roi_size, roi_index)
def _sij_side_mask(dilated_sacrum_mask, pelvis_mask, side):
    """Return the SIJ mask of one side: dilated sacrum overlapped with that pelvis half.

    Args:
        dilated_sacrum_mask (numpy.ndarray): Dilated binary sacrum mask.
        pelvis_mask (numpy.ndarray): Boolean mask of the pelvis half for this side.
        side (str): "left" or "right"; only used in the warning message.

    Returns:
        numpy.ndarray: Largest connected component of the overlap, or of the
        median-height fallback when there is no overlap. All zeros when the pelvis
        half is absent as well.
    """
    intersection = dilated_sacrum_mask & pelvis_mask
    if np.all(intersection == 0):
        print(f"Warning: No {side} intersection")
        intersection = create_median_height_array(pelvis_mask)
        if intersection is None:
            # The pelvis half is missing entirely: contribute nothing instead of
            # crashing in keep_largest_component(None).
            return np.zeros(pelvis_mask.shape, dtype=int)
    # NOTE(review): applied to both the direct overlap and the fallback mask.
    return keep_largest_component(intersection)


def creat_SIJ_mask(image, input_label):
    """
    Create a mask for the sacroiliac joints (SIJ) from a pelvis and sacrum segmentation mask.

    The sacrum mask is dilated by 5% of its column extent and intersected with each
    pelvis half. If a side has no intersection, an approximate replacement is built
    with ``create_median_height_array``.

    Args:
        image (SimpleITK.Image): x-ray image; only its spacing is used.
        input_label (SimpleITK.Image): Segmentation mask containing labels for sacrum (1),
            left pelvis (2) and right pelvis (3).

    Returns:
        SimpleITK.Image: Mask of the SIJ (with a leading singleton dimension) carrying
        the spacing of ``image``.

    Raises:
        IndexError: If the segmentation contains no sacrum pixel.
    """
    original_spacing = image.GetSpacing()
    mask_array = sitk.GetArrayFromImage(input_label).squeeze()

    sacrum_value = 1
    left_pelvis_value = 2
    right_pelvis_value = 3

    # horizontal extent of the sacrum determines how far it is dilated
    sacrum_mask = mask_array == sacrum_value
    sacrum_columns = np.nonzero(np.any(sacrum_mask, axis=0))[0]
    box_width = sacrum_columns[-1] - sacrum_columns[0]
    dilation_extent = int(np.round(0.05 * box_width))
    dilated_sacrum_mask = dilate_mask(sacrum_mask, dilation_extent)

    intersection_left = _sij_side_mask(
        dilated_sacrum_mask, mask_array == left_pelvis_value, "left"
    )
    intersection_right = _sij_side_mask(
        dilated_sacrum_mask, mask_array == right_pelvis_value, "right"
    )

    # add channel dimension and restore the original spacing
    intersection_mask = (intersection_left + intersection_right)[None, ...]
    intersection_mask_im = sitk.GetImageFromArray(intersection_mask)
    intersection_mask_im.SetSpacing(original_spacing)
    return intersection_mask_im
def dilate_mask(mask, extent):
    """
    Dilate a binary mask with a square structuring element.

    Args:
        mask (numpy.ndarray): Binary segmentation mask; it is cast to ``uint8``
            before dilation.
        extent (int): Half-width of the kernel. The kernel is a square of ones
            with side length ``2 * extent + 1``.

    Returns:
        numpy.ndarray: The mask after a single dilation iteration.
    """
    side = 2 * extent + 1
    structuring_element = np.ones((side, side), np.uint8)
    return dilate(mask.astype(np.uint8), structuring_element, iterations=1)
def mask_and_crop(image, input_label):
    """
    Mask an image with its SIJ label and crop it to the region of interest.

    Args:
        image (SimpleITK.Image): The image to be processed.
        input_label (SimpleITK.Image): The corresponding segmentation label image.

    Returns:
        tuple: A tuple containing two SimpleITK.Image objects.
            - cropped_boxed_image: The image after applying bounding box masking and cropping.
            - mask: The binary mask (values 0/1) derived from the label-masked image after cropping.

    Note:
        This function relies on other functions: creat_SIJ_mask(), bounding_box_mask()
        and threshold_based_crop().
    """
    sij_label = creat_SIJ_mask(image, input_label)
    box_mask = bounding_box_mask(image, sij_label)

    # zero out everything outside the bounding box / outside the SIJ label
    image_in_box = sitk.Mask(image, box_mask, maskingValue=0, outsideValue=0)
    image_in_label = sitk.Mask(image, sij_label, maskingValue=0, outsideValue=0)

    cropped_boxed_image = threshold_based_crop(image_in_box)
    cropped_label_image = threshold_based_crop(image_in_label)

    # any remaining non-zero intensity counts as foreground
    cropped_pixels = np.squeeze(sitk.GetArrayFromImage(cropped_label_image))
    binary = np.where(cropped_pixels > 0, 1, 0)
    mask = sitk.GetImageFromArray(binary[None, ...])
    return cropped_boxed_image, mask
def create_median_height_array(mask):
    """
    Create an array based on the median height of non-zero elements in each column of the input mask.

    For every column that contains label pixels, the height (first to last non-zero
    row, inclusive) is measured. Around each column whose height is closest to the
    rounded median height, a block of that median height is filled, starting at the
    column's first non-zero row and spanning columns ``col - 5`` up to (not
    including) ``col + 5``, clipped to the array bounds.

    Args:
        mask (numpy.ndarray): A 2D binary mask with 1s representing the label and 0s the background.

    Returns:
        numpy.ndarray: A new binary mask array (integer dtype) with columns filled based
        on the median height, or None if the input mask has no non-zero columns.

    Note:
        This function is only used when there is no intersection between pelvis and sacrum, and creates an alternative
        SIJ mask, that serves as an approximate replacement.
    """
    _, cols = mask.shape

    # (height, first non-zero row, column index) for every column with label pixels
    column_details = []
    for col in range(cols):
        non_zero_indices = np.nonzero(mask[:, col])[0]
        if non_zero_indices.size > 0:
            start_idx = non_zero_indices[0]
            height = non_zero_indices[-1] - start_idx + 1
            column_details.append((height, start_idx, col))

    if not column_details:
        return None

    median_height = int(round(np.median([height for height, _, _ in column_details])))

    # With an even number of columns the median can fall between two heights, so
    # no column matches it exactly. Using the closest columns keeps the result
    # from being empty; when exact matches exist the gap is 0 and only they are used.
    smallest_gap = min(abs(height - median_height) for height, _, _ in column_details)

    new_array = np.zeros_like(mask, dtype=int)
    for height, start_idx, col in column_details:
        if abs(height - median_height) != smallest_gap:
            continue
        start_col = max(0, col - 5)
        end_col = min(cols, col + 5)
        new_array[start_idx:start_idx + median_height, start_col:end_col] = 1
    return new_array
def keep_largest_component(mask):
    """
    Identify and retain the largest connected component in a binary segmentation mask.

    Args:
        mask (numpy.ndarray): A binary mask with 1s representing the label and 0s the background.

    Returns:
        numpy.ndarray: A mask of the same dtype containing only the largest connected
        component. When the input has zero or one component, the input object itself
        is returned unchanged.
    """
    components, component_count = label(mask)

    # Nothing to remove when there is at most one component
    if component_count < 2:
        return mask

    # Pixel count per component id; index 0 is the background and is skipped
    sizes = np.bincount(components.ravel())
    biggest_id = 1 + np.argmax(sizes[1:])

    return (components == biggest_id).astype(mask.dtype)