import argparse
from PIL import Image, ImageDraw
from evaluator import Evaluator
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import os
from transformers import CLIPProcessor, CLIPModel
from copy import deepcopy
import torch
from ldm.util import instantiate_from_config
from trainer import read_official_ckpt, batch_to_device
from evaluator import set_alpha_scale, save_images, draw_masks_from_boxes
import numpy as np
import clip
from functools import partial
import torchvision.transforms.functional as F
import random

device = "cuda"


def alpha_generator(length, type=[1, 0, 0]):
    """
    length is the total number of timesteps needed for sampling.
    type should be a list containing three values whose sum is 1.
    It gives the percentage of three stages:
        alpha=1 stage
        linear decay stage
        alpha=0 stage

    For example, if length=100 and type=[0.8, 0.1, 0.1], then alpha is 1 for the
    first 80 steps, linearly decays to 0 over the next 10 steps, and stays 0 for
    the last 10 steps.
    """
    assert len(type) == 3
    assert type[0] + type[1] + type[2] == 1

    stage0_length = int(type[0] * length)
    stage1_length = int(type[1] * length)
    stage2_length = length - stage0_length - stage1_length

    if stage1_length != 0:
        decay_alphas = np.arange(start=0, stop=1, step=1 / stage1_length)[::-1]
        decay_alphas = list(decay_alphas)
    else:
        decay_alphas = []

    alphas = [1] * stage0_length + decay_alphas + [0] * stage2_length

    assert len(alphas) == length

    return alphas

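
# Illustrative example (not from the original code): with 10 sampling steps and
# type=[0.6, 0.2, 0.2], the schedule is six 1s, a two-step linear decay, then two 0s:
#   alpha_generator(10, type=[0.6, 0.2, 0.2])
#   -> [1, 1, 1, 1, 1, 1, 0.5, 0.0, 0, 0]
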
def draw_box(img, locations):
    colors = ["red", "green", "blue", "olive", "orange", "brown", "cyan", "purple"]
    draw = ImageDraw.Draw(img)
    WW, HH = img.size
    for bid, box in enumerate(locations):
        draw.rectangle([box[0] * WW, box[1] * HH, box[2] * WW, box[3] * HH],
                       outline=colors[bid % len(colors)], width=5)
    return img


def load_common_ckpt(config, common_ckpt):
    autoencoder = instantiate_from_config(config.autoencoder).to(device).eval()
    text_encoder = instantiate_from_config(config.text_encoder).to(device).eval()
    diffusion = instantiate_from_config(config.diffusion).to(device)

    autoencoder.load_state_dict(common_ckpt["autoencoder"])
    text_encoder.load_state_dict(common_ckpt["text_encoder"])
    diffusion.load_state_dict(common_ckpt["diffusion"])

    return [autoencoder, text_encoder, diffusion]


def load_ckpt(config, state_dict, common_instances):
    model = instantiate_from_config(config.model).to(device)
    model.load_state_dict(state_dict['model'])
    set_alpha_scale(model, config.alpha_scale)
    print("ckpt is loaded")

    return [model] + common_instances


def project(x, projection_matrix):
    """
    x (Batch*768) should be the penultimate feature of CLIP (before projection).
    projection_matrix (768*768) is the CLIP projection matrix, which should be the
    weight.data of the Linear layer defined in CLIP (out_dim, in_dim); thus we need
    to apply a transpose below.
    This function returns the CLIP feature (without normalization).
    """
    return x @ torch.transpose(projection_matrix, 0, 1)


@torch.no_grad()
def get_clip_feature(model, processor, input, is_image=False):
    feature_type = ['before', 'after_reproject']  # text feature, image feature

    if is_image:
        image = input  # Image.open(input).convert("RGB")
        inputs = processor(images=[image], return_tensors="pt", padding=True)
        inputs['pixel_values'] = inputs['pixel_values'].cuda()  # we use our own preprocessing without center_crop
        inputs['input_ids'] = torch.tensor([[0, 1, 2, 3]]).cuda()  # placeholder
        outputs = model(**inputs)
        feature = outputs.image_embeds
        if feature_type[1] == 'after_renorm':
            feature = feature * 28.7
        if feature_type[1] == 'after_reproject':
            feature = project(feature, torch.load('gligen/projection_matrix.pth').cuda().T).squeeze(0)
            feature = (feature / feature.norm()) * 28.7
            feature = feature.unsqueeze(0)
    else:
        inputs = processor(text=input, return_tensors="pt", padding=True)
        inputs['input_ids'] = inputs['input_ids'].cuda()
        inputs['pixel_values'] = torch.ones(1, 3, 224, 224).cuda()  # placeholder
        inputs['attention_mask'] = inputs['attention_mask'].cuda()
        outputs = model(**inputs)
        feature = outputs.text_embeds if feature_type[0] == 'after' else outputs.text_model_output.pooler_output

    return feature


def complete_mask(has_mask, max_objs):
    mask = torch.ones(1, max_objs)
    if type(has_mask) == int or type(has_mask) == float:
        return mask * has_mask
    else:
        for idx, value in enumerate(has_mask):
            mask[0, idx] = value
        return mask


@torch.no_grad()
def fire_clip(text_encoder, meta, batch=1, max_objs=30, clip_model=None):
    # import pdb; pdb.set_trace()
    phrases = meta["phrases"]
    images = meta["images"]

    if clip_model is None:
        version = "openai/clip-vit-large-patch14"
        model = CLIPModel.from_pretrained(version).cuda()
        processor = CLIPProcessor.from_pretrained(version)
    else:
        version = "openai/clip-vit-large-patch14"
        assert clip_model['version'] == version
        model = clip_model['model']
        processor = clip_model['processor']

    boxes = torch.zeros(max_objs, 4)
    masks = torch.zeros(max_objs)
    text_embeddings = torch.zeros(max_objs, 768)
    image_embeddings = torch.zeros(max_objs, 768)
    text_features = []
    image_features = []
    for phrase, image in zip(phrases, images):
        text_features.append(get_clip_feature(model, processor, phrase, is_image=False))
        image_features.append(get_clip_feature(model, processor, image, is_image=True))

    if len(text_features) > 0:
        text_features = torch.cat(text_features, dim=0)
        image_features = torch.cat(image_features, dim=0)

        for idx, (box, text_feature, image_feature) in enumerate(zip(meta['locations'], text_features, image_features)):
            boxes[idx] = torch.tensor(box)
            masks[idx] = 1
            text_embeddings[idx] = text_feature
            image_embeddings[idx] = image_feature

    out = {
        "boxes": boxes.unsqueeze(0).repeat(batch, 1, 1),
        "masks": masks.unsqueeze(0).repeat(batch, 1),
        "text_masks": masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta["has_text_mask"], max_objs),
        "image_masks": masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta["has_image_mask"], max_objs),
        "text_embeddings": text_embeddings.unsqueeze(0).repeat(batch, 1, 1),
        "image_embeddings": image_embeddings.unsqueeze(0).repeat(batch, 1, 1)
    }
    return batch_to_device(out, device)

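
# Note (illustrative, not from the original file): fire_clip reloads CLIP on every
# call unless a pre-loaded bundle is passed in. The dict keys below are the ones the
# function actually reads ('version', 'model', 'processor'); the model should already
# be on the GPU.
#   version = "openai/clip-vit-large-patch14"
#   clip_model = {
#       'version': version,
#       'model': CLIPModel.from_pretrained(version).cuda(),
#       'processor': CLIPProcessor.from_pretrained(version),
#   }
#   batch = fire_clip(text_encoder, meta, batch=1, clip_model=clip_model)
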
def remove_numbers(text):
    result = ''.join([char for char in text if not char.isdigit()])
    return result


def process_box_phrase(names, bboxes):
    d = {}
    for i, phrase in enumerate(names):
        phrase = phrase.replace('_', ' ')
        list_noun = phrase.split(' ')
        for n in list_noun:
            n = remove_numbers(n)
            if not n in d.keys():
                d.update({n: [np.array(bboxes[i])]})
            else:
                d[n].append(np.array(bboxes[i]))
    return d


def Pharse2idx_2(prompt, name_box):
    prompt = prompt.replace('.', '')
    prompt = prompt.replace(',', '')
    prompt_list = prompt.strip('.').split(' ')
    object_positions = []
    bbox_to_self_att = []
    for obj in name_box.keys():
        obj_position = []
        in_prompt = False
        for word in obj.split(' '):
            if word in prompt_list:
                obj_first_index = prompt_list.index(word) + 1
                obj_position.append(obj_first_index)
                in_prompt = True
            elif word + 's' in prompt_list:
                obj_first_index = prompt_list.index(word + 's') + 1
                obj_position.append(obj_first_index)
                in_prompt = True
            elif word + 'es' in prompt_list:
                obj_first_index = prompt_list.index(word + 'es') + 1
                obj_position.append(obj_first_index)
                in_prompt = True
        if in_prompt:
            bbox_to_self_att.append(np.array(name_box[obj]))
            object_positions.append(obj_position)

    return object_positions, bbox_to_self_att

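
# Worked example (illustrative): for names=["dog", "cat"] with normalized xyxy boxes
# bboxes=[[0.1, 0.2, 0.45, 0.8], [0.55, 0.25, 0.95, 0.85]] and the prompt
# "a dog and a cat":
#   name_box = process_box_phrase(names, bboxes)
#   # -> {"dog": [array([0.1, 0.2, 0.45, 0.8])], "cat": [array([0.55, 0.25, 0.95, 0.85])]}
#   positions, boxes_att = Pharse2idx_2("a dog and a cat", name_box)
#   # -> positions == [[2], [5]] (word indices in the prompt, offset by 1),
#   #    boxes_att  == [array([[0.1, 0.2, 0.45, 0.8]]), array([[0.55, 0.25, 0.95, 0.85]])]
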
# @torch.no_grad()
def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs):

    # -------------- prepare model and misc --------------- #
    model, autoencoder, text_encoder, diffusion = loaded_model_list

    batch_size = instruction["batch_size"]
    is_inpaint = "input_image" in instruction
    save_folder = os.path.join("create_samples", instruction["save_folder_name"])

    # -------------- set seed if required --------------- #
    if instruction.get('fix_seed', False):
        random_seed = instruction['rand_seed']
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    # ------------- prepare input for the model ------------- #
    with torch.no_grad():
        batch = fire_clip(text_encoder, instruction, batch_size, clip_model=kwargs.get('clip_model', None))
        context = text_encoder.encode([instruction["prompt"]] * batch_size)
        uc = text_encoder.encode(batch_size * [""])

    name_box = process_box_phrase(instruction['phrases'], instruction['locations'])
    position, box_att = Pharse2idx_2(instruction['prompt'], name_box)

    input = dict(x=None,
                 timesteps=None,
                 context=context,
                 boxes=batch['boxes'],
                 masks=batch['masks'],
                 text_masks=batch['text_masks'],
                 image_masks=batch['image_masks'],
                 text_embeddings=batch["text_embeddings"],
                 image_embeddings=batch["image_embeddings"],
                 boxes_att=box_att,
                 object_position=position)

    inpainting_mask = x0 = None  # used for inpainting
    if is_inpaint:
        input_image = F.pil_to_tensor(instruction["input_image"])
        input_image = (input_image.float().unsqueeze(0).cuda() / 255 - 0.5) / 0.5
        x0 = autoencoder.encode(input_image)
        if instruction["actual_mask"] is not None:
            inpainting_mask = instruction["actual_mask"][None, None].expand(batch['boxes'].shape[0], -1, -1, -1).cuda()
        else:
            actual_boxes = [instruction['inpainting_boxes_nodrop'] for _ in range(batch['boxes'].shape[0])]
            inpainting_mask = draw_masks_from_boxes(actual_boxes, (x0.shape[-2], x0.shape[-1])).cuda()
        masked_x0 = x0 * inpainting_mask
        inpainting_extra_input = torch.cat([masked_x0, inpainting_mask], dim=1)
        input["inpainting_extra_input"] = inpainting_extra_input

    # ------------- prepare sampler ------------- #
    alpha_generator_func = partial(alpha_generator, type=instruction["alpha_type"])
    if False:  # DDIM path kept for reference; PLMS is used below
        sampler = DDIMSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
        steps = 250
    else:
        sampler = PLMSSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
        steps = 50

    # ------------- run sampler ------------- #
    shape = (batch_size, model.in_channels, model.image_size, model.image_size)
    samples_fake = sampler.sample(S=steps, shape=shape, input=input, uc=uc,
                                  guidance_scale=instruction['guidance_scale'],
                                  mask=inpainting_mask, x0=x0)
    with torch.no_grad():
        samples_fake = autoencoder.decode(samples_fake)

    # ------------- other logistics ------------- #
    sample_list = []
    for sample in samples_fake:
        sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        sample = sample.cpu().numpy().transpose(1, 2, 0) * 255
        sample = Image.fromarray(sample.astype(np.uint8))
        sample_list.append(sample)

    return sample_list, None


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--folder", type=str, default="create_samples", help="path to OUTPUT")
#     parser.add_argument("--official_ckpt", type=str, default='../../../data/sd-v1-4.ckpt', help="")
#     parser.add_argument("--batch_size", type=int, default=10, help="This will overwrite the one in yaml.")
#     parser.add_argument("--no_plms", action='store_true')
#     parser.add_argument("--guidance_scale", type=float, default=5, help="")
#     parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
#     args = parser.parse_args()
#     assert "sd-v1-4.ckpt" in args.official_ckpt, "only support for stable-diffusion model"
#     grounded_generation(args)
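

# Illustrative end-to-end usage (a sketch, not part of the original file; the config
# and checkpoint paths are assumptions, and the placeholder images simply satisfy
# fire_clip when only text grounding is used):
#   config = OmegaConf.load("configs/inference.yaml")                       # assumed path
#   common_instances = load_common_ckpt(config, torch.load("common.pth"))   # assumed path
#   loaded_model_list = load_ckpt(config, torch.load("gligen_box.pth"), common_instances)
#
#   instruction = dict(
#       prompt="a dog and a cat",
#       phrases=["dog", "cat"],
#       images=[Image.new("RGB", (224, 224))] * 2,             # placeholders, masked out below
#       locations=[[0.1, 0.2, 0.45, 0.8], [0.55, 0.25, 0.95, 0.85]],  # normalized xyxy boxes
#       has_text_mask=1,
#       has_image_mask=0,                                      # text grounding only
#       alpha_type=[0.3, 0.0, 0.7],
#       guidance_scale=7.5,
#       batch_size=1,
#       save_folder_name="demo",
#   )
#   samples, _ = grounded_generation_box(loaded_model_list, instruction)
#   samples[0].save("demo.png")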