# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchmultimodal.models.albef.image_encoder import ALBEFVisionEncoder
from torchmultimodal.models.albef.model import ALBEFModel, ALBEFModelWithSimilarity
from torchmultimodal.models.albef.multimodal_encoder import ALBEFMultimodalEncoder
from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
from torchmultimodal.modules.losses.albef import (
    CausalLanguageModelingLoss,
    ImageTextContrastiveLoss,
)
from torchmultimodal.utils.attention import get_causal_attention_mask
from torchmultimodal.utils.common import momentum_update, remove_grad

_ALBEF_PRETRAINED_URLS = {
    "vqa": "https://download.pytorch.org/models/multimodal/albef/pretrained_vqa_checkpoint.pt",
    "retrieval": "https://download.pytorch.org/models/multimodal/albef/pretrained_retrieval_checkpoint.pt",
}


class PredictionHead(nn.Module):
    """
    Predict the following token autoregressively.

    Args:
        vocab_size (int): The number of different tokens the prediction_head can predict.
        hidden_size (int): The hidden size of the prediction_head.
        layer_norm_eps (float): The epsilon used by the prediction_head normalization layer.
        transform_act_fn (Callable[[Tensor], Tensor]): The activation function in the prediction_head.

    Inputs:
        hidden_states (Tensor): The hidden states of preceding tokens.

    Returns:
        Tensor: Prediction scores for the following token.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        layer_norm_eps: float = 1e-12,
        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
    ) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.transform_act_fn = transform_act_fn
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class ALBEFDecoder(nn.Module):
    """
    Generate the prediction scores for answers from image and question hidden states.

    Args:
        text_embeddings (BERTTextEmbeddings): Instantiated BERTTextEmbeddings.
        multimodal_encoder (ALBEFMultimodalEncoder): Instantiated ALBEFMultimodalEncoder.
        prediction_head (PredictionHead): Instantiated PredictionHead.

    Inputs:
        input_ids (Tensor of shape (batch_size, seq_len)): Input ids for input text tokens.
        attention_mask (Tensor of shape (batch_size, seq_len)): Input attention mask to avoid
            performing attention on padding token indices.
        encoder_hidden_states (Tensor of shape (batch_size, encoder_seq_len, hidden_size)): The encoder hidden states.
        encoder_attention_mask (Tensor of shape (batch_size, encoder_seq_len)): The attention mask for encoder hidden states.

    Returns:
        Tensor: Prediction scores for answers.
    """

    def __init__(
        self,
        text_embeddings: BERTTextEmbeddings,
        multimodal_encoder: ALBEFMultimodalEncoder,
        prediction_head: PredictionHead,
    ) -> None:
        super().__init__()
        self.text_embeddings = text_embeddings
        self.multimodal_encoder = multimodal_encoder
        self.prediction_head = prediction_head

    def get_extended_attention_mask_for_decoder(self, attention_mask: Tensor) -> Tensor:
        """
        Apply a causal mask in addition to the padding mask and make the mask broadcastable,
        such that future and masked tokens are ignored.

        Args:
            attention_mask (Tensor): Padding mask with ones indicating tokens to attend to,
                zeros for tokens to ignore.

        Returns:
            extended_attention_mask (Tensor): The broadcastable attention mask,
                with the same dtype as ``attention_mask.dtype``.
        """
        device = attention_mask.device
        batch_size, seq_length = attention_mask.shape
        causal_mask = get_causal_attention_mask(seq_length).to(device)
        causal_mask = causal_mask.repeat(batch_size, 1).view(
            batch_size, seq_length, seq_length
        )
        extended_attention_mask = (
            causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        )
        extended_attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype)
        return extended_attention_mask

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        encoder_hidden_states: Tensor,
        encoder_attention_mask: Tensor,
    ) -> Tensor:
        hidden_states = self.text_embeddings(input_ids)
        attention_mask = self.get_extended_attention_mask_for_decoder(attention_mask)
        decoder_output = self.multimodal_encoder(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        prediction_scores = self.prediction_head(decoder_output)
        return prediction_scores


class ALBEFModelForVQA(nn.Module):
    """
    ALBEF Model for VQA finetuning and inference.

    Args:
        model (ALBEFModel): Instantiated ALBEFModel.
        answer_decoder (ALBEFDecoder): Instantiated ALBEFDecoder.
        loss (CausalLanguageModelingLoss): Instantiated CausalLanguageModelingLoss.

    Inputs:
        image (Tensor of shape (B, C, H, W)): Image features.
        question (Tensor of shape (B, L)): Question text features.
        question_atts (Tensor of shape (B, L)): Question attention mask.
        answers (Tensor of shape (N, M)): Answer text features.
        answers_atts (Tensor of shape (N, M)): Answer attention mask.
        ans_weights (Optional[Tensor] of shape (N)): Weights for each answer.
            Required if is_train is True.
        ans_lengths (Optional[List[int]] of length B): Number of answers for each question.
            ans_lengths should sum to N. Required if is_train is True.
        alpha (Optional[float]): The interpolation value between clm_loss and loss_distill.
            Required if is_train is True.
        k (Optional[int]): The number of answers to return for inference.
            Required if is_train is False.
        is_train (Optional[bool]): Whether the model is in training.

    Returns:
        is_train is True:
            Tensor: The masked language modeling loss for input.
        is_train is False:
            Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
    """

    def __init__(
        self,
        model: ALBEFModel,
        answer_decoder: ALBEFDecoder,
        loss: CausalLanguageModelingLoss,
    ) -> None:
        super().__init__()
        self.model = model
        self.answer_decoder = answer_decoder
        self.loss = loss
        self.answer_decoder_m = copy.deepcopy(self.answer_decoder)
        remove_grad(
            self.answer_decoder_m
        )  # remove gradient for the momentum decoder model

    def _train_forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answers_atts: Tensor,
        ans_weights: Tensor,
        ans_lengths: List[int],
        alpha: float,
    ) -> Tensor:
        """
        Forward step for training. Encode the inputs with the ALBEFModel.
        Generate pseudo-targets using answer_decoder_m (momentum decoder model).
        Generate answer predictions using answer_decoder.
        Compute masked language modeling loss of the predictions using answers as labels,
            pseudo-targets as soft-labels, and alpha as their interpolation value.

        Inputs:
            image (Tensor of shape (B, C, H, W)): Image features.
            question (Tensor of shape (B, L)): Question text features.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Answer text features.
            answers_atts (Tensor of shape (N, M)): Answer attention mask.
            ans_weights (Tensor of shape (N)): Weights for each answer.
            ans_lengths (List[int] of length B): Number of answers for each question.
                ans_lengths should sum to N.
            alpha (float): The interpolation value between clm_loss and loss_distill.

        Returns:
            Tensor: The masked language modeling loss for input.
        """
        # get image-question embeddings from the ALBEFModel and repeat them to match ans_lengths
        encoder_outputs = self.model(image, question, question_atts)
        (
            encoder_hidden_states,
            encoder_hidden_states_m,
            encoder_attention_mask,
        ) = self._encoder_hidden_states(
            encoder_outputs.multimodal_embeddings,
            encoder_outputs.multimodal_embeddings_m,
            question_atts,
            ans_lengths,
        )

        # use the momentum model to generate pseudo-targets
        with torch.no_grad():
            momentum_update(
                self.answer_decoder, self.answer_decoder_m, self.model.momentum
            )
            prediction_scores_m = self.answer_decoder_m(
                input_ids=answers,
                attention_mask=answers_atts,
                encoder_hidden_states=encoder_hidden_states_m,
                encoder_attention_mask=encoder_attention_mask,
            )

        # generate answer predictions
        prediction_scores = self.answer_decoder(
            input_ids=answers,
            attention_mask=answers_atts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        # compute masked language modeling loss from the prediction scores
        labels = answers.masked_fill(answers == 0, self.loss.mask_token_id)
        loss = self.loss(labels, prediction_scores, prediction_scores_m, alpha)
        loss = ans_weights * loss
        loss = loss.sum() / image.size(0)
        return loss

    def _eval_forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answer_atts: Tensor,
        k: int = 128,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward step for evaluation. Encode the inputs with the ALBEFModel.
        Rank the candidate answers with the decoder, starting from the [CLS] token.
        Compute the answer ids and their respective probabilities for the top k predictions.

        Inputs:
            image (Tensor of shape (B, C, H, W)): Image features.
            question (Tensor of shape (B, L)): Question text features.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Answer text features.
            answer_atts (Tensor of shape (N, M)): Answer attention mask.
            k (int): The number of answers to return for inference.

        Returns:
            Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
        """
        # get multimodal embeddings from the ALBEFModel and
        # feed them to the decoder as cross attention
        encoder_outputs = self.model(image, question, question_atts)

        # use cls token as the decoder's initial input token
        num_ques = question.size(0)
        start_ids = answers[0, 0].repeat(num_ques, 1)
        atts = torch.ones(start_ids.shape).to(image.device)

        # score the first token of every candidate answer with a single decoder step
        prediction_scores = self.answer_decoder(
            input_ids=start_ids,
            attention_mask=atts,
            encoder_hidden_states=encoder_outputs.multimodal_embeddings,
            encoder_attention_mask=question_atts,
        )

        logits = prediction_scores[:, 0, :]
        answer_first_token = answers[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(
            dim=1, index=answer_first_token
        )
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # gather the full sequences of the top-k candidates for each question
        input_ids = []
        input_atts = []
        for topk_id in topk_ids:
            input_ids.append(answers.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids)
        input_atts = torch.cat(input_atts)
        targets_ids = input_ids.masked_fill(input_ids == 0, self.loss.mask_token_id)

        question_states = encoder_outputs.multimodal_embeddings.repeat_interleave(
            k, dim=0
        )
        question_atts = question_atts.repeat_interleave(k, dim=0)

        prediction_scores = self.answer_decoder(
            input_ids=input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
        )

        answer_loss = self.loss(targets_ids, prediction_scores)
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # topk_probs: first token probability
        topk_probs = topk_probs.view(-1, 1)
        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

        # re-calculate log probabilities for the answer sequences using chain rule
        log_probs_sum = log_probs.sum(1)
        log_probs_sum = log_probs_sum.view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)

        # get top-k after re-ranking
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs

    def _encoder_hidden_states(
        self,
        multimodal_embeds: Tensor,
        multimodal_embeds_m: Tensor,
        question_atts: Tensor,
        ans_lengths: List[int],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Repeat each image-question embedding and its attention mask to match the number of
        answers that the image-question pair has.

        Args:
            multimodal_embeds (Tensor): Image-question embeddings.
            multimodal_embeds_m (Tensor): Image-question embeddings from the momentum model.
            question_atts (Tensor): Question attention mask.
            ans_lengths (List[int]): The number of answers each image-question input has.

        Returns:
            encoder_hidden_states (Tensor): Image-question embeddings after the repetition.
            encoder_hidden_states_m (Tensor): Image-question embeddings from the momentum model
                after the repetition.
            encoder_attention_mask (Tensor): Question attention mask after the repetition.
        """
        encoder_hidden_states = []
        encoder_attention_mask = []
        for b, n in enumerate(ans_lengths):
            encoder_hidden_states += [multimodal_embeds[b]] * n
            encoder_attention_mask += [question_atts[b]] * n
        encoder_hidden_states = torch.stack(encoder_hidden_states)
        encoder_attention_mask = torch.stack(encoder_attention_mask)

        with torch.no_grad():
            encoder_hidden_states_m = []
            for b, n in enumerate(ans_lengths):
                encoder_hidden_states_m += [multimodal_embeds_m[b]] * n
            encoder_hidden_states_m = torch.stack(encoder_hidden_states_m)

        return encoder_hidden_states, encoder_hidden_states_m, encoder_attention_mask

    def forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answers_atts: Tensor,
        ans_weights: Optional[Tensor] = None,
        ans_lengths: Optional[List[int]] = None,
        alpha: Optional[float] = 0.0,
        k: Optional[int] = 128,
        is_train: Optional[bool] = True,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if is_train:
            return self._train_forward(
                image,
                question,
                question_atts,
                answers,
                answers_atts,
                ans_weights,
                ans_lengths,
                alpha,
            )
        else:
            return self._eval_forward(
                image,
                question,
                question_atts,
                answers,
                answers_atts,
                k,
            )

class ALBEFModelForRetrieval(nn.Module):
    """
    ALBEF Model for Retrieval finetuning and inference.
    In training mode, the forward step computes image-text contrastive loss and
    image-text matching loss.
    In evaluation mode, the forward step takes 3 types of input:
        image: encode image input, project and normalize the embeddings.
        text: encode text input, project and normalize the embeddings.
        multimodal: create multimodal embeddings from image and text embeddings,
            and compute image-text matching scores.

    Args:
        model_with_similarity (ALBEFModelWithSimilarity): Instantiated ALBEFModelWithSimilarity.
        itc_loss (ImageTextContrastiveLoss): Instantiated ImageTextContrastiveLoss.
        hidden_size (int): Dimensionality of encoder outputs.

    Inputs:
        image (Optional[Tensor] of shape (B, C, H, W)): Image features.
            Required if is_train is True.
            Required if input_type is "image" or "multimodal".
        text (Optional[Tensor] of shape (B, L)): Text features.
            Required if is_train is True.
            Required if input_type is "text" or "multimodal".
        text_atts (Optional[Tensor] of shape (B, L)): Text attention mask.
            Required if is_train is True.
            Required if input_type is "text" or "multimodal".
        idx (Optional[Tensor] of shape (B)): Identifier for each image sample.
            Required if is_train is True.
        alpha (Optional[float]): The interpolation value between the contrastive loss and its
            momentum distillation loss. Default is 0.
        input_type (Optional[str]): "image", "text", or "multimodal" indicating the encoding type.
            Required if is_train is False.
        is_train (Optional[bool]): Whether the model is in training. Default is True.

    Returns:
        is_train is True:
            Tensor: The sum of itc loss and itm loss.
        is_train is False:
            input_type is "image":
                Tuple[Tensor, Tensor]: Image embeddings and projected image features.
            input_type is "text":
                Tuple[Tensor, Tensor]: Text embeddings and projected text features.
            input_type is "multimodal":
                Tensor: Scores for the retrieval task.
    """

    def __init__(
        self,
        model_with_similarity: ALBEFModelWithSimilarity,
        itc_loss: ImageTextContrastiveLoss,
        hidden_size: int,
    ) -> None:
        super().__init__()
        self.model_with_similarity = model_with_similarity
        self.itc_loss = itc_loss
        self.itm_head = nn.Linear(hidden_size, 2)

    def _train_forward(
        self,
        image: Tensor,
        text: Tensor,
        text_atts: Tensor,
        idx: Tensor,
        alpha: float,
    ) -> Tensor:
        encoder_output = self.model_with_similarity(image, text, text_atts, idx)

        # compute image-text contrastive loss
        similarity_outputs = encoder_output.similarity
        similarity_targets = encoder_output.sim_targets
        itc_loss = self.itc_loss(
            similarity_outputs.sim_i2t,
            similarity_outputs.sim_t2i,
            similarity_outputs.sim_i2t_m,
            similarity_outputs.sim_t2i_m,
            similarity_targets,
            alpha,
        )

        # compute image-text matching loss
        pos_embeddings = encoder_output.multimodal_embeddings[:, 0, :]
        neg_embeddings = encoder_output.multimodal_embeddings_neg[:, 0, :]
        vl_embeddings = torch.cat([pos_embeddings, neg_embeddings], dim=0)
        vl_output = self.itm_head(vl_embeddings)
        itm_labels = torch.cat(
            [
                torch.ones(pos_embeddings.size(0), dtype=torch.long),
                torch.zeros(neg_embeddings.size(0), dtype=torch.long),
            ],
            dim=0,
        ).to(vl_embeddings.device)
        itm_loss = F.cross_entropy(vl_output, itm_labels)

        loss = itc_loss + itm_loss
        return loss

    def _encode_image(
        self,
        image: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        image_embed = self.model_with_similarity.albef_model.vision_encoder(image)
        image_feat = F.normalize(
            self.model_with_similarity.vision_proj(image_embed[:, 0, :]), dim=-1
        )
        return image_embed, image_feat

    def _encode_text(
        self,
        text: Tensor,
        text_atts: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        text_embed = self.model_with_similarity.albef_model.text_encoder(
            text, text_atts
        ).last_hidden_state
        text_feat = F.normalize(
            self.model_with_similarity.text_proj(text_embed[:, 0, :]), dim=-1
        )
        return text_embed, text_feat

    def _image_text_matching_score(
        self,
        image: Tensor,
        text: Tensor,
        text_atts: Tensor,
    ) -> Tensor:
        multimodal_embeds = self.model_with_similarity.albef_model.multimodal_encoder(
            text,
            text_atts,
            image,
        )
        score = self.itm_head(multimodal_embeds[:, 0, :])[:, 1]
        return score

    def _eval_forward(
        self,
        input_type: str,
        image: Optional[Tensor],
        text: Optional[Tensor],
        text_atts: Optional[Tensor],
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if input_type == "image":
            assert image is not None, "image input tensor cannot be None"
            return self._encode_image(image)

        elif input_type == "text":
            assert (
                text is not None and text_atts is not None
            ), "text and text attention mask cannot be None"
            return self._encode_text(text, text_atts)

        elif input_type == "multimodal":
            assert (
                image is not None and text is not None and text_atts is not None
            ), "image embeddings, text embeddings, and text attention mask cannot be None"
            return self._image_text_matching_score(image, text, text_atts)

        else:
            raise ValueError("input_type must be image, text, or multimodal")

    def forward(
        self,
        image: Optional[Tensor] = None,
        text: Optional[Tensor] = None,
        text_atts: Optional[Tensor] = None,
        idx: Optional[Tensor] = None,
        alpha: Optional[float] = 0.0,
        input_type: Optional[str] = None,
        is_train: Optional[bool] = True,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if is_train:
            return self._train_forward(
                image,
                text,
                text_atts,
                idx,
                alpha,
            )
        else:
            return self._eval_forward(
                input_type,
                image,
                text,
                text_atts,
            )

def albef_model_for_vqa(
    config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForVQA:
    vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
    text_encoder = bert_text_encoder(**config["text_encoder_args"])
    question_multimodal_encoder = ALBEFMultimodalEncoder(
        **config["multimodal_encoder_args"]
    )
    text_embeddings = BERTTextEmbeddings(**config["text_embeddings_args"])
    answer_multimodal_encoder = ALBEFMultimodalEncoder(
        **config["multimodal_encoder_args"]
    )
    prediction_head = PredictionHead(**config["prediction_head_args"])
    albef_model = ALBEFModel(vision_encoder, text_encoder, question_multimodal_encoder)
    decoder = ALBEFDecoder(text_embeddings, answer_multimodal_encoder, prediction_head)
    loss = CausalLanguageModelingLoss()
    model = ALBEFModelForVQA(albef_model, decoder, loss)

    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            _ALBEF_PRETRAINED_URLS["vqa"], map_location="cpu"
        )
        model.load_state_dict(checkpoint)
    return model


def albef_model_for_retrieval(
    config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForRetrieval:
    vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
    text_encoder = bert_text_encoder(**config["text_encoder_args"])
    multimodal_encoder = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
    vision_proj = nn.Linear(**config["projection_args"])
    text_proj = nn.Linear(**config["projection_args"])

    albef_model = ALBEFModel(vision_encoder, text_encoder, multimodal_encoder)
    albef_model_with_sim = ALBEFModelWithSimilarity(
        albef_model, vision_proj, text_proj, **config["similarity_args"]
    )
    itc_loss = ImageTextContrastiveLoss()

    model = ALBEFModelForRetrieval(
        albef_model_with_sim, itc_loss, config["hidden_size"]
    )

    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            _ALBEF_PRETRAINED_URLS["retrieval"], map_location="cpu"
        )
        model.load_state_dict(checkpoint)
    return model
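
# Example (illustrative sketch, not part of the original module): the config dict consumed by
# the builders above. The top-level keys mirror the lookups in albef_model_for_vqa and
# albef_model_for_retrieval; each nested dict holds the keyword arguments of the corresponding
# constructor and normally comes from a config file rather than being written inline.
#
#   vqa_config = {
#       "vision_encoder_args": {...},      # forwarded to ALBEFVisionEncoder
#       "text_encoder_args": {...},        # forwarded to bert_text_encoder
#       "multimodal_encoder_args": {...},  # forwarded to ALBEFMultimodalEncoder
#       "text_embeddings_args": {...},     # forwarded to BERTTextEmbeddings
#       "prediction_head_args": {...},     # forwarded to PredictionHead
#   }
#   vqa_model = albef_model_for_vqa(vqa_config, pretrained=True)  # loads the released checkpoint
#
# The retrieval config additionally needs "projection_args" (nn.Linear), "similarity_args"
# (ALBEFModelWithSimilarity), and "hidden_size" (int) alongside the encoder args.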