# GlyphControl/scripts/gradio_rendertext.py
from cldm.model import load_state_dict
from cldm.ddim_hacked import DDIMSampler
from ldm.util import instantiate_from_config
import os
import argparse
from omegaconf import OmegaConf
from torchvision.transforms import ToTensor
from torch import autocast
from contextlib import nullcontext
from scripts.rendertext_tool import Render_Text, load_model_from_config
# def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
#     sd = load_state_dict(ckpt, location='cpu')
#     if "model_ema.input_blocks10in_layers0weight" not in sd:
#         cfg.model.params.use_ema = False
#     model = instantiate_from_config(cfg.model)
#     if not not_use_ckpt:
#         m, u = model.load_state_dict(sd, strict=False)
#         if len(m) > 0 and verbose:
#             print("missing keys: {}".format(len(m)))
#             print(m)
#         if len(u) > 0 and verbose:
#             print("unexpected keys: {}".format(len(u)))
#             print(u)
#     model.cuda()
#     model.eval()
#     return model
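# The active load_model_from_config is imported above from
# scripts.rendertext_tool; the commented-out variant is kept only for reference.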
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg",
        type=str,
        default="configs/stable-diffusion/textcaps_cldm_v20.yaml",
        help="path to the config which constructs the model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--hint_range_m11",
        action="store_true",
        help="use [-1, 1] as the value range of the hint image",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="full",  # "autocast"
    )
    parser.add_argument(
        "--not_use_ckpt",
        action="store_true",
        help="do not load the checkpoint weights",
    )
    parser.add_argument(
        "--build_demo",
        action="store_true",
        help="whether to build the gradio demo",
    )
    parser.add_argument(
        "--sep_prompt",
        action="store_true",
        help="whether to use a separate prompt for the ControlNet branch",
    )
    parser.add_argument(
        "--spell_prompt_type",
        type=int,
        default=1,
        help="prompt template: 1: A sign with the word 'xxx' written on it; "
             "2/20: A sign that says 'xxx' (single/double quotes); "
             "3/30: A whiteboard that says 'xxx' (single/double quotes)",
    )
    parser.add_argument(
        "--max_num_prompts",
        type=int,
        default=None,
        help="maximum number of prompts to use",
    )
    parser.add_argument(
        "--grams",
        type=int,
        default=1,
        help="how many grams (words or symbols) form the to-be-rendered text (used for the DrawSpelling benchmark)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given prompt, i.e. the batch size",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file, separated by newlines",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a sign that says 'Stable Diffusion'",
        help="the prompt",
    )
    parser.add_argument(
        "--rendered_txt",
        type=str,
        nargs="?",
        default="Stable Diffusion",
        help="the text to render",
    )
    parser.add_argument(
        "--uncond_glycon_img",
        action="store_true",
        help="whether to set the glyph embedding to None for the unconditional conditioning",
    )
    parser.add_argument(
        "--deepspeed_ckpt",
        action="store_true",
        help="whether the checkpoint was saved by DeepSpeed during training",
    )
    parser.add_argument(
        "--glyph_img_size",
        type=int,
        default=256,
        help="the input image size of the glyph image encoder",
    )
    parser.add_argument(
        "--uncond_glyph_image_type",
        type=str,
        default="white",
        help="the type of rendered glyph image used as the unconditional condition for classifier-free guidance",
    )
    parser.add_argument(
        "--remove_txt_in_prompt",
        action="store_true",
        help="whether to remove the rendered text from the prompt",
    )
    parser.add_argument(
        "--replace_token",
        type=str,
        default="",
        help="the token used to replace the removed text in the prompt",
    )
    return parser
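# Example invocations (all paths are illustrative and depend on your checkout):
#   python scripts/gradio_rendertext.py \
#       --cfg configs/stable-diffusion/textcaps_cldm_v20.yaml \
#       --ckpt checkpoints/model.ckpt --build_demo
#   python scripts/gradio_rendertext.py \
#       --cfg configs/stable-diffusion/textcaps_cldm_v20.yaml \
#       --ckpt checkpoints/model.ckpt \
#       --from-file prompts/DrawText_Spelling.txt --num_samples 4 --spell_prompt_type 2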
if not os.path.basename(os.getcwd()) == "stablediffusion":
    os.chdir(os.path.join(os.getcwd(), "stablediffusion"))
    print(os.getcwd())
parser = parse_args()
opt = parser.parse_args()
if opt.deepspeed_ckpt:
    assert os.path.isdir(opt.ckpt)
    opt.ckpt = os.path.join(opt.ckpt, "checkpoint", "mp_rank_00_model_states.pt")
    assert os.path.exists(opt.ckpt)
cfg = OmegaConf.load(opt.cfg)
model = load_model_from_config(cfg, opt.ckpt, verbose=True, not_use_ckpt=opt.not_use_ckpt)
hint_range_m11 = opt.hint_range_m11
sep_prompt = opt.sep_prompt
ddim_sampler = DDIMSampler(model)
precision_scope = autocast if opt.precision == "autocast" else nullcontext
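# precision_scope is entered as a context manager around sampling inside
# Render_Text (presumably as `with precision_scope("cuda")`): torch.autocast
# enables mixed-precision inference, nullcontext keeps full fp32 precision.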
trans = ToTensor()
render_tool = Render_Text(
    model, precision_scope,
    trans,
    hint_range_m11,
    sep_prompt,
    uncond_glycon_img=cfg.uncond_glycon_img if hasattr(cfg, "uncond_glycon_img") else opt.uncond_glycon_img,
    glyph_control_proc_config=cfg.glyph_control_proc_config if hasattr(cfg, "glyph_control_proc_config") else None,
    glyph_img_size=opt.glyph_img_size,
    uncond_glyph_image_type=cfg.uncond_glyph_image_type if hasattr(cfg, "uncond_glyph_image_type") else opt.uncond_glyph_image_type,
    remove_txt_in_prompt=cfg.remove_txt_in_prompt if hasattr(cfg, "remove_txt_in_prompt") else opt.remove_txt_in_prompt,
    replace_token=cfg.replace_token if hasattr(cfg, "replace_token") else opt.replace_token,
)
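# A minimal sketch of driving render_tool directly, mirroring the batch loop
# further below (all numeric values are the same illustrative defaults):
# images = render_tool.process(
#     "Stable Diffusion", "a sign that says 'Stable Diffusion'",
#     0.3, 0,                 # bbox width and width/height ratio
#     0.5, 0.5, 0, 1,         # bbox top-left x/y, yaw, number of text rows
#     'best quality, extremely detailed', '',
#     1, 512, 20, False, 1.0, 9.0, -1, 0.0,  # samples, resolution, steps, guess mode, strength, scale, seed, eta
# )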
if opt.build_demo:
    import gradio as gr
    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## Control Stable Diffusion with Glyph Images")
        with gr.Row():
            with gr.Column():
                # input_image = gr.Image(source='upload', type="numpy")
                rendered_txt = gr.Textbox(label="rendered_txt")
                prompt = gr.Textbox(label="Prompt")
                if sep_prompt:
                    prompt_2 = gr.Textbox(label="Prompt_ControlNet")
                else:
                    prompt_2 = gr.Number(value=0, visible=False)  # None #""
                run_button = gr.Button(label="Run")
                with gr.Accordion("Advanced options", open=False):
                    width = gr.Slider(label="bbox_width", minimum=0., maximum=1, value=0.3, step=0.01)
                    # height = gr.Slider(label="bbox_height", minimum=0., maximum=1, value=0.2, step=0.01)
                    ratio = gr.Slider(label="bbox_width_height_ratio", minimum=0., maximum=5, value=0., step=0.02)
                    top_left_x = gr.Slider(label="bbox_top_left_x", minimum=0., maximum=1, value=0.5, step=0.01)
                    top_left_y = gr.Slider(label="bbox_top_left_y", minimum=0., maximum=1, value=0.5, step=0.01)
                    yaw = gr.Slider(label="bbox_yaw", minimum=-180, maximum=180, value=0, step=5)
                    num_rows = gr.Slider(label="num_rows", minimum=1, maximum=4, value=1, step=1)
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                    image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                    strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                    guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                    # low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                    # high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                    ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                    eta = gr.Number(label="eta (DDIM)", value=0.0)
                    a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
        ips = [
            rendered_txt, prompt,
            width, ratio,  # height,
            top_left_x, top_left_y, yaw, num_rows,
            a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta,
            prompt_2
        ]
        run_button.click(fn=render_tool.process, inputs=ips, outputs=[result_gallery])
        # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    block.launch(server_name='0.0.0.0', share=True)
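    # share=True publishes the demo at a temporary public gradio.live URL in
    # addition to the local server on 0.0.0.0; drop it to stay local-only.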
else:
    import easyocr
    reader = easyocr.Reader(['en'])
    # num_samples = 1
    # rendered_txt = "happy"
    # prompt = "A sign that says 'happy'"
    num_samples = opt.num_samples
    print("the number of samples is {}".format(num_samples))
    if not opt.from_file:
        prompts = [opt.prompt]
        data = [opt.rendered_txt]
        print("the prompt is {}".format(prompts))
        print("the rendered_txt is {}".format(data))
        assert prompts is not None
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
        if "gram" in os.path.basename(opt.from_file):
            data = [item.split("\t")[0] for item in data]
        if opt.grams > 1:
            data = [" ".join(data[i:i + opt.grams]) for i in range(0, len(data), opt.grams)]
if "DrawText_Spelling" in os.path.basename(opt.from_file) or "gram" in os.path.basename(opt.from_file):
if opt.spell_prompt_type == 1:
prompts = ['A sign with the word "{}" written on it'.format(line.strip()) for line in data]
elif opt.spell_prompt_type == 2:
prompts = ["A sign that says '{}'".format(line.strip()) for line in data]
elif opt.spell_prompt_type == 20:
prompts = ['A sign that says "{}"'.format(line.strip()) for line in data]
elif opt.spell_prompt_type == 3:
prompts = ["A whiteboard that says '{}'".format(line.strip()) for line in data]
elif opt.spell_prompt_type == 30:
prompts = ['A whiteboard that says "{}"'.format(line.strip()) for line in data]
else:
print("Only five types of prompt templates are supported currently")
raise ValueError
        # if opt.verbose_all_prompts:
        #     show_num = opt.max_num_prompts if (opt.max_num_prompts is not None and opt.max_num_prompts > 0) else 10
        #     for i in range(show_num):
        #         print("embed the word into the prompt template for {} Benchmark: {}".format(
        #             os.path.basename(opt.from_file), data[i])
        #         )
        # else:
        #     print("embed the word into the prompt template for {} Benchmark: e.g., {}".format(
        #         os.path.basename(opt.from_file), data[0])
        #     )
        if opt.max_num_prompts is not None and opt.max_num_prompts > 0:
            print("only use {} prompts to test the model".format(opt.max_num_prompts))
            data = data[:opt.max_num_prompts]
            prompts = prompts[:opt.max_num_prompts]
    width, ratio, top_left_x, top_left_y, yaw, num_rows = 0.3, 0, 0.5, 0.5, 0, 1
    image_resolution = 512
    strength = 1
    guess_mode = False
    ddim_steps = 20
    scale = 9.0
    seed = 1945923867
    eta = 0
    a_prompt = 'best quality, extremely detailed'
    n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
    all_results_list = []
    for i in range(len(data)):
        ips = (
            data[i], prompts[i],
            width, ratio, top_left_x, top_left_y, yaw, num_rows,
            a_prompt, n_prompt,
            num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta
        )
        all_results = render_tool.process(*ips)  # process(*ips)
        # the first returned image is presumably the rendered glyph hint, so it
        # is skipped whenever some text was actually rendered
        all_results_list.extend(all_results[1:] if data[i] != "" else all_results)
    all_ocr_info = []
    for image_array in all_results_list:
        ocr_result = reader.readtext(image_array)
        all_ocr_info.append(ocr_result)
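    # Summary pass (a minimal sketch beyond the original script): each EasyOCR
    # result entry is a (bbox, text, confidence) triple, so this prints the
    # strings recognized in every generated image.
    for idx, ocr_result in enumerate(all_ocr_info):
        recognized = [text for _bbox, text, _conf in ocr_result]
        print("image {}: recognized text: {}".format(idx, recognized))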