LSDM / generate.py
QinLei086's picture
Upload 28 files
15acbf0 verified
raw
history blame contribute delete
No virus
3.42 kB
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers import AutoencoderKL
from diffusion_module.unet import UNetModel
import torch
from diffusion_module.utils.LSDMPipeline_expandDataset import SDMLDMPipeline
from accelerate import Accelerator
from evolution import random_walk
import cv2
import numpy as np
def mask2onehot(data, num_classes):
    """Convert a label map of class indices into a one-hot float tensor.

    Args:
        data: 4-D tensor ``(batch, 1, height, width)`` of class indices. Any
            numeric dtype is accepted; values are cast to int64 (floats are
            truncated toward zero).
        num_classes: Number of output channels C; every label must lie in
            ``[0, C)`` or ``scatter_`` raises.

    Returns:
        float32 tensor of shape ``(batch, C, height, width)`` on
        ``data.device`` holding 1.0 in each pixel's class channel and 0.0
        everywhere else.
    """
    # scatter_ needs int64 indices; this only changes dtype, not device.
    label_map = data.to(dtype=torch.int64)
    bs, _, h, w = label_map.size()
    # Allocate the zero canvas directly on the input's device rather than
    # building a CPU FloatTensor and copying it across.
    one_hot = torch.zeros(
        bs, num_classes, h, w, dtype=torch.float32, device=label_map.device
    )
    # Write 1.0 along the channel axis at each pixel's label index.
    return one_hot.scatter_(1, label_map, 1.0)
# Number of semantic classes the UNet is built with; the one-hot mask fed to
# the pipeline must have the same number of channels.
_NUM_CLASSES = 151
# Spatial size of the VAE latent; the input mask is resized to this size.
# It is square, so cv2's (width, height) ordering does not matter here.
_LATENT_SIZE = (64, 64)


def _build_pipeline(pretrain_weight):
    """Load the VAE and UNet and assemble an SDMLDMPipeline.

    Returns:
        ``(pipeline, device, weight_dtype)``: the pipeline already moved to
        the accelerator device, that device, and the dtype the VAE was cast to.
    """
    noise_scheduler = UniPCMultistepScheduler()
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="vae"
    )
    unet = UNetModel(
        image_size=_LATENT_SIZE,
        in_channels=vae.config.latent_channels,
        model_channels=256,
        out_channels=vae.config.latent_channels,
        num_res_blocks=2,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 3, 4),
        num_heads=8,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=False,
        num_classes=_NUM_CLASSES,
        mask_emb="resize",
        use_checkpoint=True,
        SPADE_type="spade",
    )
    # NOTE(review): if UNetModel.from_pretrained is a classmethod (as in
    # diffusers' ModelMixin), the instance built above is discarded and the
    # architecture comes from pretrain_weight's config — confirm.
    unet = unet.from_pretrained(pretrain_weight)

    # Inference is pinned to CPU; the fp16 branch is reached only if this
    # constant is changed.
    device = "cpu"
    on_cpu = device == "cpu"
    accelerator = Accelerator(
        mixed_precision="no" if on_cpu else "fp16",
        cpu=on_cpu,
    )
    weight_dtype = (
        torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
    )

    unet, vae = accelerator.prepare(unet, vae)
    vae.to(device=accelerator.device, dtype=weight_dtype)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=noise_scheduler,
        torch_dtype=weight_dtype,
        resolution_type="crack",
    )
    # xformers memory-efficient attention is not enabled on this path.
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=False)
    return pipeline, accelerator.device, weight_dtype


def generate(img, pretrain_weight, seed=None, *, num_inference_steps=20, s=1.5):
    """Generate image(s) conditioned on a single binary mask.

    Args:
        img: Mask image as a NumPy array accepted by ``cv2.resize``. It is
            resized to 64x64 and thresholded: pixels > 1 become class 1, all
            others class 0. Must be single-channel after resizing.
        pretrain_weight: Path or identifier passed to the UNet's
            ``from_pretrained``.
        seed: Optional integer seed for a ``torch.Generator``; ``None`` leaves
            sampling unseeded.
        num_inference_steps: Scheduler step count forwarded to the pipeline.
        s: Forwarded unchanged as the pipeline's ``s`` argument (presumably
            the guidance scale — confirm against SDMLDMPipeline).

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        ValueError: if the resized mask is not two-dimensional.
    """
    # Preprocess first so malformed input fails before the costly model loads.
    resized = cv2.resize(img, _LATENT_SIZE, interpolation=cv2.INTER_AREA)
    if resized.ndim != 2:
        raise ValueError(
            f"expected a single-channel mask, got resized shape {resized.shape}"
        )
    # Binarise: pixels above 1 become 255, the rest 0.
    _, binary = cv2.threshold(resized, 1, 255, cv2.THRESH_BINARY)
    # Scale to {0, 1} and add batch/channel dims: (h, w) -> (1, 1, h, w).
    mask = torch.from_numpy(binary / 255).unsqueeze(0).unsqueeze(0)
    # (1, 1, h, w) -> (1, _NUM_CLASSES, h, w) one-hot conditioning tensor.
    condition = mask2onehot(mask, _NUM_CLASSES)

    pipeline, device, weight_dtype = _build_pipeline(pretrain_weight)

    generator = (
        None if seed is None else torch.Generator(device=device).manual_seed(seed)
    )
    condition = condition.to(dtype=weight_dtype, device=device)
    output = pipeline(
        condition,
        generator=generator,
        batch_size=1,
        num_inference_steps=num_inference_steps,
        s=s,
        num_evolution_per_mask=1,
    )
    return output.images