import random
from contextlib import nullcontext

import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything
from torchvision.transforms import ToTensor

from annotator.render_images import render_text_image_custom
from cldm.ddim_hacked import DDIMSampler
from cldm.model import load_state_dict
from ldm.util import instantiate_from_config

# from cldm.hack import disable_verbosity
# disable_verbosity()

def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
    """Instantiate the model described by `cfg` and, unless `not_use_ckpt` is set,
    load weights from `ckpt` (non-strictly, so partial checkpoints are allowed)."""
    # if "model_ema.input_blocks10in_layers0weight" not in sd:
    #     print("missing model_ema.input_blocks10in_layers0weight. set use_ema as False")
    #     cfg.model.params.use_ema = False
    model = instantiate_from_config(cfg.model)
    if ckpt.endswith("model_states.pt"):
        # DeepSpeed-style checkpoint: the state dict is stored under "module".
        sd = torch.load(ckpt, map_location="cpu")["module"]
    else:
        sd = load_state_dict(ckpt, location="cpu")
    # Strip the "module." prefix left over from DistributedDataParallel wrapping.
    for k in list(sd.keys()):
        if k.startswith("module."):
            sd[k[len("module."):]] = sd.pop(k)
    if not not_use_ckpt:
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys: {}".format(len(m)))
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys: {}".format(len(u)))
            print(u)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
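
# A minimal usage sketch for load_model_from_config. The config and checkpoint
# paths are hypothetical placeholders; the only assumption is an OmegaConf YAML
# with a `model` entry, as instantiate_from_config expects:
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.load("configs/model_config.yaml")   # hypothetical path
#     model = load_model_from_config(cfg, "models/model.ckpt", verbose=True)
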
def load_model_ckpt(model, ckpt, verbose=True):
    """Load weights from `ckpt` into an already-instantiated `model` (non-strict)."""
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    print("checkpoint map location:", map_location)
    if ckpt.endswith("model_states.pt"):
        # DeepSpeed-style checkpoint: the state dict is stored under "module".
        sd = torch.load(ckpt, map_location=map_location)["module"]
    else:
        sd = load_state_dict(ckpt, location=map_location)
    # Strip the "module." prefix left over from DistributedDataParallel wrapping.
    for k in list(sd.keys()):
        if k.startswith("module."):
            sd[k[len("module."):]] = sd.pop(k)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys: {}".format(len(m)))
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys: {}".format(len(u)))
        print(u)
    model.eval()
    return model
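
# A minimal usage sketch for load_model_ckpt (paths are hypothetical
# placeholders): build the network once, then swap other checkpoints into the
# same instance without re-instantiating:
#
#     model = load_model_from_config(cfg, "models/base.ckpt", not_use_ckpt=True)
#     model = load_model_ckpt(model, "checkpoints/finetuned/model_states.pt")
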
class Render_Text:
    def __init__(self,
                 model,
                 precision_scope=nullcontext,
                 transform=ToTensor(),
                 save_memory=False,
                 ):
        self.model = model
        self.precision_scope = precision_scope
        self.transform = transform
        self.ddim_sampler = DDIMSampler(model)
        self.save_memory = save_memory
    # Process multiple groups of rendered text (used to build the demo).
    def process_multi(self,
                      rendered_txt_values, shared_prompt,
                      width_values, ratio_values,
                      top_left_x_values, top_left_y_values,
                      yaw_values, num_rows_values,
                      shared_num_samples, shared_image_resolution,
                      shared_ddim_steps, shared_guess_mode,
                      shared_strength, shared_scale, shared_seed,
                      shared_eta, shared_a_prompt, shared_n_prompt,
                      only_show_rendered_image=False
                      ):
        if shared_seed == -1:
            shared_seed = random.randint(0, 65535)
        seed_everything(shared_seed)
        with torch.no_grad(), \
                self.precision_scope("cuda"), \
                self.model.ema_scope("Sampling on Benchmark Prompts"):
            print("rendered txt:", str(rendered_txt_values))
            # No glyph image is needed when every text group is empty.
            render_none = all(rendered_txt == "" for rendered_txt in rendered_txt_values)
            if render_none:
                control = None
                if only_show_rendered_image:
                    return [None]
            else:
                def format_bboxes(width_values, ratio_values, top_left_x_values, top_left_y_values, yaw_values):
                    # Pack the per-group layout parameters into one bbox dict per text group.
                    bboxes = []
                    for width, ratio, top_left_x, top_left_y, yaw in zip(
                            width_values, ratio_values, top_left_x_values, top_left_y_values, yaw_values):
                        bboxes.append({
                            "width": width,
                            "ratio": ratio,
                            "top_left_x": top_left_x,
                            "top_left_y": top_left_y,
                            "yaw": yaw,
                        })
                    return bboxes

                whiteboard_img = render_text_image_custom(
                    (shared_image_resolution, shared_image_resolution),
                    format_bboxes(width_values, ratio_values, top_left_x_values, top_left_y_values, yaw_values),
                    rendered_txt_values,
                    num_rows_values
                )
                whiteboard_img = whiteboard_img.convert("RGB")
                if only_show_rendered_image:
                    return [whiteboard_img]
                # Use the rendered glyph image as the control condition, repeated
                # once per sample in the batch.
                control = self.transform(whiteboard_img.copy())
                if torch.cuda.is_available():
                    control = control.cuda()
                control = torch.stack([control for _ in range(shared_num_samples)], dim=0)
                control = control.clone()
                control = [control]

            H, W = shared_image_resolution, shared_image_resolution

            if torch.cuda.is_available() and self.save_memory:
                print("low_vram_shift: is_diffusing", False)
                self.model.low_vram_shift(is_diffusing=False)
            print("control is None: {}".format(control is None))

            # Join the shared prompt and the additive prompt with sensible punctuation.
            if shared_prompt.endswith("."):
                if shared_a_prompt == "":
                    c_prompt = shared_prompt
                else:
                    c_prompt = shared_prompt + " " + shared_a_prompt
            elif shared_prompt.endswith(","):
                if shared_a_prompt == "":
                    c_prompt = shared_prompt[:-1] + "."
                else:
                    c_prompt = shared_prompt + " " + shared_a_prompt
            else:
                if shared_a_prompt == "":
                    c_prompt = shared_prompt + "."
                else:
                    c_prompt = shared_prompt + ", " + shared_a_prompt

            cond_c_cross = self.model.get_learned_conditioning([c_prompt] * shared_num_samples)
            print("prompt:", c_prompt)
            un_cond_cross = self.model.get_learned_conditioning([shared_n_prompt] * shared_num_samples)

            if torch.cuda.is_available() and self.save_memory:
                print("low_vram_shift: is_diffusing", True)
                self.model.low_vram_shift(is_diffusing=True)

            cond = {"c_concat": control,
                    "c_crossattn": [cond_c_cross] if not isinstance(cond_c_cross, list) else cond_c_cross}
            # In guess mode the unconditional branch gets no control signal.
            un_cond = {"c_concat": None if shared_guess_mode else control,
                       "c_crossattn": [un_cond_cross] if not isinstance(un_cond_cross, list) else un_cond_cross}
            # Latent shape: 4 channels at 1/8 of the pixel resolution.
            shape = (4, H // 8, W // 8)
            if not self.model.learnable_conscale:
                self.model.control_scales = [shared_strength * (0.825 ** float(12 - i)) for i in range(13)] if shared_guess_mode else ([shared_strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12 < 0.01 but 0.826**12 > 0.01
            else:
                print("learned control scale: {}".format(str(self.model.control_scales)))
            samples, intermediates = self.ddim_sampler.sample(shared_ddim_steps, shared_num_samples,
                                                              shape, cond, verbose=False, eta=shared_eta,
                                                              unconditional_guidance_scale=shared_scale,
                                                              unconditional_conditioning=un_cond)

            if torch.cuda.is_available() and self.save_memory:
                print("low_vram_shift: is_diffusing", False)
                self.model.low_vram_shift(is_diffusing=False)

            # Decode latents to images and map from [-1, 1] to uint8 [0, 255].
            x_samples = self.model.decode_first_stage(samples)
            x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            results = [x_samples[i] for i in range(shared_num_samples)]

        # Prepend the rendered glyph image when text was actually rendered.
        if not render_none:
            return [whiteboard_img] + results
        else:
            return results
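
# A minimal usage sketch for Render_Text.process_multi (every value below is a
# hypothetical placeholder, not a setting shipped with this repo; `model` is a
# model loaded as above):
#
#     render = Render_Text(model, save_memory=False)
#     outputs = render.process_multi(
#         rendered_txt_values=["HELLO"], shared_prompt="a street sign",
#         width_values=[0.3], ratio_values=[0.5],
#         top_left_x_values=[0.35], top_left_y_values=[0.4],
#         yaw_values=[0.0], num_rows_values=[1],
#         shared_num_samples=1, shared_image_resolution=512,
#         shared_ddim_steps=20, shared_guess_mode=False,
#         shared_strength=1.0, shared_scale=9.0, shared_seed=-1,
#         shared_eta=0.0, shared_a_prompt="best quality", shared_n_prompt="blurry",
#     )
#     # outputs[0] is the rendered glyph image; outputs[1:] are the generated samples.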