# Spaces: Sleeping  -- page-status text captured together with this file; not part of the program.
from transformers import AutoProcessor | |
from PIL import Image | |
import os | |
import torch | |
import pickle | |
## ACTUAL INPUT CONSTRUCTION | |
# Number of "base" tokens in a speaker sequence. create_speaker_caption_mask
# treats each sequence as pad + base + caption and subtracts this value
# (together with the padding count) to obtain the caption length.
# NOTE(review): 787 presumably depends on the processor settings and on the
# speaker prompt wording used in this file -- re-measure if either changes.
BASE_SPEAKER_LEN = 787
def joint_listener_input(processor, context_images, description, device):
    """Build the listener batch and the ten-way speaker batch for one caption.

    The listener side is a single comprehension prompt for ``description``.
    The speaker side is ten full speaker prompts, one per candidate target
    index 0-9, each paired with the same context images.

    Args:
        processor: Object called as ``processor(text=..., images=..., ...)``
            whose result supports ``.to(device)`` and key lookup
            (presumably the processor from ``get_processor`` -- confirm).
        context_images: Image file names, resolved inside ``tangram_pngs``.
        description: Caption text.
        device: Target handed to ``.to(...)``.

    Returns:
        A 9-tuple ``(images, l_input_tokens, l_attn_mask, l_image_attn_mask,
        s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask,
        s_target_tokens)``. The five speaker entries are returned with an
        extra leading dimension of size 1.
    """
    pictures = process_images("tangram_pngs", context_images)

    # Listener side: one prompt, always built with target index 0.
    listener_prompt = construct_listener_full_prompt(
        processor, description.lower(), 0, "verbose_instruction"
    )
    listener_batch = processor(
        text=[listener_prompt],
        images=[pictures],
        return_tensors="pt",
    ).to(device)
    # The last two token positions are dropped from the listener inputs.
    l_input_tokens = listener_batch['input_ids'][:, :-2]
    l_attn_mask = listener_batch['attention_mask'][:, :-2]
    # Positions holding token id 0 (counted as padding in this file) get no attention.
    l_attn_mask[l_input_tokens == 0] = 0
    images = listener_batch['pixel_values']
    l_image_attn_mask = listener_batch['pixel_attention_mask']

    # Speaker side: the same caption scored against each of the ten targets.
    speaker_prompts = [
        construct_speaker_full_prompt(processor, description, idx, "information_after")
        for idx in range(10)
    ]
    speaker_batch = processor(
        text=speaker_prompts,
        images=[pictures] * 10,
        padding='longest',
        return_tensors="pt",
    ).to(device)
    full_ids = speaker_batch['input_ids']
    # Inputs drop the last position and targets drop the first (shift by one).
    s_input_tokens = full_ids[:, :-1]
    s_attn_mask = speaker_batch['attention_mask'][:, :-1]
    s_attn_mask[s_input_tokens == 0] = 0
    s_image_attn_mask = speaker_batch['pixel_attention_mask']
    s_target_tokens = full_ids[:, 1:]
    s_target_mask = torch.stack(
        [
            create_speaker_caption_mask(full_ids[row], s_attn_mask[row])
            for row in range(10)
        ],
        dim=0,
    )

    return (
        images,
        l_input_tokens,
        l_attn_mask,
        l_image_attn_mask,
        s_input_tokens.unsqueeze(0),
        s_attn_mask.unsqueeze(0),
        s_image_attn_mask.unsqueeze(0),
        s_target_mask.unsqueeze(0),
        s_target_tokens.unsqueeze(0),
    )
def joint_speaker_input(processor, image_paths, target_path, device):
    """Build the generation-time speaker input for one target image.

    Args:
        processor: Object called as ``processor(text=..., images=..., ...)``
            whose result supports ``.to(device)`` and key lookup.
        image_paths: Image file names, resolved inside ``tangram_pngs``.
        target_path: The entry of ``image_paths`` to describe; its position
            in the list becomes the target index.
        device: Target handed to ``.to(...)``.

    Returns:
        ``(input_tokens, attn_mask, images, image_attn_mask, target)`` where
        ``target`` is a one-element long tensor holding the target index.

    Raises:
        ValueError: If ``target_path`` is not in ``image_paths``
            (from ``list.index``).
    """
    pictures = process_images("tangram_pngs", image_paths)
    target_idx = image_paths.index(target_path)
    generation_prompt = construct_speaker_base_prompt(
        processor, target_idx, "information_after", process=True
    )

    batch = processor(
        text=[generation_prompt],
        images=[pictures],
        return_tensors="pt",
    ).to(device)

    input_tokens = batch['input_ids']
    attn_mask = batch['attention_mask']
    # Positions holding token id 0 get no attention.
    attn_mask[input_tokens == 0] = 0
    images = batch['pixel_values']
    image_attn_mask = batch['pixel_attention_mask']
    target = torch.LongTensor([target_idx]).to(device)

    return input_tokens, attn_mask, images, image_attn_mask, target
## UTILITIES | |
def get_processor():
    """Load the ``HuggingFaceM4/idefics2-8b`` processor.

    Returns:
        The object returned by ``AutoProcessor.from_pretrained`` for that
        checkpoint, requested with ``do_image_splitting=False`` and a
        ``size`` of longest edge 448 / shortest edge 224.
    """
    image_size = {"longest_edge": 448, "shortest_edge": 224}
    return AutoProcessor.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        do_image_splitting=False,
        size=image_size,
    )
def get_index_to_token(index_to_token_path="index_to_token.pkl"):
    """Load the pickled index-to-token object from disk.

    Args:
        index_to_token_path: Path of the pickle file. Defaults to
            ``"index_to_token.pkl"`` (relative to the current working
            directory), which is the location the function always used
            before the path became configurable.

    Returns:
        Whatever object the pickle file contains.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point this at files from a trusted source.
    with open(index_to_token_path, 'rb') as f:
        return pickle.load(f)
def process_images(img_dir, context_images):
    """Open every context image under ``img_dir`` and convert it to RGB.

    Args:
        img_dir: Directory that holds the image files.
        context_images: Iterable of file names relative to ``img_dir``.

    Returns:
        A list of RGB-converted images, in the order of ``context_images``.
    """
    return [
        Image.open(os.path.join(img_dir, file_name)).convert('RGB')
        for file_name in context_images
    ]
def create_speaker_caption_mask(all_token_ids, text_mask, base_len=None):
    """Return a boolean mask that is True on the trailing caption positions.

    The token sequence is treated as ``pad + base + caption``: padding is
    every position whose id is 0, the base has a fixed length, and whatever
    remains at the end is the caption.

    Args:
        all_token_ids: 1-D tensor of token ids for one sequence.
        text_mask: Tensor whose shape (and device) the result copies; it may
            be shorter than ``all_token_ids``, as the mask is filled from
            the end.
        base_len: Length of the base portion. ``None`` (the default) uses
            the module constant ``BASE_SPEAKER_LEN``.

    Returns:
        A bool tensor shaped like ``text_mask`` whose last ``caption``
        positions are True. It is all False when no caption tokens remain.
    """
    if base_len is None:
        base_len = BASE_SPEAKER_LEN

    # Overall token comp: pad + base + caption
    padding_tokens = torch.sum(all_token_ids == 0).item()
    caption_tokens = all_token_ids.shape[0] - (padding_tokens + base_len)

    target_mask = torch.zeros_like(text_mask)
    # Guard the slice: ``mask[-0:]`` selects the WHOLE tensor, so a zero
    # caption length used to mark every position as caption; a negative
    # length would flip the slice to the wrong end.
    if caption_tokens > 0:
        target_mask[-caption_tokens:] = 1
    return target_mask.bool()
def construct_listener_full_prompt(processor, target_anno, target_idx, comprehension_prompt_type="verbose_instruction"):
    """Render the full listener (comprehension) conversation as a prompt string.

    The conversation has a user turn (task intro, ten ``Image i:`` labels each
    followed by an image slot, the caption, and the question) and an
    assistant turn stating which image the caption describes.

    Args:
        processor: Object providing
            ``apply_chat_template(messages, add_generation_prompt=...)``.
        target_anno: Caption text; lower-cased and stripped before use.
        target_idx: Image index written into the assistant answer.
        comprehension_prompt_type: Prompt layout; only
            ``"verbose_instruction"`` is supported.

    Returns:
        The templated prompt with surrounding whitespace stripped.

    Raises:
        AssertionError: If ``comprehension_prompt_type`` is not supported.
    """
    target_anno = target_anno.lower().strip()

    if comprehension_prompt_type != "verbose_instruction":
        # Raised explicitly rather than via ``assert(False)``: assert
        # statements are removed under ``python -O``, which would let an
        # unknown type fall through and template an empty conversation.
        raise AssertionError(
            f"Unsupported comprehension_prompt_type: {comprehension_prompt_type!r}"
        )

    # User side: Intro
    user_content = [
        {"type": "text", "text": "You will be presented with a sequence of 10 images and a caption describing exactly one of them. "},
        {"type": "text", "text": "Your task is to guess which image the caption describes. "},
    ]

    # User side: Images (a text label followed by an image slot, ten times)
    for i in range(10):
        label = f" Image {i}: " if i == 0 else f", Image {i}: "
        user_content.append({"type": "text", "text": label})
        user_content.append({"type": "image"})

    # User side: Caption and question
    user_content.append({"type": "text", "text": f". Caption: {target_anno}"})
    user_content.append({"type": "text", "text": " Does this caption describe Image 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9?"})

    messages = [
        {"role": "user", "content": user_content},
        # Model side: Guess
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": f"The caption describes Image {target_idx}"},
            ],
        },
    ]

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()
def construct_speaker_full_prompt(processor, target_anno, target_idx,
                                  generation_prompt_type="information_after"):
    """Render the speaker conversation, including the caption, as a prompt string.

    Args:
        processor: Object providing
            ``apply_chat_template(messages, add_generation_prompt=...)``.
        target_anno: Caption text; lower-cased and stripped, then used as the
            assistant reply.
        target_idx: Index of the target image named in the user turn.
        generation_prompt_type: Layout forwarded to
            ``construct_speaker_base_prompt``.

    Returns:
        The templated prompt with surrounding whitespace stripped.
    """
    conversation = construct_speaker_base_prompt(processor, target_idx, generation_prompt_type)

    # Assistant response: the normalized caption.
    caption = target_anno.lower().strip()
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": caption}],
    }
    conversation.append(assistant_turn)

    rendered = processor.apply_chat_template(conversation, add_generation_prompt=False)
    return rendered.strip()
def construct_speaker_base_prompt(processor, target_idx, generation_prompt_type="information_after", process=False):
    """Build the speaker's user turn, either as messages or as a prompt string.

    The user turn holds the task intro, ten ``Image i:`` labels each followed
    by an image slot, and the target assignment.

    Args:
        processor: Object providing
            ``apply_chat_template(messages, add_generation_prompt=...)``.
            Only used when ``process`` is true.
        target_idx: Index of the target image named in the instruction.
        generation_prompt_type: Prompt layout; only ``"information_after"``
            is supported.
        process: When true, return the templated prompt string (with the
            generation prompt added and whitespace stripped); otherwise
            return the raw message list.

    Returns:
        A one-element list of message dicts, or a prompt string when
        ``process`` is true.

    Raises:
        AssertionError: If ``generation_prompt_type`` is not supported.
    """
    if generation_prompt_type != "information_after":
        # Raised explicitly rather than via ``assert(False)``: assert
        # statements are removed under ``python -O``, which would let an
        # unknown type fall through with an empty message list.
        raise AssertionError(
            f"Unsupported generation_prompt_type: {generation_prompt_type!r}"
        )

    # User side: Intro
    user_content = [
        {"type": "text", "text": "You will be presented with a sequence of 10 images and be assigned a target image. "},
        {"type": "text", "text": "Your task is to produce a caption for your target image such that anyone could guess the image from your description. "},
    ]

    # User side: Images (a text label followed by an image slot, ten times)
    for i in range(10):
        label = f" Image {i}: " if i == 0 else f", Image {i}: "
        user_content.append({"type": "text", "text": label})
        user_content.append({"type": "image"})

    # User side: Target assignment
    user_content.append({"type": "text", "text": f". Your target image is Image {target_idx}. Produce your caption now."})

    messages = [{"role": "user", "content": user_content}]

    if process:
        return processor.apply_chat_template(messages, add_generation_prompt=True).strip()
    return messages
def process_idefics_listener_generation_input(speaker_context, captions, processor, img_dir, num_samples, device):
    """Turn sampled captions into a padded listener batch on ``device``.

    Args:
        speaker_context: Sequence of contexts, each a list of image file
            names (forwarded to ``get_listener_generation_prompts``).
        captions: Flat caption sequence, ``num_samples`` per context.
        processor: Object called as ``processor(text=..., images=..., ...)``
            and used for chat templating.
        img_dir: Directory that holds the image files.
        num_samples: Number of captions per context.
        device: Target handed to ``.to(...)``.

    Returns:
        ``(input_tokens, attn_mask, images, image_attn_mask)``.
    """
    # Build one prompt and one image list per caption.
    prompts, raw_images = get_listener_generation_prompts(
        speaker_context, captions, num_samples, img_dir, processor
    )

    batch = processor(
        text=prompts,
        images=raw_images,
        padding='longest',
        return_tensors='pt',
    )

    # The last two token positions are dropped before moving to the device.
    input_tokens = batch['input_ids'][:, :-2].to(device)
    attn_mask = batch['attention_mask'][:, :-2].to(device)
    # Positions holding token id 0 get no attention.
    attn_mask[input_tokens == 0] = 0
    images = batch['pixel_values'].to(device)
    image_attn_mask = batch['pixel_attention_mask'].to(device)

    return input_tokens, attn_mask, images, image_attn_mask
def get_listener_generation_prompts(speaker_contexts, captions, num_samples, img_dir, processor):
    """Build one listener prompt and one image list per sampled caption.

    ``captions`` is read as a flat sequence in which the captions for
    context ``i`` sit at positions ``i * num_samples`` through
    ``i * num_samples + num_samples - 1``.

    Args:
        speaker_contexts: Sequence of contexts, each a list of image file names.
        captions: Flat caption sequence, ``num_samples`` per context.
        num_samples: Number of captions per context.
        img_dir: Directory that holds the image files.
        processor: Object used for chat templating.

    Returns:
        ``(prompts, all_raw_images)``, two parallel lists. Every entry of
        ``all_raw_images`` that belongs to the same context is the same
        list object.
    """
    prompts = []
    all_raw_images = []

    for context_idx, speaker_context in enumerate(speaker_contexts):
        raw_images = process_images(img_dir, speaker_context)
        offset = context_idx * num_samples
        # Every prompt is built with target index 0.
        prompts.extend(
            construct_listener_full_prompt(
                processor, captions[offset + k], 0, "verbose_instruction"
            )
            for k in range(num_samples)
        )
        all_raw_images.extend([raw_images] * num_samples)

    return prompts, all_raw_images