Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from omegaconf import OmegaConf | |
from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt | |
import json | |
import numpy as np | |
from PIL import Image, ImageDraw, ImageFont | |
from functools import partial | |
from collections import Counter | |
import math | |
import gc | |
from gradio import processing_utils | |
from typing import Optional | |
import warnings | |
from datetime import datetime | |
from example_component import create_examples | |
from huggingface_hub import hf_hub_download | |
hf_hub_download = partial(hf_hub_download, library_name="gligen_demo") | |
import cv2 | |
import sys | |
sys.tracebacklimit = 0 | |
def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
    """Fetch ``filename`` from the Hugging Face repo ``repo_id`` and ``torch.load`` it on CPU.

    The download goes through the module-level ``hf_hub_download`` partial,
    which tags requests with ``library_name="gligen_demo"``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted repos.
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
    loaded = torch.load(local_path, map_location='cpu')
    return loaded
def load_ckpt_config_from_hf(modality):
    """Load the ``<modality>.pth`` checkpoint and its matching config.

    Both come from ``gligen/demo_ckpts_legacy``: the weights from the ``model``
    subfolder and the config from the ``config`` subfolder.

    Returns:
        (checkpoint, config) as returned by ``load_from_hf``.
    """
    repo = 'gligen/demo_ckpts_legacy'
    pth_name = f'{modality}.pth'
    weights = load_from_hf(repo, filename=pth_name, subfolder='model')
    cfg = load_from_hf(repo, filename=pth_name, subfolder='config')
    return weights, cfg
def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
    """Build the model list for one GLIGEN ``modality`` checkpoint.

    The shared "common" weights are downloaded and loaded only when
    ``common_instances`` is None. Otherwise the given instances are reused.
    ``is_inpaint`` / ``is_style`` are accepted but not read here.

    Returns:
        (loaded_model_list, common_instances)
    """
    gligen_state, raw_config = load_ckpt_config_from_hf(modality)
    # config used in training (its "_content" entry is fed to OmegaConf)
    config = OmegaConf.create(raw_config["_content"])
    config.alpha_scale = 1.0
    if common_instances is None:
        common_state = load_from_hf('gligen/demo_ckpts_legacy', filename='common.pth', subfolder='model')
        common_instances = load_common_ckpt(config, common_state)
    models = load_ckpt(config, gligen_state, common_instances)
    return models, common_instances
class Instance:
    """Cache of loaded GLIGEN models keyed by model type ('base', 'inpaint', 'style').

    The 'base' model is loaded eagerly in ``__init__``. At most ``capacity``
    models stay loaded. Loading another type first evicts the loaded model
    with the fewest requests since it was loaded.
    """

    # model_type -> (checkpoint modality, is_inpaint, is_style)
    _MODEL_SPECS = {
        'base': ('gligen-generation-text-box', False, False),
        'inpaint': ('gligen-inpainting-text-box', True, False),
        'style': ('gligen-generation-text-image-box', False, True),
    }

    def __init__(self, capacity=2):
        self.model_type = 'base'
        self.loaded_model_list = {}
        # Requests per currently-loaded model type (entry dropped on eviction).
        self.counter = Counter()
        # Requests per model type over the process lifetime (never reset).
        self.global_counter = Counter()
        modality, is_inpaint, is_style = self._MODEL_SPECS['base']
        self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
            modality, is_inpaint=is_inpaint, is_style=is_style, common_instances=None,
        )
        self.capacity = capacity

    def _log(self, model_type, batch_size, instruction, phrase_list):
        """Count one request for ``model_type`` and print a usage line."""
        self.counter[model_type] += 1
        self.global_counter[model_type] += 1
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
            current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
        ))

    def _evict_least_used(self):
        """Unload the loaded model with the lowest request count and free memory."""
        # Rank every loaded model, including ones never requested through
        # get_model (e.g. the eagerly loaded 'base'), whose count is 0.
        # Counter returns 0 for missing keys without inserting them.
        victim = min(self.loaded_model_list, key=lambda name: self.counter[name])
        del self.loaded_model_list[victim]
        del self.counter[victim]  # Counter.__delitem__ ignores missing keys
        gc.collect()
        torch.cuda.empty_cache()

    def get_model(self, model_type, batch_size, instruction, phrase_list):
        """Return the model for ``model_type``, loading (and evicting) as needed.

        ``batch_size``, ``instruction`` and ``phrase_list`` are only used for logging.
        """
        if model_type not in self.loaded_model_list:
            if self.loaded_model_list and len(self.loaded_model_list) >= self.capacity:
                self._evict_least_used()
            self.loaded_model_list[model_type] = self._get_model(model_type)
        self._log(model_type, batch_size, instruction, phrase_list)
        return self.loaded_model_list[model_type]

    def _get_model(self, model_type):
        """Load the model for ``model_type``, reusing the shared common weights.

        Raises:
            ValueError: if ``model_type`` is not a known model type.
        """
        try:
            modality, is_inpaint, is_style = self._MODEL_SPECS[model_type]
        except KeyError:
            raise ValueError(f"Unknown model type: {model_type!r}") from None
        return ckpt_load_helper(
            modality,
            is_inpaint=is_inpaint, is_style=is_style, common_instances=self.common_instances,
        )[0]
# Process-wide model cache shared by every request.
instance = Instance()


def load_clip_model():
    """Load CLIP ViT-L/14 (moved to CUDA) together with its processor.

    Returns:
        dict with keys 'version', 'model' and 'processor'.
    """
    from transformers import CLIPProcessor, CLIPModel

    clip_version = "openai/clip-vit-large-patch14"
    clip_net = CLIPModel.from_pretrained(clip_version).cuda()
    clip_processor = CLIPProcessor.from_pretrained(clip_version)
    return dict(version=clip_version, model=clip_net, processor=clip_processor)


clip_model = load_clip_model()
class ImageMask(gr.components.Image):
    """
    Sketch-pad image component: sets source="upload", tool="sketch", interactive=True.

    ``preprocess`` returns a dict with 'image' and 'mask'. When no strokes were
    drawn on a non-blank image, the mask is derived from the image's
    non-white pixels via ``binarize_2``.
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        if x is None:
            return x
        if self.tool == "sketch" and self.source in ["upload", "webcam"] and not isinstance(x, dict):
            # A bare base64 image arrived instead of the sketch tool's
            # {image, mask} dict: build the mask from the image itself.
            # (The previous code after this return was unreachable and removed.)
            img = np.asarray(processing_utils.decode_base64_to_image(x))
            return {'image': img, 'mask': binarize_2(img)}
        print('vao preprocess-------------------------')
        hh = super().preprocess(x)
        # Non-blank image with an empty stroke mask (RGB channels all zero):
        # treat every non-white pixel of the image as masked.
        if (hh['image'].min() != 255) and (hh['mask'][:, :, :3].max() == 0):
            hh['mask'] = binarize_2(hh['image'])
        return hh
class Blocks(gr.Blocks):
    """``gr.Blocks`` that adds 'thumbnail', 'url' and 'creator' to the app config.

    Those three keyword arguments are removed before ``gr.Blocks`` sees them
    and are merged into the dict returned by ``get_config_file``. Constructing
    an instance also calls ``warnings.filterwarnings("ignore")``.
    """

    def __init__(
        self,
        theme: str = "default",
        analytics_enabled: Optional[bool] = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: Optional[str] = None,
        **kwargs,
    ):
        thumbnail = kwargs.pop('thumbnail', '')
        url = kwargs.pop('url', 'https://gradio.app/')
        creator = kwargs.pop('creator', '@teamGradio')
        self.extra_configs = {'thumbnail': thumbnail, 'url': url, 'creator': creator}
        super().__init__(
            theme=theme,
            analytics_enabled=analytics_enabled,
            mode=mode,
            title=title,
            css=css,
            **kwargs,
        )
        warnings.filterwarnings("ignore")

    def get_config_file(self):
        config = super().get_config_file()
        config.update(self.extra_configs)
        return config
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def inference(task, language_instruction, phrase_list, location_list, inpainting_boxes_nodrop, image,
              alpha_sample, guidance_scale, batch_size,
              fix_seed, rand_seed, actual_mask, style_image,
              *args, **kwargs):
    """Run box-grounded generation with the cached GLIGEN model.

    Args:
        task: 'User provide boxes' or 'Available boxes' (both run the same path).
        language_instruction: the prompt; also used as ``save_folder_name``.
        phrase_list: grounding phrases, one per entry of ``location_list``.
        location_list: grounding boxes for the phrases.
        inpainting_boxes_nodrop: forwarded unchanged in the instruction dict.
        image: unused here; kept for call compatibility.
        alpha_sample: builds ``alpha_type = [alpha_sample, 0, 1 - alpha_sample]``.
        guidance_scale: forwarded unchanged in the instruction dict.
        batch_size: number of samples; values outside 1..4 fall back to 1.
        fix_seed, rand_seed: seeding options (coerced to bool / int).
        actual_mask: forwarded unchanged in the instruction dict.
        style_image: optional style image; when given the 'style' model is used
            and an extra 'placeholder' phrase grounds it.
        *args, **kwargs: forwarded to ``grounded_generation_box``.

    The caller's ``phrase_list`` and ``location_list`` are not modified.

    Returns:
        The result of ``grounded_generation_box``.

    Raises:
        ValueError: if ``task`` is not a supported task name.
    """
    # The previous check `task == 'User provide boxes' or 'Available boxes'`
    # was always true; validate explicitly instead of silently accepting anything.
    if task not in ('User provide boxes', 'Available boxes'):
        raise ValueError(f"Unsupported task: {task!r}")

    with Image.open('images/teddy.jpg') as teddy:
        placeholder_image = teddy.convert("RGB")

    batch_size = int(batch_size)
    if not 1 <= batch_size <= 4:
        batch_size = 1

    if style_image is None:
        has_text_mask = 1
        has_image_mask = 0
        # Placeholder input for the visual prompt, which is disabled.
        image_list = [placeholder_image] * len(phrase_list)
    else:
        valid_phrase_len = len(phrase_list)
        # Build new lists rather than `+=`, which mutated the caller's lists.
        phrase_list = list(phrase_list) + ['placeholder']
        has_text_mask = [1] * valid_phrase_len + [0]
        image_list = [placeholder_image] * valid_phrase_len + [style_image]
        has_image_mask = [0] * valid_phrase_len + [1]
        location_list = list(location_list) + [[0.0, 0.0, 1, 0.01]]  # style image grounding location

    instruction = dict(
        prompt=language_instruction,
        phrases=phrase_list,
        images=image_list,
        locations=location_list,
        alpha_type=[alpha_sample, 0, 1.0 - alpha_sample],
        has_text_mask=has_text_mask,
        has_image_mask=has_image_mask,
        save_folder_name=language_instruction,
        guidance_scale=guidance_scale,
        batch_size=batch_size,
        fix_seed=bool(fix_seed),
        rand_seed=int(rand_seed),
        actual_mask=actual_mask,
        inpainting_boxes_nodrop=inpainting_boxes_nodrop,
    )

    get_model = partial(instance.get_model,
                        batch_size=batch_size,
                        instruction=language_instruction,
                        phrase_list=phrase_list)

    model_type = 'base' if style_image is None else 'style'
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        result = grounded_generation_box(get_model(model_type), instruction, *args, **kwargs)
    # Release cached GPU memory after either model path, not only 'base'.
    torch.cuda.empty_cache()
    return result
def draw_box(boxes=None, texts=None, img=None):
    """Draw labelled bounding boxes on an image.

    Args:
        boxes: sequence of (x1, y1, x2, y2) pixel boxes; None means no boxes.
        texts: one label per box; boxes without a label get 'Obj. N'.
        img: PIL image drawn on in place. A white 512x512 RGB canvas is
            created when None.

    Returns:
        The annotated image, or None when there are no boxes and no image.
    """
    # None defaults instead of shared mutable [] defaults.
    boxes = [] if boxes is None else boxes
    texts = [] if texts is None else texts
    if len(boxes) == 0 and img is None:
        return None
    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    pen = ImageDraw.Draw(img)
    font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
    label_h = int(font.size * 1.2)
    for bid, box in enumerate(boxes):
        color = colors[bid % len(colors)]
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        anno_text = texts[bid] if bid < len(texts) else f'Obj. {bid + 1}'
        pen.rectangle([x1, y1, x2, y2], outline=color, width=4)
        # Filled label tab anchored at the box's bottom-left corner, sized to the text.
        label_w = int((len(anno_text) + 0.8) * font.size * 0.6)
        pen.rectangle([x1, y2 - label_h, x1 + label_w, y2], outline=color, fill=color, width=4)
        pen.text([x1 + int(font.size * 0.2), y2 - label_h], anno_text, font=font, fill=(255, 255, 255))
    return img
def get_concat(ims):
    """Tile equally sized PIL images into a grid of at most two columns.

    A single image gets a one-column canvas. Otherwise two columns are used.
    The canvas background is white, and the tile size is taken from ``ims[0]``.
    """
    n_col = 1 if len(ims) == 1 else 2
    n_row = math.ceil(len(ims) / 2)
    tile_w, tile_h = ims[0].width, ims[0].height
    canvas = Image.new('RGB', (tile_w * n_col, tile_h * n_row), color="white")
    for idx, tile in enumerate(ims):
        row, col = divmod(idx, n_col)
        canvas.paste(tile, (tile.width * col, tile.height * row))
    return canvas
def auto_append_grounding(language_instruction, grounding_texts):
    """Append each grounding phrase missing from the caption as '; <phrase>'.

    The containment test is case-insensitive and sees phrases appended
    earlier in the loop. The literal phrase 'auto' is never appended.
    """
    caption = language_instruction
    for phrase in grounding_texts:
        if phrase == 'auto' or phrase.lower() in caption.lower():
            continue
        caption = caption + "; " + phrase
    return caption
def generate(task, language_instruction, grounding_texts, sketch_pad,
             alpha_sample, guidance_scale, batch_size,
             fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
             state):
    """Generate images for the boxes stored in ``state['boxes']``.

    ``grounding_texts`` is a semicolon-separated phrase list. When there are
    more boxes than phrases, the extra boxes get empty phrases. Box pixel
    coordinates are divided by 512 before inference.

    Returns:
        Four ``gr.Image.update`` values (generated, blank-visible, or hidden)
        followed by ``state``.

    Raises:
        ValueError: if fewer boxes were drawn than grounding phrases given.
    """
    # The slider value may arrive as a float (inference also int()s it);
    # range() below needs an int.
    batch_size = int(batch_size)
    boxes = state.setdefault('boxes', [])
    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    if len(boxes) < len(grounding_texts):
        raise ValueError("""The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
    grounding_texts += [""] * (len(boxes) - len(grounding_texts))
    boxes = (np.asarray(boxes) / 512).tolist()

    if append_grounding:
        language_instruction = auto_append_grounding(language_instruction, grounding_texts)

    # Pass a separate list for the inpainting boxes so that mutating one
    # argument can never leak into the other.
    gen_images, _overlays = inference(
        task, language_instruction, grounding_texts, boxes, list(boxes), None,
        alpha_sample, guidance_scale, batch_size,
        fix_seed, rand_seed, None, style_cond_image, clip_model=clip_model,
    )

    # Odd batches > 1 get one visible blank so the 2-column grid stays even.
    blank_samples = batch_size % 2 if batch_size > 1 else 0
    updates = [gr.Image.update(value=x, visible=True) for x in gen_images]
    updates += [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)]
    updates += [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
    return updates + [state]
def binarize(x):
    """Return a uint8 array that is 255 where ``x`` is non-zero and 0 elsewhere."""
    is_set = x != 0
    return is_set.astype('uint8') * 255
def binarize_2(x):
    """Return a uint8 mask: 255 where the image is not pure white, 0 elsewhere.

    Colour input (3 or 4 channels, BGR order) is converted to grey with
    ``cv2.COLOR_BGR2GRAY``. Single-channel input, either 2-D or (H, W, 1),
    is already grey and is used as-is.
    """
    x = np.asarray(x)
    if x.ndim == 3 and x.shape[-1] == 1:
        x = x[..., 0]
    if x.ndim == 2:
        # cv2.cvtColor(..., COLOR_BGR2GRAY) raises on 1-channel input.
        gray_image = x
    else:
        gray_image = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
    return (gray_image != 255).astype('uint8') * 255
def sized_center_crop(img, cropx, cropy):
    """Return the ``cropy`` x ``cropx`` window centred in ``img`` (a slice, not a copy)."""
    height, width = img.shape[:2]
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2
    rows = slice(top, top + cropy)
    cols = slice(left, left + cropx)
    return img[rows, cols]
def sized_center_fill(img, fill, cropx, cropy):
    """Overwrite the centred ``cropy`` x ``cropx`` window of ``img`` with ``fill``.

    ``img`` is modified in place and also returned.
    """
    height, width = img.shape[:2]
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2
    rows, cols = slice(top, top + cropy), slice(left, left + cropx)
    img[rows, cols] = fill
    return img
def sized_center_mask(img, cropx, cropy):
    """Return a copy of ``img`` dimmed to 20% everywhere except the centred window.

    The ``cropy`` x ``cropx`` centre keeps its original values, and the result
    is cast to uint8. The input array is not modified.
    """
    height, width = img.shape[:2]
    top = height // 2 - cropy // 2
    left = width // 2 - cropx // 2
    window = (slice(top, top + cropy), slice(left, left + cropx))
    kept = img[window].copy()
    dimmed = (img * 0.2).astype('uint8')
    dimmed[window] = kept
    return dimmed
def center_crop(img, HW=None, tgt_size=(512, 512)):
    """Centre-crop ``img`` to an ``HW`` x ``HW`` square and resize it to ``tgt_size``.

    ``HW`` defaults to the shorter image side. The resize goes through PIL,
    and the result is returned as a numpy array.
    """
    side = min(img.shape[:2]) if HW is None else HW
    square = sized_center_crop(img, side, side)
    resized = Image.fromarray(square).resize(tgt_size)
    return np.array(resized)
def draw(task, input, grounding_texts, new_image_trigger, state, generate_parsed, box_image):
    """Turn sketch-pad strokes into grounding boxes and render them.

    The binarized 512x512 stroke mask is compared with the last mask stored
    in ``state['masks']``. If new pixels appeared, their bounding box is
    appended to ``state['boxes']``, provided it is more than 5 px wide and
    tall. The boxes are then drawn on a blank canvas, labelled with the
    semicolon-separated ``grounding_texts`` (or 'Obj. N').

    Args:
        task: unused; kept for the event wiring.
        input: sketch-pad value, either a dict with 'image'/'mask' or a bare mask array.
        grounding_texts: semicolon-separated phrases.
        new_image_trigger: returned unchanged.
        state: per-session dict holding 'boxes' and 'masks'. It is reset to
            {} when the mask is empty.
        generate_parsed: 1 means boxes were just set from an example. This
            call then only clears the flag.
        box_image: current parsed-box image, returned unchanged in that case.

    Returns:
        [box_image, new_image_trigger, image_scale, state, generate_parsed]
    """
    print('input', generate_parsed)
    if isinstance(input, dict):
        mask = input['mask']
        if generate_parsed == 1:
            print('do nothing')
            return [box_image, new_image_trigger, 1., state, 0]
    else:
        mask = input
    if mask.ndim == 3:
        mask = mask[..., 0]
    image_scale = 1.0
    print('vao draw--------------------')
    mask = binarize(mask)
    if mask.shape != (512, 512):
        if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
            mask = center_crop(mask, state['inpaint_hw'])
        else:
            mask = np.zeros((512, 512), dtype=np.uint8)
    # Re-binarize after the resize in center_crop (idempotent otherwise).
    mask = binarize(mask)

    if mask.sum() == 0:
        state = {}
        print('delete state')
    if 'boxes' not in state:
        state['boxes'] = []
    if 'masks' not in state or len(state['masks']) == 0:
        state['masks'] = []
        last_mask = np.zeros_like(mask)
    else:
        last_mask = state['masks'][-1]

    # Newly drawn pixels are those set now but not in the previous mask.
    # A comparison avoids the uint8 wrap-around of `mask - last_mask`. It
    # also means a change consisting only of removed pixels no longer calls
    # .min() on an empty selection.
    new_pixels = mask > last_mask
    if new_pixels.any():
        cols = np.where(new_pixels.any(axis=0))[0]
        rows = np.where(new_pixels.any(axis=1))[0]
        x1, x2 = cols.min(), cols.max()
        y1, y2 = rows.min(), rows.max()
        if (x2 - x1 > 5) and (y2 - y1 > 5):
            state['masks'].append(mask.copy())
            state['boxes'].append((x1, y1, x2, y2))

    labels = [x.strip() for x in grounding_texts.split(';')]
    labels = [x for x in labels if len(x) > 0]
    labels += [f'Obj. {bid + 1}' for bid in range(len(labels), len(state['boxes']))]
    # Boxes are always rendered on a blank canvas (no background image).
    box_image = draw_box(state['boxes'], labels, None)
    return [box_image, new_image_trigger, image_scale, state, 0]
def change_state(bboxes, layout, state, instruction, trigger_stage, boxes):
    """Load example boxes into the session state and render them.

    Args:
        bboxes: boxes as '(x1, y1, x2, y2)/(x1, y1, x2, y2)/...'.
        layout: sketch-pad value; its 'image' is binarized into the stored mask.
        state: per-session dict; its 'boxes' and 'masks' are replaced.
        instruction: semicolon-separated grounding phrases used as labels.
        trigger_stage: 0 means there is nothing to do.
        boxes: current parsed-box image, returned unchanged when trigger_stage is 0.

    Returns:
        [box_image, state, trigger_stage]
    """
    if trigger_stage == 0:
        return [boxes, state, 0]
    # Parse before touching state so a malformed string leaves it intact.
    # Blank segments (e.g. a trailing '/') are skipped, and whitespace around
    # a segment is tolerated.
    parsed = [
        [int(value) for value in chunk.strip().strip('()').split(',')]
        for chunk in bboxes.split('/')
        if chunk.strip()
    ]
    print('run change state')
    state['boxes'] = list(parsed)
    labels = [x.strip() for x in instruction.split(';')]
    labels = [x for x in labels if len(x) > 0]
    labels += [f'Obj. {bid + 1}' for bid in range(len(labels), len(parsed))]
    box_image = draw_box(parsed, labels)
    # The layout image's non-white pixels become the reference mask that
    # later sketch-pad edits (draw) are compared against.
    state['masks'] = [binarize_2(layout['image']).copy()]
    print('done change state')
    return [box_image, state, trigger_stage]
def example_click(name, grounding_instruction, instruction, bboxes, generate_parsed, trigger_parsed):
    """Render an example's boxes and bump the trigger counter.

    Args:
        bboxes: boxes as '(x1, y1, x2, y2)/(x1, y1, x2, y2)/...'.
        instruction: semicolon-separated phrases used as box labels.
        name, grounding_instruction, generate_parsed: unused here.
        trigger_parsed: counter, returned incremented by one.

    Returns:
        [box_image, trigger_parsed + 1]
    """
    parsed = [
        [int(value) for value in chunk.strip().strip('()').split(',')]
        for chunk in bboxes.split('/')
        if chunk.strip()
    ]
    print('run change state')
    # draw_box indexes one label per box. Passing the raw string labelled
    # each box with a single character, so split it into phrases first.
    labels = [x.strip() for x in instruction.split(';')]
    labels = [x for x in labels if len(x) > 0]
    labels += [f'Obj. {bid + 1}' for bid in range(len(labels), len(parsed))]
    box_image = draw_box(parsed, labels)
    trigger_parsed += 1
    print('done the example click')
    return [box_image, trigger_parsed]
def clear(task, sketch_pad_trigger, batch_size, state, trigger_stage, switch_task=False):
    """Reset the sketch pad, parsed boxes, outputs and session state.

    ``task``, ``state``, ``trigger_stage`` and ``switch_task`` are not read;
    state is replaced by {} and trigger_stage by 0.

    Returns:
        [sketch_pad=None, sketch_pad_trigger + 1, out_imagebox=None, image_scale=1.0,
         4 output-image updates, state={}, trigger_stage=0]
    """
    # The slider value may arrive as a float; range() needs an int.
    batch_size = int(batch_size)
    blank_samples = batch_size % 2 if batch_size > 1 else 0
    out_images = [gr.Image.update(value=None, visible=True) for _ in range(batch_size)] \
        + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
        + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
    return [None, sketch_pad_trigger + 1, None, 1.0] + out_images + [{}] + [0]
# Page stylesheet.
# - The sketch pad (#img2img_image and its nested fixed-height wrappers/img)
#   and #my_image's fixed-height div are sized from the --height CSS
#   variable, which rescale_js sets.
# - #paper-info links are coloured and undecorated.
css = """
#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
{
height: var(--height) !important;
max-height: var(--height) !important;
min-height: var(--height) !important;
}
#paper-info a {
color:#008AD7;
text-decoration: none;
}
#paper-info a:hover {
cursor: pointer;
text-decoration: none;
}
#my_image > div.fixed-height
{
height: var(--height) !important;
}
"""

# Client-side JS hooked to init_white_trigger.change in the UI wiring.
# Sets --height on <body> to (#img2img_image width) * (#image_scale value,
# default 1.0), hides the first two 'button.justify-center.rounded' elements,
# and returns its input unchanged.
# NOTE(review): assumes at least two such buttons exist; otherwise the
# `[1].style` access throws and x is not returned.
rescale_js = """
function(x) {
const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
const image_width = root.querySelector('#img2img_image').clientWidth;
const target_height = parseInt(image_width * image_scale);
document.body.style.setProperty('--height', `${target_height}px`);
root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
return x;
}
"""
# [<a href="https://arxiv.org/abs/2301.07093" target="_blank">Paper</a>]
# ---------------------------------------------------------------------------
# UI layout, event wiring and launch.
# ---------------------------------------------------------------------------
with Blocks(
    css=css,
    analytics_enabled=False,
    title="Attention-refocusing demo",
) as main:
    description = """<p style="text-align: center; font-weight: bold;">
<span style="font-size: 28px">Grounded Text-to-Image Synthesis with Attention Refocusing</span>
<br>
<span style="font-size: 18px" id="paper-info">
[<a href="https://attention-refocusing.github.io/" target="_blank">Project Page</a>]
[<a href="https://github.com/Attention-Refocusing/attention-refocusing" target="_blank">GitHub</a>]
</span>
</p>
<p>
To identify the areas of interest based on specific spatial parameters, you need to (1) ⌨️ input the names of the concepts you're interested in <em> Grounding Instruction</em>, and (2) 🖱️ draw their corresponding bounding boxes using <em> Sketch Pad</em> -- the parsed boxes will automatically be showed up once you've drawn them.
<br>
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/gligen/demo?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>
</p>
"""
    gr.HTML(description)

    with gr.Row():
        with gr.Column(scale=4):
            # Hidden counters/flags used to chain events.
            sketch_pad_trigger = gr.Number(value=0, visible=False)
            sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
            trigger_stage = gr.Number(value=0, visible=False)
            init_white_trigger = gr.Number(value=0, visible=False)
            image_scale = gr.Number(value=1.0, elem_id="image_scale", visible=False)
            new_image_trigger = gr.Number(value=0, visible=False)
            text_box = gr.Textbox(visible=False)
            generate_parsed = gr.Number(value=0, visible=False)

            task = gr.Radio(
                choices=["Available boxes", 'User provide boxes'],
                type="value",
                value="User provide boxes",
                label="Task",
                visible=False
            )
            language_instruction = gr.Textbox(
                label="Language instruction",
            )
            grounding_instruction = gr.Textbox(
                label="Grounding instruction (Separated by semicolon)",
            )
            with gr.Row():
                sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
                out_imagebox = gr.Image(type="pil", elem_id="my_image", label="Parsed Sketch Pad", shape=(512, 512))
            with gr.Row():
                clear_btn = gr.Button(value='Clear')
                gen_btn = gr.Button(value='Generate')
            with gr.Row():
                parsed_btn = gr.Button(value='generate parsed boxes', visible=False)

            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
                    guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
                    batch_size = gr.Slider(minimum=1, maximum=4, visible=False, step=1, value=1, label="Number of Samples")
                    append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
                    use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
                    with gr.Row():
                        fix_seed = gr.Checkbox(value=True, label="Fixed seed")
                        rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
                    with gr.Row():
                        use_style_cond = gr.Checkbox(value=False, visible=False, label="Enable Style Condition")
                        style_cond_image = gr.Image(type="pil", visible=False, label="Style Condition", interactive=True)

        with gr.Column(scale=4):
            gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
            with gr.Row():
                out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
                out_gen_2 = gr.Image(type="pil", visible=False, show_label=False)
            with gr.Row():
                out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
                out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)

        state = gr.State({})

    class Controller:
        """Holds callbacks that reset the sketch pad and toggle output slots."""

        def __init__(self):
            self.calls = 0
            self.tracks = 0
            self.resizes = 0
            self.scales = 0

        def init_white(self, init_white_trigger):
            """Return a white 512x512 canvas, scale 1.0 and the bumped trigger."""
            self.calls += 1
            return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1

        def change_n_samples(self, n_samples):
            blank_samples = n_samples % 2 if n_samples > 1 else 0
            return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
                + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]

    controller = Controller()

    # Page load bumps sketch_pad_trigger, which whitens the sketch pad
    # (controller.init_white) and then runs rescale_js via init_white_trigger.
    main.load(
        lambda x: x + 1,
        inputs=sketch_pad_trigger,
        outputs=sketch_pad_trigger,
        queue=False)

    sketch_pad.edit(
        draw,
        inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed, out_imagebox],
        outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed],
        queue=False,
    )
    trigger_stage.change(
        change_state,
        inputs=[text_box, sketch_pad, state, grounding_instruction, trigger_stage, out_imagebox],
        outputs=[out_imagebox, state, trigger_stage],
        queue=True
    )
    grounding_instruction.change(
        draw,
        inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed, out_imagebox],
        outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed],
        queue=False,
    )
    clear_btn.click(
        clear,
        # Order must match clear(task, sketch_pad_trigger, batch_size, state, trigger_stage);
        # state and trigger_stage were previously passed swapped.
        inputs=[task, sketch_pad_trigger, batch_size, state, trigger_stage],
        outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state, trigger_stage],
        queue=False)
    sketch_pad_trigger.change(
        controller.init_white,
        inputs=[init_white_trigger],
        outputs=[sketch_pad, image_scale, init_white_trigger],
        queue=False)
    gen_btn.click(
        generate,
        inputs=[
            task, language_instruction, grounding_instruction, sketch_pad,
            alpha_sample, guidance_scale, batch_size,
            fix_seed, rand_seed,
            use_actual_mask,
            append_grounding, style_cond_image,
            state,
        ],
        outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
        queue=True
    )
    init_white_trigger.change(
        None,
        None,
        init_white_trigger,
        _js=rescale_js,
        queue=False)

    # Each example: [guide image, grounding phrases, caption, boxes, generate_parsed, trigger_stage]
    examples = [
        [
            'guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg',
            "a cat;a dog",
            "a cat on the right of a dog",
            '(291, 88, 481, 301)/(25, 64, 260, 391)',
            1, 1
        ],
        [
            'guide_imgs/0_a_bus_on_the_left_of_a_car.jpg',  # 'guide_imgs/0_a_bus_on_the_left_of_a_car.jpg',
            "a bus;a car",
            "a bus and a car",
            '(8,128,266,384)/(300,196,502,316)',  # '(8,128,266,384)', #/(300,196,502,316)
            1, 2
        ],
        [
            'guide_imgs/1_Two_cars_on_the_street..jpg',
            "a car;a car",
            "Two cars on the street.",
            '(34, 98, 247, 264)/(271, 122, 481, 293)',
            1, 3
        ],
        [
            'guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg',
            "an apple;an apple",
            "two apples lay side by side on a wooden table, their glossy red and green skins glinting in the sunlight.",
            '(40, 210, 235, 450)/(275, 210, 470, 450)',
            1, 4
        ],
        [
            'guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg',
            "a banana;an apple",
            "A banana on the left of an apple.",
            '(62, 193, 225, 354)/(300, 184, 432, 329)',
            1, 5
        ],
        [
            'guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg',
            "a pizza ;a suitcase",
            "A pizza on the right of a suitcase.",
            '(307, 112, 490, 280)/(41, 120, 244, 270)',
            1, 6
        ],
        [
            'guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg',
            "a wine glass;a dog",
            "A wine glass on top of a dog.",
            '(206, 78, 306, 214)/(137, 222, 367, 432)',
            1, 7
        ],
        [
            'guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg',
            "a bicycle;a boat",
            "A bicycle on top of a boat.",
            '(185, 110, 335, 205)/(111, 228, 401, 373)',
            1, 8
        ],
        [
            'guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg',
            "a laptop;a teddy bear",
            "A laptop on top of a teddy bear.",
            '(180, 70, 332, 210)/(150, 240, 362, 420)',
            1, 9
        ],
        [
            'guide_imgs/0_A_train_on_top_of_a_surfboard..jpg',
            "a train;a surfboard",
            "A train on top of a surfboard.",
            '(130, 80, 385, 240)/(75, 260, 440, 450)',
            1, 10
        ],
    ]
    with gr.Column():
        create_examples(
            examples=examples,
            inputs=[sketch_pad, grounding_instruction, language_instruction, text_box, generate_parsed, trigger_stage],
            outputs=None,
            fn=None,
            cache_examples=False,
        )

main.queue(concurrency_count=1, api_open=False)
main.launch(share=False, show_api=False, show_error=True, debug=False, server_name="0.0.0.0")