import os
import warnings
from itertools import product
from typing import List, Optional, Tuple, Union

import ipdb
import numpy as np
import torch
import transformers
import transformers.models.clip.modeling_clip as modeling_clip
import transformers.models.git.modeling_git as modeling_git
import transformers.models.opt.modeling_opt as hg_opt
import transformers.models.vit.modeling_vit as modeling_vit
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers import ViTConfig, ViTFeatureExtractor, ViTModel
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.opt.modeling_opt import OPTConfig


# Default checkpoint loaded as the GIT image encoder by both model classes below.
_DINO_CHECKPOINT = "facebook/dino-vitb16"


def _swap_in_dino_image_encoder(git, checkpoint: str = _DINO_CHECKPOINT) -> None:
    """Replace ``git.image_encoder`` with a pretrained ``ViTModel`` and resize dependents.

    Side effects on ``git`` (a ``modeling_git.GitModel``):
      * ``image_encoder`` becomes ``ViTModel.from_pretrained(checkpoint)``;
      * ``git.config.vision_config.hidden_size`` is overwritten with the ViT hidden
        size (this config object is mutated in place);
      * ``visual_projection`` is rebuilt from that config, so it is freshly initialised;
      * every encoder layer's ``attention.self.image_patch_tokens`` is set to the
        ViT token count (patches + CLS token).

    Args:
        git: The GIT backbone to modify.
        checkpoint: Hugging Face id or path of the ViT checkpoint to load.
    """
    del git.image_encoder
    git.image_encoder = ViTModel.from_pretrained(checkpoint)
    vit_cfg = git.image_encoder.config

    config = git.config
    config.vision_config.hidden_size = vit_cfg.hidden_size
    del git.visual_projection
    git.visual_projection = modeling_git.GitProjection(config)

    # (image_size / patch_size)^2 patch tokens plus one CLS token.
    num_image_tokens = (vit_cfg.image_size // vit_cfg.patch_size) ** 2 + 1
    # Each layer's self-attention holds its own copy of this count. Update all of
    # them, not only layer 0, so every layer agrees on the image-token prefix length.
    for layer in git.encoder.layer:
        layer.attention.self.image_patch_tokens = num_image_tokens


class GitForCausalLM(modeling_git.GitForCausalLM):
    """GIT causal language model whose image encoder is a pretrained DINO ViT.

    The LM head is re-created as a bias-free ``nn.Linear``. The vision tower built by
    ``modeling_git.GitForCausalLM`` is then replaced with
    ``ViTModel.from_pretrained(image_encoder_checkpoint)``.
    """

    # Checkpoint used for the image encoder; override in a subclass to use another ViT.
    image_encoder_checkpoint = _DINO_CHECKPOINT

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fresh bias-free projection from hidden states to the vocabulary.
        del self.output
        self.output = nn.Linear(
            self.config.hidden_size, self.config.vocab_size, bias=False)
        # Called before the DINO swap, presumably so post_init's weight
        # initialisation does not overwrite the pretrained encoder weights.
        self.post_init()
        _swap_in_dino_image_encoder(self.git, self.image_encoder_checkpoint)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], modeling_git.CausalLMOutputWithPast]:
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            labels: Target token ids for the text part of the sequence. Logits at
                text position ``t`` are scored against ``labels[:, t + 1]`` with
                ``CrossEntropyLoss`` (default ``ignore_index=-100``).
            Other arguments are forwarded unchanged to ``self.git``.

        Returns:
            A ``CausalLMOutputWithPast``, or, if ``return_dict`` is false, the tuple
            ``(loss, logits, *outputs[1:])`` (``loss`` omitted when not computed).
            ``logits`` spans the image-token prefix as well as the text tokens.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # No key/value cache is needed when computing a training loss.
            use_cache = False

        outputs = self.git(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.output(sequence_output)

        loss = None
        if labels is not None:
            # Next-token prediction: skip the image-token prefix, then shift
            # predictions and targets by one position.
            if pixel_values is not None:
                num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
            else:
                num_image_tokens = 0
            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return modeling_git.CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class GitForSequenceClassification(modeling_git.GitPreTrainedModel):
    """GIT backbone with a DINO ViT image encoder and a linear classification head.

    Per-token logits are pooled at the last non-padding text token (see ``forward``).
    """

    # Checkpoint used for the image encoder; override in a subclass to use another ViT.
    image_encoder_checkpoint = _DINO_CHECKPOINT

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_labels = self.config.num_labels
        self.git = modeling_git.GitModel(self.config)
        self.classifier = nn.Linear(
            self.config.hidden_size, self.config.num_labels, bias=False)
        # post_init runs once every submodule exists, but before the DINO swap (same
        # order as GitForCausalLM), presumably so pretrained encoder weights are not
        # re-initialised.
        self.post_init()
        _swap_in_dino_image_encoder(self.git, self.image_encoder_checkpoint)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args,
        **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Pooling: the classifier is applied to every position, and the logits at the
        last non-padding text token are returned. When ``config.pad_token_id`` is
        unset or only ``inputs_embeds`` is given, the final sequence position is used.
        Image tokens prepended by GIT are accounted for when locating the text token.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # outputs[0] is the last hidden state over the [image tokens, text tokens] sequence.
        outputs = self.git(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            *args,
            **kwargs)
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.config.pad_token_id is None or input_ids is None:
            # Padding cannot be located, so pool the final position of the whole sequence.
            # NOTE: with `inputs_embeds`, padding tokens are therefore not detected.
            sequence_lengths = -1
        else:
            # Index (in text coordinates) of the token before the first pad. When no pad
            # is found, argmax gives 0 -> -1, and the modulo maps that to the last text
            # position without negative indexing (for ONNX compatibility).
            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            # GIT prepends the image tokens (if any) to the text tokens, so text
            # position i sits at i + num_prefix_tokens in `logits`.
            num_prefix_tokens = logits.shape[1] - sequence_length
            sequence_lengths = (sequence_lengths + num_prefix_tokens).to(logits.device)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )