from cldm.model import load_state_dict
from cldm.ddim_hacked import DDIMSampler
from ldm.util import instantiate_from_config
import argparse
import os
from omegaconf import OmegaConf
from torchvision.transforms import ToTensor
from torch import autocast
from contextlib import nullcontext
from scripts.rendertext_tool import Render_Text, load_model_from_config
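
# Inference script for glyph-conditioned text rendering with a ControlNet-style
# Stable Diffusion model. It runs in one of two modes: an interactive Gradio demo
# (--build_demo), or batch generation over a prompt file followed by EasyOCR-based
# read-back of the rendered text.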
# def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
#     sd = load_state_dict(ckpt, location='cpu')
#     if "model_ema.input_blocks10in_layers0weight" not in sd:
#         cfg.model.params.use_ema = False
#     model = instantiate_from_config(cfg.model)
#     if not not_use_ckpt:
#         m, u = model.load_state_dict(sd, strict=False)
#         if len(m) > 0 and verbose:
#             print("missing keys: {}".format(len(m)))
#             print(m)
#         if len(u) > 0 and verbose:
#             print("unexpected keys: {}".format(len(u)))
#             print(u)
#     model.cuda()
#     model.eval()
#     return model
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg",
        type=str,
        default="configs/stable-diffusion/textcaps_cldm_v20.yaml",
        help="path to the config which constructs the model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--hint_range_m11",
        action="store_true",
        help="whether the hint image is in the range [-1, 1]",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="full",  # "autocast"
    )
    parser.add_argument(
        "--not_use_ckpt",
        action="store_true",
        help="skip loading the checkpoint weights",
    )
    parser.add_argument(
        "--build_demo",
        action="store_true",
        help="whether to build the Gradio demo",
    )
    parser.add_argument(
        "--sep_prompt",
        action="store_true",
        help="whether to feed a separate prompt to the ControlNet branch",
    )
    parser.add_argument(
        "--spell_prompt_type",
        type=int,
        default=1,
        help="prompt template: 1: A sign with the word 'xxx' written on it; "
             "2/20: A sign that says 'xxx' (single/double quotes); "
             "3/30: A whiteboard that says 'xxx' (single/double quotes)",
    )
    parser.add_argument(
        "--max_num_prompts",
        type=int,
        default=None,
        help="maximum number of prompts to use",
    )
    parser.add_argument(
        "--grams",
        type=int,
        default=1,
        help="how many grams (words or symbols) are joined to form the to-be-rendered text (used for the DrawSpelling benchmark)",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given prompt, i.e. the batch size",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file, separated by newlines",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a sign that says 'Stable Diffusion'",
        help="the prompt",
    )
    parser.add_argument(
        "--rendered_txt",
        type=str,
        nargs="?",
        default="Stable Diffusion",
        help="the text to render",
    )
    parser.add_argument(
        "--uncond_glycon_img",
        action="store_true",
        help="whether to set the glyph embedding to None in the unconditional conditioning",
    )
    parser.add_argument(
        "--deepspeed_ckpt",
        action="store_true",
        help="whether the checkpoint was saved by DeepSpeed during training",
    )
    parser.add_argument(
        "--glyph_img_size",
        type=int,
        default=256,
        help="the input image size of the glyph image encoder",
    )
    parser.add_argument(
        "--uncond_glyph_image_type",
        type=str,
        default="white",
        help="the type of rendered glyph image used as the unconditional condition for classifier-free guidance",
    )
    parser.add_argument(
        "--remove_txt_in_prompt",
        action="store_true",
        help="whether to remove the quoted text from the prompt",
    )
    parser.add_argument(
        "--replace_token",
        type=str,
        default="",
        help="the token used to replace the removed text",
    )
    return parser
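
# Example invocations (illustrative; "inference.py" and the paths are placeholders,
# not taken from the source):
#   python inference.py --ckpt <path/to/model.ckpt> --build_demo
#   python inference.py --ckpt <path/to/model.ckpt> --from-file <prompts.txt> \
#       --spell_prompt_type 2 --num_samples 4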
if not os.path.basename(os.getcwd()) == "stablediffusion":
    os.chdir(os.path.join(os.getcwd(), "stablediffusion"))
    print(os.getcwd())

parser = parse_args()
opt = parser.parse_args()
if opt.deepspeed_ckpt:
    # DeepSpeed saves a directory; point at the consolidated model states inside it.
    assert os.path.isdir(opt.ckpt)
    opt.ckpt = os.path.join(opt.ckpt, "checkpoint", "mp_rank_00_model_states.pt")
    assert os.path.exists(opt.ckpt)
cfg = OmegaConf.load(f"{opt.cfg}")
model = load_model_from_config(cfg, f"{opt.ckpt}", verbose=True, not_use_ckpt=opt.not_use_ckpt)
hint_range_m11 = opt.hint_range_m11
sep_prompt = opt.sep_prompt
ddim_sampler = DDIMSampler(model)
precision_scope = autocast if opt.precision == "autocast" else nullcontext
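# autocast runs sampling in mixed precision; nullcontext is a no-op for full precision.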
trans = ToTensor()
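# Render_Text bundles the model with the glyph-image rendering pipeline; when the
# config file defines an option, it takes precedence over the corresponding CLI flag.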
render_tool = Render_Text(
    model, precision_scope,
    trans,
    hint_range_m11,
    sep_prompt,
    uncond_glycon_img=cfg.uncond_glycon_img if hasattr(cfg, "uncond_glycon_img") else opt.uncond_glycon_img,
    glyph_control_proc_config=cfg.glyph_control_proc_config if hasattr(cfg, "glyph_control_proc_config") else None,
    glyph_img_size=opt.glyph_img_size,
    uncond_glyph_image_type=cfg.uncond_glyph_image_type if hasattr(cfg, "uncond_glyph_image_type") else opt.uncond_glyph_image_type,
    remove_txt_in_prompt=cfg.remove_txt_in_prompt if hasattr(cfg, "remove_txt_in_prompt") else opt.remove_txt_in_prompt,
    replace_token=cfg.replace_token if hasattr(cfg, "replace_token") else opt.replace_token,
)
if opt.build_demo:
    import gradio as gr
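    # Interactive mode: a Gradio UI exposing the glyph bounding-box and sampling controls.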
    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## Control Stable Diffusion with Glyph Images")
        with gr.Row():
            with gr.Column():
                # input_image = gr.Image(source='upload', type="numpy")
                rendered_txt = gr.Textbox(label="rendered_txt")
                prompt = gr.Textbox(label="Prompt")
                if sep_prompt:
                    prompt_2 = gr.Textbox(label="Prompt_ControlNet")
                else:
                    prompt_2 = gr.Number(value=0, visible=False)  # None # ""
                run_button = gr.Button(label="Run")
                with gr.Accordion("Advanced options", open=False):
                    width = gr.Slider(label="bbox_width", minimum=0., maximum=1, value=0.3, step=0.01)
                    # height = gr.Slider(label="bbox_height", minimum=0., maximum=1, value=0.2, step=0.01)
                    ratio = gr.Slider(label="bbox_width_height_ratio", minimum=0., maximum=5, value=0., step=0.02)
                    top_left_x = gr.Slider(label="bbox_top_left_x", minimum=0., maximum=1, value=0.5, step=0.01)
                    top_left_y = gr.Slider(label="bbox_top_left_y", minimum=0., maximum=1, value=0.5, step=0.01)
                    yaw = gr.Slider(label="bbox_yaw", minimum=-180, maximum=180, value=0, step=5)
                    num_rows = gr.Slider(label="num_rows", minimum=1, maximum=4, value=1, step=1)
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                    image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                    strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                    guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                    # low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                    # high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                    ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                    eta = gr.Number(label="eta (DDIM)", value=0.0)
                    a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                    n_prompt = gr.Textbox(label="Negative Prompt",
                                          value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
        ips = [
            rendered_txt, prompt,
            width, ratio,  # height,
            top_left_x, top_left_y, yaw, num_rows,
            a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta,
            prompt_2
        ]
        run_button.click(fn=render_tool.process, inputs=ips, outputs=[result_gallery])
        # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    block.launch(server_name='0.0.0.0', share=True)
else:
    import easyocr
    reader = easyocr.Reader(['en'])
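    # Batch-evaluation mode: generate images for each prompt and read the rendered
    # text back with EasyOCR.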
    # num_samples = 1
    # rendered_txt = "happy"
    # prompt = "A sign that says 'happy'"
    num_samples = opt.num_samples
    print("the num of samples is {}".format(num_samples))
    if not opt.from_file:
        prompts = [opt.prompt]
        data = [opt.rendered_txt]
        print("the prompt is {}".format(prompts))
        print("the rendered_txt is {}".format(data))
        assert prompts is not None
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
        # benchmark files store the target word in the first tab-separated column
        if "gram" in os.path.basename(opt.from_file):
            data = [item.split("\t")[0] for item in data]
        if opt.grams > 1:
            # group consecutive entries into n-grams to form longer to-be-rendered strings
            data = [" ".join(data[i:i + opt.grams]) for i in range(0, len(data), opt.grams)]
if "DrawText_Spelling" in os.path.basename(opt.from_file) or "gram" in os.path.basename(opt.from_file): | |
if opt.spell_prompt_type == 1: | |
prompts = ['A sign with the word "{}" written on it'.format(line.strip()) for line in data] | |
elif opt.spell_prompt_type == 2: | |
prompts = ["A sign that says '{}'".format(line.strip()) for line in data] | |
elif opt.spell_prompt_type == 20: | |
prompts = ['A sign that says "{}"'.format(line.strip()) for line in data] | |
elif opt.spell_prompt_type == 3: | |
prompts = ["A whiteboard that says '{}'".format(line.strip()) for line in data] | |
elif opt.spell_prompt_type == 30: | |
prompts = ['A whiteboard that says "{}"'.format(line.strip()) for line in data] | |
else: | |
print("Only five types of prompt templates are supported currently") | |
raise ValueError | |
# if opt.verbose_all_prompts: | |
# show_num = opt.max_num_prompts if (opt.max_num_prompts is not None and opt.max_num_prompts >0) else 10 | |
# for i in range(show_num): | |
# print("embed the word into the prompt template for {} Benchmark: {}".format( | |
# os.path.basename(opt.from_file), data[i]) | |
# ) | |
# else: | |
# print("embed the word into the prompt template for {} Benchmark: e.g., {}".format( | |
# os.path.basename(opt.from_file), data[0]) | |
# ) | |
if opt.max_num_prompts is not None and opt.max_num_prompts >0: | |
print("only use {} prompts to test the model".format(opt.max_num_prompts)) | |
data = data[:opt.max_num_prompts] | |
prompts = prompts[:opt.max_num_prompts] | |
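    # Fixed generation settings, matching the demo's "Advanced options" defaults
    # (with a fixed seed for reproducibility).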
    width, ratio, top_left_x, top_left_y, yaw, num_rows = 0.3, 0, 0.5, 0.5, 0, 1
    image_resolution = 512
    strength = 1
    guess_mode = False
    ddim_steps = 20
    scale = 9.0
    seed = 1945923867
    eta = 0
    a_prompt = 'best quality, extremely detailed'
    n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
    all_results_list = []
    for i in range(len(data)):
        ips = (
            data[i], prompts[i],
            width, ratio, top_left_x, top_left_y, yaw, num_rows,
            a_prompt, n_prompt,
            num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta
        )
        all_results = render_tool.process(*ips)  # process(*ips)
        # when text was rendered, the first returned element (the glyph hint) is skipped
        all_results_list.extend(all_results[1:] if data[i] != "" else all_results)
    all_ocr_info = []
    for image_array in all_results_list:
        ocr_result = reader.readtext(image_array)
        all_ocr_info.append(ocr_result)
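    # --- Illustrative summary (an addition, not from the original script): compute a
    # simple OCR exact-match accuracy. Assumes each prompt contributed exactly
    # `num_samples` generated images to `all_results_list`, in order.
    matches = 0
    for idx, ocr_result in enumerate(all_ocr_info):
        target = data[idx // num_samples].strip().lower()
        # each EasyOCR result item is (bounding_box, text, confidence)
        recognized = " ".join(text for _, text, _ in ocr_result).strip().lower()
        matches += int(recognized == target)
    if all_ocr_info:
        print("OCR exact-match accuracy: {:.2%}".format(matches / len(all_ocr_info)))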