# Copyright 2023 Stability AI team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union

from transformers import CLIPVisionConfig, PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging

logger = logging.get_logger(__name__)


class LlavaMlpConfig(PretrainedConfig):
    """Configuration for the MLP projector that maps vision-encoder features
    into the language model's embedding space."""

    model_type = "llava_mlp"

    def __init__(
        self,
        num_hidden_layers=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_hidden_layers = num_hidden_layers

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # Get the mlp config dict if we are loading from a LlavaConfig.
        if config_dict.get("model_type") == "llava":
            config_dict = config_dict["mlp_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class LlavaConfig(PretrainedConfig):
    """Composition config holding the vision encoder, MLP projector, and
    language model configs for a LLaVA-style model."""

    model_type = "llava"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        mlp_config=None,
        text_config=None,
        vision_select_layer=-2,
        vision_select_feature="patch",
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. Initializing the CLIPVisionConfig with default values."
            )

        if mlp_config is None:
            mlp_config = {}
            logger.info(
                "mlp_config is None. Initializing the LlavaMlpConfig with default values."
            )

        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the text config with default values (`OPTConfig`)."
            )

        self.vision_config = CLIPVisionConfig(**vision_config)
        self.mlp_config = LlavaMlpConfig(**mlp_config)
        # Fall back to OPT when no model type is given, matching the log
        # message above (indexing text_config directly would raise a KeyError
        # for an empty dict).
        text_model_type = text_config.get("model_type", "opt")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder
        self.use_decoder_only_language_model = (
            self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )

        self.vision_select_layer = vision_select_layer
        if vision_select_feature not in ["cls_patch", "patch"]:
            raise ValueError(f"Unexpected select feature: {vision_select_feature}")
        self.vision_select_feature = vision_select_feature

        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_mlp_text_configs(
        cls,
        vision_config: CLIPVisionConfig,
        mlp_config: LlavaMlpConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        """Instantiate a LlavaConfig from its three sub-configs."""
        return cls(
            vision_config=vision_config.to_dict(),
            mlp_config=mlp_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )
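

# Minimal usage sketch: it assumes an OPT text backbone purely for
# illustration (any model type in CONFIG_MAPPING would work) and shows how
# the three sub-configs compose via `from_vision_mlp_text_configs`.
if __name__ == "__main__":
    from transformers import OPTConfig

    config = LlavaConfig.from_vision_mlp_text_configs(
        vision_config=CLIPVisionConfig(),
        mlp_config=LlavaMlpConfig(num_hidden_layers=2),
        text_config=OPTConfig(),
    )
    # OPT is a causal LM, so the composed config flags decoder-only usage.
    print(config.use_decoder_only_language_model)  # True
    print(config.vision_select_feature)  # "patch"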