License: MIT

This is a trained ControlNet-StableDiffusion model for generating synthetic CT/MRI images from segmentation maps. Running it requires a customized version of the ControlNet-StableDiffusion pipeline.

The model is trained on the JHU dataset, which contains 5312 CT volumes with corresponding segmentation masks.

We slice the CT volumes into roughly 1.3M 2D slices.
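The slicing step can be illustrated with a small NumPy sketch. It assumes each CT volume and its label volume are already loaded as arrays of the same shape (H, W, D), with the slicing axis last. The loading code, the actual slicing axis, and any intensity preprocessing used for this model are not described in this card, and the helper name volume_to_slices is hypothetical. At the stated scale, about 1.3M slices over 5312 volumes works out to roughly 245 slices per volume on average.

```python
import numpy as np


def volume_to_slices(ct_volume, label_volume, axis=2):
    """Split a 3D CT volume and its label volume into paired 2D slices.

    Hypothetical helper: assumes both arrays have the same shape and that
    `axis` is the slicing (e.g. axial) axis.
    """
    if ct_volume.shape != label_volume.shape:
        raise ValueError("CT and label volumes must have the same shape")
    # Move the slicing axis to the front so iteration yields one 2D slice at a time
    ct_slices = np.moveaxis(ct_volume, axis, 0)
    label_slices = np.moveaxis(label_volume, axis, 0)
    return list(zip(ct_slices, label_slices))


# Usage with a tiny synthetic volume of shape (H=4, W=5, D=3)
ct = np.random.rand(4, 5, 3).astype(np.float32)
labels = np.zeros((4, 5, 3), dtype=np.uint8)
pairs = volume_to_slices(ct, labels)
print(len(pairs), pairs[0][0].shape)  # 3 (4, 5)
```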

Here is the training and inference code for Diff_Synth_CT

Training details

Hardware: 8x Nvidia-A6000

Batch size: 8 x 4 x 32
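The card does not say what each of the three batch-size factors means (for example GPUs, gradient-accumulation steps, and per-GPU batch). If they multiply, the effective batch size is 8 × 4 × 32 = 1024. The sketch below works through that arithmetic under this assumption and adds the approximate number of optimizer steps per pass over the ~1.3M slices. The interpretation and the helper name are assumptions.

```python
import math


def effective_batch_size(*factors):
    """Multiply batch-size factors (assumed: GPUs, accumulation steps, per-GPU batch)."""
    return math.prod(factors)


batch = effective_batch_size(8, 4, 32)
num_slices = 1_300_000  # approximate slice count stated in the card
steps_per_epoch = math.ceil(num_slices / batch)
print(batch, steps_per_epoch)  # 1024 1270
```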

For direct inference

Step 1: Clone the GitHub repo to get the customized ControlNet-StableDiffusion pipeline implementation.

git clone https://github.com/Onkarsus13/DiffCTSeg

Step 2: Go into the repository and install it along with its dependencies.

cd DiffCTSeg
pip install -e ".[torch]"
pip install -e ".[all,dev,notebooks]"
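The inference code imports from diffusers, so you may want to confirm that Python resolves diffusers to the cloned DiffCTSeg checkout and not to a separately installed PyPI release. The minimal sketch below uses only the standard library. It assumes two things: that the repository installs a package named diffusers (implied by the imports, not stated explicitly), and that you kept the default clone directory name DiffCTSeg. The helper name package_origin is hypothetical.

```python
import importlib.util


def package_origin(name):
    """Return the file a top-level package would be imported from, or None if not installed."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


origin = package_origin("diffusers")
if origin is None:
    print("diffusers is not installed")
elif "DiffCTSeg" in origin:
    # Assumes the repo was cloned into a directory named DiffCTSeg
    print("Using the DiffCTSeg checkout:", origin)
else:
    print("diffusers resolves to a different install:", origin)
```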

Step 3: Run python test_eraser.py, or run the code below:

from diffusers import StableDiffusionControlNetInpaintPipeline, UniPCMultistepScheduler
import torch
from PIL import Image
import numpy as np


# RGB color used for each class id in the segmentation mask image
class_dict_BTCV = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}

# Organ name for each class id (used to build the text prompt)
class_dict = {
    0: "background",
    1: "aorta",
    2: "kidney_left",
    3: "liver",
    4: "postcava",
    5: "stomach",
    6: "gall_bladder",
    7: "kidney_right",
    8: "pancreas",
    9: "spleen",
}

def rgb_to_onehot(rgb_arr, color_dict=class_dict_BTCV):
    """Convert an (H, W, 3) RGB mask into an (H, W, num_classes) one-hot array."""
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for i, color in color_dict.items():
        # Mark pixels whose RGB value exactly matches the color of class i
        arr[:, :, i] = np.all(rgb_arr == color, axis=-1)
    return arr



pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1",
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_pretrained(
    "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1", subfolder="scheduler"
)
# Model CPU offload moves each sub-model to the GPU only while it runs, lowering VRAM use.
# Use either this or pipe.to("cuda:0"), not both.
pipe.enable_model_cpu_offload()


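generator = torch.Generator(device="cpu").manual_seed(1)

# Load the color-coded segmentation mask (colors must match class_dict_BTCV)
images = Image.open("<Give Segmentation Mask>").convert("RGB")
npi = np.asarray(images)
# Recover per-pixel class ids; pixels matching no listed color fall back to class 0 (background)
npi = rgb_to_onehot(npi).argmax(-1)
unique_ids = np.unique(npi)

# Text prompt listing the organs present in the mask.
# The spelling "containg" is left unchanged; it may match the prompt wording used in training (not confirmed).
prompt = 'CT image containg ' + " ".join([class_dict[i] for i in unique_ids])
print(prompt)

image = pipe(
    prompt,
    images,
    [images],
    num_inference_steps=30,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]

image.save('./result.png')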
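The pipeline input must be a segmentation mask whose colors match class_dict_BTCV exactly. If your masks are stored as integer label maps using the class ids 0 to 9 from class_dict, the hypothetical helpers below convert a label map into such an RGB image and build a prompt in the same format as the inference example. The prompt keeps the original spelling "containg". This is a sketch that assumes the ids and colors listed above, and it repeats both dictionaries so that it runs on its own.

```python
import numpy as np
from PIL import Image

# Same palette and names as in the inference example (class id -> RGB color / organ name)
PALETTE = {
    0: (0, 0, 0), 1: (255, 60, 0), 2: (255, 60, 232), 3: (134, 79, 117),
    4: (125, 0, 190), 5: (117, 200, 191), 6: (230, 91, 101), 7: (255, 0, 155),
    8: (75, 205, 155), 9: (100, 37, 200),
}
NAMES = {
    0: "background", 1: "aorta", 2: "kidney_left", 3: "liver", 4: "postcava",
    5: "stomach", 6: "gall_bladder", 7: "kidney_right", 8: "pancreas", 9: "spleen",
}


def labels_to_rgb_mask(label_map):
    """Map an (H, W) integer label map to an RGB PIL image using PALETTE (hypothetical helper)."""
    label_map = np.asarray(label_map)
    rgb = np.zeros(label_map.shape + (3,), dtype=np.uint8)
    for class_id, color in PALETTE.items():
        rgb[label_map == class_id] = color
    return Image.fromarray(rgb)  # uint8 (H, W, 3) is interpreted as an RGB image


def make_prompt(label_map):
    """Build the text prompt in the same format as the inference example (hypothetical helper)."""
    ids = np.unique(np.asarray(label_map))
    return "CT image containg " + " ".join(NAMES[int(i)] for i in ids)


# Usage: a tiny 2x3 label map containing background (0), liver (3) and spleen (9)
labels = np.array([[0, 3, 3], [0, 9, 9]], dtype=np.uint8)
mask = labels_to_rgb_mask(labels)
print(make_prompt(labels))  # CT image containg background liver spleen
```

To use a converted mask with the pipeline, save it with mask.save(...) and pass that file path to Image.open in place of the placeholder.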