File size: 13,381 Bytes
767f684
ab163d2
 
 
 
 
 
 
 
 
 
 
 
 
 
767f684
ab163d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767f684
 
 
 
 
ab163d2
767f684
 
ab163d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
from monai.transforms import Transform, Compose, LoadImage, EnsureChannelFirst
import torch
import skimage
import torch
import SimpleITK as sitk
import numpy as np
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import SimpleITK as sitk
from matplotlib.colors import ListedColormap
import base64
import numpy as np
from cv2 import dilate
from scipy.ndimage import label
from Model_Seg import RgbaToGrayscale

def image_to_base64(image_path):
    """Read a file from disk and return its content as a base64 string.

    Args:
        image_path (str): Path of the file to encode. The file is opened in
            binary mode, so any file type is accepted.

    Returns:
        str: The base64 representation of the file bytes, decoded as UTF-8 text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

class CustomCLAHE(Transform):
    """Implements Contrast-Limited Adaptive Histogram Equalization (CLAHE) as a custom transform, as described by Qiu et al.

    Attributes:
        p1 (float): Weighting factor, determines the degree of contour enhancement. Default is 0.6.
        p2 (None or int): Kernel size for adaptive histogram. Default is None.
        p3 (float): Clip limit for histogram equalization. Default is 0.01.

    """

    def __init__(self, p1=0.6, p2=None, p3=0.01):
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3

    def __call__(self, data):
        """Apply the CLAHE algorithm to input data.

        The first channel of the image is rescaled to [0, 1], median filtered,
        sharpened by subtracting ``p1`` times its Gaussian-blurred version
        (sigma=1), rescaled again and finally equalized with
        ``skimage.exposure.equalize_adapthist``.

        Args:
            data (Union[dict, torch.Tensor, np.ndarray]): Input data. Either a dictionary
                holding the image under the key ``"image"`` or the image itself. Only the
                first entry along the leading axis (``image[0]``) is processed.

        Returns:
            Union[dict, torch.Tensor]: For dictionary input, the same dictionary with
            ``data["image"]`` replaced in place by the transformed tensor; otherwise the
            transformed tensor. The result keeps a leading axis of length 1.
        """
        is_dict = isinstance(data, dict)
        im = data["image"] if is_dict else data

        # Tensors expose ``.numpy()``; plain numpy arrays (allowed by the
        # documented contract) do not, so fall back to ``np.asarray`` for them.
        im = im.numpy() if hasattr(im, "numpy") else np.asarray(im)

        # keep only the first channel, but retain a leading axis of length 1
        im = im[0]
        im = im[None, :, :]

        im = skimage.exposure.rescale_intensity(im, in_range="image", out_range=(0, 1))
        im_noi = skimage.filters.median(im)
        # subtracting a weighted blurred copy enhances contours
        im_fil = im_noi - self.p1 * skimage.filters.gaussian(im_noi, sigma=1)
        im_fil = skimage.exposure.rescale_intensity(im_fil, in_range="image", out_range=(0, 1))
        im_ce = skimage.exposure.equalize_adapthist(im_fil, kernel_size=self.p2, clip_limit=self.p3)

        if is_dict:
            data["image"] = torch.Tensor(im_ce)
        else:
            data = torch.Tensor(im_ce)

        return data



def custom_colormap():
    """Build the colormap used to draw segmentation classes on top of an image.

    Returns:
        matplotlib.colors.ListedColormap: Four RGBA entries, indexed by class:
        class 0 is fully transparent, classes 1-3 are green, red and yellow,
        each with an alpha of 0.5.
    """
    transparent_background = (0, 0, 0, 0)  # class 0
    half_green = (0, 1, 0, 0.5)            # class 1
    half_red = (1, 0, 0, 0.5)              # class 2
    half_yellow = (1, 1, 0, 0.5)           # class 3
    return ListedColormap(
        [transparent_background, half_green, half_red, half_yellow]
    )

def read_image(image_path):
    """Load an image from disk as a squeezed numpy array.

    The MONAI pipeline (load, channel-first, RGBA-to-grayscale) is tried first
    and its result is cast to ``uint8``. If that raises for any reason, the
    file is read with SimpleITK instead (no dtype cast in that case).

    Args:
        image_path (str): Path of the image file.

    Returns:
        numpy.ndarray or None: The image with singleton axes removed, or
        ``None`` when both loaders fail (the SimpleITK error is printed).
    """
    monai_loader = Compose([
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        RgbaToGrayscale(),  # Convert RGBA to grayscale
    ])

    # First attempt: MONAI. Any failure deliberately falls through to SimpleITK.
    try:
        loaded = monai_loader(image_path)
        as_uint8 = loaded.numpy().astype(np.uint8)
        return as_uint8.squeeze()
    except Exception:
        pass

    # Second attempt: SimpleITK.
    try:
        sitk_image = sitk.ReadImage(image_path)
        return sitk.GetArrayFromImage(sitk_image).squeeze()
    except Exception as e:
        print("Failed Loading the Image: ", e)
        return None

def overlay_mask(image_path, image_mask):
    """Render a segmentation mask on top of an image and return the rendering.

    The image is drawn in grayscale and the mask is drawn over it using
    ``custom_colormap()``; the figure is rendered to an in-memory PNG.

    Args:
        image_path (str): Path of the image, loaded with ``read_image``.
        image_mask (array-like): Mask to overlay, passed directly to
            ``plt.imshow``.

    Returns:
        tuple: ``(overlay_image_np, original_image_np)`` where the first item is
        the rendered PNG as a numpy array and the second is the loaded image as
        a ``uint8`` array.

    Raises:
        AttributeError: If ``read_image`` could not load the image and
            returned ``None``.
    """
    original_image_np = read_image(image_path).squeeze().astype(np.uint8)

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(original_image_np, cmap='gray')
        plt.imshow(image_mask, cmap=custom_colormap(), alpha=0.5)
        plt.axis('off')

        # Save the overlay to a buffer
        buffer = BytesIO()
        plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak one figure.
        plt.close(fig)

    buffer.seek(0)
    overlay_image_np = np.array(Image.open(buffer))
    return overlay_image_np, original_image_np


def bounding_box_mask(image, label):
    """Generates a bounding box mask around a labeled region in an image.

    Args:
        image (SimpleITK.Image): The input image. Only its spacing is used.
        label (SimpleITK.Image): The labeled image containing the region of interest.
            Any non-zero pixel counts as part of the region.

    Returns:
        SimpleITK.Image: A ``uint8`` image with a leading channel axis added to the
        array, in which pixels inside the axis-aligned bounding box of the label are
        set to 1 and all others to 0. It carries the spacing of the original image.

    Raises:
        IndexError: If ``label`` contains no non-zero pixel.

    Note:
        This function assumes that the input image and label are SimpleITK.Image objects
        and that the label array is 2-D once singleton axes are squeezed out.
    """
    # get original spacing
    original_spacing = image.GetSpacing()

    label_array = np.squeeze(sitk.GetArrayFromImage(label))

    # indices of the rows / columns that contain at least one labeled pixel
    foreground = label_array != 0
    nonzero_rows = np.nonzero(np.any(foreground, axis=1))[0]
    nonzero_columns = np.nonzero(np.any(foreground, axis=0))[0]

    # np.nonzero returns sorted indices, so first/last are the box corners
    top, bottom = nonzero_rows[0], nonzero_rows[-1]
    left, right = nonzero_columns[0], nonzero_columns[-1]

    # everything outside the box is background by construction
    bounding_box_array = np.zeros(label_array.shape, dtype=np.uint8)
    bounding_box_array[top : bottom + 1, left : right + 1] = 1

    # add channel dimension
    bounding_box_array = bounding_box_array[None, ...]

    # get Image from Array Mask and apply original spacing
    bounding_box_image = sitk.GetImageFromArray(bounding_box_array)
    bounding_box_image.SetSpacing(original_spacing)
    return bounding_box_image


def threshold_based_crop(image):
    """
    Crop an image to the bounding box of its Otsu foreground.

    Otsu's threshold estimator separates background from foreground (in medical
    imaging the background is usually air); the image is then cropped to the
    axis aligned bounding box of the region labeled with the outside value.

    Args:
        image (SimpleITK image): An image where the anatomy and background intensities form a
                                 bi-modal distribution
                                 (the assumption underlying Otsu's method.)
    Return:
        Cropped image based on foreground's axis aligned bounding box.
    """
    inside_value = 0
    outside_value = 255

    otsu_labels = sitk.OtsuThreshold(image, inside_value, outside_value)
    shape_statistics = sitk.LabelShapeStatisticsImageFilter()
    shape_statistics.Execute(otsu_labels)

    # The first half of the tuple is passed on as the region index, the second half as its size.
    box = shape_statistics.GetBoundingBox(outside_value)
    half = int(len(box) / 2)
    region_index = box[0:half]
    region_size = box[half:]
    return sitk.RegionOfInterest(image, region_size, region_index)

def _sij_side_mask(dilated_sacrum_mask, mask_array, pelvis_value, side):
    """Return the SIJ mask of one body side.

    Args:
        dilated_sacrum_mask (numpy.ndarray): Dilated binary sacrum mask.
        mask_array (numpy.ndarray): Full segmentation label array.
        pelvis_value (int): Label value of the pelvis half on this side.
        side (str): ``"left"`` or ``"right"``; only used in warning messages.

    Returns:
        numpy.ndarray: Largest connected component of the overlap between the
        dilated sacrum and the pelvis half. If they do not overlap, the
        approximation from ``create_median_height_array`` is used instead; if
        the pelvis label is absent altogether, an all-zero mask is returned.
    """
    pelvis_mask = (mask_array == pelvis_value)
    intersection = dilated_sacrum_mask & pelvis_mask

    if not intersection.any():
        print(f"Warning: No {side} intersection")
        intersection = create_median_height_array(pelvis_mask)
        if intersection is None:
            # The pelvis label does not occur at all, so there is nothing to
            # approximate from. Return an empty mask so the other side can
            # still be used instead of failing later on a None operand.
            print(f"Warning: No {side} pelvis label found")
            return np.zeros_like(dilated_sacrum_mask)

    return keep_largest_component(intersection)


def creat_SIJ_mask(image, input_label):
    """
    Create a mask for the sacroiliac joints (SIJ) from a pelvis and sacrum segmentation mask.

    The sacrum mask is dilated by 5% of its column extent and intersected with
    each pelvis half; the largest connected component of each intersection is
    kept and both sides are summed into one mask.

    Args:
        image (SimpleITK.Image): x-ray image. Only its spacing is used.
        input_label (SimpleITK.Image): Segmentation mask containing labels for sacrum (1),
            left pelvis (2) and right pelvis (3).

    Returns:
        SimpleITK.Image: Mask of the SIJ with a leading channel axis added to the array
        and the spacing of ``image``.

    Raises:
        IndexError: If the segmentation contains no sacrum pixel.

    """
    original_spacing = image.GetSpacing()
    mask_array = sitk.GetArrayFromImage(input_label).squeeze()

    sacrum_value = 1
    left_pelvis_value = 2
    right_pelvis_value = 3
    dilation_fraction = 0.05  # dilation radius as a fraction of the sacrum width

    sacrum_mask = (mask_array == sacrum_value)

    # columns that contain sacrum pixels; indices come back sorted
    sacrum_columns = np.nonzero(np.any(sacrum_mask, axis=0))[0]
    box_width = sacrum_columns[-1] - sacrum_columns[0]

    dilation_extent = int(np.round(dilation_fraction * box_width))
    dilated_sacrum_mask = dilate_mask(sacrum_mask, dilation_extent)

    intersection_left = _sij_side_mask(
        dilated_sacrum_mask, mask_array, left_pelvis_value, "left"
    )
    intersection_right = _sij_side_mask(
        dilated_sacrum_mask, mask_array, right_pelvis_value, "right"
    )

    intersection_mask = intersection_left + intersection_right
    intersection_mask = intersection_mask[None, ...]  # add channel dimension

    intersection_mask_im = sitk.GetImageFromArray(intersection_mask)
    intersection_mask_im.SetSpacing(original_spacing)
    return intersection_mask_im

def dilate_mask(mask, extent):
    """
    Dilates a binary mask with a square structuring element.

    Args:
        mask (numpy.ndarray): The mask to dilate. It is cast to ``uint8`` before dilation.
        extent (int): Half-size of the square kernel; the kernel has
                      ``2 * extent + 1`` pixels per side.

    Returns:
        numpy.ndarray: The result of a single dilation iteration applied to the
                       ``uint8`` version of the mask.

    """
    kernel_side = 2 * extent + 1
    structuring_element = np.ones((kernel_side, kernel_side), np.uint8)
    return dilate(mask.astype(np.uint8), structuring_element, iterations=1)

def mask_and_crop(image, input_label):
    """
    Masks an image with its SIJ region and crops it to that region.

    Args:
        image (SimpleITK.Image): The image to be processed.
        input_label (SimpleITK.Image): The corresponding segmentation label image.

    Returns:
        tuple: A tuple containing two SimpleITK.Image objects.
            - cropped_boxed_image: The image masked with the bounding box of the SIJ mask, then cropped.
            - mask: Binary mask (1 where the cropped, SIJ-masked image is greater than 0),
              built from an array with a leading channel axis.

    Note:
        This function relies on other functions: creat_SIJ_mask(), bounding_box_mask()
        and threshold_based_crop().
    """
    sij_label = creat_SIJ_mask(image, input_label)
    sij_box = bounding_box_mask(image, sij_label)

    # image restricted to the bounding box / to the exact SIJ region
    image_in_box = sitk.Mask(image, sij_box, maskingValue=0, outsideValue=0)
    image_in_sij = sitk.Mask(image, sij_label, maskingValue=0, outsideValue=0)

    cropped_boxed_image = threshold_based_crop(image_in_box)
    cropped_sij_image = threshold_based_crop(image_in_sij)

    # binarize the cropped SIJ image
    sij_pixels = np.squeeze(sitk.GetArrayFromImage(cropped_sij_image))
    binary_pixels = np.where(sij_pixels > 0, 1, 0)
    mask = sitk.GetImageFromArray(binary_pixels[None, ...])
    return cropped_boxed_image, mask

def create_median_height_array(mask, half_width=5):
    """
    Creates an array based on the median height of non-zero elements in each column of the input mask.

    For every column containing non-zero elements, the height is the distance from
    its first to its last non-zero row (inclusive). The columns whose height is
    closest to the rounded median height are selected, and around each of them a
    rectangle of the median height is drawn, starting at that column's first
    non-zero row.

    Args:
        mask (numpy.ndarray): A 2-D binary mask with 1s representing the label and 0s the background.
        half_width (int): Number of columns the rectangle extends to the left of a
            selected column; it ends ``half_width`` columns to the right (exclusive).
            Default is 5.

    Returns:
        numpy.ndarray: A new binary mask array (integer dtype) with columns filled based on
                       the median height, or None if the input mask has no non-zero columns.

    Note:
        This function is only used when there is no intersection between pelvis and sacrum, and creates an alternative
        SIJ mask, that serves as an approximate replacement.
    """
    rows, cols = mask.shape
    foreground = np.asarray(mask) != 0

    occupied_cols = np.flatnonzero(foreground.any(axis=0))
    if occupied_cols.size == 0:
        return None

    # argmax on a boolean column yields the index of its first True; applying
    # it to the vertically flipped mask yields the last True counted from the bottom.
    first_rows = foreground.argmax(axis=0)[occupied_cols]
    last_rows = rows - 1 - foreground[::-1].argmax(axis=0)[occupied_cols]
    heights = last_rows - first_rows + 1

    median_height = int(round(float(np.median(heights))))

    # With an even number of columns the median can fall between two heights,
    # so an exact match may not exist. Selecting the closest heights equals the
    # exact-match selection whenever a match exists and avoids an empty result otherwise.
    deviation = np.abs(heights - median_height)
    selected = deviation == deviation.min()

    new_array = np.zeros_like(mask, dtype=int)
    for col, start_idx in zip(occupied_cols[selected].tolist(), first_rows[selected].tolist()):
        start_col = max(0, col - half_width)
        end_col = min(cols, col + half_width)
        new_array[start_idx:start_idx + median_height, start_col:end_col] = 1
    return new_array

def keep_largest_component(mask):
    """
    Reduces a binary segmentation mask to its largest connected component.

    Args:
        mask (numpy.ndarray): A binary mask with 1s representing the label and 0s the background.

    Returns:
        numpy.ndarray: The input mask itself when it holds at most one connected
        component; otherwise a new mask of the same dtype that is 1 on the
        largest component and 0 elsewhere.
    """
    components, component_count = label(mask)

    # Nothing to filter when there is zero or one component
    if component_count < 2:
        return mask

    # Pixel count per component id; id 0 is the background and must not win
    pixel_counts = np.bincount(components.ravel())
    pixel_counts[0] = 0
    biggest_id = pixel_counts.argmax()

    return (components == biggest_id).astype(mask.dtype)