# Copyright 2023 Stability AI team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union, Any
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    PreTrainedModel,
    CLIPVisionModel,
)
from transformers.utils import logging, ModelOutput

from .configuration_llava import LlavaConfig

logger = logging.get_logger(__name__)


@dataclass
class LlavaForConditionalGenerationModelOutput(ModelOutput):
    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ["vision_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class LlavaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading
    and loading pretrained models.
    """

    config_class = LlavaConfig
    base_model_prefix = "llava"

    # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->Llava
    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        if (
            isinstance(module, nn.Conv2d)
            or isinstance(module, nn.Embedding)
            or isinstance(module, nn.Linear)
        ):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class LlavaForConditionalGeneration(LlavaPreTrainedModel):
    config_class = LlavaConfig
    main_input_name = "pixel_values"
    _no_split_modules = []
    # `PreTrainedModel` defaults this attribute to `None`; it must be a list here so it
    # can be extended with the language model's entries in `__init__` below.
    _keep_in_fp32_modules = []

    def __init__(self, config: LlavaConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)

        if config.use_decoder_only_language_model:
            language_model = AutoModelForCausalLM.from_config(config.text_config)
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

        if language_model._no_split_modules is not None:
            self._no_split_modules.extend(language_model._no_split_modules)

        if language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)

        self.language_model = language_model

        # LLaVA-style MLP projector: maps CLIP vision features into the language
        # model's embedding space, with GELU-separated hidden layers of the text
        # hidden size.
        modules = [
            nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
        ]
        for _ in range(1, config.mlp_config.num_hidden_layers):
            modules.append(nn.GELU())
            modules.append(
                nn.Linear(
                    config.text_config.hidden_size, config.text_config.hidden_size
                )
            )
        self.mlp = nn.Sequential(*modules)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if (
            len(hf_device_map) > 1
            and "language_model" not in hf_device_map
            and torch.cuda.device_count() > 1
        ):
            # warn users about unexpected behavior when using multi-GPU + LLaVA + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = (
                True  # For `generate` compatibility
            )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaForConditionalGenerationModelOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # step 1: forward the images through the vision encoder
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            return_dict=return_dict,
            output_hidden_states=True,
        )
        # (bsz, seq_len, hidden_size)
        image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
        if self.config.vision_select_feature == "patch":
            # drop the CLS token, keep only the patch tokens
            image_embeds = image_embeds[:, 1:]
        elif self.config.vision_select_feature == "cls_patch":
            image_embeds = image_embeds
        else:
            raise ValueError(
                f"Unexpected select feature: {self.config.vision_select_feature}"
            )

        # step 2: forward the image embeddings through the mlp
        image_embeds = self.mlp(image_embeds)

        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], device=image_embeds.device
        )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # step 3: concatenate the image embeddings with the text embeddings
        inputs_embeds = torch.cat(
            [image_embeds, inputs_embeds.to(image_embeds.device)],
            dim=1,
        )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat(
            [image_attention_mask.to(attention_mask.device), attention_mask],
            dim=1,
        )

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
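            # `labels` only covers the text tokens, whereas the logits also span the
            # prepended image tokens, so the loss is computed manually below on the
            # last `labels.size(1)` positions rather than by passing `labels` to the
            # language model.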
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            # we compute the loss here since we need to take into account the
            # sequence length of the image embeddings
            if labels is not None:
                labels = labels.to(logits.device)
                logits = logits[:, -labels.size(1) :, :]
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                # Flatten the tokens
                loss_fct = CrossEntropyLoss(reduction="mean")
                loss = loss_fct(
                    shift_logits.view(-1, self.config.text_config.vocab_size),
                    shift_labels.view(-1),
                )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return LlavaForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            language_model_outputs=outputs,
        )

    def get_image_embeds(self, pixel_values: torch.FloatTensor):
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
        if self.config.vision_select_feature == "patch":
            image_embeds = image_embeds[:, 1:]
        elif self.config.vision_select_feature == "cls_patch":
            image_embeds = image_embeds
        else:
            raise ValueError(
                f"Unexpected select feature: {self.config.vision_select_feature}"
            )

        image_embeds = self.mlp(image_embeds)
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], device=image_embeds.device
        )
        return dict(
            image_embeds=image_embeds,
            image_attention_mask=image_attention_mask,
        )

    def prepare_for_lm_generation(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        batch_size = pixel_values.shape[0]

        vision_outputs = self.get_image_embeds(pixel_values)
        image_embeds = vision_outputs["image_embeds"]
        image_attention_mask = vision_outputs["image_attention_mask"]

        # fall back to a BOS-only prompt when no input_ids are provided
        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat(
            [
                image_attention_mask,
                attention_mask.to(image_attention_mask.device),
            ],
            dim=1,
        )

        # concatenate the image embeddings with the prompt embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat(
            [image_embeds, inputs_embeds.to(image_embeds.device)],
            dim=1,
        )

        return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        encodings = self.prepare_for_lm_generation(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        outputs = self.language_model.generate(
            **encodings,
            **generate_kwargs,
        )

        return outputs
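

# Example usage (sketch). The checkpoint path, image file, and prompt below are
# illustrative assumptions only, not part of this module; substitute a real LLaVA
# checkpoint (with its matching tokenizer and CLIP image processor) to run it.
#
#     from PIL import Image
#     from transformers import AutoTokenizer, CLIPImageProcessor
#
#     checkpoint = "path/to/llava-checkpoint"  # hypothetical
#     model = LlavaForConditionalGeneration.from_pretrained(checkpoint).eval()
#     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#     image_processor = CLIPImageProcessor.from_pretrained(checkpoint)
#
#     image = Image.open("example.jpg")  # hypothetical input image
#     pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
#     input_ids = tokenizer("Describe this image.", return_tensors="pt").input_ids
#
#     generated_ids = model.generate(
#         pixel_values=pixel_values,
#         input_ids=input_ids,
#         max_new_tokens=64,
#     )
#     print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])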