# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email [email protected] twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings

import numpy as np
import torch

torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
warnings.filterwarnings('ignore')

from typing import Optional
from contextlib import nullcontext

from PIL import Image
from omegaconf import OmegaConf
from torch import autocast
from einops import repeat, rearrange

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from utils import interpolate_spherical


def pad_image(input_image):
    r"""Pad a PIL image on its right/bottom edges so both sides become
    integer multiples of 64 (and at least 128 px), replicating edge pixels."""
    pad_w, pad_h = np.max(((2, 2), np.ceil(
        np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
    im_padded = Image.fromarray(
        np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
    return im_padded
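
# A minimal sketch of the padding behavior (example sizes are illustrative,
# not from the original source): a 500x375 image is padded on the right and
# bottom up to the next multiples of 64, i.e. 512x384:
#   >>> pad_image(Image.new("RGB", (500, 375))).size
#   (512, 384)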


def make_batch_superres(
        image,
        txt,
        device,
        num_samples=1):
    r"""Build the conditioning batch for the x4 upscaling model: the image is
    mapped to [-1, 1] and repeated num_samples times along the batch axis."""
    image = np.array(image.convert("RGB"))
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    batch = {
        "lr": rearrange(image, 'h w c -> 1 c h w'),
        "txt": num_samples * [txt],
    }
    batch["lr"] = repeat(batch["lr"].to(device=device),
                         "1 ... -> n ...", n=num_samples)
    return batch
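
# For example (shapes follow from the code above, the sizes are assumed):
# a 512x384 PIL image with num_samples=2 yields batch["lr"] of shape
# (2, 3, 384, 512) with values in [-1, 1], and batch["txt"] == [txt, txt].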


def make_noise_augmentation(model, batch, noise_level=None):
    r"""Apply the model's low-scale noise augmentation to the low-res
    conditioning image; returns the augmented image and the noise level."""
    x_low = batch[model.low_scale_key]
    x_low = x_low.to(memory_format=torch.contiguous_format).float()
    x_aug, noise_level = model.low_scale_model(x_low, noise_level)
    return x_aug, noise_level
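
# A hedged call sketch (the tensor value is illustrative): noise_level is a
# long tensor selecting how strongly the low-res image is noised before it
# is concatenated to the latents, e.g.
#   x_aug, noise_level = make_noise_augmentation(
#       model, batch, torch.tensor([20], device=model.device).long())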


class StableDiffusionHolder:
    def __init__(self,
                 fp_ckpt: str = None,
                 fp_config: str = None,
                 num_inference_steps: int = 30,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 device: str = None,
                 precision: str = 'autocast',
                 ):
        r"""
        Initializes the stable diffusion holder, which contains the model and sampler.
        Args:
            fp_ckpt: File path to the .ckpt model file
            fp_config: File path to the .yaml config file
            num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
            height: Height of the resulting image.
            width: Width of the resulting image.
            device: Device to run the model on.
            precision: Precision mode; 'autocast' enables torch autocast, any other value runs in full precision.
        """
        self.seed = 42
        self.guidance_scale = 5.0
        if device is None:
            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        else:
            self.device = device
        self.precision = precision
        self.init_model(fp_ckpt, fp_config)
        self.f = 8  # downsampling factor of the autoencoder, most often 8 or 16
        self.C = 4  # number of latent channels
        self.ddim_eta = 0
        self.num_inference_steps = num_inference_steps
        if height is None and width is None:
            self.init_auto_res()
        else:
            assert height is not None, "specify both width and height"
            assert width is not None, "specify both width and height"
            self.height = height
            self.width = width
        self.negative_prompt = [""]

    def init_model(self, fp_ckpt, fp_config):
        r"""Loads the model and sampler, auto-detecting the config from the
        checkpoint filename if fp_config is not given."""
        assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
        self.fp_ckpt = fp_ckpt
        # Auto-detect the config from the checkpoint filename?
        if fp_config is None:
            fn_ckpt = os.path.basename(fp_ckpt)
            if 'depth' in fn_ckpt:
                fp_config = 'configs/v2-midas-inference.yaml'
            elif 'upscaler' in fn_ckpt:
                fp_config = 'configs/x4-upscaling.yaml'
            elif '512' in fn_ckpt:
                fp_config = 'configs/v2-inference.yaml'
            elif '768' in fn_ckpt:
                fp_config = 'configs/v2-inference-v.yaml'
            elif 'v1-5' in fn_ckpt:
                fp_config = 'configs/v1-inference.yaml'
            else:
                raise ValueError("Auto-detection of the config failed. Please specify fp_config manually!")
        assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
        config = OmegaConf.load(fp_config)
        self.model = instantiate_from_config(config.model)
        self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)
        self.model = self.model.to(self.device)
        self.sampler = DDIMSampler(self.model)

    def init_auto_res(self):
        r"""Automatically set the resolution to the one used in training.
        """
        if '768' in self.fp_ckpt:
            self.height = 768
            self.width = 768
        else:
            self.height = 512
            self.width = 512

    def set_negative_prompt(self, negative_prompt):
        r"""Set the negative prompt. Currently only one negative prompt is supported.
        """
        if isinstance(negative_prompt, str):
            self.negative_prompt = [negative_prompt]
        else:
            self.negative_prompt = negative_prompt
        if len(self.negative_prompt) > 1:
            self.negative_prompt = [self.negative_prompt[0]]

    def get_text_embedding(self, prompt):
        c = self.model.get_learned_conditioning(prompt)
        return c
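
    # A hedged shape sketch (the embedding width is an assumption and depends
    # on the text encoder of the loaded checkpoint, e.g. 1024 for SD v2):
    #   self.get_text_embedding(["a photo of a cat"])  # -> (1, 77, 1024)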

    def get_cond_upscaling(self, image, text_embedding, noise_level):
        r"""
        Initializes the conditioning for the x4 upscaling model.
        """
        image = pad_image(image)  # pad to an integer multiple of 64
        w, h = image.size
        noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
        batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)
        x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
        cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
        # unconditional conditioning for classifier-free guidance
        uc_cross = self.model.get_unconditional_conditioning(1, "")
        uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
        return cond, uc_full
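
    # Hypothetical call sketch (variable names are illustrative): given a
    # low-res PIL image and a prompt embedding, build the paired
    # conditionings for classifier-free guidance:
    #   cond, uc_full = self.get_cond_upscaling(
    #       img_lowres, self.get_text_embedding(["a photo"]), noise_level=20)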

    def run_diffusion_standard(
            self,
            text_embeddings: torch.FloatTensor,
            latents_start: torch.FloatTensor,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            spatial_mask=None,
            return_image: Optional[bool] = False):
        r"""
        Diffusion standard version.
        Args:
            text_embeddings: torch.FloatTensor
                Text embeddings used for diffusion
            latents_start: torch.FloatTensor
                Latents that are injected as the starting point of the diffusion
            idx_start: int
                Index of the diffusion step at which latents_start is injected
            list_latents_mixing: list
                Latents that are mixed into the diffusion trajectory
            mixing_coeffs: float or list
                Mixing coefficients for latent blending
            spatial_mask:
                Experimental feature for enforcing pixels from list_latents_mixing
            return_image: Optional[bool]
                Optionally return the decoded image directly
        """
        # Asserts
        if isinstance(mixing_coeffs, float):
            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
        elif isinstance(mixing_coeffs, list):
            assert len(mixing_coeffs) == self.num_inference_steps
            list_mixing_coeffs = mixing_coeffs
        else:
            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
        if np.sum(list_mixing_coeffs) > 0:
            assert len(list_latents_mixing) == self.num_inference_steps
        precision_scope = autocast if self.precision == "autocast" else nullcontext
        with precision_scope("cuda"):
            with self.model.ema_scope():
                if self.guidance_scale != 1.0:
                    uc = self.model.get_learned_conditioning(self.negative_prompt)
                else:
                    uc = None
                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
                latents = latents_start.clone()
                timesteps = self.sampler.ddim_timesteps
                time_range = np.flip(timesteps)
                total_steps = timesteps.shape[0]
                # Collect latents
                list_latents_out = []
                for i, step in enumerate(time_range):
                    # Set the right starting latents
                    if i < idx_start:
                        list_latents_out.append(None)
                        continue
                    elif i == idx_start:
                        latents = latents_start.clone()
                    # Mix latents
                    if i > 0 and list_mixing_coeffs[i] > 0:
                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
                    if spatial_mask is not None and list_latents_mixing is not None:
                        latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
                    index = total_steps - i - 1
                    ts = torch.full((1,), step, device=self.device, dtype=torch.long)
                    outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
                                                      quantize_denoised=False, temperature=1.0,
                                                      noise_dropout=0.0, score_corrector=None,
                                                      corrector_kwargs=None,
                                                      unconditional_guidance_scale=self.guidance_scale,
                                                      unconditional_conditioning=uc,
                                                      dynamic_threshold=None)
                    latents, pred_x0 = outs
                    list_latents_out.append(latents.clone())
                if return_image:
                    return self.latent2image(latents)
                else:
                    return list_latents_out

    def run_diffusion_upscaling(
            self,
            cond,
            uc_full,
            latents_start: torch.FloatTensor,
            idx_start: int = -1,
            list_latents_mixing: list = None,
            mixing_coeffs: float = 0.0,
            return_image: Optional[bool] = False):
        r"""
        Diffusion upscaling version.
        """
        # Asserts
        if isinstance(mixing_coeffs, float):
            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
        elif isinstance(mixing_coeffs, list):
            assert len(mixing_coeffs) == self.num_inference_steps
            list_mixing_coeffs = mixing_coeffs
        else:
            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
        if np.sum(list_mixing_coeffs) > 0:
            assert len(list_latents_mixing) == self.num_inference_steps
        precision_scope = autocast if self.precision == "autocast" else nullcontext
        h = uc_full['c_concat'][0].shape[2]
        w = uc_full['c_concat'][0].shape[3]
        with precision_scope("cuda"):
            with self.model.ema_scope():
                shape_latents = [self.model.channels, h, w]
                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
                C, H, W = shape_latents
                size = (1, C, H, W)
                b = size[0]
                latents = latents_start.clone()
                timesteps = self.sampler.ddim_timesteps
                time_range = np.flip(timesteps)
                total_steps = timesteps.shape[0]
                # Collect latents
                list_latents_out = []
                for i, step in enumerate(time_range):
                    # Set the right starting latents
                    if i < idx_start:
                        list_latents_out.append(None)
                        continue
                    elif i == idx_start:
                        latents = latents_start.clone()
                    # Mix the latents
                    if i > 0 and list_mixing_coeffs[i] > 0:
                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
                    index = total_steps - i - 1
                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
                                                      quantize_denoised=False, temperature=1.0,
                                                      noise_dropout=0.0, score_corrector=None,
                                                      corrector_kwargs=None,
                                                      unconditional_guidance_scale=self.guidance_scale,
                                                      unconditional_conditioning=uc_full,
                                                      dynamic_threshold=None)
                    latents, pred_x0 = outs
                    list_latents_out.append(latents.clone())
                if return_image:
                    return self.latent2image(latents)
                else:
                    return list_latents_out

    def latent2image(
            self,
            latents: torch.FloatTensor):
        r"""
        Returns an image provided a latent representation from diffusion.
        Args:
            latents: torch.FloatTensor
                Result of the diffusion process.
        """
        x_sample = self.model.decode_first_stage(latents)
        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
        image = x_sample.astype(np.uint8)
        return image
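

# A minimal end-to-end usage sketch, not part of the original module. The
# checkpoint and config paths below are assumptions; adjust them to your
# local setup.
if __name__ == "__main__":
    sdh = StableDiffusionHolder(
        fp_ckpt="checkpoints/v2-1_512-ema-pruned.ckpt",  # assumed path
        fp_config="configs/v2-inference.yaml",           # assumed path
        num_inference_steps=30)
    torch.manual_seed(sdh.seed)
    # Start from pure noise in latent space: (1, C, height/f, width/f).
    latents_start = torch.randn(
        (1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f),
        device=sdh.device)
    text_embeddings = sdh.get_text_embedding(["a photograph of a mountain lake at dawn"])
    image = sdh.run_diffusion_standard(text_embeddings, latents_start, return_image=True)
    Image.fromarray(image).save("sample.png")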