---
license: mit
---

This is the trained ControlNet-StableDiffusion model for synthetic CT/MRI generation from segmentation maps. It requires a customized ControlNet-StableDiffusion pipeline (installation steps below).

The model is trained on the JHU dataset, which contains 5,312 CT volumes with corresponding segmentation masks. We sliced the CT volumes into ~1.3M 2D slices.

Here is the training and inference code: [Diff_Synth_CT](https://github.com/Onkarsus13/DiffCTSeg)

Training details

- Hardware: 8x Nvidia-A6000
- Batch size: 8 x 4 x 32

For direct inference

Step 1: Clone the GitHub repo to get the customized ControlNet-StableDiffusion pipeline implementation.

```
git clone https://github.com/Onkarsus13/DiffCTSeg
```

Step 2: Go into the repository and install it together with its dependencies.

```
cd DiffCTSeg
pip install -e ".[torch]"
pip install -e ".[all,dev,notebooks]"
```

Step 3: Run `python test_eraser.py`, or run the code below.

```python
# Requires the customized pipeline installed from the DiffCTSeg repo (Steps 1-2)
from diffusers import StableDiffusionControlNetInpaintPipeline, UniPCMultistepScheduler
import torch
from PIL import Image
import numpy as np

# Class id -> RGB color used in the segmentation-map images
class_dict_BTCV = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}

# Class id -> organ name, used to build the text prompt
class_dict = {
    0: "background",
    1: "aorta",
    2: "kidney_left",
    3: "liver",
    4: "postcava",
    5: "stomach",
    6: "gall_bladder",
    7: "kidney_right",
    8: "pancreas",
    9: "spleen",
}


def rgb_to_onehot(rgb_arr, color_dict=class_dict_BTCV):
    """Convert an (H, W, 3) RGB segmentation map into an (H, W, num_classes) one-hot array."""
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for i in range(num_classes):
        # Mark the pixels whose color exactly matches class i
        arr[:, :, i] = np.all(rgb_arr.reshape((-1, 3)) == color_dict[i], axis=1).reshape(shape[:2])
    return arr


pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1",
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_pretrained(
    "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1", subfolder="scheduler"
)
# Offload sub-models to the CPU when idle to save GPU memory.
# Alternatively, use pipe.to("cuda:0") instead; there is no need to combine the two.
pipe.enable_model_cpu_offload()

generator = torch.Generator(device="cpu").manual_seed(1)

# Path to an RGB segmentation map colored with class_dict_BTCV
images = Image.open("")
npi = np.asarray(images.convert("RGB"))
npi = rgb_to_onehot(npi).argmax(-1)  # (H, W) map of class ids
unique_ids = np.unique(npi)

# The prompt lists the organs present in the map (wording kept verbatim from the original example)
prompt = "CT image containg " + " ".join([class_dict[i] for i in unique_ids])
print(prompt)

# Positional arguments follow the customized pipeline in the DiffCTSeg repo
image = pipe(
    prompt,
    images,
    [images],
    num_inference_steps=30,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]
image.save("./result.png")
```
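The inference example expects the conditioning image to be an RGB map colored with the `class_dict_BTCV` palette. If you start from integer label volumes instead (the JHU masks are 3D and are sliced into 2D for this model), the sketch below shows one way to colorize a label slice, decode it back, and build the prompt in the same form as the example. It is a hypothetical helper, not part of DiffCTSeg: the names `labels_to_rgb`, `rgb_to_labels` and `build_prompt` are illustrative, slicing along the last axis of an `(H, W, D)` volume is an assumption, and it may not match the exact preprocessing used for training.

```python
import numpy as np

# Palette and class names copied from the inference example (class id -> RGB / organ name)
CLASS_COLORS = {
    0: (0, 0, 0), 1: (255, 60, 0), 2: (255, 60, 232), 3: (134, 79, 117),
    4: (125, 0, 190), 5: (117, 200, 191), 6: (230, 91, 101), 7: (255, 0, 155),
    8: (75, 205, 155), 9: (100, 37, 200),
}
CLASS_NAMES = {
    0: "background", 1: "aorta", 2: "kidney_left", 3: "liver", 4: "postcava",
    5: "stomach", 6: "gall_bladder", 7: "kidney_right", 8: "pancreas", 9: "spleen",
}


def labels_to_rgb(label_slice: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map an (H, W) integer label slice to an (H, W, 3) uint8 image."""
    # Lookup table indexed by class id; assumes ids are 0..9 as in the palette
    lut = np.zeros((len(CLASS_COLORS), 3), dtype=np.uint8)
    for class_id, color in CLASS_COLORS.items():
        lut[class_id] = color
    return lut[label_slice]


def rgb_to_labels(rgb: np.ndarray) -> np.ndarray:
    """Hypothetical helper: inverse mapping, (H, W, 3) RGB -> (H, W) ids (unknown colors -> 0)."""
    labels = np.zeros(rgb.shape[:2], dtype=np.int64)
    for class_id, color in CLASS_COLORS.items():
        labels[np.all(rgb == np.array(color, dtype=rgb.dtype), axis=-1)] = class_id
    return labels


def build_prompt(label_slice: np.ndarray) -> str:
    """Hypothetical helper: prefix + names of the classes present (prefix copied verbatim)."""
    names = [CLASS_NAMES[int(i)] for i in np.unique(label_slice)]
    return "CT image containg " + " ".join(names)


# Toy (H, W, D) label volume; real masks would be loaded from disk (assumption)
volume = np.zeros((4, 4, 3), dtype=np.int64)
volume[1:3, 1:3, 1] = 3  # a small "liver" region in the middle slice
slices = [volume[:, :, z] for z in range(volume.shape[2])]

rgb_slice = labels_to_rgb(slices[1])
assert np.array_equal(rgb_to_labels(rgb_slice), slices[1])  # round trip is lossless
print(build_prompt(slices[1]))  # CT image containg background liver
```

A lookup table indexed by class id colorizes a whole slice in one NumPy indexing step; the decoder compares each pixel against every palette color, which mirrors what `rgb_to_onehot(...).argmax(-1)` computes in the inference example.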