# SwinUNETR_body_segmentation / dicom_to_nii.py
# Margerie's picture
# requirements, model weights, preprocessing and post processing
# 5e2c32d verified
import os
import sys

import nibabel as nib
import numpy as np
import pydicom
import scipy
import scipy.interpolate
import scipy.ndimage
def convert_ct_dicom_to_nii(dir_dicom, dir_nii, outputname, newvoxelsize=None):
    """Convert a CT series found under ``dir_dicom`` into a NIfTI file.

    The folder is scanned (including sub-folders) for DICOM files, the data of
    the first patient in the resulting list is imported (and resampled when
    ``newvoxelsize`` is given), and that patient's first CT series is written
    to ``os.path.join(dir_nii, outputname)``.

    Args:
        dir_dicom: Folder containing the DICOM files.
        dir_nii: Destination folder of the NIfTI file.
        outputname: File name of the NIfTI file to write.
        newvoxelsize: Optional target voxel size, one value per axis.
            ``None`` keeps the sampling of the DICOM series.

    Returns:
        The ``CTimage`` object that was written to disk.

    Raises:
        IndexError: If no patient, or no CT series for the first patient,
            was found under ``dir_dicom``.
    """
    patients = PatientList()
    # Collect every DICOM file of the folder tree, grouped by patient/series.
    patients.list_dicom_files(dir_dicom, 1)

    # Fail with an explicit message instead of a bare "list index out of range".
    if not patients.list:
        raise IndexError("No DICOM patient found in " + str(dir_dicom))
    patient = patients.list[0]
    if not patient.CTimages:
        raise IndexError("No CT series found in " + str(dir_dicom))

    patient.import_patient_data(newvoxelsize)
    ct = patient.CTimages[0]

    save_images(
        dst_dir=dir_nii,
        voxelsize=np.array(ct.PixelSpacing),
        image_position_patient=ct.ImagePositionPatient,
        image=ct.Image,
        outputname=outputname,
    )
    return ct
def save_images(dst_dir, voxelsize, image_position_patient, image, outputname):
    """Encode ``image`` as a NIfTI-1 file and save it in ``dst_dir``.

    Args:
        dst_dir: Folder in which the file is written.
        voxelsize: Voxel size along the three image axes.
        image_position_patient: Position of the image corner, written to the
            translation column of the affine.
        image: Array holding the image (CT, MR ...).
        outputname: File name of the NIfTI file.

    Returns:
        None. The file ``os.path.join(dst_dir, outputname)`` is written.
    """
    # TODO: the signs of the affine entries still have to be confirmed,
    # although in the end what matters is how the data is used downstream.
    # Instead of changing the header field by field, pixdim and affine could
    # be encoded with the set_sform method, see
    # https://nipy.org/nibabel/nifti_images.html
    #
    # Conversion is kept separate from preprocessing: no HU thresholding here.

    # Nifti1 header; change here for a Nifti2 type of header.
    image_nii = nib.Nifti1Image(image, affine=np.eye(4))
    # Update header fields (voxel size and image corner).
    image_nii = set_header_info(image_nii, voxelsize, image_position_patient)
    nib.save(image_nii, os.path.join(dst_dir, outputname))
# def overwrite_ct_threshold(ct_image, body, artefact=None, contrast=None):
# # Change the HU out of the body to air: -1000
# ct_image[body == 0] = -1000
# if artefact is not None:
# # Change the HU to muscle: 14
# ct_image[artefact == 1] = 14
# if contrast is not None:
# # Change the HU to water: 0 Hounsfield units (HU is the CT intensity unit)
# ct_image[contrast == 1] = 0
# # Threshold above 1560HU
# ct_image[ct_image > 1560] = 1560
# return ct_image
def set_header_info(nii_file, voxelsize, image_position_patient, contours_exist=None):
    """Write voxel size and image corner into a NIfTI image, in place.

    Args:
        nii_file: Image object exposing ``header['pixdim']`` and ``affine``.
        voxelsize: Three voxel sizes; copied to ``pixdim[1..3]`` and to the
            diagonal of the affine.
        image_position_patient: Three coordinates of the image corner; copied
            to the last column of the affine.
        contours_exist: Optional value; when truthy it is converted with
            ``bytearray`` and appended to the header as an extension with
            code 0.

    Returns:
        The same ``nii_file`` object, modified in place.
    """
    spatial_axes = range(3)

    # pixdim slot 0 is left untouched; slots 1..3 receive the voxel size.
    pixdim = nii_file.header['pixdim']
    for axis in spatial_axes:
        pixdim[axis + 1] = voxelsize[axis]

    affine = nii_file.affine
    # affine - voxelsize (diagonal)
    for axis in spatial_axes:
        affine[axis][axis] = voxelsize[axis]
    # affine - image corner (translation column)
    for axis in spatial_axes:
        affine[axis][3] = image_position_patient[axis]

    if contours_exist:
        extension = nib.nifti1.Nifti1Extension(0, bytearray(contours_exist))
        nii_file.header.extensions.append(extension)
    return nii_file
class PatientList:
    """List of patients discovered while scanning folders of DICOM files.

    ``self.list`` holds one ``PatientData`` per distinct DICOM PatientID;
    it is filled by :meth:`list_dicom_files`.
    """

    # SOPClassUID value that marks a file as a DICOM CT slice.
    _CT_SOP_CLASS_UID = "1.2.840.10008.5.1.4.1.1.2"

    def __init__(self):
        self.list = []

    def _find_loaded(self, attr_name, display_id):
        """Locate the ``display_id``-th loaded item over all patients.

        Items are read from the list attribute ``attr_name`` of every patient
        and counted (from 0) in list order when their ``isLoaded`` is 1.

        Returns:
            Tuple ``(patient_id, item_id)`` of indices.

        Raises:
            IndexError: If fewer than ``display_id + 1`` items are loaded.
        """
        count = -1
        for patient_id, patient in enumerate(self.list):
            for item_id, item in enumerate(getattr(patient, attr_name)):
                if item.isLoaded != 1:
                    continue
                count += 1
                if count == display_id:
                    return patient_id, item_id
        # Raise rather than return the indices of whatever was scanned last.
        raise IndexError(
            "No loaded entry " + str(display_id) + " in " + attr_name)

    def find_CT_image(self, display_id):
        """Return ``(patient_id, ct_id)`` of the ``display_id``-th loaded CT.

        Raises:
            IndexError: If there is no such loaded CT image.
        """
        return self._find_loaded("CTimages", display_id)

    def find_dose_image(self, display_id):
        """Return ``(patient_id, dose_id)`` of the ``display_id``-th loaded dose.

        Reads the ``RTdoses`` attribute of each patient.
        NOTE(review): ``PatientData`` in this file does not define
        ``RTdoses`` -- confirm the patient objects used here provide it.

        Raises:
            IndexError: If there is no such loaded dose image.
        """
        return self._find_loaded("RTdoses", display_id)

    def find_contour(self, ROIName):
        """Find the first contour named ``ROIName`` among loaded structure sets.

        Reads the ``RTstructs`` attribute of each patient.
        NOTE(review): ``PatientData`` in this file does not define
        ``RTstructs`` -- confirm the patient objects used here provide it.

        Returns:
            ``(patient_id, struct_id, contour_id)``, or ``None`` when no
            loaded structure set contains a contour with that name.
        """
        for patient_id, patient in enumerate(self.list):
            for struct_id, struct in enumerate(patient.RTstructs):
                if struct.isLoaded != 1:
                    continue
                for contour_id, contour in enumerate(struct.Contours):
                    if contour.ROIName == ROIName:
                        return patient_id, struct_id, contour_id
        return None

    def _get_or_add_patient(self, dcm):
        """Return the patient matching ``dcm``'s PatientID, creating it if new."""
        patient_uid = getattr(dcm, 'PatientID', '')
        for known in self.list:
            if known.PatientInfo.PatientID == patient_uid:
                return known
        patient = PatientData()
        patient.PatientInfo.PatientID = patient_uid
        # Descriptive tags may be absent; fall back to an empty string.
        patient.PatientInfo.PatientName = str(getattr(dcm, 'PatientName', ''))
        patient.PatientInfo.PatientBirthDate = getattr(
            dcm, 'PatientBirthDate', '')
        patient.PatientInfo.PatientSex = getattr(dcm, 'PatientSex', '')
        self.list.append(patient)
        return patient

    def _get_or_add_ct(self, patient, dcm):
        """Return the CT series of ``patient`` matching ``dcm``, creating it if new."""
        for known in patient.CTimages:
            if known.SeriesInstanceUID == dcm.SeriesInstanceUID:
                return known
        ct = CTimage()
        ct.SeriesInstanceUID = dcm.SeriesInstanceUID
        # Assignment: the previous code compared (==) and stored nothing.
        ct.SOPClassUID = self._CT_SOP_CLASS_UID
        ct.PatientInfo = patient.PatientInfo
        ct.StudyInfo = StudyInfo()
        ct.StudyInfo.StudyInstanceUID = dcm.StudyInstanceUID
        ct.StudyInfo.StudyID = getattr(dcm, 'StudyID', '')
        ct.StudyInfo.StudyDate = getattr(dcm, 'StudyDate', '')
        ct.StudyInfo.StudyTime = getattr(dcm, 'StudyTime', '')
        # Use the series description as name, else the series UID.
        ct.ImgName = (getattr(dcm, 'SeriesDescription', '')
                      or dcm.SeriesInstanceUID)
        patient.CTimages.append(ct)
        return ct

    def list_dicom_files(self, folder_path, recursive):
        """Index the DICOM files of ``folder_path`` by patient and CT series.

        Every readable DICOM file registers its patient in ``self.list``.
        Files whose SOPClassUID is the CT one are appended to the ``DcmFiles``
        of the matching ``CTimage``; other files are only reported on stdout.

        Args:
            folder_path: Folder to scan.
            recursive: When truthy, sub-folders are scanned as well.
        """
        # Sorted so that the scan order does not depend on the file system.
        for file_name in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, file_name)

            # folders
            if os.path.isdir(file_path):
                if recursive:
                    self.list_dicom_files(file_path, True)
                continue

            # other
            if not os.path.isfile(file_path):
                print("Unknown file type " + file_path)
                continue

            # files
            try:
                # Only header tags are needed here; pixel data is read later
                # by the image import.
                dcm = pydicom.dcmread(file_path, stop_before_pixels=True)
            except Exception:
                print("Invalid Dicom file: " + file_path)
                continue

            patient = self._get_or_add_patient(dcm)

            # Dicom CT
            sop_class_uid = getattr(dcm, 'SOPClassUID', '')
            if sop_class_uid == self._CT_SOP_CLASS_UID:
                self._get_or_add_ct(patient, dcm).DcmFiles.append(file_path)
            else:
                print("Unknown SOPClassUID " +
                      sop_class_uid + " for file " + file_path)

    def print_patient_list(self):
        """Print the information of every patient, framed by blank lines."""
        print("")
        for patient in self.list:
            patient.print_patient_info()
        print("")
class PatientData:
    """Data attached to one patient: identity and the list of CT series."""

    def __init__(self):
        self.PatientInfo = PatientInfo()
        self.CTimages = []

    def print_patient_info(self, prefix=""):
        """Print the patient's name and ID, then every CT series.

        Args:
            prefix: Text put in front of each printed line; the CT series
                lines get this prefix followed by one space.
        """
        print("")
        info = self.PatientInfo
        print(prefix + "PatientName: " + info.PatientName)
        print(prefix + "PatientID: " + info.PatientID)
        series_prefix = prefix + " "
        for series in self.CTimages:
            print("")
            series.print_CT_info(series_prefix)

    def import_patient_data(self, newvoxelsize=None):
        """Load every CT series not loaded yet, then resample all of them.

        Args:
            newvoxelsize: Forwarded unchanged to ``resample_CT`` of each CT.
        """
        # import CT images (skip the ones flagged as already loaded)
        for series in self.CTimages:
            if series.isLoaded != 1:
                series.import_Dicom_CT()
        # Resample CT images
        for series in self.CTimages:
            series.resample_CT(newvoxelsize)
class PatientInfo:
    """Identity fields of a patient; all start as empty strings."""

    def __init__(self):
        # Every field defaults to the empty string until filled by the caller.
        self.PatientID = self.PatientName = ''
        self.PatientBirthDate = self.PatientSex = ''
class StudyInfo:
    """Identification fields of a study; all start as empty strings."""

    def __init__(self):
        # Every field defaults to the empty string until filled by the caller.
        self.StudyInstanceUID = self.StudyID = ''
        self.StudyDate = self.StudyTime = ''
class CTimage:
    """One CT series: the paths of its DICOM slices and, once loaded, the volume.

    After :meth:`import_Dicom_CT` the instance also carries ``Image`` (a
    float32 array built by stacking the slices along the third axis),
    ``GridSize``, ``NumVoxels``, ``PixelSpacing``, ``ImagePositionPatient``,
    ``SOPInstanceUIDs`` and the coordinate vectors ``VoxelX/Y/Z``.
    """

    def __init__(self):
        self.SeriesInstanceUID = ""
        self.PatientInfo = {}
        self.StudyInfo = {}
        self.FrameOfReferenceUID = ""
        self.ImgName = ""
        self.SOPClassUID = ""
        self.DcmFiles = []
        self.isLoaded = 0

    def print_CT_info(self, prefix=""):
        """Print the series UID followed by the path of every slice file."""
        print(prefix + "CT series: " + self.SeriesInstanceUID)
        for ct_slice in self.DcmFiles:
            print(prefix + " " + ct_slice)

    def _store_volume(self, ct):
        """Store ``ct`` as the image and refresh the grid bookkeeping.

        Uses the current ``ImagePositionPatient`` and ``PixelSpacing``.
        NOTE(review): axis 0 of the array is paired with index 0 of the
        position/spacing although slices are stacked as (rows, columns,
        slice) -- confirm the intended in-plane axis convention.
        """
        self.GridSize = list(ct.shape)
        self.NumVoxels = self.GridSize[0] * self.GridSize[1] * self.GridSize[2]
        self.Image = ct
        self.VoxelX = self.ImagePositionPatient[0] + \
            np.arange(self.GridSize[0])*self.PixelSpacing[0]
        self.VoxelY = self.ImagePositionPatient[1] + \
            np.arange(self.GridSize[1])*self.PixelSpacing[1]
        self.VoxelZ = self.ImagePositionPatient[2] + \
            np.arange(self.GridSize[2])*self.PixelSpacing[2]

    def resample_CT(self, newvoxelsize):
        """Resample the loaded image to ``newvoxelsize`` by linear interpolation.

        The new grid starts at the same ``ImagePositionPatient``. Axes that
        get coarser are smoothed with a Gaussian filter first. Target points
        outside the source grid are filled with -1000.

        Args:
            newvoxelsize: Target voxel size, one value per axis, or ``None``
                to leave the image untouched.
        """
        # Rescaling only happens when a new voxel size is given.
        if newvoxelsize is None:
            return

        ct = self.Image
        source_shape = self.GridSize
        voxelsize = self.PixelSpacing
        origin = self.ImagePositionPatient
        axes = range(3)

        source_axes = tuple(
            origin[a] + np.arange(source_shape[a])*voxelsize[a] for a in axes)
        target_shape = np.ceil(np.array(source_shape).astype(
            float)*np.array(voxelsize).astype(float)/newvoxelsize).astype(int)
        target_axes = [
            origin[a] + np.arange(target_shape[a])*newvoxelsize[a] for a in axes]

        # The norm is taken of the difference itself; comparing before taking
        # the norm (as was done previously) gave the norm of a boolean array.
        same_grid = (np.array_equal(source_shape, target_shape) and
                     np.linalg.norm(np.subtract(voxelsize, newvoxelsize)) < 0.001)
        if same_grid:
            print("Image does not need filtering")
        else:
            # anti-aliasing filter, only along axes that are downsampled
            sigma = [0.4 * (newvoxelsize[a]/voxelsize[a])
                     if newvoxelsize[a] > voxelsize[a] else 0 for a in axes]
            if sigma != [0, 0, 0]:
                print("Image is filtered before downsampling")
                ct = scipy.ndimage.gaussian_filter(ct, sigma)

            # 'ij' indexing keeps the query points in (axis0, axis1, axis2)
            # order, so the result reshapes directly to target_shape even
            # when the first two dimensions differ.
            xi = np.stack(np.meshgrid(*target_axes, indexing='ij'),
                          axis=-1).reshape(-1, 3)
            # get resized ct
            ct = scipy.interpolate.interpn(
                source_axes, ct, xi, method='linear',
                fill_value=-1000, bounds_error=False).reshape(tuple(target_shape))

        self.PixelSpacing = newvoxelsize
        self._store_volume(ct)
        self.isLoaded = 1

    def import_Dicom_CT(self):
        """Read every file of ``DcmFiles`` and assemble the 3-D image.

        Each slice is converted with its RescaleSlope/RescaleIntercept, slices
        are ordered by ``ImagePositionPatient[2]`` and stacked along the third
        axis as float32. Inconsistencies between DICOM tags and the assembled
        volume are reported on stdout. Does nothing but print a warning when
        the series is already loaded.

        Raises:
            ValueError: If the series has no file, or has a single slice and
                no SliceThickness from which to derive the slice spacing.
        """
        if (self.isLoaded == 1):
            print("Warning: CT serries " +
                  self.SeriesInstanceUID + " is already loaded")
            return
        if not self.DcmFiles:
            raise ValueError(
                "CT series " + self.SeriesInstanceUID + " has no DICOM file")

        images = []
        SOPInstanceUIDs = []
        SliceLocation = np.zeros(len(self.DcmFiles), dtype='float')
        for i, file_path in enumerate(self.DcmFiles):
            dcm = pydicom.dcmread(file_path)
            # SliceLocation is optional and may be empty; only compare when set.
            slice_location = getattr(dcm, 'SliceLocation', None)
            if (slice_location is not None and
                    abs(slice_location - dcm.ImagePositionPatient[2]) > 0.001):
                print("WARNING: SliceLocation (" + str(dcm.SliceLocation) +
                      ") is different than ImagePositionPatient[2] (" + str(dcm.ImagePositionPatient[2]) + ") for " + file_path)
            SliceLocation[i] = float(dcm.ImagePositionPatient[2])
            images.append(dcm.pixel_array * dcm.RescaleSlope +
                          dcm.RescaleIntercept)
            SOPInstanceUIDs.append(dcm.SOPInstanceUID)

        # sort slices according to their location in order to reconstruct the 3d image
        sort_index = np.argsort(SliceLocation)
        SliceLocation = SliceLocation[sort_index]
        SOPInstanceUIDs = [SOPInstanceUIDs[n] for n in sort_index]
        images = [images[n] for n in sort_index]
        ct = np.dstack(images).astype("float32")

        # From here on ``dcm`` is the last file read; its tags are taken as
        # representative of the whole series.
        if ct.shape[0:2] != (dcm.Rows, dcm.Columns):
            print("WARNING: GridSize " + str(ct.shape[0:2]) + " different from Dicom Rows (" + str(
                dcm.Rows) + ") and Columns (" + str(dcm.Columns) + ")")

        slice_thickness = getattr(dcm, 'SliceThickness', None)
        if len(images) > 1:
            MeanSliceDistance = float(
                SliceLocation[-1] - SliceLocation[0]) / (len(images)-1)
        elif slice_thickness is not None:
            # A single slice has no neighbour to measure a distance from.
            MeanSliceDistance = float(slice_thickness)
        else:
            raise ValueError(
                "Cannot determine the slice spacing of CT series " +
                self.SeriesInstanceUID)
        if (slice_thickness is not None and
                abs(MeanSliceDistance - slice_thickness) > 0.001):
            print("WARNING: MeanSliceDistance (" + str(MeanSliceDistance) +
                  ") is different from SliceThickness (" + str(dcm.SliceThickness) + ")")

        self.FrameOfReferenceUID = dcm.FrameOfReferenceUID
        self.ImagePositionPatient = [float(dcm.ImagePositionPatient[0]), float(
            dcm.ImagePositionPatient[1]), float(SliceLocation[0])]
        self.PixelSpacing = [float(dcm.PixelSpacing[0]), float(
            dcm.PixelSpacing[1]), MeanSliceDistance]
        self.SOPInstanceUIDs = SOPInstanceUIDs
        self._store_volume(ct)
        self.isLoaded = 1
# NOTE(review): indentation was lost in this copy of the file, so it is unclear
# whether this message belongs at module level or at the end of
# CTimage.import_Dicom_CT -- confirm against the original source.
print("Convert CT dicom to nii done")