from transformers import AutoProcessor
from PIL import Image
import os
import torch
import pickle
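
# Helpers for constructing Idefics2 speaker and listener inputs over 10-image
# tangram contexts (prompt building, image loading, and loss-mask creation).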

## ACTUAL INPUT CONSTRUCTION
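# Token length of the (unpadded) speaker base prompt; used by
# create_speaker_caption_mask to locate where the caption tokens begin.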
BASE_SPEAKER_LEN = 787

def joint_listener_input(processor, context_images, description, device):
    # Preliminaries
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, context_images)
    target_anno = description.lower()
    prompt = construct_listener_full_prompt(
        processor, target_anno, 0, "verbose_instruction"
    )

    # Listener processing
    outputs = processor(
        text=[prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
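    # The templated prompt already ends with a filled-in answer ("The caption
    # describes Image 0"); trimming the last two tokens strips that answer so
    # the model can instead be scored on the image index it assigns.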
    l_input_tokens = outputs['input_ids'][:, :-2]
    l_attn_mask = outputs['attention_mask'][:, :-2]
    l_attn_mask[(l_input_tokens == 0).bool()] = 0
    images = outputs['pixel_values']
    l_image_attn_mask = outputs['pixel_attention_mask']

    # Speaker processing
    prompts = []
    for i in range(10):
        prompt = construct_speaker_full_prompt(processor, description, i, "information_after")
        prompts.append(prompt)
    outputs = processor(
        text=prompts,
        images=[raw_images]*10,
        padding='longest',
        return_tensors="pt"
    ).to(device)

    s_input_tokens = outputs['input_ids'][:, :-1]
    s_attn_mask = outputs['attention_mask'][:, :-1]
    s_attn_mask[(s_input_tokens == 0).bool()] = 0
    s_image_attn_mask = outputs['pixel_attention_mask']
    s_target_tokens = outputs['input_ids'][:, 1:]
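    # Per-prompt mask that is 1 only over the caption tokens of the shifted
    # targets, so downstream scoring can ignore the shared base prompt.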
    s_target_mask = []
    for i in range(10):
        curr_mask = create_speaker_caption_mask(outputs['input_ids'][i], s_attn_mask[i])
        s_target_mask.append(curr_mask)
    s_target_mask = torch.stack(s_target_mask, dim=0)

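    # unsqueeze(0) adds a batch dimension, so the ten speaker prompts are
    # returned as a single [1, 10, ...] batch alongside the listener inputs.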
    return images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens.unsqueeze(0), \
        s_attn_mask.unsqueeze(0), s_image_attn_mask.unsqueeze(0), s_target_mask.unsqueeze(0), \
        s_target_tokens.unsqueeze(0)

def joint_speaker_input(processor, image_paths, target_path, device):
    # Get the prompt
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, image_paths)
    target_idx = image_paths.index(target_path)
    base_prompt = construct_speaker_base_prompt(processor, target_idx, "information_after", process=True)

    # Create the basic input
    outputs = processor(
        text=[base_prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)

    input_tokens = outputs['input_ids']
    attn_mask = outputs['attention_mask']
    attn_mask[(input_tokens == 0).bool()] = 0
    images = outputs['pixel_values']
    image_attn_mask = outputs['pixel_attention_mask']

    return input_tokens, attn_mask, images, image_attn_mask, torch.LongTensor([target_idx]).to(device)


## UTILITIES

def get_processor():
    checkpoint = "HuggingFaceM4/idefics2-8b"
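    # Image splitting is disabled and images are resized (longest edge 448 px),
    # so each of the ten context images is processed as a single image.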
    processor = AutoProcessor.from_pretrained(checkpoint, do_image_splitting=False,
                                               size={"longest_edge": 448, "shortest_edge": 224})
    return processor

def get_index_to_token():
    index_to_token_path = "index_to_token.pkl"
    with open(index_to_token_path, 'rb') as f:
        index_to_token = pickle.load(f)
    return index_to_token

def process_images(img_dir, context_images):
    raw_images = []
    for img in context_images:
        image_path = os.path.join(img_dir, img)
        raw_image = Image.open(image_path).convert('RGB')
        raw_images.append(raw_image)
    return raw_images

def create_speaker_caption_mask(all_token_ids, text_mask):
    # Overall token composition: padding + base prompt + caption
    padding_tokens = torch.sum(all_token_ids == 0).item()
    caption_tokens = all_token_ids.shape[0] - (padding_tokens + BASE_SPEAKER_LEN)

    # Construct a mask where the last caption tokens are 1
    target_mask = torch.zeros_like(text_mask)
    target_mask[-caption_tokens:] = 1

    return target_mask.bool()

def construct_listener_full_prompt(processor, target_anno, target_idx, comprehension_prompt_type="verbose_instruction"):
    target_anno = target_anno.lower().strip()
    messages = []

    if comprehension_prompt_type == "verbose_instruction":
        # User side: Intro
        messages.append(
            {
                "role" : "user",
                "content" : [
                    {"type" : "text", "text" : "You will be presented with a sequence of 10 images and a caption describing exactly one of them. "},
                    {"type" : "text", "text" : "Your task is to guess which image the caption describes. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type" : "text", "text" : f" Image {i}: "})
            else:
                messages[0]["content"].append({"type" : "text", "text" : f", Image {i}: "})
            messages[0]["content"].append({"type" : "image"})

        # User side: Caption
        messages[0]["content"].append({"type" : "text", "text" : f". Caption: {target_anno}"})
        messages[0]["content"].append({"type" : "text", "text" : f" Does this caption describe Image 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9?"})

        # Model side: Guess
        messages.append(
            {
                "role" : "assistant",
                "content" : [
                    {"type" : "text", "text" : f"The caption describes Image {target_idx}"}
                ]
            }
        )
    else:
        raise ValueError(f"Unknown comprehension_prompt_type: {comprehension_prompt_type}")

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip() 

def construct_speaker_full_prompt(processor, target_anno, target_idx,
                                  generation_prompt_type="information_after"):
    messages = construct_speaker_base_prompt(processor, target_idx, generation_prompt_type)

    # Assistant response
    target_anno = target_anno.lower().strip()
    messages.append(
        {
            "role" : "assistant",
            "content" : [
                {"type" : "text", "text" : target_anno}
            ]
        }
    )

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()

def construct_speaker_base_prompt(processor, target_idx, generation_prompt_type="information_after", process=False):
    messages = []

    if generation_prompt_type == "information_after":
        # User side: Intro
        messages.append(
            {
                "role" : "user",
                "content" : [
                    {"type" : "text", "text" : "You will be presented with a sequence of 10 images and be assigned a target image. "},
                    {"type" : "text", "text" : "Your task is to produce a caption for your target image such that anyone could guess the image from your description. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type" : "text", "text" : f" Image {i}: "})
            else:
                messages[0]["content"].append({"type" : "text", "text" : f", Image {i}: "})
            messages[0]["content"].append({"type" : "image"})

        # User side: Target assignment
        messages[0]["content"].append({"type" : "text", "text" : f". Your target image is Image {target_idx}. Produce your caption now."})
    else:
        raise ValueError(f"Unknown generation_prompt_type: {generation_prompt_type}")

    if process:
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True).strip()
        return prompt
    else:
        return messages

def process_idefics_listener_generation_input(speaker_context, captions, processor, img_dir, num_samples, device):
    # First construct the prompts
    prompts, raw_images = get_listener_generation_prompts(speaker_context, captions, num_samples, img_dir, processor)

    # Process the prompts
    listener_inputs = processor(
        text=prompts,
        images=raw_images,
        padding='longest',
        return_tensors='pt'
    )

    input_tokens = listener_inputs['input_ids'][:, :-2].to(device)
    attn_mask = listener_inputs['attention_mask'][:, :-2].to(device)
    attn_mask[input_tokens == 0] = 0
    images = listener_inputs['pixel_values'].to(device)
    image_attn_mask = listener_inputs['pixel_attention_mask'].to(device)

    return input_tokens, attn_mask, images, image_attn_mask

def get_listener_generation_prompts(speaker_contexts, captions, num_samples, img_dir, processor):
    prompts = []
    all_raw_images = []

    for i, speaker_context in enumerate(speaker_contexts):
        raw_images = process_images(img_dir, speaker_context)
        for j in range(num_samples):
            curr_idx = i * num_samples + j
            caption = captions[curr_idx]
            prompt = construct_listener_full_prompt(processor, caption, 0, "verbose_instruction")

            prompts.append(prompt)
            all_raw_images.append(raw_images)
    return prompts, all_raw_images
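

## EXAMPLE USAGE (illustrative sketch only)
# A minimal sketch of how the entry points above might be wired together.
# The filenames and caption below are placeholders, not real data; the
# `tangram_pngs/` directory and the listed PNGs must exist for this to run.
if __name__ == "__main__":
    processor = get_processor()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    context = [f"tangram_{i}.png" for i in range(10)]  # placeholder filenames

    # Speaker side: build inputs for captioning context[3] as the target.
    speaker_inputs = joint_speaker_input(processor, context, context[3], device)

    # Listener side: build inputs for resolving a caption against the context.
    listener_inputs = joint_listener_input(
        processor, context, "a bird facing left with folded wings", device
    )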