from transformers import AutoProcessor
from PIL import Image
import os
import torch
import pickle

## ACTUAL INPUT CONSTRUCTION

BASE_SPEAKER_LEN = 787

def joint_listener_input(processor, context_images, description, device):
    # Preliminaries
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, context_images)
    target_anno = description.lower()
    prompt = construct_listener_full_prompt(
        processor, target_anno, 0, "verbose_instruction"
    )

    # Listener processing
    outputs = processor(
        text=[prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    l_input_tokens = outputs['input_ids'][:, :-2]
    l_attn_mask = outputs['attention_mask'][:, :-2]
    l_attn_mask[(l_input_tokens == 0).bool()] = 0
    images = outputs['pixel_values']
    l_image_attn_mask = outputs['pixel_attention_mask']

    # Speaker processing
    prompts = []
    for i in range(10):
        prompt = construct_speaker_full_prompt(processor, description, i, "information_after")
        prompts.append(prompt)
    outputs = processor(
        text=prompts,
        images=[raw_images]*10,
        padding='longest',
        return_tensors="pt"
    ).to(device)
    s_input_tokens = outputs['input_ids'][:, :-1]
    s_attn_mask = outputs['attention_mask'][:, :-1]
    s_attn_mask[(s_input_tokens == 0).bool()] = 0
    s_image_attn_mask = outputs['pixel_attention_mask']
    s_target_tokens = outputs['input_ids'][:, 1:]

    s_target_mask = []
    for i in range(10):
        curr_mask = create_speaker_caption_mask(outputs['input_ids'][i], s_attn_mask[i])
        s_target_mask.append(curr_mask)
    s_target_mask = torch.stack(s_target_mask, dim=0)

    return images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens.unsqueeze(0), \
        s_attn_mask.unsqueeze(0), s_image_attn_mask.unsqueeze(0), s_target_mask.unsqueeze(0), \
        s_target_tokens.unsqueeze(0)

def joint_speaker_input(processor, image_paths, target_path, device):
    # Get the prompt
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, image_paths)
    target_idx = image_paths.index(target_path)
    base_prompt = construct_speaker_base_prompt(processor, target_idx, "information_after", process=True)

    # Create the basic input
    outputs = processor(
        text=[base_prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    input_tokens = outputs['input_ids']
    attn_mask = outputs['attention_mask']
    attn_mask[(input_tokens == 0).bool()] = 0
    images = outputs['pixel_values']
    image_attn_mask = outputs['pixel_attention_mask']

    return input_tokens, attn_mask, images, image_attn_mask, torch.LongTensor([target_idx]).to(device)

## UTILITIES

def get_processor():
    checkpoint = "HuggingFaceM4/idefics2-8b"
    processor = AutoProcessor.from_pretrained(checkpoint, do_image_splitting=False,
                                              size={"longest_edge": 448, "shortest_edge": 224})
    return processor

def get_index_to_token():
    index_to_token_path = "index_to_token.pkl"
    with open(index_to_token_path, 'rb') as f:
        index_to_token = pickle.load(f)
    return index_to_token

def process_images(img_dir, context_images):
    raw_images = []
    for img in context_images:
        image_path = os.path.join(img_dir, img)
        raw_image = Image.open(image_path).convert('RGB')
        raw_images.append(raw_image)
    return raw_images

def create_speaker_caption_mask(all_token_ids, text_mask):
    # Overall token comp: pad + base + caption
    padding_tokens = torch.sum(all_token_ids == 0).item()
    caption_tokens = all_token_ids.shape[0] - (padding_tokens + BASE_SPEAKER_LEN)

    # Construct a mask where the last caption tokens are 1
    target_mask = torch.zeros_like(text_mask)
    target_mask[-caption_tokens:] = 1
    return target_mask.bool()

def construct_listener_full_prompt(processor, target_anno, target_idx,
                                   comprehension_prompt_type="verbose_instruction"):
    target_anno = target_anno.lower().strip()

    messages = []
    if comprehension_prompt_type == "verbose_instruction":
        # User side: Intro
        messages.append(
            {
                "role" : "user",
                "content" : [
                    {"type" : "text", "text" : "You will be presented with a sequence of 10 images and a caption describing exactly one of them. "},
                    {"type" : "text", "text" : "Your task is to guess which image the caption describes. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type" : "text", "text" : f" Image {i}: "})
            else:
                messages[0]["content"].append({"type" : "text", "text" : f", Image {i}: "})
            messages[0]["content"].append({"type" : "image"})

        # User side: Caption
        messages[0]["content"].append({"type" : "text", "text" : f". Caption: {target_anno}"})
        messages[0]["content"].append({"type" : "text", "text" : " Does this caption describe Image 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9?"})

        # Model side: Guess
        messages.append(
            {
                "role" : "assistant",
                "content" : [
                    {"type" : "text", "text" : f"The caption describes Image {target_idx}"}
                ]
            }
        )
    else:
        assert(False)

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()

def construct_speaker_full_prompt(processor, target_anno, target_idx, generation_prompt_type="information_after"):
    messages = construct_speaker_base_prompt(processor, target_idx, generation_prompt_type)

    # Assistant response
    target_anno = target_anno.lower().strip()
    messages.append(
        {
            "role" : "assistant",
            "content" : [
                {"type" : "text", "text" : target_anno}
            ]
        }
    )

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()

def construct_speaker_base_prompt(processor, target_idx, generation_prompt_type="information_after", process=False):
    messages = []
    if generation_prompt_type == "information_after":
        # User side: Intro
        messages.append(
            {
                "role" : "user",
                "content" : [
                    {"type" : "text", "text" : "You will be presented with a sequence of 10 images and be assigned a target image. "},
                    {"type" : "text", "text" : "Your task is to produce a caption for your target image such that anyone could guess the image from your description. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type" : "text", "text" : f" Image {i}: "})
            else:
                messages[0]["content"].append({"type" : "text", "text" : f", Image {i}: "})
            messages[0]["content"].append({"type" : "image"})

        # User side: Target assignment
        messages[0]["content"].append({"type" : "text", "text" : f". Your target image is Image {target_idx}. Produce your caption now."})
    else:
        assert(False)

    if process:
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True).strip()
        return prompt
    else:
        return messages

def process_idefics_listener_generation_input(speaker_context, captions, processor, img_dir, num_samples, device):
    # First construct the prompts
    prompts, raw_images = get_listener_generation_prompts(speaker_context, captions, num_samples, img_dir, processor)

    # Process the prompts
    listener_inputs = processor(
        text=prompts,
        images=raw_images,
        padding='longest',
        return_tensors='pt'
    )
    input_tokens = listener_inputs['input_ids'][:, :-2].to(device)
    attn_mask = listener_inputs['attention_mask'][:, :-2].to(device)
    attn_mask[input_tokens == 0] = 0
    images = listener_inputs['pixel_values'].to(device)
    image_attn_mask = listener_inputs['pixel_attention_mask'].to(device)

    return input_tokens, attn_mask, images, image_attn_mask

def get_listener_generation_prompts(speaker_contexts, captions, num_samples, img_dir, processor):
    prompts = []
    all_raw_images = []
    for i, speaker_context in enumerate(speaker_contexts):
        raw_images = process_images(img_dir, speaker_context)
        for j in range(num_samples):
            curr_idx = i * num_samples + j
            caption = captions[curr_idx]
            prompt = construct_listener_full_prompt(processor, caption, 0, "verbose_instruction")
            prompts.append(prompt)
            all_raw_images.append(raw_images)
    return prompts, all_raw_images
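
## USAGE SKETCH
# A minimal, illustrative sketch of how these helpers are expected to fit together:
# load the processor, pick a 10-image context from the tangram_pngs directory, and
# build the speaker-side inputs for one target image. The filenames below are
# placeholders (assumptions for illustration), and a downstream Idefics2 model is
# still needed to consume the returned tensors.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = get_processor()

    # Hypothetical context: any 10 PNG filenames that exist under tangram_pngs/.
    context_images = [f"tangram_{i}.png" for i in range(10)]
    target_image = context_images[3]

    input_tokens, attn_mask, images, image_attn_mask, target_idx = joint_speaker_input(
        processor, context_images, target_image, device
    )
    print("speaker input ids:", tuple(input_tokens.shape))
    print("pixel values:", tuple(images.shape))
    print("target index:", target_idx.item())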