# File: model_utils
# -----------------
# Contains utilities for models, such as loading and saving models, and the joint IDEFICS listener/speaker inference wrapper

import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from transformers import GenerationConfig
from dataset import process_idefics_listener_generation_input
import pdb

def filter_targets(logits, index_to_token):
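    '''
    Keep only the logits of the candidate tokens given by index_to_token, so a
    softmax can be taken over the candidates rather than the full vocabulary.
    '''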
    target_logits = logits[:, index_to_token]
    return target_logits

class IdeficsJointInferenceModel(nn.Module):
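    '''
    Joint listener/speaker inference wrapper around IDEFICS models.

    The listener and speaker either share one set of parameters (pass model=...)
    or are separate models (pass listener=... and speaker=...). Listener and
    speaker scores are combined by log-linear reranking, weighted by
    listener_lambda (comprehension) and speaker_lambda (generation/reranking).
    '''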

    def __init__(self, listener_lambda, speaker_lambda,
                 model=None, listener=None, speaker=None):
        super().__init__()
        self.l_lambda = listener_lambda
        self.s_lambda = speaker_lambda

        self.has_shared_parameters = model is not None
        if self.has_shared_parameters:
            self.model = model
        else:
            self.listener = listener
            self.speaker = speaker

    def forward(self, inf_mode, arguments):
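        '''
        Dispatch to the requested inference mode; `arguments` is a tuple whose
        layout depends on the mode (see the corresponding methods below).
        '''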
        if inf_mode == "joint_comprehension":
            return self.comprehension_side(arguments)
        elif inf_mode == "joint_reranking":
            return self.reranking_side(arguments)
        elif inf_mode == "comprehension":
            return self.split_comprehension_forward(arguments)
        elif inf_mode == "split_reranking":
            return self.split_reranking_forward(arguments)
        elif inf_mode == "generation":
            return self.split_generation_forward(arguments)

    def get_listener(self):
        if self.has_shared_parameters:
            return self.model
        else:
            return self.listener            

    def get_speaker(self):
        if self.has_shared_parameters:
            return self.model
        else:
            return self.speaker            

    def get_image_embeddings(self, pixel_values, pixel_attention_mask, model):
        '''
        Get image embeddings to avoid repeated computation for images during joint inference.
        Adapted from the IDEFICS-2 source code.
        '''
        # Get the model
        model = self.get_listener() if model == "listener" else self.get_speaker()
        if len(pixel_attention_mask.shape) == 5:
            pixel_attention_mask = pixel_attention_mask[:, 0].contiguous()

        # Assume images of shape: batch x num_images x channels x height x width
        batch_size, num_images, num_channels, height, width = pixel_values.shape
        pixel_values = pixel_values.to(dtype=model.dtype)  # fp16 compatibility
        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

        # Remove padding images - padding images are full 0.
        nb_values_per_image = pixel_values.shape[1:].numel()
        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
        pixel_values = pixel_values[real_images_inds].contiguous()

        # Remove padding images from the mask
        pixel_attention_mask = pixel_attention_mask.view(
            batch_size * num_images, *pixel_attention_mask.shape[2:]
        )
        pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

        patch_size = model.model.config.vision_config.patch_size
        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        image_hidden_states = model.model.model.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        ).last_hidden_state

        # Modality projection & resampling
        image_hidden_states = model.model.model.connector(
            image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
        )

        return image_hidden_states

    def split_comprehension_side(self, input_tokens, attn_mask, images, image_attn_mask, index_to_token):
        '''
        Redundant with split_comprehension_forward except for the final computation. 
        Used during deployment in ray_models.py.
        '''
        listener = self.get_listener()
        all_logits = listener(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            pixel_values=images, 
            pixel_attention_mask=image_attn_mask
        )['logits']
        target_logits = filter_targets(all_logits[:, -1], index_to_token)
        listener_log_probs = F.log_softmax(target_logits, dim=1)
        return listener_log_probs

    def split_comprehension_forward(self, arguments):
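        '''
        Listener-only forward pass: return the raw vocabulary logits for the
        given text and image inputs (no candidate filtering or softmax).
        '''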
        input_tokens, attn_mask, images, image_attn_mask = arguments
        listener = self.get_listener()
        all_logits = listener(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            pixel_values=images,
            pixel_attention_mask=image_attn_mask
        )['logits']
        return all_logits

    def split_generation_forward(self, arguments):
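        '''
        Speaker-only forward pass: return the raw vocabulary logits for the
        given text and image inputs.
        '''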
        input_tokens, attn_mask, images, image_attn_mask = arguments
        speaker = self.get_speaker()
        all_logits = speaker(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            pixel_values=images,
            pixel_attention_mask=image_attn_mask
        )['logits']
        return all_logits

    def split_reranking_forward(self, arguments):
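        '''
        Speaker-only reranking: encode the images once, repeat the embeddings
        for every candidate utterance, and return the speaker log probability
        of each candidate (B x mult).
        '''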
        images, input_tokens, attn_mask, image_attn_mask, target_tokens, target_mask = arguments

        # Get the image embeddings
        image_embeddings = self.get_image_embeddings(images, image_attn_mask, "speaker")
        embed_shape = image_embeddings.shape
        B, mult = input_tokens.shape[:2]
        C = images.shape[1]
        image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
        image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])

        annotation_mask = torch.zeros(B, mult, device=image_embeddings.device).bool()
        _, speaker_log_probs = self.reranking_speaker_side(image_embeddings, input_tokens, attn_mask,
                                                           image_attn_mask, target_tokens, target_mask,
                                                           annotation_mask)
        return speaker_log_probs

    def comprehension_side(self, arguments):
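        '''
        Joint comprehension: the listener scores the candidate images given the
        utterance, the speaker scores the utterance given each candidate, and
        the two distributions are combined by log-linear reranking.
        '''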
        images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token, \
            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label = arguments

        if self.has_shared_parameters:
            image_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener")
            listener_log_probs = self.comprehension_listener_side(
                image_embeddings, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token
            ) # TODO

            speaker_log_probs = self.comprehension_speaker_side(
                image_embeddings, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label
            )
        else:
            # Deprecated and not used in experiments
            listener_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener")
            listener_log_probs = self.comprehension_listener_side(
                listener_embeddings, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token
            )        

            speaker_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker")
            speaker_log_probs = self.comprehension_speaker_side(
                speaker_embeddings, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label
            )

        joint_log_probs = self.comprehension_reranking(listener_log_probs, speaker_log_probs)
        return listener_log_probs, speaker_log_probs, joint_log_probs

    def comprehension_listener_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask,
                                    index_to_token):
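        '''
        Listener side of comprehension: take the logits at the final position,
        keep only the candidate tokens, and return their log-softmax (B x C).
        '''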
        listener = self.get_listener()
        all_logits = listener(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_encoder_embeddings,
            pixel_attention_mask=image_attn_mask
        )['logits']

        target_logits = filter_targets(all_logits[:, -1], index_to_token) # BxC
        listener_log_probs = F.log_softmax(target_logits, dim=1)
        return listener_log_probs

    def comprehension_speaker_side(self, image_encoder_embeddings, input_tokens, attn_mask, image_attn_mask,
                                   target_mask, target_label):
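        '''
        Speaker side of comprehension: run the speaker on the C per-candidate
        prompts and sum the masked token log probabilities of the target
        utterance, returning B x C utterance log probabilities.
        '''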
        # Expand embeddings
        B, C = input_tokens.shape[:2]
        embed_shape = image_encoder_embeddings.shape
        image_encoder_embeddings = image_encoder_embeddings.view(B, C, *embed_shape[1:])
        image_encoder_embeddings = image_encoder_embeddings.unsqueeze(1).repeat(1, C, 1, 1, 1).view(-1, *embed_shape[1:])
        input_tokens = input_tokens.view(B*C, -1)
        attn_mask = attn_mask.view(B*C, -1)

        # Forward pass
        speaker = self.get_speaker()
        all_logits = speaker(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_encoder_embeddings,
        )['logits']
        
        # Get tokenwise probabilities
        all_log_probs = F.log_softmax(all_logits, dim=2)
        target_label = target_label.view(B*C, -1).unsqueeze(2)
        target_mask = target_mask.view(B*C, -1)
        token_log_probs = torch.gather(all_log_probs, 2, target_label).squeeze(2) # BCxT

        # Compute the log probabilities
        token_log_probs = token_log_probs * target_mask
        utterance_log_probs = torch.sum(token_log_probs, dim=1).view(B, C)

        return utterance_log_probs

    def comprehension_reranking(self, listener_log_probs, speaker_log_probs):
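        '''
        Mix listener and speaker log probabilities with weight l_lambda and
        renormalize with logsumexp into a joint log distribution over candidates.
        '''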
        rerank_weights = self.l_lambda * listener_log_probs + (1 - self.l_lambda) * speaker_log_probs
        rerank_denominator = torch.logsumexp(rerank_weights, dim=1).unsqueeze(1)
        rerank_log_distribution = rerank_weights - rerank_denominator            
        return rerank_log_distribution
    
    def reranking_side(self, arguments):
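        '''
        Joint reranking of candidate utterances: score each candidate under the
        speaker and under the listener (log probability of the true target image
        given the candidate), then combine the two scores with weight s_lambda.
        '''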
        images, label, s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_tokens, s_target_mask, \
            l_input_tokens, l_attn_mask, l_image_attn_mask, \
            index_to_token, annotation_mask = arguments

        # Repeat image embeddings according to number of distractors
        if self.has_shared_parameters:
            image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker")
            embed_shape = image_embeddings.shape
            B, mult = s_input_tokens.shape[:2]
            C = images.shape[1]
            image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
            image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])

            speaker_logits, speaker_log_probs = self.reranking_speaker_side(image_embeddings, s_input_tokens,
                                                                            s_attn_mask, s_image_attn_mask,
                                                                            s_target_tokens, s_target_mask,
                                                                            annotation_mask)

            listener_log_probs = self.reranking_listener_side(image_embeddings, l_input_tokens, l_attn_mask,
                                                              l_image_attn_mask, label, index_to_token,
                                                              annotation_mask)
        else:
            # Deprecated and no longer used in main experiments
            image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker")
            embed_shape = image_embeddings.shape
            B, mult = s_input_tokens.shape[:2]
            C = images.shape[1]
            image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
            image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])

            speaker_logits, speaker_log_probs = self.reranking_speaker_side(image_embeddings, s_input_tokens,
                                                                            s_attn_mask, s_image_attn_mask,
                                                                            s_target_tokens, s_target_mask,
                                                                            annotation_mask)


            image_embeddings = self.get_image_embeddings(images, l_image_attn_mask, "listener")
            embed_shape = image_embeddings.shape
            B, mult = s_input_tokens.shape[:2]
            C = images.shape[1]
            image_embeddings = image_embeddings.view(B, C, *embed_shape[1:])
            image_embeddings = image_embeddings.unsqueeze(1).repeat(1, mult, 1, 1, 1).view(-1, *embed_shape[1:])

            listener_log_probs = self.reranking_listener_side(image_embeddings, l_input_tokens, l_attn_mask,
                                                              l_image_attn_mask, label, index_to_token, annotation_mask)

        # Full forward passes
        utterance_distribution = self.reranking_combination(speaker_log_probs, listener_log_probs)
        return speaker_logits, speaker_log_probs, listener_log_probs, utterance_distribution
        

    def reranking_speaker_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask,
                               target_tokens, target_mask, annotation_mask):
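        '''
        Speaker scores for reranking: summed masked token log probabilities of
        each candidate utterance (B x mult); padded candidates are set to -inf
        via annotation_mask.
        '''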
        # Flatten inputs and outputs
        B, mult = input_tokens.shape[:2]
        input_tokens = input_tokens.view(B*mult, -1)
        attn_mask = attn_mask.view(B*mult, -1)
        target_tokens = target_tokens.view(B*mult, -1).unsqueeze(-1)
        target_mask = target_mask.view(B*mult, -1)

        # Forward pass: Compute utterance probabilities for all
        speaker = self.get_speaker()
        all_logits = speaker(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_embeddings,
        )['logits']
        
        # Compute utterance log probabilities
        all_log_probs = F.log_softmax(all_logits, dim=2)
        token_log_probs = torch.gather(all_log_probs, 2, target_tokens).squeeze(2) # BCxT
        token_log_probs = token_log_probs * target_mask
        utterance_log_probs = torch.sum(token_log_probs, dim=1).view(B, mult)
        utterance_log_probs[annotation_mask] = float('-inf') # Mask in the event there aren't 9 distractors
            
        return all_logits, utterance_log_probs

    def reranking_listener_side(self, image_embeddings, input_tokens, attn_mask, image_attn_mask,
                                label, index_to_token, annotation_mask):
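        '''
        Listener scores for reranking: the log probability the listener assigns
        to the true target image given each candidate utterance (B x mult), with
        padded candidates set to -inf via annotation_mask.
        '''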
        # Flatten inputs and outputs
        B, mult = input_tokens.shape[:2]
        input_tokens = input_tokens.view(B*mult, -1)
        attn_mask = attn_mask.view(B*mult, -1)
        label = label.unsqueeze(1).repeat(1, mult).view(-1).unsqueeze(1)

        # Forward pass: Compute listener log-probs
        listener = self.get_listener()
        all_logits = listener(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            image_hidden_states=image_embeddings,
        )['logits']

        target_logits = filter_targets(all_logits[:, -1], index_to_token) # BmultxC
        listener_log_probs = F.log_softmax(target_logits, dim=1) #BmultxC
        utterance_log_probs = torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, mult)

        utterance_log_probs[annotation_mask] = float('-inf') # Mask in the event there aren't mult distractors

        return utterance_log_probs

    def reranking_combination(self, speaker_utterance_log_probs, listener_utterance_log_probs):
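        '''
        Log-linear combination of speaker and listener utterance scores with
        weight s_lambda, renormalized over the candidate utterances.
        '''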
        weights = self.s_lambda * speaker_utterance_log_probs + (1-self.s_lambda) * listener_utterance_log_probs
        rerank_denominator = torch.logsumexp(weights, dim=1).unsqueeze(1)
        rerank_log_distribution = weights - rerank_denominator
        return rerank_log_distribution

    def split_generate(self, input_tokens, attn_mask, images, image_attn_mask, processor,
                       max_steps=25, sampling_type="nucleus", temperature=1.0,
                       top_k=40, top_p=0.9, repetition_penalty=1, num_samples=1):
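        '''
        Speaker-only generation: sample num_samples captions per example, score
        each sample by its summed token log probability under the speaker, and
        return the most probable caption for each example.
        '''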
        # (1) Perform generation
        speaker = self.get_speaker()
        generation_config = GenerationConfig(
            max_new_tokens=max_steps,
            do_sample=True,
            temperature=temperature,
            top_k=top_k, top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_samples,
            output_hidden_states=True,
            return_dict_in_generate=True
        ) 
        outputs = speaker.generate(
            input_ids=input_tokens,
            attention_mask=attn_mask,
            pixel_values=images,
            pixel_attention_mask=image_attn_mask,
            generation_config=generation_config,
            use_cache=True
        )

        # (2) Get the speaker captions
        B = input_tokens.shape[0]
        observed_steps = len(outputs['hidden_states'])
        filtered_seqs = []
        for seq in outputs['sequences']:
            filtered_seqs.append(seq[-observed_steps:])
        speaker_outputs = processor.batch_decode(filtered_seqs, skip_special_tokens=True)
        
        # (3) Get the speaker log probabilities
        target_outputs = torch.stack(filtered_seqs, dim=0) # BNxT
        target_mask = target_outputs != 0
        final_states = torch.stack([outputs['hidden_states'][i][-1][:, -1] for i in range(observed_steps)], dim=1) # BNxTxD
        token_logits = speaker.lm_head(final_states) # BNxTxV
        token_log_probs = F.log_softmax(token_logits, dim=2)
        token_log_probs = torch.gather(token_log_probs, 2, target_outputs.unsqueeze(2)).squeeze(2)
        
        # (4) Choose the output with the top probability
        if B == 1:
            utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(num_samples) # N
            best_idx = torch.argmax(utterance_log_probs).item()
            return [speaker_outputs[best_idx]]
        else:
            utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(B, num_samples) # BxN
            best_indices = torch.argmax(utterance_log_probs, dim=1)
            choices = []
            for i in range(B):
                curr_index = num_samples * i + best_indices[i].item()                
                choices.append(speaker_outputs[curr_index])
            return choices
            

    def generate(self, images, s_input_tokens, s_attn_mask, s_image_attn_mask, label,
                 image_paths, processor, image_dir, index_to_token,
                 max_steps=25, sampling_type="nucleus", temperature=1.0, top_k=40,
                 top_p=0.9, repetition_penalty=1, num_samples=10):
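        '''
        Joint generation: sample candidate captions from the speaker, score each
        with the listener (log probability of the true target image), and return
        the captions that maximize the s_lambda-weighted combination of the two.
        '''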
        # Get the repeated image embeddings; assume parameter sharing
        image_embeddings = self.get_image_embeddings(images, s_image_attn_mask, "speaker") 

        # Sample utterances from the speaker
        speaker_utterance_log_probs, speaker_utterances = self.generate_speaker_side(processor, images, s_input_tokens,
                                                                                     s_attn_mask, s_image_attn_mask, max_steps,
                                                                                     sampling_type, temperature,
                                                                                     top_k, top_p, repetition_penalty,
                                                                                     num_samples) # BxN, BN list

        # Get probabilities for the utterances from the listener
        listener_log_probs = self.generate_listener_side(image_embeddings, speaker_utterances, label, image_paths, processor,
                                                         image_dir, index_to_token, num_samples)

        # Reranked selection
        utterance_weights = self.s_lambda*speaker_utterance_log_probs + (1-self.s_lambda)*listener_log_probs
        chosen_indices = torch.argmax(utterance_weights, dim=1)
        choices = []
        for i in range(speaker_utterance_log_probs.shape[0]):
            curr_index = num_samples * i + chosen_indices[i].item()
            choices.append(speaker_utterances[curr_index])
            
        return choices, speaker_utterances, listener_log_probs, speaker_utterance_log_probs, utterance_weights

    def generate_speaker_side(self, processor, images, s_input_tokens, s_attn_mask, s_image_attn_mask, max_steps,
                              sampling_type, temperature, top_k, top_p, repetition_penalty, num_samples):
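        '''
        Sample num_samples captions per example and return their summed token
        log probabilities (B x N) together with the decoded caption strings.
        '''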
        # (1) Perform generation
        speaker = self.get_speaker()
        generation_config = GenerationConfig(
            max_new_tokens=max_steps,
            do_sample=True,
            temperature=temperature,
            top_k=top_k, top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_samples,
            output_hidden_states=True,
            return_dict_in_generate=True
        ) 
        outputs = speaker.generate(
            input_ids=s_input_tokens,
            attention_mask=s_attn_mask,
            pixel_values=images,
            pixel_attention_mask=s_image_attn_mask,
            generation_config=generation_config,
            use_cache=True
        )

        # (2) Get the speaker captions
        B = s_input_tokens.shape[0]
        observed_steps = len(outputs['hidden_states'])
        filtered_seqs = []
        for seq in outputs['sequences']:
            filtered_seqs.append(seq[-observed_steps:])
        speaker_outputs = processor.batch_decode(filtered_seqs, skip_special_tokens=True)

        # (3) Get the speaker log probabilities
        target_outputs = torch.stack(filtered_seqs, dim=0) # BNxT
        target_mask = target_outputs != 0
        final_states = torch.stack([outputs['hidden_states'][i][-1][:, -1] for i in range(observed_steps)], dim=1) # BNxTxD
        token_logits = speaker.lm_head(final_states) # BNxTxV
        token_log_probs = F.log_softmax(token_logits, dim=2)
        token_log_probs = torch.gather(token_log_probs, 2, target_outputs.unsqueeze(2)).squeeze(2)
        utterance_log_probs = torch.sum(token_log_probs * target_mask, dim=1).view(B, num_samples) # BxN
            
        return utterance_log_probs, speaker_outputs

    def generate_listener_side(self, image_embeddings, speaker_utterances, label, image_paths, processor,
                               image_dir, index_to_token, num_samples):
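        '''
        Re-encode the sampled captions as listener inputs and return the listener
        log probability of the true target image for each caption (B x N).
        '''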
        # Construct the inputs
        B = label.shape[0]
        embed_shape = image_embeddings.shape
        image_embeddings = image_embeddings.view(B, -1, *embed_shape[1:])
        image_embeddings = image_embeddings.unsqueeze(1).repeat(1, num_samples, 1, 1, 1).view(-1, *embed_shape[1:])

        l_batch = process_idefics_listener_generation_input(image_paths, speaker_utterances, processor, 
                                                            image_dir, num_samples, image_embeddings.device)
        l_input_tokens, l_attn_mask, _, l_image_attn_mask = l_batch
        label = label.unsqueeze(1).repeat(1, num_samples).view(-1).unsqueeze(1)

        # Forward pass
        listener = self.get_listener()
        all_logits = listener(
            input_ids=l_input_tokens,
            attention_mask=l_attn_mask,
            image_hidden_states=image_embeddings,
            pixel_attention_mask=l_image_attn_mask
        )['logits']

        target_logits = filter_targets(all_logits[:, -1], index_to_token)
        listener_log_probs = F.log_softmax(target_logits, dim=1)
        utterance_log_probs = torch.gather(listener_log_probs, 1, label).squeeze(1).view(B, num_samples)

        return utterance_log_probs
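
# Example usage (a minimal sketch, not part of the pipeline above). It assumes
# `base_model` is an IDEFICS-2 model wrapped the way the rest of this codebase
# wraps it (i.e. it exposes a forward pass accepting pixel_values /
# image_hidden_states, plus .generate, .lm_head, and .model.model.vision_model),
# and that the input tensors come from the dataset's processing utilities:
#
#   joint_model = IdeficsJointInferenceModel(
#       listener_lambda=0.5, speaker_lambda=0.5, model=base_model
#   )
#   # Listener-only forward pass: raw vocabulary logits, which can then be
#   # restricted to the candidate tokens with filter_targets.
#   logits = joint_model(
#       "comprehension", (input_tokens, attn_mask, images, image_attn_mask)
#   )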