metadata
license: mit
This is the trained ControlNet-StableDiffusion model for synthetic CT/MRI generation from segmentation maps. It relies on a customized ControlNet-StableDiffusion pipeline.
The model is trained on the JHU dataset, which contains 5312 CT volumes with corresponding segmentation masks.
The CT volumes are split into 2D slices, giving roughly 1.3M slices in total.
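As a rough illustration of the slicing step, the sketch below splits a 3D array into 2D slices with NumPy. The function name `volume_to_slices`, the choice of slicing axis, and the toy array shape are assumptions made for this example; the actual preprocessing (file loading, intensity windowing, slice filtering) lives in the training code. For scale, 1.3M slices over 5312 volumes works out to about 245 slices per volume on average.

```python
import numpy as np

def volume_to_slices(volume, axis=2):
    """Split a 3D volume into a list of 2D slices taken along `axis`."""
    if volume.ndim != 3:
        raise ValueError("expected a 3D array")
    # Move the slicing axis to the front so that iterating yields 2D slices
    return [np.ascontiguousarray(s) for s in np.moveaxis(volume, axis, 0)]

# Toy arrays standing in for a CT volume and its segmentation mask (hypothetical shape)
rng = np.random.default_rng(0)
ct = rng.integers(-1000, 1000, size=(64, 64, 10)).astype(np.int16)
seg = rng.integers(0, 10, size=(64, 64, 10)).astype(np.uint8)

# Slice image and mask along the same axis so the pairs stay aligned
ct_slices = volume_to_slices(ct)
seg_slices = volume_to_slices(seg)
```

Slicing the volume and its mask with the same axis keeps each CT slice paired with the matching segmentation slice.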
Here is the training and inference code for Diff_Synth_CT
Training details
Hardware: 8x NVIDIA A6000
Batch size: 8 x 4 x 32
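The batch size is given as a product of three factors. The card does not say what each factor stands for (a common convention would be number of GPUs, gradient accumulation steps, and per-device batch size, but that reading is an assumption). Whatever the interpretation, the product itself can be checked directly:

```python
import math

# Factors exactly as listed above; their individual meaning is not stated in the card
batch_factors = (8, 4, 32)

# Total number of samples implied by the product of the three factors
effective_batch_size = math.prod(batch_factors)
print(effective_batch_size)  # 1024
```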
For direct inference
Step 1: Clone the GitHub repo to get the customized ControlNet-StableDiffusion pipeline implementation
git clone https://github.com/Onkarsus13/DiffCTSeg
Step 2: Go into the repository and install it together with its dependencies
cd DiffCTSeg
pip install -e ".[torch]"
pip install -e ".[all,dev,notebooks]"
Step 3: Run python test_eraser.py
Alternatively, you can run the code given below:
from diffusers import (
    StableDiffusionControlNetInpaintPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
    PNDMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
)
import torch
from PIL import Image
import numpy as np
import glob
# RGB color used for each class index in the segmentation mask
class_dict_BTCV = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}
# Class name for each class index, used to build the text prompt
class_dict = {
    0: "background",
    1: "aorta",
    2: "kidney_left",
    3: "liver",
    4: "postcava",
    5: "stomach",
    6: "gall_bladder",
    7: "kidney_right",
    8: "pancreas",
    9: "spleen",
}
def rgb_to_onehot(rgb_arr, color_dict=class_dict_BTCV):
    """Convert an (H, W, 3) RGB mask into an (H, W, num_classes) one-hot array."""
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for i, cls in enumerate(color_dict):
        # Channel i is 1 wherever the pixel color equals the color of class i
        arr[:, :, i] = np.all(rgb_arr.reshape((-1, 3)) == color_dict[i], axis=1).reshape(shape[:2])
    return arr
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1",
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_pretrained('onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1', subfolder="scheduler")
pipe.to('cuda:0')
pipe.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(1)
images = Image.open("<Give Segmentation Mask>")
npi = np.asarray(images.convert("RGB"))
# Convert the RGB mask into a per-pixel class-index map
npi = rgb_to_onehot(npi).argmax(-1)
unique_ids = np.unique(npi)
# Build the text prompt from the classes present in the mask
prompt = 'CT image containg '+" ".join([class_dict[i] for i in unique_ids])
print(prompt)
image = pipe(
    prompt,
    images,
    [images],
    num_inference_steps=30,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]
image.save('./result.png')
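The script above expects the segmentation mask as an RGB image colored with the palette in class_dict_BTCV. If a mask is only available as a 2D array of class indices, a conversion along the following lines could produce a matching RGB mask. This is a minimal sketch: the helper name `labels_to_rgb` and the toy label map are made up for illustration, and only the palette values come from the code above.

```python
import numpy as np

# Palette copied from class_dict_BTCV above
PALETTE = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}

def labels_to_rgb(label_map, palette=PALETTE):
    """Map an (H, W) array of class indices to an (H, W, 3) uint8 RGB mask."""
    # Lookup table: row i holds the RGB color of class i
    lut = np.zeros((max(palette) + 1, 3), dtype=np.uint8)
    for idx, color in palette.items():
        lut[idx] = color
    # Integer-array indexing replaces every class index with its color
    return lut[label_map]

# Toy label map: background with a liver patch (3) and one spleen pixel (9)
labels = np.zeros((8, 8), dtype=np.int64)
labels[2:5, 2:5] = 3
labels[6, 6] = 9
rgb_mask = labels_to_rgb(labels)
```

The resulting array can be turned into an image with `Image.fromarray(rgb_mask)` from PIL and saved as a PNG, which avoids the color shifts that lossy formats such as JPEG would introduce into an exact-match palette.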