import random
from contextlib import nullcontext

import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything
from torchvision.transforms import ToTensor

# Silence verbose start-up logging before importing the heavier modules.
from cldm.hack import disable_verbosity
disable_verbosity()

from annotator.render_images import render_text_image_custom
from cldm.ddim_hacked import DDIMSampler
from cldm.model import load_state_dict
from ldm.util import instantiate_from_config

save_memory = False

def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
    model = instantiate_from_config(cfg.model)
    # DeepSpeed-style checkpoints keep the weights under the "module" key.
    if ckpt.endswith("model_states.pt"):
        sd = torch.load(ckpt, map_location='cpu')["module"]
    else:
        sd = load_state_dict(ckpt, location='cpu')
    # Strip any "module." prefix left over from distributed training wrappers.
    for k in list(sd.keys()):
        if k.startswith("module."):
            sd[k[7:]] = sd.pop(k)
    # if "model_ema.input_blocks10in_layers0weight" not in sd:
    #     print("missing model_ema.input_blocks10in_layers0weight. set use_ema as False")
    #     cfg.model.params.use_ema = False
    if not not_use_ckpt:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if len(missing) > 0 and verbose:
            print("missing keys: {}".format(len(missing)))
            print(missing)
        if len(unexpected) > 0 and verbose:
            print("unexpected keys: {}".format(len(unexpected)))
            print(unexpected)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
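
# A minimal usage sketch for load_model_from_config; the config and checkpoint
# paths below are hypothetical placeholders, not files known to ship with this
# repo (instantiate_from_config expects an OmegaConf-style object with a
# `model` entry):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/config.yaml")                # hypothetical path
#   model = load_model_from_config(cfg, "checkpoints/model_states.pt", verbose=True)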


class Render_Text:
    """Render text layouts into a glyph control image and sample from the model."""
    def __init__(self,
                 model,
                 precision_scope=nullcontext,
                 transform=ToTensor()):
        self.model = model
        self.precision_scope = precision_scope
        self.transform = transform
        self.ddim_sampler = DDIMSampler(model)
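
    # Note (assumption): precision_scope is invoked as precision_scope("cuda"),
    # so torch.autocast can be passed in its place to sample in mixed
    # precision, e.g. Render_Text(model, precision_scope=torch.autocast).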
    def process_multi(self,
                      rendered_txt_values, shared_prompt,
                      width_values, ratio_values,
                      top_left_x_values, top_left_y_values,
                      yaw_values, num_rows_values,
                      shared_num_samples, shared_image_resolution,
                      shared_ddim_steps, shared_guess_mode,
                      shared_strength, shared_scale, shared_seed,
                      shared_eta, shared_a_prompt, shared_n_prompt,
                      only_show_rendered_image=False):
        with torch.no_grad(), \
                self.precision_scope("cuda"), \
                self.model.ema_scope("Sampling on Benchmark Prompts"):
            print("rendered txt:", str(rendered_txt_values), "[t]")
            if rendered_txt_values == "":
                control = None
            else:
                # Pack the per-box layout parameters into dicts for the renderer.
                def format_bboxes(width_values, ratio_values, top_left_x_values,
                                  top_left_y_values, yaw_values):
                    bboxes = []
                    for width, ratio, top_left_x, top_left_y, yaw in zip(
                            width_values, ratio_values,
                            top_left_x_values, top_left_y_values, yaw_values):
                        bboxes.append({
                            "width": width,
                            "ratio": ratio,
                            # "height": height,
                            "top_left_x": top_left_x,
                            "top_left_y": top_left_y,
                            "yaw": yaw,
                        })
                    return bboxes

                # Render the text into a whiteboard-style glyph image that
                # serves as the control signal.
                whiteboard_img = render_text_image_custom(
                    (shared_image_resolution, shared_image_resolution),
                    format_bboxes(width_values, ratio_values, top_left_x_values,
                                  top_left_y_values, yaw_values),
                    rendered_txt_values,
                    num_rows_values
                )
                whiteboard_img = whiteboard_img.convert("RGB")
                if only_show_rendered_image:
                    return [whiteboard_img]

                # Batch the control image: one copy per requested sample.
                control = self.transform(whiteboard_img.copy())
                if torch.cuda.is_available():
                    control = control.cuda()
                control = torch.stack([control for _ in range(shared_num_samples)], dim=0)
                control = control.clone()
                control = [control]

            H, W = shared_image_resolution, shared_image_resolution
            if shared_seed == -1:
                shared_seed = random.randint(0, 65535)
            seed_everything(shared_seed)

            print("control is None: {}".format(control is None))
            print("prompt for the SD branch:", str(shared_prompt), "[t]")
            # Conditioning: the user prompt plus the shared additional prompt
            # for the positive branch, the shared negative prompt for the
            # unconditional branch.
            cond_c_cross = self.model.get_learned_conditioning(
                [shared_prompt + ', ' + shared_a_prompt] * shared_num_samples)
            un_cond_cross = self.model.get_learned_conditioning(
                [shared_n_prompt] * shared_num_samples)
            cond = {"c_concat": control,
                    "c_crossattn": [cond_c_cross] if not isinstance(cond_c_cross, list) else cond_c_cross}
            # In guess mode the unconditional branch receives no control signal.
            un_cond = {"c_concat": None if shared_guess_mode else control,
                       "c_crossattn": [un_cond_cross] if not isinstance(un_cond_cross, list) else un_cond_cross}
            shape = (4, H // 8, W // 8)

            if not self.model.learnable_conscale:
                self.model.control_scales = (
                    [shared_strength * (0.825 ** float(12 - i)) for i in range(13)]
                    if shared_guess_mode else ([shared_strength] * 13)
                )  # Magic number. IDK why. Perhaps because 0.825**12 < 0.01 but 0.826**12 > 0.01.
            else:
                print("learned control scale: {}".format(str(self.model.control_scales)))

            samples, intermediates = self.ddim_sampler.sample(
                shared_ddim_steps, shared_num_samples,
                shape, cond, verbose=False, eta=shared_eta,
                unconditional_guidance_scale=shared_scale,
                unconditional_conditioning=un_cond)

            # Decode the latents and map from [-1, 1] to uint8 [0, 255].
            x_samples = self.model.decode_first_stage(samples)
            x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5
                         ).cpu().numpy().clip(0, 255).astype(np.uint8)
            results = [x_samples[i] for i in range(shared_num_samples)]
            if rendered_txt_values != "":
                return [whiteboard_img] + results
            else:
                return results
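

if __name__ == "__main__":
    # A minimal end-to-end sketch, assuming an OmegaConf config whose `model`
    # entry matches the checkpoint. Every path and layout value below is a
    # hypothetical placeholder, and the argument shapes expected by
    # process_multi are inferred from its signature, not confirmed upstream.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/config.yaml")                    # hypothetical path
    model = load_model_from_config(cfg, "checkpoints/model.ckpt")  # hypothetical path
    renderer = Render_Text(model)
    outputs = renderer.process_multi(
        rendered_txt_values=["HELLO"], shared_prompt="a street sign",
        width_values=[0.3], ratio_values=[0.5],
        top_left_x_values=[0.35], top_left_y_values=[0.4],
        yaw_values=[0.0], num_rows_values=[1],
        shared_num_samples=1, shared_image_resolution=512,
        shared_ddim_steps=20, shared_guess_mode=False,
        shared_strength=1.0, shared_scale=9.0, shared_seed=-1,
        shared_eta=0.0, shared_a_prompt="best quality",
        shared_n_prompt="low quality, blurry",
    )
    # outputs[0] is the rendered glyph image; the remaining entries are the
    # sampled images as uint8 HWC arrays.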