GlyphControl / scripts /rendertext_tool.py
yyk19's picture
update codes to support the scenario when no glyph instructions are provided.
e06ab5a
raw
history blame
8.88 kB
from cldm.ddim_hacked import DDIMSampler
import torch
from annotator.render_images import render_text_image_custom
from pytorch_lightning import seed_everything
# save_memory = False
# from cldm.hack import disable_verbosity
# disable_verbosity()
import random
import einops
import numpy as np
from ldm.util import instantiate_from_config
from cldm.model import load_state_dict
from torchvision.transforms import ToTensor
from contextlib import nullcontext
def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
    """Build the model described by ``cfg.model`` and optionally load a checkpoint into it.

    Args:
        cfg: config whose ``model`` entry is handed to ``instantiate_from_config``.
        ckpt: checkpoint path. A path ending in ``"model_states.pt"`` is read with
            ``torch.load`` and its ``"module"`` entry is used as the state dict; any
            other path is read with ``load_state_dict`` (from ``cldm.model``).
        verbose: print the missing / unexpected keys reported by the non-strict
            ``model.load_state_dict`` call.
        not_use_ckpt: keep the freshly instantiated weights. In that case
            ``ckpt`` is not read at all.

    Returns:
        The model in eval mode, moved to CUDA when CUDA is available.
    """
    model = instantiate_from_config(cfg.model)

    # Only read the checkpoint when it is actually going to be applied. Previously
    # it was loaded from disk (and therefore had to exist) even with
    # not_use_ckpt=True, only to be thrown away.
    if not not_use_ckpt:
        if ckpt.endswith("model_states.pt"):
            sd = torch.load(ckpt, map_location='cpu')["module"]
        else:
            sd = load_state_dict(ckpt, location='cpu')

        # Drop the "module." key prefix (presumably added by a wrapper such as
        # DataParallel / DeepSpeed during training) so keys match the bare model.
        # The key list is snapshotted before the dict is mutated.
        prefix = "module."
        for key in [k for k in sd if k.startswith(prefix)]:
            sd[key[len(prefix):]] = sd.pop(key)

        # strict=False: report, rather than raise on, mismatched keys.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose and len(missing) > 0:
            print("missing keys: {}".format(len(missing)))
            print(missing)
        if verbose and len(unexpected) > 0:
            print("unexpected keys: {}".format(len(unexpected)))
            print(unexpected)

    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
def load_model_ckpt(model, ckpt, verbose=True):
    """Load the weights stored at ``ckpt`` into an existing ``model``.

    The state dict is mapped to ``"cuda"`` when CUDA is available and to
    ``"cpu"`` otherwise. Keys carrying a ``"module."`` prefix are renamed
    without it before a non-strict ``load_state_dict``.

    Args:
        model: an already-built ``torch.nn.Module``.
        ckpt: checkpoint path. Paths ending in ``"model_states.pt"`` are read
            with ``torch.load`` and their ``"module"`` entry is used; any other
            path goes through ``load_state_dict`` (from ``cldm.model``).
        verbose: print missing / unexpected keys, if any.

    Returns:
        The same ``model`` object, switched to eval mode.
    """
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    print("checkpoint map location:", map_location)

    if not ckpt.endswith("model_states.pt"):
        state_dict = load_state_dict(ckpt, location=map_location)
    else:
        state_dict = torch.load(ckpt, map_location=map_location)["module"]

    # Rename "module.<name>" -> "<name>"; iterate over a snapshot of the
    # matching keys because the dict is mutated inside the loop.
    prefix = "module."
    for key in [k for k in state_dict if k.startswith(prefix)]:
        state_dict[key[len(prefix):]] = state_dict.pop(key)

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose and missing:
        print("missing keys: {}".format(len(missing)))
        print(missing)
    if verbose and unexpected:
        print("unexpected keys: {}".format(len(unexpected)))
        print(unexpected)

    model.eval()
    return model
class Render_Text:
    """Render glyph text into a control image and sample images conditioned on it.

    The wrapped ``model`` must provide the members used below:
    ``ema_scope``, ``get_learned_conditioning``, ``learnable_conscale``,
    ``control_scales``, ``decode_first_stage`` and, when ``save_memory`` is
    enabled on a CUDA machine, ``low_vram_shift``.
    """

    def __init__(self,
                 model,
                 precision_scope=nullcontext,
                 transform=ToTensor(),
                 save_memory=False,
                 ):
        """Store the model and build a DDIM sampler around it.

        Args:
            model: the diffusion model to sample from (see the class docstring
                for the members it must provide).
            precision_scope: context-manager factory entered as
                ``precision_scope("cuda")`` while sampling; the default
                ``nullcontext`` is a no-op.
            transform: callable converting the rendered PIL image into a
                tensor. The default ``ToTensor()`` instance is created once,
                at class-definition time, and shared by all instances.
            save_memory: when true and CUDA is available, call
                ``model.low_vram_shift`` between the conditioning, diffusion
                and decoding stages.
        """
        self.model = model
        self.precision_scope = precision_scope
        self.transform = transform
        self.ddim_sampler = DDIMSampler(model)
        self.save_memory = save_memory

    @staticmethod
    def _format_bboxes(width_values, ratio_values, top_left_x_values, top_left_y_values, yaw_values):
        """Zip the per-group layout lists into one bbox dict per text group."""
        return [
            {
                "width": width,
                "ratio": ratio,
                "top_left_x": top_left_x,
                "top_left_y": top_left_y,
                "yaw": yaw,
            }
            for width, ratio, top_left_x, top_left_y, yaw
            in zip(width_values, ratio_values, top_left_x_values, top_left_y_values, yaw_values)
        ]

    @staticmethod
    def _compose_prompt(prompt, a_prompt):
        """Join the user prompt and the additional prompt into one punctuated string.

        With no additional prompt the result always ends in "." ("x," -> "x.",
        "x" -> "x."). Otherwise the additional prompt is appended after a space
        (prompt already ends in "." or ",") or after ", " (it does not).
        """
        if a_prompt == "":
            if prompt.endswith("."):
                return prompt
            if prompt.endswith(","):
                return prompt[:-1] + "."
            return prompt + "."
        if prompt.endswith((".", ",")):
            return prompt + " " + a_prompt
        return prompt + ", " + a_prompt

    @staticmethod
    def _as_list(cond):
        """Wrap a conditioning value in a list unless it already is one."""
        return cond if isinstance(cond, list) else [cond]

    def _build_control(self, whiteboard_img, num_samples):
        """Turn the rendered glyph image into the batched control input.

        Returns a one-element list holding the glyph tensor repeated
        ``num_samples`` times along a new leading (batch) dimension.
        """
        control = self.transform(whiteboard_img.copy())
        if torch.cuda.is_available():
            control = control.cuda()
        # torch.stack always allocates a fresh tensor, so no extra clone is needed.
        control = torch.stack([control] * num_samples, dim=0)
        return [control]

    def _shift_low_vram(self, is_diffusing):
        """Call ``model.low_vram_shift`` when memory saving is enabled on CUDA."""
        if torch.cuda.is_available() and self.save_memory:
            print("low_vram_shift: is_diffusing", is_diffusing)
            self.model.low_vram_shift(is_diffusing=is_diffusing)

    # process multiple groups of rendered text for building demo
    def process_multi(self,
                      rendered_txt_values, shared_prompt,
                      width_values, ratio_values,
                      top_left_x_values, top_left_y_values,
                      yaw_values, num_rows_values,
                      shared_num_samples, shared_image_resolution,
                      shared_ddim_steps, shared_guess_mode,
                      shared_strength, shared_scale, shared_seed,
                      shared_eta, shared_a_prompt, shared_n_prompt,
                      only_show_rendered_image=False
                      ):
        """Render several text groups and sample images conditioned on them.

        Args:
            rendered_txt_values: one text per group. Empty strings are skipped
                when deciding whether there is anything to render; if every
                entry is empty, sampling runs without a glyph control image.
            width_values, ratio_values, top_left_x_values, top_left_y_values,
            yaw_values: per-group layout values, zipped into the bbox dicts
                passed to ``render_text_image_custom`` (their units are defined
                by that function).
            num_rows_values: passed through to ``render_text_image_custom``.
            shared_num_samples: number of images sampled in one batch.
            shared_image_resolution: side length of the square rendered image
                and of the sampled images (latent shape is resolution // 8).
            shared_ddim_steps, shared_eta: forwarded to ``DDIMSampler.sample``.
            shared_guess_mode: if true, the unconditional branch gets no control
                and the control scales decay geometrically.
            shared_strength: control scale, used only when
                ``model.learnable_conscale`` is false.
            shared_scale: classifier-free ``unconditional_guidance_scale``.
            shared_seed: RNG seed; -1 draws one uniformly from [0, 65535].
            shared_a_prompt: text appended to ``shared_prompt``.
            shared_n_prompt: negative prompt used for the unconditional branch.
            only_show_rendered_image: return only ``[whiteboard_img]`` (or
                ``[None]`` when nothing is rendered) without sampling.

        Returns:
            ``[whiteboard_img] + samples`` when text was rendered, otherwise
            ``samples``. Each sample is a uint8 array in (H, W, C) layout.
        """
        if shared_seed == -1:
            shared_seed = random.randint(0, 65535)
        seed_everything(shared_seed)

        with torch.no_grad(), \
                self.precision_scope("cuda"), \
                self.model.ema_scope("Sampling on Benchmark Prompts"):
            print("rendered txt:", str(rendered_txt_values), "[t]")
            render_none = not any(txt != "" for txt in rendered_txt_values)

            if render_none:
                # No glyph instructions: sample from the text prompt alone.
                whiteboard_img = None
                control = None
                if only_show_rendered_image:
                    return [None]
            else:
                whiteboard_img = render_text_image_custom(
                    (shared_image_resolution, shared_image_resolution),
                    self._format_bboxes(width_values, ratio_values,
                                        top_left_x_values, top_left_y_values, yaw_values),
                    rendered_txt_values,
                    num_rows_values,
                ).convert("RGB")
                if only_show_rendered_image:
                    return [whiteboard_img]
                control = self._build_control(whiteboard_img, shared_num_samples)

            H = W = shared_image_resolution

            self._shift_low_vram(is_diffusing=False)
            print("control is None: {}".format(control is None))

            c_prompt = self._compose_prompt(shared_prompt, shared_a_prompt)
            cond_c_cross = self.model.get_learned_conditioning([c_prompt] * shared_num_samples)
            print("prompt:", c_prompt)
            un_cond_cross = self.model.get_learned_conditioning([shared_n_prompt] * shared_num_samples)

            self._shift_low_vram(is_diffusing=True)

            cond = {"c_concat": control, "c_crossattn": self._as_list(cond_c_cross)}
            un_cond = {
                "c_concat": None if shared_guess_mode else control,
                "c_crossattn": self._as_list(un_cond_cross),
            }
            # Latent shape: 4 channels at 1/8 of the image resolution.
            shape = (4, H // 8, W // 8)

            if not self.model.learnable_conscale:
                if shared_guess_mode:
                    # 13 geometrically increasing scales, from strength * 0.825**12
                    # (~0.0994 * strength) for the first up to strength for the last.
                    # The base is presumably chosen so the smallest scale falls just
                    # under 0.1: 0.825**12 ~ 0.0994 < 0.1 < 0.826**12 ~ 0.1009.
                    self.model.control_scales = [shared_strength * (0.825 ** float(12 - i)) for i in range(13)]
                else:
                    self.model.control_scales = [shared_strength] * 13
            else:
                print("learned control scale: {}".format(str(self.model.control_scales)))

            samples, _ = self.ddim_sampler.sample(shared_ddim_steps, shared_num_samples,
                                                  shape, cond, verbose=False, eta=shared_eta,
                                                  unconditional_guidance_scale=shared_scale,
                                                  unconditional_conditioning=un_cond)

            self._shift_low_vram(is_diffusing=False)

            x_samples = self.model.decode_first_stage(samples)
            # Map decoder output (assumed to lie in [-1, 1] -- TODO confirm) to uint8 (B, H, W, C).
            x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            results = [x_samples[i] for i in range(shared_num_samples)]

            if render_none:
                return results
            return [whiteboard_img] + results