diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..487b3a03936584403eb0140324f63cc01118f8e8 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..2b84e60803a5b0868ca3e0b33e09000815b9ecb7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +gligen/ldm/data/imagenet_train_hr_indices.p filter=lfs diff=lfs merge=lfs -text +gligen/projection_matrix.pth filter=lfs diff=lfs merge=lfs -text +gligen/ldm/data/imagenet_val_hr_indices.p filter=lfs diff=lfs merge=lfs -text +gligen/SD_input_conv_weight_bias.pth filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 200e1f3ba1da6cfa52f3b1fb08087653a10fb9e5..e80c1844fd3a5f88c1d10df4d018672093e59f56 100644 --- a/.gitignore +++ b/.gitignore @@ -110,3 +110,6 @@ create_samples/ create_samples/* ckpts/* + +**/__pycache__/* +**/__pycache__ diff --git a/README.md b/README.md index b3f216cd2f7e8ff733bfc4461ce8f1774521041e..b961d3e7557609d173f66d0f8fc460e5d99f26f8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ --- -title: LoCo_Gligen Demo -emoji: 👁 -colorFrom: blue -colorTo: purple +title: Attention Refocusing +emoji: 🌖 +colorFrom: yellow +colorTo: indigo sdk: gradio sdk_version: 3.19.1 app_file: app.py pinned: false --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py index 8e1926a0b5001c8cd6cbce7b79fefca88088f8f0..d44dee13553dd0633f20fd54e443a64255fbbb8c 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,5 @@ import gradio as gr +import os import torch from omegaconf import OmegaConf from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt @@ -18,9 +19,11 @@ import warnings from datetime import datetime +from example_component import create_examples + from huggingface_hub import hf_hub_download hf_hub_download = partial(hf_hub_download, library_name="gligen_demo") - +import cv2 import sys sys.tracebacklimit = 0 @@ -39,8 +42,6 @@ def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None): pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality) config = OmegaConf.create( config["_content"] ) # config used in training config.alpha_scale = 1.0 - config.model['params']['is_inpaint'] = is_inpaint - config.model['params']['is_style'] = is_style if common_instances is None: common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model') @@ -138,13 +139,25 @@ class ImageMask(gr.components.Image): if x is None: return x if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict: + decode_image = processing_utils.decode_base64_to_image(x) + print('decode to 64') width, height = decode_image.size + img = np.asarray(decode_image) + return {'image':img, 'mask':binarize_2(img)} + mask = np.zeros((height, width, 4), dtype=np.uint8) + mask[..., -1] = 255 mask = self.postprocess(mask) x = {'image': x, 'mask': mask} - return super().preprocess(x) + print('vao preprocess-------------------------') + hh = super().preprocess(x) + if (hh['image'].min()!=255) and (hh['mask'][:,:,:3].max()==0): + + 
hh['mask'] = binarize_2(hh['image']) + + return hh class Blocks(gr.Blocks): @@ -180,23 +193,25 @@ class Blocks(gr.Blocks): inference model ''' -@torch.no_grad() -def inference(task, language_instruction, grounding_instruction, inpainting_boxes_nodrop, image, +# @torch.no_grad() +def inference(task, language_instruction, phrase_list, location_list, inpainting_boxes_nodrop, image, alpha_sample, guidance_scale, batch_size, fix_seed, rand_seed, actual_mask, style_image, *args, **kwargs): - grounding_instruction = json.loads(grounding_instruction) - phrase_list, location_list = [], [] - for k, v in grounding_instruction.items(): - phrase_list.append(k) - location_list.append(v) + # import pdb; pdb.set_trace() + + # grounding_instruction = json.loads(grounding_instruction) + # phrase_list, location_list = [], [] + # for k, v in grounding_instruction.items(): + # phrase_list.append(k) + # location_list.append(v) placeholder_image = Image.open('images/teddy.jpg').convert("RGB") image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled batch_size = int(batch_size) if not 1 <= batch_size <= 4: - batch_size = 2 + batch_size = 1 if style_image == None: has_text_mask = 1 @@ -212,9 +227,6 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe location_list += [ [0.0, 0.0, 1, 0.01] ] # style image grounding location - if task == 'Grounded Inpainting': - alpha_sample = 1.0 - instruction = dict( prompt = language_instruction, phrases = phrase_list, @@ -238,21 +250,19 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe phrase_list=phrase_list) with torch.autocast(device_type='cuda', dtype=torch.float16): - if task == 'Grounded Generation': + if task in ('User provide boxes', 'Available boxes'): if style_image == None: - return grounded_generation_box(get_model('base'), instruction, *args, **kwargs) + result = grounded_generation_box(get_model('base'), instruction, *args, **kwargs) + torch.cuda.empty_cache() + return result else: return grounded_generation_box(get_model('style'), instruction, *args, **kwargs) - elif task == 'Grounded Inpainting': - assert image is not None - instruction['input_image'] = image.convert("RGB") - return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs) def draw_box(boxes=[], texts=[], img=None): if len(boxes) == 0 and img is None: return None - + if img is None: img = Image.new('RGB', (512, 512), (255, 255, 255)) colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"] @@ -281,7 +291,7 @@ def get_concat(ims): def auto_append_grounding(language_instruction, grounding_texts): for grounding_text in grounding_texts: - if grounding_text not in language_instruction and grounding_text != 'auto': + if grounding_text.lower() not in language_instruction.lower() and grounding_text != 'auto': language_instruction += "; " + grounding_text return language_instruction @@ -292,6 +302,7 @@ def generate(task, language_instruction, grounding_texts, sketch_pad, alpha_sample, guidance_scale, batch_size, fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image, state): + if 'boxes' not in state: state['boxes'] = [] @@ -307,44 +318,18 @@ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(groun boxes = (np.asarray(boxes) / 512).tolist() grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)}) - image = None actual_mask = None - if task == 'Grounded Inpainting': - image = 
state.get('original_image', sketch_pad['image']).copy() - image = center_crop(image) - image = Image.fromarray(image) - - if use_actual_mask: - actual_mask = sketch_pad['mask'].copy() - if actual_mask.ndim == 3: - actual_mask = actual_mask[..., 0] - actual_mask = center_crop(actual_mask, tgt_size=(64, 64)) - actual_mask = torch.from_numpy(actual_mask == 0).float() - - if state.get('inpaint_hw', None): - boxes = np.asarray(boxes) * 0.9 + 0.05 - boxes = boxes.tolist() - grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes) if obj != 'auto'}) + if append_grounding: language_instruction = auto_append_grounding(language_instruction, grounding_texts) gen_images, gen_overlays = inference( - task, language_instruction, grounding_instruction, boxes, image, + task, language_instruction, grounding_texts,boxes, boxes, image, alpha_sample, guidance_scale, batch_size, fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model, ) - - for idx, gen_image in enumerate(gen_images): - - if task == 'Grounded Inpainting' and state.get('inpaint_hw', None): - hw = min(*state['original_image'].shape[:2]) - gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw) - gen_image = Image.fromarray(gen_image) - - gen_images[idx] = gen_image - blank_samples = batch_size % 2 if batch_size > 1 else 0 gen_images = [gr.Image.update(value=x, visible=True) for i,x in enumerate(gen_images)] \ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \ @@ -355,6 +340,9 @@ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(groun def binarize(x): return (x != 0).astype('uint8') * 255 +def binarize_2(x): + gray_image = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY) + return (gray_image!=255).astype('uint8') * 255 def sized_center_crop(img, cropx, cropy): y, x = img.shape[:2] @@ -387,10 +375,20 @@ def center_crop(img, HW=None, tgt_size=(512, 512)): img = img.resize(tgt_size) return np.array(img) -def draw(task, input, grounding_texts, new_image_trigger, state): +# Receives the input from the sketchpad (left-hand panel) +def draw(task, input, grounding_texts, new_image_trigger, state, generate_parsed, box_image): + print('input', generate_parsed) + if type(input) == dict: image = input['image'] mask = input['mask'] + if generate_parsed==1: + generate_parsed = 0 + # import pdb; pdb.set_trace() + print('do nothing') + + return [box_image, new_image_trigger, 1., state, generate_parsed] + else: mask = input @@ -398,24 +396,8 @@ def draw(task, input, grounding_texts, new_image_trigger, state): mask = mask[..., 0] image_scale = 1.0 - - # resize trigger - if task == "Grounded Inpainting": - mask_cond = mask.sum() == 0 - # size_cond = mask.shape != (512, 512) - if mask_cond and 'original_image' not in state: - image = Image.fromarray(image) - width, height = image.size - scale = 600 / min(width, height) - image = image.resize((int(width * scale), int(height * scale))) - state['original_image'] = np.array(image).copy() - image_scale = float(height / width) - return [None, new_image_trigger + 1, image_scale, state] - else: - original_image = state['original_image'] - H, W = original_image.shape[:2] - image_scale = float(H / W) - + + print('entering draw') mask = binarize(mask) if mask.shape != (512, 512): # assert False, "should not receive any non- 512x512 masks." 
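# A minimal, self-contained usage sketch of the binarize_2 helper added above -- not part of
# the diff itself. The helper treats every non-white pixel of the sketch-pad image as mask,
# which is how the demo recovers the drawn-box mask from an uploaded RGB sketch. The synthetic
# canvas and rectangle below are illustrative assumptions, not code from the repository.
import cv2
import numpy as np

def binarize_2(x):
    # non-white pixels -> 255, white background -> 0 (same logic as the helper in app.py)
    gray_image = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
    return (gray_image != 255).astype('uint8') * 255

canvas = np.full((512, 512, 3), 255, dtype=np.uint8)         # blank (all-white) sketch pad
cv2.rectangle(canvas, (100, 100), (300, 300), (0, 0, 0), 2)  # a user-drawn box outline
mask = binarize_2(canvas)                                    # 255 exactly where the outline was drawn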
@@ -424,16 +406,16 @@ def draw(task, input, grounding_texts, new_image_trigger, state): image = center_crop(state['original_image'], state['inpaint_hw']) else: mask = np.zeros((512, 512), dtype=np.uint8) - # mask = center_crop(mask) mask = binarize(mask) if type(mask) != np.ndarray: mask = np.array(mask) - - if mask.sum() == 0 and task != "Grounded Inpainting": + # + if mask.sum() == 0: state = {} + print('delete state') - if task != 'Grounded Inpainting': + if True: image = None else: image = Image.fromarray(image) @@ -441,20 +423,20 @@ def draw(task, input, grounding_texts, new_image_trigger, state): if 'boxes' not in state: state['boxes'] = [] - if 'masks' not in state or len(state['masks']) == 0: + if 'masks' not in state or len(state['masks']) == 0 : state['masks'] = [] last_mask = np.zeros_like(mask) else: last_mask = state['masks'][-1] - - if type(mask) == np.ndarray and mask.size > 1: + + if type(mask) == np.ndarray and mask.size > 1 : diff_mask = mask - last_mask else: diff_mask = np.zeros([]) if diff_mask.sum() > 0: - x1x2 = np.where(diff_mask.max(0) != 0)[0] - y1y2 = np.where(diff_mask.max(1) != 0)[0] + x1x2 = np.where(diff_mask.max(0) > 1)[0] + y1y2 = np.where(diff_mask.max(1) > 1)[0] y1, y2 = y1y2.min(), y1y2.max() x1, x2 = x1x2.min(), x1x2.max() @@ -466,26 +448,73 @@ def draw(task, input, grounding_texts, new_image_trigger, state): grounding_texts = [x for x in grounding_texts if len(x) > 0] if len(grounding_texts) < len(state['boxes']): grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))] - + box_image = draw_box(state['boxes'], grounding_texts, image) + generate_parsed = 0 + + return [box_image, new_image_trigger, image_scale, state, generate_parsed] + +def change_state(bboxes,layout, state, instruction, trigger_stage, boxes): + if trigger_stage ==0 : + return [boxes, state, 0] + # mask = + state['boxes'] = [] + state['masks'] = [] + image = None + list_boxes = bboxes.split('/') + result =[] + for b in list_boxes: + ints = b[1:-1].split(',') + l = [] + for i in ints: + l.append(int(i)) + result.append(l) + print('run change state') + + for box in result: + state['boxes'].append(box) + grounding_texts = [x.strip() for x in instruction.split(';')] + grounding_texts = [x for x in grounding_texts if len(x) > 0] + if len(grounding_texts) < len(result): + grounding_texts += [f'Obj. 
{bid+1}' for bid in range(len(grounding_texts), len(result))] - if box_image is not None and state.get('inpaint_hw', None): - inpaint_hw = state['inpaint_hw'] - box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw))) - original_image = state['original_image'].copy() - box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw) - - return [box_image, new_image_trigger, image_scale, state] + box_image = draw_box(result, grounding_texts) + + mask = binarize_2(layout['image']) + state['masks'].append(mask.copy()) + # print('done change state', state) + print('done change state') + # import pdb; pdb.set_trace() + return [box_image,state, trigger_stage] + +def example_click(name, grounding_instruction, instruction, bboxes,generate_parsed, trigger_parsed): + + list_boxes = bboxes.split('/') + result =[] + + for b in list_boxes: + ints = b[1:-1].split(',') + l = [] + for i in ints: + l.append(int(i)) + result.append(l) + print('run change state') + + box_image = draw_box(result, instruction) + trigger_parsed += 1 + print('done the example click') + return [box_image, trigger_parsed] -def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False): - if task != 'Grounded Inpainting': - sketch_pad_trigger = sketch_pad_trigger + 1 +def clear(task, sketch_pad_trigger, batch_size, state,trigger_stage, switch_task=False): + + sketch_pad_trigger = sketch_pad_trigger + 1 + trigger_stage = 0 blank_samples = batch_size % 2 if batch_size > 1 else 0 out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)] state = {} - return [None, sketch_pad_trigger, None, 1.0] + out_images + [state] + return [None, sketch_pad_trigger, None, 1.0] + out_images + [state] + [trigger_stage] css = """ #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img @@ -502,6 +531,10 @@ css = """ cursor: pointer; text-decoration: none; } +#my_image > div.fixed-height +{ + height: var(--height) !important; +} """ rescale_js = """ @@ -516,42 +549,48 @@ function(x) { return x; } """ - +# [Paper] with Blocks( css=css, analytics_enabled=False, - title="GLIGen demo", + title="Attention-refocusing demo", ) as main: description = """

- GLIGen: Open-Set Grounded Text-to-Image Generation + Grounded Text-to-Image Synthesis with Attention Refocusing
- [Project Page] - [Paper] - [GitHub] + [Project Page] + + [GitHub]

- To ground concepts of interest with desired spatial specification, please (1) ⌨️ enter the concept names in Grounding Instruction, and (2) 🖱️ draw their corresponding bounding boxes one by one using Sketch Pad -- the parsed boxes will be displayed automatically. + To identify the areas of interest with their specific spatial layouts, (1) ⌨️ enter the names of the concepts you're interested in under Grounding Instruction, and (2) 🖱️ draw their corresponding bounding boxes using Sketch Pad -- the parsed boxes will show up automatically once you've drawn them.
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. Duplicate Space

""" gr.HTML(description) - + with gr.Row(): with gr.Column(scale=4): sketch_pad_trigger = gr.Number(value=0, visible=False) sketch_pad_resize_trigger = gr.Number(value=0, visible=False) + trigger_stage = gr.Number(value=0, visible=False) + init_white_trigger = gr.Number(value=0, visible=False) - image_scale = gr.Number(value=0, elem_id="image_scale", visible=False) + image_scale = gr.Number(value=1.0, elem_id="image_scale", visible=False) new_image_trigger = gr.Number(value=0, visible=False) - + text_box = gr.Textbox(visible=False) + generate_parsed = gr.Number(value=0, visible=False) + task = gr.Radio( - choices=["Grounded Generation", 'Grounded Inpainting'], + choices=["Available boxes", 'User provide boxes'], type="value", - value="Grounded Generation", + value="User provide boxes", label="Task", + visible=False + ) language_instruction = gr.Textbox( label="Language instruction", @@ -561,33 +600,38 @@ with Blocks( ) with gr.Row(): sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image") - out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad") + out_imagebox = gr.Image(type="pil",elem_id="my_image" ,label="Parsed Sketch Pad", shape=(512,512)) with gr.Row(): clear_btn = gr.Button(value='Clear') gen_btn = gr.Button(value='Generate') + with gr.Row(): + parsed_btn = gr.Button(value='generate parsed boxes') + with gr.Accordion("Advanced Options", open=False): with gr.Column(): alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)") guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale") - batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=2, label="Number of Samples") + batch_size = gr.Slider(minimum=1, maximum=4,visible=False, step=1, value=1, label="Number of Samples") append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption") use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False) with gr.Row(): fix_seed = gr.Checkbox(value=True, label="Fixed seed") rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed") - with gr.Row(): - use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition") - style_cond_image = gr.Image(type="pil", label="Style Condition", visible=False, interactive=True) + + with gr.Row(): + use_style_cond = gr.Checkbox(value=False,visible=False, label="Enable Style Condition") + style_cond_image = gr.Image(type="pil",visible=False, label="Style Condition", interactive=True) with gr.Column(scale=4): gr.HTML('Generated Images') with gr.Row(): out_gen_1 = gr.Image(type="pil", visible=True, show_label=False) - out_gen_2 = gr.Image(type="pil", visible=True, show_label=False) + out_gen_2 = gr.Image(type="pil", visible=False, show_label=False) with gr.Row(): out_gen_3 = gr.Image(type="pil", visible=False, show_label=False) out_gen_4 = gr.Image(type="pil", visible=False, show_label=False) state = gr.State({}) + class Controller: def __init__(self): @@ -605,75 +649,43 @@ with Blocks( return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \ + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)] - def resize_centercrop(self, state): - self.resizes += 1 - image = state['original_image'].copy() - inpaint_hw = int(0.9 * min(*image.shape[:2])) - state['inpaint_hw'] = inpaint_hw - image_cc = center_crop(image, inpaint_hw) - # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape) - return image_cc, 
state - - def resize_masked(self, state): - self.resizes += 1 - image = state['original_image'].copy() - inpaint_hw = int(0.9 * min(*image.shape[:2])) - state['inpaint_hw'] = inpaint_hw - image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw) - state['masked_image'] = image_mask.copy() - # print(f'mask triggered {self.resizes}') - return image_mask, state - - def switch_task_hide_cond(self, task): - cond = False - if task == "Grounded Generation": - cond = True - - return gr.Checkbox.update(visible=cond, value=False), gr.Image.update(value=None, visible=False), gr.Slider.update(visible=cond), gr.Checkbox.update(visible=(not cond), value=False) - controller = Controller() main.load( lambda x:x+1, inputs=sketch_pad_trigger, outputs=sketch_pad_trigger, queue=False) + sketch_pad.edit( draw, - inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state], - outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state], + inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed, out_imagebox], + outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed], queue=False, ) + trigger_stage.change( + change_state, + inputs=[text_box,sketch_pad, state, grounding_instruction, trigger_stage,out_imagebox], + outputs=[out_imagebox,state,trigger_stage], + queue=True + ) grounding_instruction.change( draw, - inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state], - outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state], + inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed,out_imagebox], + outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed], queue=False, ) clear_btn.click( clear, - inputs=[task, sketch_pad_trigger, batch_size, state], - outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state], - queue=False) - task.change( - partial(clear, switch_task=True), - inputs=[task, sketch_pad_trigger, batch_size, state], - outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state], + inputs=[task, sketch_pad_trigger, batch_size,trigger_stage, state], + outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state, trigger_stage], queue=False) + sketch_pad_trigger.change( controller.init_white, inputs=[init_white_trigger], outputs=[sketch_pad, image_scale, init_white_trigger], queue=False) - sketch_pad_resize_trigger.change( - controller.resize_masked, - inputs=[state], - outputs=[sketch_pad, state], - queue=False) - batch_size.change( - controller.change_n_samples, - inputs=[batch_size], - outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4], - queue=False) + gen_btn.click( generate, inputs=[ @@ -687,88 +699,98 @@ with Blocks( outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state], queue=True ) - sketch_pad_resize_trigger.change( - None, - None, - sketch_pad_resize_trigger, - _js=rescale_js, - queue=False) init_white_trigger.change( None, None, init_white_trigger, _js=rescale_js, queue=False) - use_style_cond.change( - lambda cond: gr.Image.update(visible=cond), - use_style_cond, - style_cond_image, - queue=False) - task.change( - controller.switch_task_hide_cond, - inputs=task, - outputs=[use_style_cond, style_cond_image, alpha_sample, use_actual_mask], - queue=False) - - with gr.Column(): - gr.Examples( - 
examples=[ - [ - "images/blank.png", - "Grounded Generation", - "a dog and an apple", - "a dog;an apple", + examples = [ + [ + 'guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg', + "a cat;a dog", + "a cat on the right of a dog", + '(291, 88, 481, 301)/(25, 64, 260, 391)', + 1, 1 ], [ - "images/blank.png", - "Grounded Generation", - "John Lennon is using a pc", - "John Lennon;a pc", - [ - "images/blank.png", - "Grounded Generation", - "a painting of a fox sitting in a field at sunrise in the style of Claude Mone", - "fox;sunrise", - ], + 'guide_imgs/0_a_bus_on_the_left_of_a_car.jpg',#'guide_imgs/0_a_bus_on_the_left_of_a_car.jpg', + "a bus;a car", + "a bus and a car", + '(8,128,266,384)/(300,196,502,316)', #'(8,128,266,384)', #/(300,196,502,316) + 1, 2 ], [ - "images/blank.png", - "Grounded Generation", - "a beautiful painting of hot dog by studio ghibli, octane render, brilliantly coloured", - "hot dog", + 'guide_imgs/1_Two_cars_on_the_street..jpg', + "a car;a car", + "Two cars on the street.", + '(34, 98, 247, 264)/(271, 122, 481, 293)', + 1, 3 ], [ - "images/blank.png", - "Grounded Generation", - "a sport car, unreal engine, global illumination, ray tracing", - "a sport car", + 'guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg', + "an apple;an apple", + "two apples lay side by side on a wooden table, their glossy red and green skins glinting in the sunlight.", + '(40, 210, 235, 450)/(275, 210, 470, 450)', + 1, 4 ], [ - "images/flower_beach.jpg", - "Grounded Inpainting", - "a squirrel and the space needle", - "a squirrel;the space needle", + 'guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg', + "a banana;an apple", + "A banana on the left of an apple.", + '(62, 193, 225, 354)/(300, 184, 432, 329)', + 1, 5 ], [ - "images/arg_corgis.jpeg", - "Grounded Inpainting", - "a dog and a birthday cake", - "a dog; a birthday cake", + 'guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg', + "a pizza ;a suitcase", + "A pizza on the right of a suitcase.", + '(307, 112, 490, 280)/(41, 120, 244, 270)', + 1, 6 ], [ - "images/teddy.jpg", - "Grounded Inpainting", - "a teddy bear wearing a santa claus red shirt; holding a Christmas gift box on hand", - "a santa claus shirt; a Christmas gift box", - ], - ], - inputs=[sketch_pad, task, language_instruction, grounding_instruction], + 'guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg', + "a wine glass;a dog", + "A wine glass on top of a dog.", + '(206, 78, 306, 214)/(137, 222, 367, 432)', + 1, 7 + ] + , + [ + 'guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg', + "a bicycle;a boat", + "A bicycle on top of a boat.", + '(185, 110, 335, 205)/(111, 228, 401, 373)', + 1, 8 + ] + , + [ + 'guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg', + "a laptop;a teddy bear", + "A laptop on top of a teddy bear.", + '(180, 70, 332, 210)/(150, 240, 362, 420)', + 1, 9 + ] + , + [ + 'guide_imgs/0_A_train_on_top_of_a_surfboard..jpg', + "a train;a surfboard", + "A train on top of a surfboard.", + '(130, 80, 385, 240)/(75, 260, 440, 450)', + 1, 10 + ] + ] + + with gr.Column(): + + create_examples( + examples=examples, + inputs=[sketch_pad, grounding_instruction,language_instruction , text_box, generate_parsed, trigger_stage], outputs=None, fn=None, cache_examples=False, + ) main.queue(concurrency_count=1, api_open=False) -main.launch(share=False, show_api=False, show_error=True) - - +main.launch(share=False, show_api=False, show_error=True, debug=False,) diff --git a/dataset/__pycache__/__init__.cpython-310.pyc 
b/dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a8ab8eabbfd7df6c2f77020678efc89e6a5c531 Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/dataset/__pycache__/__init__.cpython-38.pyc b/dataset/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05b0ef16d074d28a35ecbe025b6821120d8605da Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-38.pyc differ diff --git a/dataset/__pycache__/catalog.cpython-310.pyc b/dataset/__pycache__/catalog.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5479f5eb754f7dafa022a0a3a5b4bf42ad05533 Binary files /dev/null and b/dataset/__pycache__/catalog.cpython-310.pyc differ diff --git a/dataset/__pycache__/catalog.cpython-38.pyc b/dataset/__pycache__/catalog.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc5b0f96eb8b80206ac3e28b4e5237e9d1cf29a3 Binary files /dev/null and b/dataset/__pycache__/catalog.cpython-38.pyc differ diff --git a/dataset/__pycache__/concat_dataset.cpython-310.pyc b/dataset/__pycache__/concat_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e54af630c170e2572d6e50f19adbfd00b450eaf4 Binary files /dev/null and b/dataset/__pycache__/concat_dataset.cpython-310.pyc differ diff --git a/dataset/__pycache__/concat_dataset.cpython-38.pyc b/dataset/__pycache__/concat_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0d24742b9b477ac163ee9f1bd4a17b85f370051 Binary files /dev/null and b/dataset/__pycache__/concat_dataset.cpython-38.pyc differ diff --git a/environment.yaml b/environment.yaml index 1102441a1a99f62687385aee15b4a82d4da7453a..6fa931c37c2b460f4de1959a0e8fc5d777cf71c4 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,4 +1,4 @@ -name: loco_gligen_demo +name: gligen_demo channels: - xformers/label/dev - pytorch diff --git a/example_component.py b/example_component.py new file mode 100644 index 0000000000000000000000000000000000000000..19fceb0d8abb853da6d66901201c0784930be8fe --- /dev/null +++ b/example_component.py @@ -0,0 +1,805 @@ +""" +Defines helper methods useful for loading and caching Interface examples. +""" +from __future__ import annotations + +import ast +import csv +import inspect +import os +import subprocess +import tempfile +import threading +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import PIL +import PIL.Image + +from gradio import components, processing_utils, routes, utils +from gradio.context import Context +from gradio.documentation import document, set_documentation_group +from gradio.flagging import CSVLogger + +if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). 
+ from gradio.components import IOComponent + +CACHED_FOLDER = "gradio_cached_examples" +LOG_FILE = "log.csv" + +set_documentation_group("helpers") + + +def create_examples( + examples: List[Any] | List[List[Any]] | str, + inputs: IOComponent | List[IOComponent], + outputs: IOComponent | List[IOComponent] | None = None, + fn: Callable | None = None, + cache_examples: bool = False, + examples_per_page: int = 10, + _api_mode: bool = False, + label: str | None = None, + elem_id: str | None = None, + run_on_click: bool = False, + preprocess: bool = True, + postprocess: bool = True, + batch: bool = False, +): + """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component.""" + examples_obj = Examples( + examples=examples, + inputs=inputs, + outputs=outputs, + fn=fn, + cache_examples=cache_examples, + examples_per_page=examples_per_page, + _api_mode=_api_mode, + label=label, + elem_id=elem_id, + run_on_click=run_on_click, + preprocess=preprocess, + postprocess=postprocess, + batch=batch, + _initiated_directly=False, + ) + utils.synchronize_async(examples_obj.create) + return examples_obj + + +class Examples: + """ + This class is a wrapper over the Dataset component and can be used to create Examples + for Blocks / Interfaces. Populates the Dataset component with examples and + assigns event listener so that clicking on an example populates the input/output + components. Optionally handles example caching for fast inference. + + Demos: blocks_inputs, fake_gan + Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan + """ + + def __init__( + self, + examples: List[Any] | List[List[Any]] | str, + inputs: IOComponent | List[IOComponent], + outputs: IOComponent | List[IOComponent] | None = None, + fn: Callable | None = None, + cache_examples: bool = False, + examples_per_page: int = 10, + _api_mode: bool = False, + label: str | None = "Examples", + elem_id: str | None = None, + run_on_click: bool = False, + preprocess: bool = True, + postprocess: bool = True, + batch: bool = False, + _initiated_directly: bool = True, + ): + """ + Parameters: + examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs. + inputs: the component or list of components corresponding to the examples + outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True. + fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True. + cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided + examples_per_page: how many examples to show per page. + label: the label to use for the examples component (by default, "Examples") + elem_id: an optional string that is assigned as the id of this component in the HTML DOM. 
+ run_on_click: if cache_examples is False, clicking on an example does not run the function when an example is clicked. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True. + preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True. + postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True. + batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True. + """ + if _initiated_directly: + warnings.warn( + "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.", + ) + + if cache_examples and (fn is None or outputs is None): + raise ValueError("If caching examples, `fn` and `outputs` must be provided") + + if not isinstance(inputs, list): + inputs = [inputs] + if outputs and not isinstance(outputs, list): + outputs = [outputs] + + working_directory = Path().absolute() + + if examples is None: + raise ValueError("The parameter `examples` cannot be None") + elif isinstance(examples, list) and ( + len(examples) == 0 or isinstance(examples[0], list) + ): + pass + elif ( + isinstance(examples, list) and len(inputs) == 1 + ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists + examples = [[e] for e in examples] + elif isinstance(examples, str): + if not Path(examples).exists(): + raise FileNotFoundError( + "Could not find examples directory: " + examples + ) + working_directory = examples + if not (Path(examples) / LOG_FILE).exists(): + if len(inputs) == 1: + examples = [[e] for e in os.listdir(examples)] + else: + raise FileNotFoundError( + "Could not find log file (required for multiple inputs): " + + LOG_FILE + ) + else: + with open(Path(examples) / LOG_FILE) as logs: + examples = list(csv.reader(logs)) + examples = [ + examples[i][: len(inputs)] for i in range(1, len(examples)) + ] # remove header and unnecessary columns + + else: + raise ValueError( + "The parameter `examples` must either be a string directory or a list" + "(if there is only 1 input component) or (more generally), a nested " + "list, where each sublist represents a set of inputs." + ) + + input_has_examples = [False] * len(inputs) + for example in examples: + for idx, example_for_input in enumerate(example): + if not (example_for_input is None): + try: + input_has_examples[idx] = True + except IndexError: + pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. 
loading from a log file where outputs and timestamps are also logged) + + inputs_with_examples = [ + inp for (inp, keep) in zip(inputs, input_has_examples) if keep + ] + non_none_examples = [ + [ex for (ex, keep) in zip(example, input_has_examples) if keep] + for example in examples + ] + + self.examples = examples + self.non_none_examples = non_none_examples + self.inputs = inputs + self.inputs_with_examples = inputs_with_examples + self.outputs = outputs + self.fn = fn + self.cache_examples = cache_examples + self._api_mode = _api_mode + self.preprocess = preprocess + self.postprocess = postprocess + self.batch = batch + + with utils.set_directory(working_directory): + self.processed_examples = [ + [ + component.postprocess(sample) + for component, sample in zip(inputs, example) + ] + for example in examples + ] + self.non_none_processed_examples = [ + [ex for (ex, keep) in zip(example, input_has_examples) if keep] + for example in self.processed_examples + ] + if cache_examples: + for example in self.examples: + if len([ex for ex in example if ex is not None]) != len(self.inputs): + warnings.warn( + "Examples are being cached but not all input components have " + "example values. This may result in an exception being thrown by " + "your function. If you do get an error while caching examples, make " + "sure all of your inputs have example values for all of your examples " + "or you provide default values for those particular parameters in your function." + ) + break + + with utils.set_directory(working_directory): + self.dataset = components.Dataset( + components=inputs_with_examples, + samples=non_none_examples, + type="index", + label=label, + samples_per_page=examples_per_page, + elem_id=elem_id, + ) + + self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id) + self.cached_file = Path(self.cached_folder) / "log.csv" + self.cache_examples = cache_examples + self.run_on_click = run_on_click + + async def create(self) -> None: + """Caches the examples if self.cache_examples is True and creates the Dataset + component to hold the examples""" + + async def load_example(example_id): + # import pdb; pdb.set_trace() + if self.cache_examples: + processed_example = self.non_none_processed_examples[ + example_id + ] + await self.load_from_cache(example_id) + else: + processed_example = self.non_none_processed_examples[example_id] + return utils.resolve_singleton(processed_example) + + if Context.root_block: + if self.cache_examples and self.outputs: + targets = self.inputs_with_examples + self.outputs + else: + targets = self.inputs_with_examples + self.dataset.click( + load_example, + inputs=[self.dataset], + outputs=targets, # type: ignore + postprocess=False, + queue=False, + ) + self.dataset.click( + self.fn, + inputs=[self.dataset], + outputs=targets, # type: ignore + postprocess=False, + queue=False, + ) + # if self.run_on_click and not self.cache_examples: + # if self.fn is None: + # raise ValueError("Cannot run_on_click if no function is provided") + # self.dataset.click( + # self.fn, + # inputs=self.inputs, # type: ignore + # outputs=self.outputs, # type: ignore + # ) + + if self.cache_examples: + await self.cache() + + async def cache(self) -> None: + """ + Caches all of the examples so that their predictions can be shown immediately. + """ + if Path(self.cached_file).exists(): + print( + f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache." 
+ ) + else: + if Context.root_block is None: + raise ValueError("Cannot cache examples if not in a Blocks context") + + print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'") + cache_logger = CSVLogger() + + # create a fake dependency to process the examples and get the predictions + dependency = Context.root_block.set_event_trigger( + event_name="fake_event", + fn=self.fn, + inputs=self.inputs_with_examples, # type: ignore + outputs=self.outputs, # type: ignore + preprocess=self.preprocess and not self._api_mode, + postprocess=self.postprocess and not self._api_mode, + batch=self.batch, + ) + + fn_index = Context.root_block.dependencies.index(dependency) + assert self.outputs is not None + cache_logger.setup(self.outputs, self.cached_folder) + for example_id, _ in enumerate(self.examples): + processed_input = self.processed_examples[example_id] + if self.batch: + processed_input = [[value] for value in processed_input] + prediction = await Context.root_block.process_api( + fn_index=fn_index, inputs=processed_input, request=None, state={} + ) + output = prediction["data"] + if self.batch: + output = [value[0] for value in output] + cache_logger.flag(output) + # Remove the "fake_event" to prevent bugs in loading interfaces from spaces + Context.root_block.dependencies.remove(dependency) + Context.root_block.fns.pop(fn_index) + + async def load_from_cache(self, example_id: int) -> List[Any]: + """Loads a particular cached example for the interface. + Parameters: + example_id: The id of the example to process (zero-indexed). + """ + # import pdb; pdb.set_trace() + with open(self.cached_file, encoding="utf-8") as cache: + examples = list(csv.reader(cache)) + example = examples[example_id + 1] # +1 to adjust for header + output = [] + assert self.outputs is not None + for component, value in zip(self.outputs, example): + try: + value_as_dict = ast.literal_eval(value) + assert utils.is_update(value_as_dict) + output.append(value_as_dict) + except (ValueError, TypeError, SyntaxError, AssertionError): + output.append(component.serialize(value, self.cached_folder)) + return output + + +class TrackedIterable: + def __init__( + self, + iterable: Iterable | None, + index: int | None, + length: int | None, + desc: str | None, + unit: str | None, + _tqdm=None, + progress: float | None = None, + ) -> None: + self.iterable = iterable + self.index = index + self.length = length + self.desc = desc + self.unit = unit + self._tqdm = _tqdm + self.progress = progress + + +@document("__call__", "tqdm") +class Progress(Iterable): + """ + The Progress class provides a custom progress tracker that is used in a function signature. + To attach a Progress tracker to a function, simply add a parameter right after the input parameters that has a default value set to a `gradio.Progress()` instance. + The Progress tracker can then be updated in the function by calling the Progress object or using the `tqdm` method on an Iterable. + The Progress tracker is currently only available with `queue()`. 
+ Example: + import gradio as gr + import time + def my_function(x, progress=gr.Progress()): + progress(0, desc="Starting...") + time.sleep(1) + for i in progress.tqdm(range(100)): + time.sleep(0.1) + return x + gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue().launch() + Demos: progress + """ + + def __init__( + self, + track_tqdm: bool = False, + _callback: Callable | None = None, # for internal use only + _event_id: str | None = None, + ): + """ + Parameters: + track_tqdm: If True, the Progress object will track any tqdm.tqdm iterations with the tqdm library in the function. + """ + self.track_tqdm = track_tqdm + self._callback = _callback + self._event_id = _event_id + self.iterables: List[TrackedIterable] = [] + + def __len__(self): + return self.iterables[-1].length + + def __iter__(self): + return self + + def __next__(self): + """ + Updates progress tracker with next item in iterable. + """ + if self._callback: + current_iterable = self.iterables[-1] + while ( + not hasattr(current_iterable.iterable, "__next__") + and len(self.iterables) > 0 + ): + current_iterable = self.iterables.pop() + self._callback( + event_id=self._event_id, + iterables=self.iterables, + ) + assert current_iterable.index is not None, "Index not set." + current_iterable.index += 1 + try: + return next(current_iterable.iterable) # type: ignore + except StopIteration: + self.iterables.pop() + raise StopIteration + else: + return self + + def __call__( + self, + progress: float | Tuple[int, int | None] | None, + desc: str | None = None, + total: int | None = None, + unit: str = "steps", + _tqdm=None, + ): + """ + Updates progress tracker with progress and message text. + Parameters: + progress: If float, should be between 0 and 1 representing completion. If Tuple, first number represents steps completed, and second value represents total steps or None if unknown. If None, hides progress bar. + desc: description to display. + total: estimated total number of steps. + unit: unit of iterations. + """ + if self._callback: + if isinstance(progress, tuple): + index, total = progress + progress = None + else: + index = None + self._callback( + event_id=self._event_id, + iterables=self.iterables + + [TrackedIterable(None, index, total, desc, unit, _tqdm, progress)], + ) + else: + return progress + + def tqdm( + self, + iterable: Iterable | None, + desc: str | None = None, + total: int | None = None, + unit: str = "steps", + _tqdm=None, + *args, + **kwargs, + ): + """ + Attaches progress tracker to iterable, like tqdm. + Parameters: + iterable: iterable to attach progress tracker to. + desc: description to display. + total: estimated total number of steps. + unit: unit of iterations. + """ + if self._callback: + if iterable is None: + new_iterable = TrackedIterable(None, 0, total, desc, unit, _tqdm) + self.iterables.append(new_iterable) + self._callback(event_id=self._event_id, iterables=self.iterables) + return self + length = len(iterable) if hasattr(iterable, "__len__") else None # type: ignore + self.iterables.append( + TrackedIterable(iter(iterable), 0, length, desc, unit, _tqdm) + ) + return self + + def update(self, n=1): + """ + Increases latest iterable with specified number of steps. + Parameters: + n: number of steps completed. + """ + if self._callback and len(self.iterables) > 0: + current_iterable = self.iterables[-1] + assert current_iterable.index is not None, "Index not set." 
+ current_iterable.index += n + self._callback( + event_id=self._event_id, + iterables=self.iterables, + ) + else: + return + + def close(self, _tqdm): + """ + Removes iterable with given _tqdm. + """ + if self._callback: + for i in range(len(self.iterables)): + if id(self.iterables[i]._tqdm) == id(_tqdm): + self.iterables.pop(i) + break + self._callback( + event_id=self._event_id, + iterables=self.iterables, + ) + else: + return + + +def create_tracker(root_blocks, event_id, fn, track_tqdm): + + progress = Progress(_callback=root_blocks._queue.set_progress, _event_id=event_id) + if not track_tqdm: + return progress, fn + + try: + _tqdm = __import__("tqdm") + except ModuleNotFoundError: + return progress, fn + if not hasattr(root_blocks, "_progress_tracker_per_thread"): + root_blocks._progress_tracker_per_thread = {} + + def init_tqdm(self, iterable=None, desc=None, *args, **kwargs): + self._progress = root_blocks._progress_tracker_per_thread.get( + threading.get_ident() + ) + if self._progress is not None: + self._progress.event_id = event_id + self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs) + kwargs["file"] = open(os.devnull, "w") + self.__init__orig__(iterable, desc, *args, **kwargs) + + def iter_tqdm(self): + if self._progress is not None: + return self._progress + else: + return self.__iter__orig__() + + def update_tqdm(self, n=1): + if self._progress is not None: + self._progress.update(n) + return self.__update__orig__(n) + + def close_tqdm(self): + if self._progress is not None: + self._progress.close(self) + return self.__close__orig__() + + def exit_tqdm(self, exc_type, exc_value, traceback): + if self._progress is not None: + self._progress.close(self) + return self.__exit__orig__(exc_type, exc_value, traceback) + + if not hasattr(_tqdm.tqdm, "__init__orig__"): + _tqdm.tqdm.__init__orig__ = _tqdm.tqdm.__init__ + _tqdm.tqdm.__init__ = init_tqdm + if not hasattr(_tqdm.tqdm, "__update__orig__"): + _tqdm.tqdm.__update__orig__ = _tqdm.tqdm.update + _tqdm.tqdm.update = update_tqdm + if not hasattr(_tqdm.tqdm, "__close__orig__"): + _tqdm.tqdm.__close__orig__ = _tqdm.tqdm.close + _tqdm.tqdm.close = close_tqdm + if not hasattr(_tqdm.tqdm, "__exit__orig__"): + _tqdm.tqdm.__exit__orig__ = _tqdm.tqdm.__exit__ + _tqdm.tqdm.__exit__ = exit_tqdm + if not hasattr(_tqdm.tqdm, "__iter__orig__"): + _tqdm.tqdm.__iter__orig__ = _tqdm.tqdm.__iter__ + _tqdm.tqdm.__iter__ = iter_tqdm + if hasattr(_tqdm, "auto") and hasattr(_tqdm.auto, "tqdm"): + _tqdm.auto.tqdm = _tqdm.tqdm + + def tracked_fn(*args): + thread_id = threading.get_ident() + root_blocks._progress_tracker_per_thread[thread_id] = progress + response = fn(*args) + del root_blocks._progress_tracker_per_thread[thread_id] + return response + + return progress, tracked_fn + + +def special_args( + fn: Callable, + inputs: List[Any] | None = None, + request: routes.Request | None = None, +): + """ + Checks if function has special arguments Request (via annotation) or Progress (via default value). + If inputs is provided, these values will be loaded into the inputs array. + Parameters: + block_fn: function to check. + inputs: array to load special arguments into. + request: request to load into inputs. 
+ Returns: + updated inputs, request index, progress index + """ + signature = inspect.signature(fn) + positional_args = [] + for i, param in enumerate(signature.parameters.values()): + if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD): + break + positional_args.append(param) + progress_index = None + for i, param in enumerate(positional_args): + if isinstance(param.default, Progress): + progress_index = i + if inputs is not None: + inputs.insert(i, param.default) + elif param.annotation == routes.Request: + if inputs is not None: + inputs.insert(i, request) + if inputs is not None: + while len(inputs) < len(positional_args): + i = len(inputs) + param = positional_args[i] + if param.default == param.empty: + warnings.warn("Unexpected argument. Filling with None.") + inputs.append(None) + else: + inputs.append(param.default) + return inputs or [], progress_index + + +@document() +def update(**kwargs) -> dict: + """ + Updates component properties. When a function passed into a Gradio Interface or a Blocks events returns a typical value, it updates the value of the output component. But it is also possible to update the properties of an output component (such as the number of lines of a `Textbox` or the visibility of an `Image`) by returning the component's `update()` function, which takes as parameters any of the constructor parameters for that component. + This is a shorthand for using the update method on a component. + For example, rather than using gr.Number.update(...) you can just use gr.update(...). + Note that your editor's autocompletion will suggest proper parameters + if you use the update method on the component. + Demos: blocks_essay, blocks_update, blocks_essay_update + + Parameters: + kwargs: Key-word arguments used to update the component's properties. + Example: + # Blocks Example + import gradio as gr + with gr.Blocks() as demo: + radio = gr.Radio([1, 2, 4], label="Set the value of the number") + number = gr.Number(value=2, interactive=True) + radio.change(fn=lambda value: gr.update(value=value), inputs=radio, outputs=number) + demo.launch() + + # Interface example + import gradio as gr + def change_textbox(choice): + if choice == "short": + return gr.Textbox.update(lines=2, visible=True) + elif choice == "long": + return gr.Textbox.update(lines=8, visible=True) + else: + return gr.Textbox.update(visible=False) + gr.Interface( + change_textbox, + gr.Radio( + ["short", "long", "none"], label="What kind of essay would you like to write?" + ), + gr.Textbox(lines=2), + live=True, + ).launch() + """ + kwargs["__type__"] = "generic_update" + return kwargs + + +def skip() -> dict: + return update() + + +@document() +def make_waveform( + audio: str | Tuple[int, np.ndarray], + *, + bg_color: str = "#f3f4f6", + bg_image: str | None = None, + fg_alpha: float = 0.75, + bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"), + bar_count: int = 50, + bar_width: float = 0.6, +): + """ + Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component. + Parameters: + audio: Audio file path or tuple of (sample_rate, audio_data) + bg_color: Background color of waveform (ignored if bg_image is provided) + bg_image: Background image of waveform + fg_alpha: Opacity of foreground waveform + bars_color: Color of waveform bars. 
Can be a single color or a tuple of (start_color, end_color) of gradient + bar_count: Number of bars in waveform + bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc. + Returns: + A filepath to the output video. + """ + if isinstance(audio, str): + audio_file = audio + audio = processing_utils.audio_from_file(audio) + else: + tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name) + audio_file = tmp_wav.name + duration = round(len(audio[1]) / audio[0], 4) + + # Helper methods to create waveform + def hex_to_RGB(hex_str): + return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)] + + def get_color_gradient(c1, c2, n): + assert n > 1 + c1_rgb = np.array(hex_to_RGB(c1)) / 255 + c2_rgb = np.array(hex_to_RGB(c2)) / 255 + mix_pcts = [x / (n - 1) for x in range(n)] + rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts] + return [ + "#" + "".join([format(int(round(val * 255)), "02x") for val in item]) + for item in rgb_colors + ] + + # Reshape audio to have a fixed number of bars + samples = audio[1] + if len(samples.shape) > 1: + samples = np.mean(samples, 1) + bins_to_pad = bar_count - (len(samples) % bar_count) + samples = np.pad(samples, [(0, bins_to_pad)]) + samples = np.reshape(samples, (bar_count, -1)) + samples = np.abs(samples) + samples = np.max(samples, 1) + + matplotlib.use("Agg") + plt.clf() + # Plot waveform + color = ( + bars_color + if isinstance(bars_color, str) + else get_color_gradient(bars_color[0], bars_color[1], bar_count) + ) + plt.bar( + np.arange(0, bar_count), + samples * 2, + bottom=(-1 * samples), + width=bar_width, + color=color, + ) + plt.axis("off") + plt.margins(x=0) + tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"} + if bg_image is not None: + savefig_kwargs["transparent"] = True + else: + savefig_kwargs["facecolor"] = bg_color + plt.savefig(tmp_img.name, **savefig_kwargs) + waveform_img = PIL.Image.open(tmp_img.name) + waveform_img = waveform_img.resize((1000, 200)) + + # Composite waveform with background image + if bg_image is not None: + waveform_array = np.array(waveform_img) + waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha + waveform_img = PIL.Image.fromarray(waveform_array) + + bg_img = PIL.Image.open(bg_image) + waveform_width, waveform_height = waveform_img.size + bg_width, bg_height = bg_img.size + if waveform_width != bg_width: + bg_img = bg_img.resize( + (waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2)) + ) + bg_width, bg_height = bg_img.size + composite_height = max(bg_height, waveform_height) + composite = PIL.Image.new("RGBA", (waveform_width, composite_height), "#FFFFFF") + composite.paste(bg_img, (0, composite_height - bg_height)) + composite.paste( + waveform_img, (0, composite_height - waveform_height), waveform_img + ) + composite.save(tmp_img.name) + img_width, img_height = composite.size + else: + img_width, img_height = waveform_img.size + waveform_img.save(tmp_img.name) + + # Convert waveform to video with ffmpeg + output_mp4 = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) + + ffmpeg_cmd = f"""ffmpeg -loop 1 -i {tmp_img.name} -i {audio_file} -vf "color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1" -t {duration} -y {output_mp4.name}""" + + subprocess.call(ffmpeg_cmd, shell=True) + return output_mp4.name diff --git 
a/gligen/.DS_Store b/gligen/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b880aeb4826dd12db4035e8c2abbda8457d64eaa Binary files /dev/null and b/gligen/.DS_Store differ diff --git a/gligen/SD_input_conv_weight_bias.pth b/gligen/SD_input_conv_weight_bias.pth new file mode 100644 index 0000000000000000000000000000000000000000..76eed06e176fec68ff2d1c4a3fd179cf620a7d7d --- /dev/null +++ b/gligen/SD_input_conv_weight_bias.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5a0efad69747a766158304f39091c2b6a24cafb5f833d174f32bee8e864a562 +size 130 diff --git a/gligen/__pycache__/__init__.cpython-310.pyc b/gligen/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24d4dfffcc487fa19918fea81b886d0232938f1c Binary files /dev/null and b/gligen/__pycache__/__init__.cpython-310.pyc differ diff --git a/gligen/__pycache__/__init__.cpython-38.pyc b/gligen/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dddf0891f9b86a5a4c19aad9273cdac089f3782 Binary files /dev/null and b/gligen/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/__pycache__/distributed.cpython-310.pyc b/gligen/__pycache__/distributed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcef2b878821df992a78d5701ef00ee43a73e8ed Binary files /dev/null and b/gligen/__pycache__/distributed.cpython-310.pyc differ diff --git a/gligen/__pycache__/distributed.cpython-38.pyc b/gligen/__pycache__/distributed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7e01e94be05a8a6f8afe0291682765163cf72ac Binary files /dev/null and b/gligen/__pycache__/distributed.cpython-38.pyc differ diff --git a/gligen/__pycache__/evaluator.cpython-310.pyc b/gligen/__pycache__/evaluator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23536570dfc1c541b49cef24c2fab64c7a2d6dfc Binary files /dev/null and b/gligen/__pycache__/evaluator.cpython-310.pyc differ diff --git a/gligen/__pycache__/evaluator.cpython-38.pyc b/gligen/__pycache__/evaluator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b775a0d708154f3ead33f5fd96f7d51fcf266103 Binary files /dev/null and b/gligen/__pycache__/evaluator.cpython-38.pyc differ diff --git a/gligen/__pycache__/task_grounded_generation.cpython-310.pyc b/gligen/__pycache__/task_grounded_generation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e89c08b11cdfdeaf134fe6cdfa1551d76dfa85bc Binary files /dev/null and b/gligen/__pycache__/task_grounded_generation.cpython-310.pyc differ diff --git a/gligen/__pycache__/task_grounded_generation.cpython-38.pyc b/gligen/__pycache__/task_grounded_generation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5b36ff062326f901f0a31f3f547ce95993e4fd Binary files /dev/null and b/gligen/__pycache__/task_grounded_generation.cpython-38.pyc differ diff --git a/gligen/__pycache__/trainer.cpython-310.pyc b/gligen/__pycache__/trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43bca3e34683a417f9e312f079afa51759b9f7ef Binary files /dev/null and b/gligen/__pycache__/trainer.cpython-310.pyc differ diff --git a/gligen/__pycache__/trainer.cpython-38.pyc b/gligen/__pycache__/trainer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16ac8ae5d0a4e8c3ade40746a342320c7aab222b 
Binary files /dev/null and b/gligen/__pycache__/trainer.cpython-38.pyc differ diff --git a/gligen/evaluator.py b/gligen/evaluator.py index afb61ec9aef76ef2654769c878bc233e4c805767..436c3d9b1c733bf3a3cc1ff027eb08d03b2d2fed 100644 --- a/gligen/evaluator.py +++ b/gligen/evaluator.py @@ -14,7 +14,7 @@ from trainer import read_official_ckpt, batch_to_device, ImageCaptionSaver, wrap from PIL import Image import math import json - +#hello def draw_masks_from_boxes(boxes,size): diff --git a/gligen/ldm/.DS_Store b/gligen/ldm/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..cc8962d7b5d196908d8d336eb5b39dc4d7ee7b02 Binary files /dev/null and b/gligen/ldm/.DS_Store differ diff --git a/gligen/ldm/__pycache__/util.cpython-310.pyc b/gligen/ldm/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48dd3593600d1df25ced228c16a04de926fea931 Binary files /dev/null and b/gligen/ldm/__pycache__/util.cpython-310.pyc differ diff --git a/gligen/ldm/__pycache__/util.cpython-38.pyc b/gligen/ldm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c49bacc41c35909fac61e2ecc1c916fc1ffb7605 Binary files /dev/null and b/gligen/ldm/__pycache__/util.cpython-38.pyc differ diff --git a/gligen/ldm/data/.DS_Store b/gligen/ldm/data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/gligen/ldm/data/.DS_Store differ diff --git a/gligen/ldm/data/imagenet_train_hr_indices.p b/gligen/ldm/data/imagenet_train_hr_indices.p new file mode 100644 index 0000000000000000000000000000000000000000..f55f631aa0c1ae1a805896d42f133bacd3f7139b --- /dev/null +++ b/gligen/ldm/data/imagenet_train_hr_indices.p @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f86ea1924a1522b20bc0f709a069cc65f09d5fc617a7a31af7aaa3839a5a4d73 +size 132 diff --git a/gligen/ldm/data/imagenet_val_hr_indices.p b/gligen/ldm/data/imagenet_val_hr_indices.p new file mode 100644 index 0000000000000000000000000000000000000000..93e8f10adc6c89e445f6b3f7af9d5c7d2c0da3df --- /dev/null +++ b/gligen/ldm/data/imagenet_val_hr_indices.p @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff1f5eb275a93c0fb53e227679f323ea1d024c87db296453296cebeef86fc0f4 +size 131 diff --git a/gligen/ldm/models/.DS_Store b/gligen/ldm/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8e2da8e41b76fb8a3c71433582fcacba45e51b72 Binary files /dev/null and b/gligen/ldm/models/.DS_Store differ diff --git a/gligen/ldm/models/__pycache__/autoencoder.cpython-310.pyc b/gligen/ldm/models/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14285a6fe4e9333e7d8ce4945009d5b36513980d Binary files /dev/null and b/gligen/ldm/models/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31267258e90660e79d7d00084135454ec92e8285 Binary files /dev/null and b/gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc b/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4907c0c3353be2c2ea5b86b55f4fbe370b70907c Binary files /dev/null 
and b/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aab67efed99c0c65c5da44eab775d687302dcbd4 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc b/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc5958c922e3cb9cd8a1c856de98f8cc30d86527 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e496b8b850a5e43196b3aa6381f453e21f1d1766 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc b/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..619bfa7112444ce4f78b8f11c58ec3bae705d25e Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..967b19d2615c7aca3ade02c313bdf641867ed6a0 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-310.pyc b/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b0dd634d90eb6657ddaaf143cdb7013e7c4fc99 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-310.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07c6ce260940772e875980d5df10dfba907d352e Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-310.pyc b/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94445ecb328130f1db6cb9809087918ec001b168 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-310.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0130f2932816a221829355e8f8fbf412e035960 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/loss.cpython-310.pyc b/gligen/ldm/models/diffusion/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fc020e96f4a79e76be29cc03fd75ba60cd48cc5 Binary files /dev/null and 
b/gligen/ldm/models/diffusion/__pycache__/loss.cpython-310.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d450a9c0e02daf65da42ac35d0f564f98c894e8 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc b/gligen/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5e8c61e915c87f5158e3d9a4b95f6bd73c5ae65 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3067ac490bb3ebd63857301562bf694b559888fc Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/ddim.py b/gligen/ldm/models/diffusion/ddim.py index ef5603ae921ee3ad88a1b5914201c1385bee3a2a..7db86661e94319b54bec15bf521097bb7b7faf87 100644 --- a/gligen/ldm/models/diffusion/ddim.py +++ b/gligen/ldm/models/diffusion/ddim.py @@ -87,7 +87,9 @@ class DDIMSampler(object): # set alpha if self.alpha_generator_func != None: self.set_alpha_scale(self.model, alphas[i]) - + if alphas[i] == 0: + self.model.restore_first_conv_from_SD() + # run index = total_steps - i - 1 input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long) @@ -110,9 +112,7 @@ class DDIMSampler(object): e_t = self.model(input) if uc is not None and guidance_scale != 1: - unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc) - if "inpainting_extra_input" in input: - unconditional_input["inpainting_extra_input"] = input["inpainting_extra_input"] + unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=input["inpainting_extra_input"], grounding_extra_input=input['grounding_extra_input']) e_t_uncond = self.model( unconditional_input ) e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond) diff --git a/gligen/ldm/models/diffusion/loss.py b/gligen/ldm/models/diffusion/loss.py index bec1c4cabd1a209aad2eefcc1329851030a68810..c07fd730fc0a67facbc30e9a2e0afc4afcfb89bb 100644 --- a/gligen/ldm/models/diffusion/loss.py +++ b/gligen/ldm/models/diffusion/loss.py @@ -639,11 +639,11 @@ def caculate_loss_LAC(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, objec # avg_fg_value = torch.mean(ca_map_obj * mask) # print('avg_fg_value', avg_fg_value) - sum_in += (norm_ca_map_obj * mask).sum() - sum_out += (norm_ca_map_obj * (1 - mask)).sum() - + sum_in = (norm_ca_map_obj * mask).sum() + sum_out = (norm_ca_map_obj * (1 - mask)).sum() + + obj_loss += (1 - sum_in / (sum_in + sum_out)) ** 2 # here each object corresponds to a single box, so the length is 1 - loss += (obj_loss/len(object_positions[obj_idx])) + loss += obj_loss # get pad_loss #sot_map = attn_map[:, :, 0].reshape(H, W) @@ -676,7 +676,7 @@ def caculate_loss_LAC(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, objec # print('optimization finished for this step') - loss += (1 - sum_in / (sum_in + sum_out)) ** 2 + # loss += (1 - sum_in / (sum_in + sum_out)) ** 2 # loss += max_loss # print('loss', loss) # print('pad_loss', alpha * pad_loss) diff --git a/gligen/ldm/models/diffusion/plms.py b/gligen/ldm/models/diffusion/plms.py index
ac3ce55cbeb7ef794781907b5fb43051f0b88655..a8719e021665cad3c4479384f75f1db130cb1451 100644 --- a/gligen/ldm/models/diffusion/plms.py +++ b/gligen/ldm/models/diffusion/plms.py @@ -3,7 +3,7 @@ import numpy as np from tqdm import tqdm from functools import partial from copy import deepcopy - +from diffusers import AutoencoderKL, LMSDiscreteScheduler from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like import math from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att, caculate_loss_LoCo,caculate_loss_LAC, caculate_loss_LoCo_V2 @@ -82,6 +82,11 @@ class PLMSSampler(object): if self.alpha_generator_func != None: alphas = self.alpha_generator_func(len(time_range)) + # newly added scheduler + noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", num_train_timesteps=1000) + noise_scheduler.set_timesteps(50) + for i, step in enumerate(time_range): # set alpha and restore first conv layer @@ -103,10 +108,21 @@ class PLMSSampler(object): # three loss types if loss_type !=None and loss_type!='standard': if input['object_position'] != []: - - - x = self.update_loss_LoCo( input,i, index, ts, time_factor = time_factor) - + if loss_type=='SAR_CAR': + x = self.update_loss_self_cross( input,i, index, ts ) + elif loss_type=='SAR': + x = self.update_only_self( input,i, index, ts ) + elif loss_type=='CAR': + x = self.update_loss_only_cross( input,i, index, ts ) + elif loss_type=='LoCo': + + #print('Utilizing LoCo!!') + time_factor = noise_scheduler.sigmas[i] ** 2 + x = self.update_loss_LoCo( input,i, index, ts, time_factor = time_factor) + + elif loss_type=='LAC': + #print('Utilizing LAC!!') + x = self.update_loss_LAC( input,i, index, ts ) input["x"] = x img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next) input["x"] = img @@ -116,6 +132,86 @@ class PLMSSampler(object): return img + def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ): + if index1 < 10: + loss_scale = 3 + max_iter = 5 + elif index1 < 20: + loss_scale = 2 + max_iter = 3 + else: + loss_scale = 1 + max_iter = 1 + + loss_threshold = 0.1 + max_index = 30 + x = deepcopy(input["x"]) + iteration = 0 + loss = torch.tensor(10000) + input["timesteps"] = ts + + print("optimize", index1) + while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) : + print('iter', iteration) + x = x.requires_grad_(True) + input['x'] = x + e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input) + bboxes = input['boxes'] + object_positions = input['object_position'] + loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + loss = loss1 + loss2 + print('AR loss:', loss, 'SAR:', loss1, 'CAR:', loss2) + hh = torch.autograd.backward(loss) + grad_cond = x.grad + x = x - grad_cond + x = x.detach() + iteration += 1 + torch.cuda.empty_cache() + return x + + def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'): + + if index1 < 10: + loss_scale = 3 + max_iter = 5 + elif index1 < 20: + loss_scale = 2 + max_iter = 5 + else: + loss_scale = 1 + max_iter = 1 + loss_threshold = 0.1 + + max_index = 30 + x = 
deepcopy(input["x"]) + iteration = 0 + loss = torch.tensor(10000) + input["timesteps"] = ts + + print("optimize", index1) + while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) : + print('iter', iteration) + x = x.requires_grad_(True) + print('x shape', x.shape) + input['x'] = x + e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input) + + bboxes = input['boxes'] + object_positions = input['object_position'] + loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + loss = loss2 + print('loss', loss) + hh = torch.autograd.backward(loss, retain_graph=True) + grad_cond = x.grad + x = x - grad_cond + x = x.detach() + iteration += 1 + torch.cuda.empty_cache() + return x def update_loss_LoCo(self, input,index1, index, ts, time_factor, type_loss='self_accross'): @@ -165,7 +261,90 @@ class PLMSSampler(object): torch.cuda.empty_cache() return x + def update_loss_LAC(self, input,index1, index, ts,type_loss='self_accross'): + + # loss_scale = 30 + # max_iter = 5 + + if index1 < 10: + loss_scale = 6 + max_iter = 5 + elif index1 < 20: + loss_scale = 4 + max_iter = 3 + else: + loss_scale = 1 + max_iter = 1 + loss_threshold = 0.002 + + max_index = 30 + x = deepcopy(input["x"]) + iteration = 0 + loss = torch.tensor(10000) + input["timesteps"] = ts + + print("optimize", index1) + while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) : + print('iter', iteration) + x = x.requires_grad_(True) + # print('x shape', x.shape) + input['x'] = x + e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input) + + bboxes = input['boxes'] + object_positions = input['object_position'] + loss2 = caculate_loss_LAC(att_second,att_first,att_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + loss = loss2 + print('LoCo loss', loss) + hh = torch.autograd.backward(loss, retain_graph=True) + grad_cond = x.grad + x = x - grad_cond + x = x.detach() + iteration += 1 + torch.cuda.empty_cache() + return x + + + def update_only_self(self, input,index1, index, ts,type_loss='self_accross' ): + if index1 < 10: + loss_scale = 4 + max_iter = 5 + elif index1 < 20: + loss_scale = 3 + max_iter = 5 + else: + loss_scale = 1 + max_iter = 1 + loss_threshold = 0.1 + + max_index = 30 + x = deepcopy(input["x"]) + iteration = 0 + loss = torch.tensor(10000) + input["timesteps"] = ts + + print("optimize", index1) + while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) : + print('iter', iteration) + x = x.requires_grad_(True) + input['x'] = x + e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input) + + bboxes = input['boxes'] + object_positions = input['object_position'] + loss = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + print('loss', loss) + hh = torch.autograd.backward(loss) + grad_cond = x.grad + + x = x - grad_cond + x = x.detach() + iteration += 1 + torch.cuda.empty_cache() + return x @torch.no_grad() def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None): diff --git a/gligen/ldm/modules/.DS_Store b/gligen/ldm/modules/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4c77b6cb34a492f9c0d376c771e131c8e8dc9388 Binary files /dev/null and b/gligen/ldm/modules/.DS_Store 
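The `update_loss_*` helpers added to `plms.py` above all follow one recipe: enable gradients on the latent, run the UNet to obtain its cross- and self-attention maps, score those maps against the target boxes, and take a raw gradient step on the latent, with the loss scale and iteration budget shrinking as denoising proceeds (the LoCo branch additionally weights its loss by the scheduler's sigma-squared `time_factor`). The sketch below restates that shared loop in isolation; `model` and `attention_loss` are hypothetical stand-ins, not the repository's actual interfaces.

```python
# Minimal sketch of the gradient-based latent refinement used by the
# update_loss_* helpers above. `model` and `attention_loss` are hypothetical
# stand-ins: `model` must expose the cross/self-attention maps and
# `attention_loss` must score them against the target boxes.
from copy import deepcopy
import torch

def refine_latent(model, attention_loss, input, step, loss_threshold=0.1, max_index=30):
    # schedule: stronger, longer optimization early in sampling
    if step < 10:
        loss_scale, max_iter = 3, 5
    elif step < 20:
        loss_scale, max_iter = 2, 3
    else:
        loss_scale, max_iter = 1, 1

    x = deepcopy(input["x"])
    loss = torch.tensor(1e4)
    iteration = 0
    while loss.item() > loss_threshold and iteration < max_iter and step < max_index:
        x = x.requires_grad_(True)
        input["x"] = x
        # the patched UNet also returns attention maps alongside the noise estimate
        _, cross_maps, self_maps = model(input)
        loss = attention_loss(cross_maps, self_maps,
                              boxes=input["boxes"],
                              object_positions=input["object_position"]) * loss_scale
        loss.backward()
        # one un-normalized gradient step taken directly on the latent
        x = (x - x.grad).detach()
        iteration += 1
        torch.cuda.empty_cache()
    return x
```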
differ diff --git a/gligen/ldm/modules/__pycache__/attention.cpython-310.pyc b/gligen/ldm/modules/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cc90a3c5910b61919e89cfbaca109f20795adcc Binary files /dev/null and b/gligen/ldm/modules/__pycache__/attention.cpython-310.pyc differ diff --git a/gligen/ldm/modules/__pycache__/attention.cpython-38.pyc b/gligen/ldm/modules/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab45dbe365eb7a3fdccee25685d18676c93762b2 Binary files /dev/null and b/gligen/ldm/modules/__pycache__/attention.cpython-38.pyc differ diff --git a/gligen/ldm/modules/__pycache__/x_transformer.cpython-310.pyc b/gligen/ldm/modules/__pycache__/x_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d37ec1c0f305133ca702dc9fc379863b744bee92 Binary files /dev/null and b/gligen/ldm/modules/__pycache__/x_transformer.cpython-310.pyc differ diff --git a/gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc b/gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..506b77ab3cb448eba85cb89937a96c6649eebf20 Binary files /dev/null and b/gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc differ diff --git a/gligen/ldm/modules/attention.py b/gligen/ldm/modules/attention.py index c443da348bc1ce707487fb8962a13b1810a43454..2147b3d23b1a1ecd539e741cff42b61c29476a97 100644 --- a/gligen/ldm/modules/attention.py +++ b/gligen/ldm/modules/attention.py @@ -4,17 +4,13 @@ import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat - +# import configigure # from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder from torch.utils import checkpoint +import os +from torchvision.utils import save_image -try: - import xformers - import xformers.ops - XFORMERS_IS_AVAILBLE = True -except: - XFORMERS_IS_AVAILBLE = False - +iter_att = 0 def exists(val): return val is not None @@ -106,13 +102,14 @@ class LinearAttention(nn.Module): + class CrossAttention(nn.Module): def __init__(self, query_dim, key_dim, value_dim, heads=8, dim_head=64, dropout=0): super().__init__() inner_dim = dim_head * heads self.scale = dim_head ** -0.5 self.heads = heads - self.dim_head = dim_head + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(key_dim, inner_dim, bias=False) @@ -129,9 +126,18 @@ class CrossAttention(nn.Module): max_neg_value = -torch.finfo(sim.dtype).max sim.masked_fill_(~mask, max_neg_value) return sim + # def scaled_dot_product(q, k, v, mask=None): + # d_k = q.size()[-1] + # attn_logits = torch.matmul(q, k.transpose(-2, -1)) + # attn_logits = attn_logits / math.sqrt(d_k) + # if mask is not None: + # attn_logits = attn_logits.masked_fill(mask == 0, -9e15) + # attention = F.softmax(attn_logits, dim=-1) + # values = torch.matmul(attention, v) + # return values, attention - def forward_plain(self, x, key, value, mask=None): - + def forward(self, x, key, value, mask=None): + # import pdb; pdb.set_trace() q = self.to_q(x) # B*N*(H*C) k = self.to_k(key) # B*M*(H*C) v = self.to_v(value) # B*M*(H*C) @@ -148,44 +154,21 @@ class CrossAttention(nn.Module): sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale # (B*H)*N*M self.fill_inf_from_mask(sim, mask) attn = sim.softmax(dim=-1) # (B*H)*N*M - + # import pdb; pdb.set_trace() + # if attn.shape[1] == 4096: + # self.visual_att(attn) out = 
torch.einsum('b i j, b j d -> b i d', attn, v) # (B*H)*N*C out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C) - return self.to_out(out) - - def forward(self, x, key, value, mask=None): - if not XFORMERS_IS_AVAILBLE: - return self.forward_plain(x, key, value, mask) - - q = self.to_q(x) # B*N*(H*C) - k = self.to_k(key) # B*M*(H*C) - v = self.to_v(value) # B*M*(H*C) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) - - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - return self.to_out(out) - - + return self.to_out(out), attn + def visual_att(self, att): + global iter_att + ll = [0,2,7] + for i in range(12): + kk = torch.sum(att[:,:,i], axis=0) + kk = kk.reshape(64,64) + save_image( (kk-kk.min()) / (kk.max() - kk.min()) , os.path.join('att', str(iter_att) + '_' +str(i) + '.png')) + iter_att += 1 @@ -195,7 +178,6 @@ class SelfAttention(nn.Module): inner_dim = dim_head * heads self.scale = dim_head ** -0.5 self.heads = heads - self.dim_head = dim_head self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(query_dim, inner_dim, bias=False) @@ -203,7 +185,7 @@ class SelfAttention(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) - def forward_plain(self, x): + def forward(self, x, gated=False): q = self.to_q(x) # B*N*(H*C) k = self.to_k(x) # B*N*(H*C) v = self.to_v(x) # B*N*(H*C) @@ -211,50 +193,29 @@ class SelfAttention(nn.Module): B, N, HC = q.shape H = self.heads C = HC // H - + # if gated: import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C k = k.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C v = v.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale # (B*H)*N*N attn = sim.softmax(dim=-1) # (B*H)*N*N - + # if gated and attn.shape[1] == 4126: + # self.visual_att(attn) out = torch.einsum('b i j, b j c -> b i c', attn, v) # (B*H)*N*C out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C) - return self.to_out(out) - - def forward(self, x, context=None, mask=None): - if not XFORMERS_IS_AVAILBLE: - return self.forward_plain(x) - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) - - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - return self.to_out(out) + return self.to_out(out), attn + + def visual_att(self, att): + global iter_att + ll = [0,2,7] + for i in range(): + kk 
= torch.sum(att[i],axis=0) + kk = kk[:4096].reshape(64,64) + save_image( (kk-kk.min()) / (kk.max() - kk.min()) , os.path.join('att', str(iter_att) + '_' +str(i) + '.png')) + iter_att += 1 class GatedCrossAttentionDense(nn.Module): @@ -272,7 +233,7 @@ class GatedCrossAttentionDense(nn.Module): # this can be useful: we can externally change magnitude of tanh(alpha) # for example, when it is set to 0, then the entire model is same as original one - self.scale = 1 + self.scale = 1 def forward(self, x, objs): @@ -303,17 +264,87 @@ class GatedSelfAttentionDense(nn.Module): self.scale = 1 + def forward(self, x, objs,t): + # if t >300: + # self.scale = 1 + # elif t > 200: + # self.scale = 0.9 + # else: + # self.scale = 0.6 + # if t >700: + # self.scale = 1 + # elif t > 300: + # self.scale = 0.7 + # else: + # self.scale = 0.4 + # self.scale = 0 + + N_visual = x.shape[1] + objs = self.linear(objs) + out, grounding_att = self.attn( self.norm1(torch.cat([x,objs],dim=1)), True ) + out = out[:,0:N_visual,:] + x = x + self.scale*torch.tanh(self.alpha_attn) * out + x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) ) + + return x , grounding_att + + + + + + +class GatedSelfAttentionDense2(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) ) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) ) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as original one + self.scale = 1 + + def forward(self, x, objs): - N_visual = x.shape[1] + B, N_visual, _ = x.shape + B, N_ground, _ = objs.shape + objs = self.linear(objs) - x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(torch.cat([x,objs],dim=1)) )[:,0:N_visual,:] + # sanity check + size_v = math.sqrt(N_visual) + size_g = math.sqrt(N_ground) + assert int(size_v) == size_v, "Visual tokens must be square rootable" + assert int(size_g) == size_g, "Grounding tokens must be square rootable" + size_v = int(size_v) + size_g = int(size_g) + + # select grounding token and resize it to visual token size as residual + out = self.attn( self.norm1(torch.cat([x,objs],dim=1)) )[:,N_visual:,:] + out = out.permute(0,2,1).reshape( B,-1,size_g,size_g ) + out = torch.nn.functional.interpolate(out, (size_v,size_v), mode='bicubic') + residual = out.reshape(B,-1,N_visual).permute(0,2,1) + + # add residual to visual feature + x = x + self.scale*torch.tanh(self.alpha_attn) * residual x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) ) return x + + + class BasicTransformerBlock(nn.Module): def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=True): super().__init__() @@ -328,25 +359,34 @@ class BasicTransformerBlock(nn.Module): if fuser_type == "gatedSA": # note key_dim here actually is context_dim self.fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head) + elif fuser_type == "gatedSA2": + # note key_dim here actually is context_dim + self.fuser = GatedSelfAttentionDense2(query_dim, key_dim, n_heads, 
d_head) elif fuser_type == "gatedCA": self.fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head) else: assert False - def forward(self, x, context, objs): + def forward(self, x, context, objs,t): # return checkpoint(self._forward, (x, context, objs), self.parameters(), self.use_checkpoint) - if self.use_checkpoint and x.requires_grad: - return checkpoint.checkpoint(self._forward, x, context, objs) - else: - return self._forward(x, context, objs) - - def _forward(self, x, context, objs): - x = self.attn1( self.norm1(x) ) + x - x = self.fuser(x, objs) # identity mapping in the beginning - x = self.attn2(self.norm2(x), context, context) + x + # import pdb; pdb.set_trace() + # if self.use_checkpoint and x.requires_grad: + # return checkpoint.checkpoint(self._forward, x, context, objs,t) + # else: + return self._forward(x, context, objs,t) + + def _forward(self, x, context, objs,t): + # self_att_grounding = [] + out, self_prob = self.attn1( self.norm1(x) ) + x = x + out + x, self_prob_grounding = self.fuser(x, objs,t) # identity mapping in the beginning + x_1, prob = self.attn2(self.norm2(x), context, context) + x = x + x_1 x = self.ff(self.norm3(x)) + x - return x + # self_att_grounding.append(self_prob) + # self_att_grounding.append(self_prob_grounding) + return x, prob, self_prob class SpatialTransformer(nn.Module): @@ -356,7 +396,7 @@ class SpatialTransformer(nn.Module): query_dim = n_heads * d_head self.norm = Normalize(in_channels) - + self.proj_in = nn.Conv2d(in_channels, query_dim, kernel_size=1, @@ -374,14 +414,18 @@ class SpatialTransformer(nn.Module): stride=1, padding=0)) - def forward(self, x, context, objs): + def forward(self, x, context, objs,t): b, c, h, w = x.shape x_in = x x = self.norm(x) x = self.proj_in(x) x = rearrange(x, 'b c h w -> b (h w) c') + probs = [] + self_prob_list = [] for block in self.transformer_blocks: - x = block(x, context, objs) + x, prob, self_prob = block(x, context, objs,t) + probs.append(prob) + self_prob_list.append(self_prob) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) - return x + x_in \ No newline at end of file + return x + x_in, probs, self_prob_list \ No newline at end of file diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4da360ccbce2627cd457ac0b7ded90d655be5c2a Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93f5570e474602a9a3aea0ff5aa7e5c559b65b1d Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..801791333bcb0415e1f9fe39caf1c14d32b41019 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc new file mode 100644 index 
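The `attention.py` hunks above remove the xformers fast path and make each attention layer return its softmax probabilities alongside its output, which `BasicTransformerBlock` and `SpatialTransformer` then collect per block so the samplers can compute box losses on them. A reduced, self-contained sketch of an attention layer with that `(out, attn)` contract (illustrative names, not the repository's class):

```python
# Reduced sketch of an attention layer that also exposes its probability map,
# mirroring the (out, attn) return values introduced above. Dimensions and
# names are illustrative only.
import torch
import torch.nn as nn
from einops import rearrange

class ProbingCrossAttention(nn.Module):
    def __init__(self, query_dim, key_dim, heads=8, dim_head=64):
        super().__init__()
        inner = heads * dim_head
        self.heads, self.scale = heads, dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner, bias=False)
        self.to_k = nn.Linear(key_dim, inner, bias=False)
        self.to_v = nn.Linear(key_dim, inner, bias=False)
        self.to_out = nn.Linear(inner, query_dim)

    def forward(self, x, context):
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        q, k, v = (rearrange(t, 'b n (h c) -> (b h) n c', h=self.heads) for t in (q, k, v))
        attn = (torch.einsum('b i c, b j c -> b i j', q, k) * self.scale).softmax(dim=-1)
        out = torch.einsum('b i j, b j c -> b i c', attn, v)
        out = rearrange(out, '(b h) n c -> b n (h c)', h=self.heads)
        # return the probabilities so a sampler can build box losses from them
        return self.to_out(out), attn
```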
0000000000000000000000000000000000000000..fbb364cafadd2945347ab854337c8a7796e3f3af Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36f8e939771a478c5615ab03a63fdd28a9b8f9ec Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d2479f7e20192b033f24aef127c6a28b70944e9 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a23f0e3ce617bb29d18bf1b11247785f889758e Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..181ed91a5d567acce3f783168e1445ae2aeaab06 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e5e84dfe28e5bf086472a66bf1f7eaad5516221 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-310.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..890ce97d46f85170ac65139d0558ffe304894c5d Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-310.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5df70a6186c395c6a2baaa716781fe7b0455e624 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..045a35181d22dbf23c3637558dbf7de8cdabe768 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b41b3506c39ddcac8985749922726679da227b57 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a62ce04191ef90b485b1d5d6f5cc22ffd279695c Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/canny_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/canny_grounding_downsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6331d15c76e0418a1e4a050d199727b53006ecfd --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/canny_grounding_downsampler.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class GroundingDownsampler(nn.Module): + def __init__(self, resize_input=256, out_dim=8): + super().__init__() + self.resize_input = resize_input + self.out_dim = out_dim + + self.layers = nn.Sequential( + nn.Conv2d(1,4,4,2,1), + nn.SiLU(), + nn.Conv2d(4,self.out_dim,4,2,1) + ) + + def forward(self, grounding_extra_input): + # this is actually gary scale, but converted to rgb in dataset, information redudant + grounding_extra_input = grounding_extra_input[:,0].unsqueeze(1) + + out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='bicubic') + out = self.layers(out) + + assert out.shape[1] == self.out_dim + return out + + diff --git a/gligen/ldm/modules/diffusionmodules/canny_grounding_net.py b/gligen/ldm/modules/diffusionmodules/canny_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8fcf7c64b387d99f067466a8a265082f805a88 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/canny_grounding_net.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F +from ..attention import SelfAttention, FeedForward +from .convnext import convnext_tiny + + + + +class PositionNet(nn.Module): + def __init__(self, resize_input=448, out_dim=768): + super().__init__() + self.resize_input = resize_input + self.down_factor = 32 # determined by the convnext backbone + self.out_dim = out_dim + assert self.resize_input % self.down_factor == 0 + + self.convnext_tiny_backbone = convnext_tiny(pretrained=True) + + self.num_tokens = (self.resize_input // self.down_factor) ** 2 + + convnext_feature_dim = 768 + self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT + + self.linears = nn.Sequential( + nn.Linear( convnext_feature_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim])) + + + def forward(self, canny_edge, mask): + B = canny_edge.shape[0] + + # token from edge map + canny_edge = torch.nn.functional.interpolate(canny_edge, self.resize_input) + canny_edge_feature = self.convnext_tiny_backbone(canny_edge) + objs = canny_edge_feature.reshape(B, -1, self.num_tokens) + objs = 
objs.permute(0, 2, 1) # N*Num_tokens*dim + + # expand null token + null_objs = self.null_feature.view(1,1,-1) + null_objs = null_objs.repeat(B,self.num_tokens,1) + + # mask replacing + mask = mask.view(-1,1,1) + objs = objs*mask + null_objs*(1-mask) + + # add pos + objs = objs + self.pos_embedding + + # fuse them + objs = self.linears(objs) + + assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim]) + return objs + + + diff --git a/gligen/ldm/modules/diffusionmodules/convnext.py b/gligen/ldm/modules/diffusionmodules/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..71956848b6631ecb7ae12b9d684e69e142a3ef45 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/convnext.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_, DropPath +from timm.models.registry import register_model + +class Block(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
+ """ + def __init__(self, in_chans=3, num_classes=1000, + depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., + layer_scale_init_value=1e-6, head_init_scale=1., + ): + super().__init__() + + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[Block(dim=dims[i], drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + # self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + # self.head = nn.Linear(dims[-1], num_classes) + + # self.apply(self._init_weights) + # self.head.weight.data.mul_(head_init_scale) + # self.head.bias.data.mul_(head_init_scale) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return x + # return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + return x + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
+ """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +model_urls = { + "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", + "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", + "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", + "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", + "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", +} + +@register_model +def convnext_tiny(pretrained=False,in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) + if pretrained: + url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + model.load_state_dict(checkpoint["model"], strict=False) # we remove classifer head + return model + +@register_model +def convnext_small(pretrained=False,in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) + if pretrained: + url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def convnext_base(pretrained=False, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + if pretrained: + url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def convnext_large(pretrained=False, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + if pretrained: + url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + if pretrained: + assert in_22k, "only ImageNet-22K 
pre-trained ConvNeXt-XL is available; please set in_22k=True" + url = model_urls['convnext_xlarge_22k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + return model \ No newline at end of file diff --git a/gligen/ldm/modules/diffusionmodules/depth_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/depth_grounding_downsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..80826ae2a96b615e24474f43b4ecdc9267049261 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/depth_grounding_downsampler.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class GroundingDownsampler(nn.Module): + def __init__(self, resize_input=256, out_dim=8): + super().__init__() + self.resize_input = resize_input + self.out_dim = out_dim + + self.layers = nn.Sequential( + nn.Conv2d(1,4,4,2,1), + nn.SiLU(), + nn.Conv2d(4,self.out_dim,4,2,1) + ) + + def forward(self, grounding_extra_input): + # this is actually gary scale, but converted to rgb in dataset, information redudant + + grounding_extra_input = grounding_extra_input[:,0].unsqueeze(1) + + out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='bicubic') + out = self.layers(out) + + assert out.shape[1] == self.out_dim + return out + + diff --git a/gligen/ldm/modules/diffusionmodules/depth_grounding_net.py b/gligen/ldm/modules/diffusionmodules/depth_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..637816e79a97e38cf987e6311fa91d9792dc0fce --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/depth_grounding_net.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F +from ..attention import SelfAttention, FeedForward +from .convnext import convnext_tiny + + + + +class PositionNet(nn.Module): + def __init__(self, resize_input=448, out_dim=768): + super().__init__() + self.resize_input = resize_input + self.down_factor = 32 # determined by the convnext backbone + self.out_dim = out_dim + assert self.resize_input % self.down_factor == 0 + + self.convnext_tiny_backbone = convnext_tiny(pretrained=True) + + self.num_tokens = (self.resize_input // self.down_factor) ** 2 + + convnext_feature_dim = 768 + self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT + + self.linears = nn.Sequential( + nn.Linear( convnext_feature_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim])) + + + def forward(self, depth, mask): + B = depth.shape[0] + + # token from edge map + depth = torch.nn.functional.interpolate(depth, self.resize_input) + depth_feature = self.convnext_tiny_backbone(depth) + objs = depth_feature.reshape(B, -1, self.num_tokens) + objs = objs.permute(0, 2, 1) # N*Num_tokens*dim + + # expand null token + null_objs = self.null_feature.view(1,1,-1) + null_objs = null_objs.repeat(B,self.num_tokens,1) + + # mask replacing + mask = mask.view(-1,1,1) + objs = objs*mask + null_objs*(1-mask) + + # add pos + objs = objs + self.pos_embedding + + # fuse them + objs = 
self.linears(objs) + + assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim]) + return objs + + + diff --git a/gligen/ldm/modules/diffusionmodules/grounding_net_example.py b/gligen/ldm/modules/diffusionmodules/grounding_net_example.py new file mode 100644 index 0000000000000000000000000000000000000000..7a09caf5e48bb11f789236a4c34bdbd9ee6cabee --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/grounding_net_example.py @@ -0,0 +1,22 @@ +""" +This is a high-level pseudo code for grounding net. + +This class needs to tokenize grounding input into gronding tokens which +will be used in GatedAttenion layers. + + +class PositionNet(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + kwargs should be defined by model.grounding_tokenizer in config yaml file. + + def forward(self, **kwargs): + + kwargs should be the output of grounding_tokenizer_input network + + return grounding_tokens # with shape: Batch * Num_Of_Token* Token_Channel_Dimension + + + +""" \ No newline at end of file diff --git a/gligen/ldm/modules/diffusionmodules/hed_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/hed_grounding_downsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..99d1e7def372b74db331b6f50d3dc4574ace47a2 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/hed_grounding_downsampler.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class GroundingDownsampler(nn.Module): + def __init__(self, out_dim=1): + super().__init__() + self.out_dim = out_dim + # No learnable params for hed edge map, just downsample it with bicubic + + def forward(self, grounding_extra_input): + # this is actually gary scale, but converted to rgb in dataset, information redudant + grounding_extra_input = grounding_extra_input[:,0].unsqueeze(1) + + out = torch.nn.functional.interpolate(grounding_extra_input, (64,64), mode='bicubic') + assert out.shape[1] == self.out_dim + return out + + diff --git a/gligen/ldm/modules/diffusionmodules/hed_grounding_net.py b/gligen/ldm/modules/diffusionmodules/hed_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..e566bb35c914abd19e51c8661d54a8702c3d55df --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/hed_grounding_net.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F +from ..attention import SelfAttention, FeedForward +from .convnext import convnext_tiny + + + + +class PositionNet(nn.Module): + def __init__(self, resize_input=448, out_dim=768): + super().__init__() + self.resize_input = resize_input + self.down_factor = 32 # determined by the res50 backbone + self.out_dim = out_dim + assert self.resize_input % self.down_factor == 0 + + self.convnext_tiny_backbone = convnext_tiny(pretrained=True) + + self.num_tokens = (self.resize_input // self.down_factor) ** 2 + + convnext_feature_dim = 768 + self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT + + self.linears = nn.Sequential( + nn.Linear( convnext_feature_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim])) + + 
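`grounding_net_example.py` above only states the contract: a `PositionNet` must turn its grounding input into `Batch * Num_Of_Token * Token_Channel_Dimension` tokens for the gated-attention layers. The canny, depth, and hed variants in this diff all realize it the same way: backbone features, a learned null token selected by `mask`, a positional embedding, and a small MLP. A condensed sketch of that shared pattern, with a generic `backbone` standing in for `convnext_tiny`:

```python
# Condensed sketch of the shared PositionNet pattern used by the canny/depth/hed
# grounding tokenizers in this diff. `backbone` is a placeholder for any network
# that maps a resized map to a (B, feat_dim, H/32, W/32) feature grid.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchPositionNet(nn.Module):
    def __init__(self, backbone, feat_dim=768, resize_input=448, out_dim=768, down_factor=32):
        super().__init__()
        assert resize_input % down_factor == 0
        self.backbone = backbone
        self.resize_input = resize_input
        self.num_tokens = (resize_input // down_factor) ** 2
        self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, feat_dim).normal_(std=0.02))
        self.null_feature = nn.Parameter(torch.zeros(feat_dim))
        self.linears = nn.Sequential(
            nn.Linear(feat_dim, 512), nn.SiLU(),
            nn.Linear(512, 512), nn.SiLU(),
            nn.Linear(512, out_dim),
        )

    def forward(self, grounding_map, mask):
        B = grounding_map.shape[0]
        feats = self.backbone(F.interpolate(grounding_map, self.resize_input))
        objs = feats.reshape(B, -1, self.num_tokens).permute(0, 2, 1)   # B * num_tokens * feat_dim
        null = self.null_feature.view(1, 1, -1).repeat(B, self.num_tokens, 1)
        mask = mask.view(-1, 1, 1)        # 1 keeps the real tokens, 0 swaps in the null token
        objs = objs * mask + null * (1 - mask)
        return self.linears(objs + self.pos_embedding)  # grounding tokens for the gated attention
```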
+ def forward(self, hed_edge, mask): + B = hed_edge.shape[0] + + # token from edge map + hed_edge = torch.nn.functional.interpolate(hed_edge, self.resize_input) + hed_edge_feature = self.convnext_tiny_backbone(hed_edge) + objs = hed_edge_feature.reshape(B, -1, self.num_tokens) + objs = objs.permute(0, 2, 1) # N*Num_tokens*dim + + # expand null token + null_objs = self.null_feature.view(1,1,-1) + null_objs = null_objs.repeat(B,self.num_tokens,1) + + # mask replacing + mask = mask.view(-1,1,1) + objs = objs*mask + null_objs*(1-mask) + + # add pos + objs = objs + self.pos_embedding + + # fuse them + objs = self.linears(objs) + + assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim]) + return objs + + + diff --git a/gligen/ldm/modules/diffusionmodules/keypoint_grounding_net.py b/gligen/ldm/modules/diffusionmodules/keypoint_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..e9da67a713c917ac8aedf09ba2803421550021e3 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/keypoint_grounding_net.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class PositionNet(nn.Module): + def __init__(self, max_persons_per_image, out_dim, fourier_freqs=8): + super().__init__() + self.max_persons_per_image = max_persons_per_image + self.out_dim = out_dim + + self.person_embeddings = torch.nn.Parameter(torch.zeros([max_persons_per_image,out_dim])) + self.keypoint_embeddings = torch.nn.Parameter(torch.zeros([17,out_dim])) + + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs*2*2 # 2 is sin&cos, 2 is xy + + self.linears = nn.Sequential( + nn.Linear( self.out_dim + self.position_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_person_feature = torch.nn.Parameter(torch.zeros([self.out_dim])) + self.null_xy_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + + def forward(self, points, masks): + + masks = masks.unsqueeze(-1) + N = points.shape[0] + + person_embeddings = self.person_embeddings.unsqueeze(1).repeat(1,17,1).reshape(self.max_persons_per_image*17, self.out_dim) + keypoint_embeddings = torch.cat([self.keypoint_embeddings]*self.max_persons_per_image, dim=0) + person_embeddings = person_embeddings + keypoint_embeddings # (num_person*17) * C + person_embeddings = person_embeddings.unsqueeze(0).repeat(N,1,1) + + # embedding position (it may includes padding as placeholder) + xy_embedding = self.fourier_embedder(points) # B*N*2 --> B*N*C + + + # learnable null embedding + person_null = self.null_person_feature.view(1,1,-1) + xy_null = self.null_xy_feature.view(1,1,-1) + + # replace padding with learnable null embedding + person_embeddings = person_embeddings*masks + (1-masks)*person_null + xy_embedding = xy_embedding*masks + (1-masks)*xy_null + + objs = self.linears( torch.cat([person_embeddings, xy_embedding], dim=-1) ) + + return objs + + + diff --git a/gligen/ldm/modules/diffusionmodules/normal_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/normal_grounding_downsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..b663401b253e09bd3f1cd78e725373c1f537b4f8 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/normal_grounding_downsampler.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from ldm.modules.attention 
import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class GroundingDownsampler(nn.Module): + def __init__(self, resize_input=256, out_dim=8): + super().__init__() + self.resize_input = resize_input + self.out_dim = out_dim + + self.layers = nn.Sequential( + nn.Conv2d(3,4,4,2,1), + nn.SiLU(), + nn.Conv2d(4,self.out_dim,4,2,1) + ) + + def forward(self, grounding_extra_input): + + out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='bicubic') + out = self.layers(out) + + assert out.shape[1] == self.out_dim + return out + + diff --git a/gligen/ldm/modules/diffusionmodules/normal_grounding_net.py b/gligen/ldm/modules/diffusionmodules/normal_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..38cadb7c9321f2d4aacabf3a4e31ac2207bebb32 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/normal_grounding_net.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F +from ..attention import SelfAttention, FeedForward +from .convnext import convnext_tiny + + + + +class PositionNet(nn.Module): + def __init__(self, resize_input=448, out_dim=768): + super().__init__() + self.resize_input = resize_input + self.down_factor = 32 # determined by the convnext backbone + self.out_dim = out_dim + assert self.resize_input % self.down_factor == 0 + + self.convnext_tiny_backbone = convnext_tiny(pretrained=True) + + self.num_tokens = (self.resize_input // self.down_factor) ** 2 + + convnext_feature_dim = 768 + self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT + + self.linears = nn.Sequential( + nn.Linear( convnext_feature_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim])) + + + def forward(self, normal, mask): + B = normal.shape[0] + + # token from edge map + normal = torch.nn.functional.interpolate(normal, self.resize_input) + normal_feature = self.convnext_tiny_backbone(normal) + objs = normal_feature.reshape(B, -1, self.num_tokens) + objs = objs.permute(0, 2, 1) # N*Num_tokens*dim + + # expand null token + null_objs = self.null_feature.view(1,1,-1) + null_objs = null_objs.repeat(B,self.num_tokens,1) + + # mask replacing + mask = mask.view(-1,1,1) + objs = objs*mask + null_objs*(1-mask) + + # add pos + objs = objs + self.pos_embedding + + # fuse them + objs = self.linears(objs) + + assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim]) + return objs + + + diff --git a/gligen/ldm/modules/diffusionmodules/openaimodel.py b/gligen/ldm/modules/diffusionmodules/openaimodel.py index e96ba0266e47c20d4c11de4b94064e27a595ad3b..34e39ea3f9d8ab58055beb26783d14d047878a5a 100644 --- a/gligen/ldm/modules/diffusionmodules/openaimodel.py +++ b/gligen/ldm/modules/diffusionmodules/openaimodel.py @@ -17,7 +17,10 @@ from ldm.modules.diffusionmodules.util import ( timestep_embedding, ) from ldm.modules.attention import SpatialTransformer +# from .positionnet import PositionNet from torch.utils import checkpoint +from ldm.util import instantiate_from_config +from copy import deepcopy class TimestepBlock(nn.Module): """ @@ -37,15 +40,20 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 
support it as an extra input. """ - def forward(self, x, emb, context, objs): + def forward(self, x, emb, context, objs,t): + probs = [] + self_prob_list = [] + for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer): - x = layer(x, context, objs) + x, prob, self_prob = layer(x, context, objs,t) + probs.append(prob) + self_prob_list.append(self_prob) else: x = layer(x) - return x + return x, probs, self_prob_list class Upsample(nn.Module): @@ -200,10 +208,10 @@ class ResBlock(TimestepBlock): # return checkpoint( # self._forward, (x, emb), self.parameters(), self.use_checkpoint # ) - if self.use_checkpoint and x.requires_grad: - return checkpoint.checkpoint(self._forward, x, emb ) - else: - return self._forward(x, emb) + # if self.use_checkpoint and x.requires_grad: + # return checkpoint.checkpoint(self._forward, x, emb ) + # else: + return self._forward(x, emb) def _forward(self, x, emb): @@ -247,12 +255,15 @@ class UNetModel(nn.Module): use_checkpoint=False, num_heads=8, use_scale_shift_norm=False, - transformer_depth=1, - positive_len = 768, # this is pre-processing embedding len for each 'obj/box' + transformer_depth=1, + positive_len = 768, context_dim=None, - fuser_type = None, + fuser_type = None, is_inpaint = False, - is_style = False, + is_style = False, + grounding_downsampler = None, + + ): super().__init__() @@ -267,13 +278,13 @@ class UNetModel(nn.Module): self.conv_resample = conv_resample self.use_checkpoint = use_checkpoint self.num_heads = num_heads - self.positive_len = positive_len self.context_dim = context_dim self.fuser_type = fuser_type self.is_inpaint = is_inpaint - self.is_style = is_style - self.use_o2 = False # This will be turned into True by externally if use o2 durining training - assert fuser_type in ["gatedSA", "gatedCA"] + self.positive_len = positive_len + assert fuser_type in ["gatedSA","gatedSA2","gatedCA"] + + self.grounding_tokenizer_input = None # set externally time_embed_dim = model_channels * 4 @@ -284,9 +295,25 @@ class UNetModel(nn.Module): ) - total_in_channels = in_channels+in_channels+1 if self.is_inpaint else in_channels - self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, total_in_channels, model_channels, 3, padding=1))]) - + + self.downsample_net = None + self.additional_channel_from_downsampler = 0 + self.first_conv_type = "SD" + self.first_conv_restorable = True + if grounding_downsampler is not None: + self.downsample_net = instantiate_from_config(grounding_downsampler) + self.additional_channel_from_downsampler = self.downsample_net.out_dim + self.first_conv_type = "GLIGEN" + + if is_inpaint: + # The new added channels are: masked image (encoded image) and mask, which is 4+1 + in_c = in_channels+self.additional_channel_from_downsampler+in_channels+1 + self.first_conv_restorable = False # in inpaint; You must use extra channels to take in masked real image + else: + in_c = in_channels+self.additional_channel_from_downsampler + self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_c, model_channels, 3, padding=1))]) + + input_block_chans = [model_channels] ch = model_channels ds = 1 @@ -376,16 +403,36 @@ class UNetModel(nn.Module): zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) - if self.is_style: - from .positionnet_with_image import PositionNet - else: - from .positionnet import PositionNet - self.position_net = PositionNet(positive_len=positive_len, out_dim=context_dim) + # self.position_net = 
instantiate_from_config(grounding_tokenizer) + from .text_grounding_net import PositionNet + self.position_net = PositionNet(in_dim=positive_len, out_dim=context_dim) + + + + def restore_first_conv_from_SD(self): + if self.first_conv_restorable: + device = self.input_blocks[0][0].weight.device + + SD_weights = th.load("gligen/SD_input_conv_weight_bias.pth") + self.GLIGEN_first_conv_state_dict = deepcopy(self.input_blocks[0][0].state_dict()) + + self.input_blocks[0][0] = conv_nd(2, 4, 320, 3, padding=1) + self.input_blocks[0][0].load_state_dict(SD_weights) + self.input_blocks[0][0].to(device) + + self.first_conv_type = "SD" + else: + print("First conv layer is not restorable and skipped this process, probably because this is an inpainting model?") + + + def restore_first_conv_from_GLIGEN(self): + breakpoint() # TODO def forward_position_net(self,input): + # import pdb; pdb.set_trace() if ("boxes" in input): boxes, masks, text_embeddings = input["boxes"], input["masks"], input["text_embeddings"] _ , self.max_box, _ = text_embeddings.shape @@ -403,10 +450,6 @@ class UNetModel(nn.Module): return objs - - - - def forward_position_net_with_image(self,input): if ("boxes" in input): @@ -441,42 +484,72 @@ class UNetModel(nn.Module): return objs + def forward(self, input,unc=False): + + if ("boxes" in input): + # grounding_input = input["grounding_input"] + boxes, masks, text_embeddings = input["boxes"], input["masks"], input["text_embeddings"] + _ , self.max_box, _ = text_embeddings.shape + else: + # Guidance null case + # grounding_input = self.grounding_tokenizer_input.get_null_input() + # boxes, masks, text_embeddings = input["boxes"]*0, input["masks"]*0, input["text_embeddings"]*0 + dtype = input["x"].dtype + batch = input["x"].shape[0] + device = input["x"].device + boxes = th.zeros(batch, self.max_box, 4,).type(dtype).to(device) + masks = th.zeros(batch, self.max_box).type(dtype).to(device) + text_masks = th.zeros(batch, self.max_box).type(dtype).to(device) + image_masks = th.zeros(batch, self.max_box).type(dtype).to(device) + text_embeddings = th.zeros(batch, self.max_box, self.positive_len).type(dtype).to(device) + image_embeddings = th.zeros(batch, self.max_box, self.positive_len).type(dtype).to(device) + if self.training and random.random() < 0.1 : # random drop for guidance + boxes, masks, text_embeddings = boxes*0, masks*0, text_embeddings*0 - - def forward(self, input): - - if self.is_style: - objs = self.forward_position_net_with_image(input) - else: - objs = self.forward_position_net(input) - + objs = self.position_net( boxes, masks, text_embeddings ) - hs = [] - + # Time embedding + t_emb = timestep_embedding(input["timesteps"], self.model_channels, repeat_only=False) - if self.use_o2: - t_emb = t_emb.to(th.float16) # not sure why apex will not cast this emb = self.time_embed(t_emb) - + # input tensor h = input["x"] - if self.is_inpaint: + t = input["timesteps"] + if self.downsample_net != None and self.first_conv_type=="GLIGEN": + temp = self.downsample_net(input["grounding_extra_input"]) + h = th.cat( [h,temp], dim=1 ) + if self.is_inpaint:#self.inpaint_mode: + if self.downsample_net != None: + breakpoint() # TODO: think about this case h = th.cat( [h, input["inpainting_extra_input"]], dim=1 ) - context = input["context"] + # Text input + context = input["context"] + # Start forwarding + hs = [] + probs_first = [] + self_prob_list_first = [] + for module in self.input_blocks: - h = module(h, emb, context, objs) + h,prob, self_prob = module(h, emb, context, objs,t) hs.append(h) + 
probs_first.append(prob) + self_prob_list_first.append(self_prob) - h = self.middle_block(h, emb, context, objs) - + h,mid_prob, self_prob_list_second = self.middle_block(h, emb, context, objs,t) + + probs_third = [] + self_prob_list_third = [] for module in self.output_blocks: h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context, objs) + h, prob, self_prob = module(h, emb, context, objs,t) + probs_third.append(prob) + self_prob_list_third.append(self_prob) - return self.out(h) + return self.out(h),probs_third , mid_prob, probs_first, self_prob_list_first, [self_prob_list_second], self_prob_list_third diff --git a/gligen/ldm/modules/diffusionmodules/pseudo_example.py b/gligen/ldm/modules/diffusionmodules/pseudo_example.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba5014e9e8a6e71538232d86cba46e098110c0e --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/pseudo_example.py @@ -0,0 +1,52 @@ +""" +This is a high-level pseudo code for grounding net. + +This class needs to tokenize grounding input into grounding tokens which +will be used in GatedAttention layers. + + +class PositionNet(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + kwargs should be defined by model.grounding_tokenizer in the config yaml file. + + def forward(self, **kwargs): + + kwargs should be the output of the grounding_tokenizer_input network + + return grounding_tokens # with shape: Batch * Num_Of_Token * Token_Channel_Dimension + + + +""" + + +# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # + + +""" +This is a high-level pseudo code for downsampler. + +This class needs to process input and output a spatial feature such that it will be +fed into the first conv layer. + + +class GroundingDownsampler(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + kwargs should be defined by model.grounding_downsampler in the config yaml file.
+ + you MUST define self.out_dim such that Unet knows add how many extra layers + + + def forward(self, **kwargs): + + kwargs should be the output of grounding_downsampler_input network + + return spatial_feature # with shape: Batch * self.out_dim * H *W (64*64 for SD) + + + +""" \ No newline at end of file diff --git a/gligen/ldm/modules/diffusionmodules/resnet.py b/gligen/ldm/modules/diffusionmodules/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ce07516fef99554c51c58fb3379448cf89154f --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/resnet.py @@ -0,0 +1,337 @@ +import torch +from torch import Tensor +import torch.nn as nn +from typing import Type, Any, Callable, Union, List, Optional + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the 
first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + print("Please manually decide which layer as output") + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + #self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # 
Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # The comment resolution is based on input size is 224*224 + out = {} + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + out['f0'] = x # N*64*56*56 + + x = self.layer1(x) + out['f1'] = x # N*64*56*56 + + x = self.layer2(x) + out['f2'] = x # N*128*28*28 + + x = self.layer3(x) + out['f3'] = x # N*256*14*14 + + x = self.layer4(x) + out['f4'] = x # N*512*7*7 + return x + + + # x = self.avgpool(x) + # x = torch.flatten(x, 1) + # out['penultimate'] = x # N*512 + + # x = self.fc(x) + # out['logits'] = x # N*1000 + + # return out + + def forward(self, x): + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict, strict=False) # we remove fc, and only keep backbone + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) + diff --git a/gligen/ldm/modules/diffusionmodules/sem_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/sem_grounding_downsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..10e4fc09a40bb912860cd3186b4970cabd2a0938 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/sem_grounding_downsampler.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class GroundingDownsampler(nn.Module): + def __init__(self, resize_input=256, in_dim=152, out_dim=8): + super().__init__() + self.resize_input = resize_input + self.out_dim = out_dim + + self.layers = nn.Sequential( + nn.Conv2d(in_dim,16,4,2,1), + nn.SiLU(), + nn.Conv2d(16,self.out_dim,4,2,1) + ) + + def forward(self, grounding_extra_input): + + out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='nearest') + out = self.layers(out) + + assert out.shape[1] == self.out_dim + return out + + diff --git a/gligen/ldm/modules/diffusionmodules/sem_grounding_net.py b/gligen/ldm/modules/diffusionmodules/sem_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..80ef6dd49dc5bcc58205913276c39982b07320b9 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/sem_grounding_net.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F +from ..attention import SelfAttention, FeedForward +from .convnext import convnext_tiny + + + + +class PositionNet(nn.Module): + def __init__(self, resize_input=448, in_dim=152, out_dim=768): + super().__init__() + + self.resize_input = resize_input + self.down_factor = 32 # determined by the convnext backbone + self.out_dim = out_dim + assert self.resize_input % self.down_factor == 0 + + self.in_conv = nn.Conv2d(in_dim,3,3,1,1) # from num_sem to 3 channels + self.convnext_tiny_backbone = convnext_tiny(pretrained=True) + + self.num_tokens = (self.resize_input // self.down_factor) ** 2 + + convnext_feature_dim = 768 + self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT + + self.linears = nn.Sequential( + nn.Linear( 
convnext_feature_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim])) + + + def forward(self, sem, mask): + B = sem.shape[0] + + # token from edge map + sem = torch.nn.functional.interpolate(sem, self.resize_input, mode="nearest") + sem = self.in_conv(sem) + sem_feature = self.convnext_tiny_backbone(sem) + objs = sem_feature.reshape(B, -1, self.num_tokens) + objs = objs.permute(0, 2, 1) # N*Num_tokens*dim + + # expand null token + null_objs = self.null_feature.view(1,1,-1) + null_objs = null_objs.repeat(B,self.num_tokens,1) + + # mask replacing + mask = mask.view(-1,1,1) + objs = objs*mask + null_objs*(1-mask) + + # add pos + objs = objs + self.pos_embedding + + # fuse them + objs = self.linears(objs) + + assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim]) + return objs + + + diff --git a/gligen/ldm/modules/diffusionmodules/text_grounding_net.py b/gligen/ldm/modules/diffusionmodules/text_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..288bb99290ebe828d0a191ab1a48b640d9f450cc --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/text_grounding_net.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class PositionNet(nn.Module): + def __init__(self, in_dim, out_dim, fourier_freqs=8): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy + + self.linears = nn.Sequential( + nn.Linear( self.in_dim + self.position_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.in_dim])) + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + + def forward(self, boxes, masks, positive_embeddings): + B, N, _ = boxes.shape + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C + + # learnable null embedding + positive_null = self.null_positive_feature.view(1,1,-1) + xyxy_null = self.null_position_feature.view(1,1,-1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings*masks + (1-masks)*positive_null + xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null + + objs = self.linears( torch.cat([positive_embeddings, xyxy_embedding], dim=-1) ) + assert objs.shape == torch.Size([B,N,self.out_dim]) + return objs + + + diff --git a/gligen/ldm/modules/diffusionmodules/text_image_grounding_net.py b/gligen/ldm/modules/diffusionmodules/text_image_grounding_net.py new file mode 100644 index 0000000000000000000000000000000000000000..d712c70c6ed1318d5977619c825905a3f722a857 --- /dev/null +++ b/gligen/ldm/modules/diffusionmodules/text_image_grounding_net.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder +import torch.nn.functional as F + + + +class PositionNet(nn.Module): + def __init__(self, in_dim, out_dim, fourier_freqs=8): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + 
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy + + # -------------------------------------------------------------- # + self.linears_text = nn.Sequential( + nn.Linear( self.in_dim + self.position_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.linears_image = nn.Sequential( + nn.Linear( self.in_dim + self.position_dim, 512), + nn.SiLU(), + nn.Linear( 512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + # -------------------------------------------------------------- # + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.in_dim])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.in_dim])) + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + + def forward(self, boxes, masks, text_masks, image_masks, text_embeddings, image_embeddings): + B, N, _ = boxes.shape + masks = masks.unsqueeze(-1) # B*N*1 + text_masks = text_masks.unsqueeze(-1) # B*N*1 + image_masks = image_masks.unsqueeze(-1) # B*N*1 + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C + + # learnable null embedding + text_null = self.null_text_feature.view(1,1,-1) # 1*1*C + image_null = self.null_image_feature.view(1,1,-1) # 1*1*C + xyxy_null = self.null_position_feature.view(1,1,-1) # 1*1*C + + # replace padding with learnable null embedding + text_embeddings = text_embeddings*text_masks + (1-text_masks)*text_null + image_embeddings = image_embeddings*image_masks + (1-image_masks)*image_null + xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null + + objs_text = self.linears_text( torch.cat([text_embeddings, xyxy_embedding], dim=-1) ) + objs_image = self.linears_image( torch.cat([image_embeddings,xyxy_embedding], dim=-1) ) + objs = torch.cat( [objs_text,objs_image], dim=1 ) + + assert objs.shape == torch.Size([B,N*2,self.out_dim]) + return objs + + + diff --git a/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22a3bec3b2e28a6253ad9d10efcc793b99dedac1 Binary files /dev/null and b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc differ diff --git a/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9e4a6a210e228373cf9c9c6f3f9455029c4d145 Binary files /dev/null and b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3cb2d36ebb5540deba1ec7f512e3aa7d2fcb356 Binary files /dev/null and b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc differ diff --git a/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..663f17e84df6d157242a63297c17dc0f4aa7b926 Binary files /dev/null and b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ diff --git 
a/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16c6d485b9f67ea0d7b6de43c9298935b2530676 Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..732165b4f79c8221b53aaf08739ccd2134a0adff Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85b12ec23f4cd95b0cae336f58fe9f021dfc0efb Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc differ diff --git a/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d872efd7c5506a0bb150b9132246e35d4a5c3369 Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ diff --git a/gligen/task_grounded_generation.py b/gligen/task_grounded_generation.py index b5bea79a691faa4879adba54f0651f1bc865ac98..7b7d7d276a15271047f18ed00599677f4de7b54e 100644 --- a/gligen/task_grounded_generation.py +++ b/gligen/task_grounded_generation.py @@ -77,10 +77,12 @@ def load_common_ckpt(config, common_ckpt): return [autoencoder, text_encoder, diffusion] def load_ckpt(config, state_dict, common_instances): - model = instantiate_from_config(config.model).to(device).eval() + + model = instantiate_from_config(config.model).to(device) model.load_state_dict(state_dict['model']) set_alpha_scale(model, config.alpha_scale) + print("ckpt is loaded") return [model] + common_instances @@ -97,7 +99,7 @@ def project(x, projection_matrix): """ return x@torch.transpose(projection_matrix, 0, 1) - +@torch.no_grad() def get_clip_feature(model, processor, input, is_image=False): feature_type = ['before','after_reproject'] # text feature, image feature @@ -138,6 +140,7 @@ def complete_mask(has_mask, max_objs): @torch.no_grad() def fire_clip(text_encoder, meta, batch=1, max_objs=30, clip_model=None): + # import pdb; pdb.set_trace() phrases = meta["phrases"] images = meta["images"] @@ -184,19 +187,66 @@ def fire_clip(text_encoder, meta, batch=1, max_objs=30, clip_model=None): } return batch_to_device(out, device) +def remove_numbers(text): + result = ''.join([char for char in text if not char.isdigit()]) + return result +def process_box_phrase(names, bboxes): + d = {} + for i, phrase in enumerate(names): + phrase = phrase.replace('_',' ') + list_noun = phrase.split(' ') + for n in list_noun: + n = remove_numbers(n) + if not n in d.keys(): + d.update({n:[np.array(bboxes[i])]}) + else: + d[n].append(np.array(bboxes[i])) + return d + +def Pharse2idx_2(prompt, name_box): + prompt = prompt.replace('.','') + prompt = prompt.replace(',','') + prompt_list = prompt.strip('.').split(' ') + object_positions = [] + bbox_to_self_att = [] + for obj in name_box.keys(): + obj_position = [] + in_prompt = False + for word in obj.split(' '): + if word in prompt_list: + obj_first_index = prompt_list.index(word) + 1 + 
obj_position.append(obj_first_index) + in_prompt = True + elif word +'s' in prompt_list: + obj_first_index = prompt_list.index(word+'s') + 1 + obj_position.append(obj_first_index) + in_prompt = True + elif word +'es' in prompt_list: + obj_first_index = prompt_list.index(word+'es') + 1 + obj_position.append(obj_first_index) + in_prompt = True + if in_prompt : + bbox_to_self_att.append(np.array(name_box[obj])) + + object_positions.append(obj_position) + return object_positions, bbox_to_self_att -@torch.no_grad() + + +# @torch.no_grad() def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs): # -------------- prepare model and misc --------------- # + model, autoencoder, text_encoder, diffusion = loaded_model_list + batch_size = instruction["batch_size"] is_inpaint = True if "input_image" in instruction else False save_folder = os.path.join("create_samples", instruction["save_folder_name"]) - + # -------------- set seed if required --------------- # if instruction.get('fix_seed', False): @@ -206,10 +256,13 @@ def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs): torch.manual_seed(random_seed) # ------------- prepare input for the model ------------- # - batch = fire_clip(text_encoder, instruction, batch_size, clip_model=kwargs.get('clip_model', None)) - context = text_encoder.encode( [instruction["prompt"]]*batch_size ) - uc = text_encoder.encode( batch_size*[""] ) - # print(batch['boxes']) + with torch.no_grad(): + batch = fire_clip(text_encoder, instruction, batch_size, clip_model=kwargs.get('clip_model', None)) + context = text_encoder.encode( [instruction["prompt"]]*batch_size ) + uc = text_encoder.encode( batch_size*[""] ) + name_box = process_box_phrase(instruction['phrases'], instruction['locations']) + + position, box_att = Pharse2idx_2(instruction['prompt'],name_box ) input = dict(x = None, timesteps = None, context = context, @@ -218,7 +271,9 @@ def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs): text_masks = batch['text_masks'], image_masks = batch['image_masks'], text_embeddings = batch["text_embeddings"], - image_embeddings = batch["image_embeddings"] ) + image_embeddings = batch["image_embeddings"], + boxes_att=box_att, + object_position = position ) inpainting_mask = x0 = None # used for inpainting if is_inpaint: @@ -228,10 +283,8 @@ def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs): if instruction["actual_mask"] is not None: inpainting_mask = instruction["actual_mask"][None, None].expand(batch['boxes'].shape[0], -1, -1, -1).cuda() else: - # inpainting_mask = draw_masks_from_boxes( batch['boxes'], (x0.shape[-2], x0.shape[-1]) ).cuda() actual_boxes = [instruction['inpainting_boxes_nodrop'] for _ in range(batch['boxes'].shape[0])] inpainting_mask = draw_masks_from_boxes(actual_boxes, (x0.shape[-2], x0.shape[-1]) ).cuda() - # extra input for the model masked_x0 = x0*inpainting_mask inpainting_extra_input = torch.cat([masked_x0,inpainting_mask], dim=1) input["inpainting_extra_input"] = inpainting_extra_input @@ -249,7 +302,8 @@ def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs): # ------------- run sampler ... 
------------- # shape = (batch_size, model.in_channels, model.image_size, model.image_size) samples_fake = sampler.sample(S=steps, shape=shape, input=input, uc=uc, guidance_scale=instruction['guidance_scale'], mask=inpainting_mask, x0=x0) - samples_fake = autoencoder.decode(samples_fake) + with torch.no_grad(): + samples_fake = autoencoder.decode(samples_fake) # ------------- other logistics ------------- # diff --git a/guide_imgs/0_A_train_on_top_of_a_surfboard..jpg b/guide_imgs/0_A_train_on_top_of_a_surfboard..jpg new file mode 100644 index 0000000000000000000000000000000000000000..1318f6f48ee175b459f23437c9d87e5057a605a2 Binary files /dev/null and b/guide_imgs/0_A_train_on_top_of_a_surfboard..jpg differ diff --git a/guide_imgs/0_a_bus_on_the_left_of_a_car.jpg b/guide_imgs/0_a_bus_on_the_left_of_a_car.jpg new file mode 100644 index 0000000000000000000000000000000000000000..24d35534cb7d736d63835bec13769c4a7fc78275 Binary files /dev/null and b/guide_imgs/0_a_bus_on_the_left_of_a_car.jpg differ diff --git a/guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg b/guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ec9c39c682b3a66a37a6ab8dce8d0e034f278956 Binary files /dev/null and b/guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg differ diff --git a/guide_imgs/0_two_cups_filled_with_steaming_hot_coffee_sit_side-by-side_on_a_wooden_table..jpg b/guide_imgs/0_two_cups_filled_with_steaming_hot_coffee_sit_side-by-side_on_a_wooden_table..jpg new file mode 100644 index 0000000000000000000000000000000000000000..deb50c00c256df9620d455c33935b8461d4f8c99 Binary files /dev/null and b/guide_imgs/0_two_cups_filled_with_steaming_hot_coffee_sit_side-by-side_on_a_wooden_table..jpg differ diff --git a/guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg b/guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg new file mode 100644 index 0000000000000000000000000000000000000000..29770c541f2fc67ae62bed1d5733729fe16119e6 Binary files /dev/null and b/guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg differ diff --git a/guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg b/guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg new file mode 100644 index 0000000000000000000000000000000000000000..5247285078fd13bc156407c134e8a78600d2a219 Binary files /dev/null and b/guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg differ diff --git a/guide_imgs/1_A_book_on_the_right_of_a_vase_on_the_table._There_is_a_cup_of_tea_on_the_top_of_the_book.jpg b/guide_imgs/1_A_book_on_the_right_of_a_vase_on_the_table._There_is_a_cup_of_tea_on_the_top_of_the_book.jpg new file mode 100644 index 0000000000000000000000000000000000000000..17aeb56bafcd0b7152eea7335f2343bc44fe919a Binary files /dev/null and b/guide_imgs/1_A_book_on_the_right_of_a_vase_on_the_table._There_is_a_cup_of_tea_on_the_top_of_the_book.jpg differ diff --git a/guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg b/guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg new file mode 100644 index 0000000000000000000000000000000000000000..00e1ee1373f4f6e480835152cd460da27878abd3 Binary files /dev/null and b/guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg differ diff --git a/guide_imgs/1_Two_cars_on_the_street..jpg b/guide_imgs/1_Two_cars_on_the_street..jpg new file mode 100644 index 0000000000000000000000000000000000000000..cecd67700703cf466915e8d62a80a9b9f49a26fd Binary files /dev/null and b/guide_imgs/1_Two_cars_on_the_street..jpg differ diff --git a/guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg 
b/guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg new file mode 100644 index 0000000000000000000000000000000000000000..f5fec00f44ade8070ac5392ac078fdea7189afab Binary files /dev/null and b/guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg differ diff --git a/guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg b/guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg new file mode 100644 index 0000000000000000000000000000000000000000..44ece39f17bbace51284f6b6251550773749db89 Binary files /dev/null and b/guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg differ diff --git a/guide_imgs/70_two_cats_are_curled_up_together_on_a_sunny_windowsill,_purring_contentedly..jpg b/guide_imgs/70_two_cats_are_curled_up_together_on_a_sunny_windowsill,_purring_contentedly..jpg new file mode 100644 index 0000000000000000000000000000000000000000..a0f9ea1efb86e911b6ea8d4d747152ecb63a93cf Binary files /dev/null and b/guide_imgs/70_two_cats_are_curled_up_together_on_a_sunny_windowsill,_purring_contentedly..jpg differ diff --git a/guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg b/guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg new file mode 100644 index 0000000000000000000000000000000000000000..5e4d2b6e47f4e11a0242e7309cd02dfb136b0cbc Binary files /dev/null and b/guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg differ diff --git a/images/cat_dog.jpg b/images/cat_dog.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a27416e5a6c7befe1e956667bb1eb4caac6e042d Binary files /dev/null and b/images/cat_dog.jpg differ diff --git a/images/img.png b/images/img.png new file mode 100644 index 0000000000000000000000000000000000000000..c71d32c25b47b434f85912881bca51a0799709de Binary files /dev/null and b/images/img.png differ
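The sketches below are illustrations only and are not part of the diff. First, the box/text PositionNet from gligen/ldm/modules/diffusionmodules/text_grounding_net.py (the module the rewritten UNetModel now instantiates) can be shape-checked in isolation. This assumes the gligen/ directory is on sys.path, which the module's own `from ldm...` imports already require; the batch size, box count and random tensors are arbitrary example values:

    import torch
    from ldm.modules.diffusionmodules.text_grounding_net import PositionNet

    pos_net = PositionNet(in_dim=768, out_dim=768)   # in_dim=positive_len, out_dim=context_dim as wired in openaimodel.py; 768 is the usual SD 1.x value
    boxes = torch.rand(2, 30, 4)                     # B=2, up to 30 boxes, normalized xyxy
    masks = torch.zeros(2, 30)
    masks[:, :3] = 1                                 # pretend only the first 3 boxes per sample are real
    text_embeddings = torch.rand(2, 30, 768)         # per-box CLIP text features
    objs = pos_net(boxes, masks, text_embeddings)    # padded slots are swapped for the learned null embeddings
    print(objs.shape)                                # torch.Size([2, 30, 768])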
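Second, UNetModel.forward in gligen/ldm/modules/diffusionmodules/openaimodel.py no longer returns just the noise prediction: it also hands back the per-block cross-attention and self-attention probability maps, so callers must unpack a 7-tuple. A sketch of the new calling convention, where `unet` and `unet_input` are placeholders for an instantiated model and its usual input dict:

    (eps,                 # noise prediction, as before
     cross_attn_up,       # cross-attention probs from the output (up) blocks
     cross_attn_mid,      # cross-attention probs from the middle block
     cross_attn_down,     # cross-attention probs from the input (down) blocks
     self_attn_down,      # self-attention probs from the input blocks
     self_attn_mid,       # self-attention probs from the middle block, wrapped in a one-element list
     self_attn_up) = unet(unet_input)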
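Finally, the new helpers in gligen/task_grounded_generation.py (process_box_phrase and Pharse2idx_2) turn the grounded phrases and their boxes into prompt-word indices and per-noun box groups. A minimal usage sketch with a made-up prompt and boxes, again assuming the repository root is on the import path:

    from gligen.task_grounded_generation import process_box_phrase, Pharse2idx_2

    prompt = "a cat on the right of a dog"
    phrases = ["cat", "dog"]                    # one grounded phrase per box
    boxes = [[0.05, 0.25, 0.45, 0.75],          # normalized xyxy box for "cat"
             [0.55, 0.25, 0.95, 0.75]]          # normalized xyxy box for "dog"

    # Group the boxes by (de-numbered) noun: {"cat": [box 0], "dog": [box 1]}
    name_box = process_box_phrase(phrases, boxes)

    # Map each noun to its word position in the prompt, offset by one (presumably to line up with
    # the CLIP start-of-text token), and collect the box arrays for nouns that appear in the prompt.
    object_positions, bbox_to_self_att = Pharse2idx_2(prompt, name_box)
    print(object_positions)                     # [[2], [8]]
    print([b.shape for b in bbox_to_self_att])  # [(1, 4), (1, 4)]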