|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from typing import Union |
|
|
|
from transformers import PretrainedConfig, CLIPVisionConfig |
|
from transformers.models.auto import CONFIG_MAPPING |
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES |
|
from transformers.utils import logging |
|
|
|
|
|
# Module-level logger created through transformers.utils.logging.
logger = logging.get_logger(__name__)
|
|
|
|
|
class LlavaMlpConfig(PretrainedConfig):
    """Configuration for the LLaVA "mlp" sub-model.

    Only ``num_hidden_layers`` is defined here; every other keyword argument
    is forwarded unchanged to :class:`PretrainedConfig`.
    """

    model_type = "llava_mlp"

    def __init__(
        self,
        num_hidden_layers=2,
        **kwargs,
    ):
        """Build the mlp sub-config.

        Args:
            num_hidden_layers: Stored as ``self.num_hidden_layers`` (default 2).
                Presumably the number of layers in the projector MLP — confirm
                against the model code that consumes this config.
            **kwargs: Passed straight through to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.num_hidden_layers = num_hidden_layers

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load this sub-config from a checkpoint name or local path.

        Both a standalone ``llava_mlp`` config and a full ``llava`` config are
        accepted; for the latter, the nested ``"mlp_config"`` entry is used.
        A warning is logged when the resolved dict declares a ``model_type``
        different from ``cls.model_type``.

        Args:
            pretrained_model_name_or_path: Anything ``get_config_dict`` accepts.
            **kwargs: Loading options handed to ``get_config_dict``; the keyword
                arguments it returns are then forwarded to ``from_dict``.

        Returns:
            An instance of ``cls`` built with ``from_dict``.
        """
        cls._set_token_in_kwargs(kwargs)

        loaded, remaining_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A composite LLaVA checkpoint keeps these settings under "mlp_config".
        if loaded.get("model_type") == "llava":
            loaded = loaded["mlp_config"]

        type_is_declared = "model_type" in loaded
        if type_is_declared and hasattr(cls, "model_type"):
            found_type = loaded["model_type"]
            if found_type != cls.model_type:
                logger.warning(
                    f"You are using a model of type {found_type} to instantiate a model of type "
                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(loaded, **remaining_kwargs)
|
|
|
|
|
class LlavaConfig(PretrainedConfig):
    """Composite configuration for a LLaVA model.

    Bundles three sub-configurations:

    * ``vision_config`` -> :class:`CLIPVisionConfig`
    * ``mlp_config`` -> :class:`LlavaMlpConfig`
    * ``text_config`` -> the config class registered in ``CONFIG_MAPPING``
      under ``text_config["model_type"]`` (``"opt"`` when that key is absent)

    plus settings for selecting vision features.
    """

    model_type = "llava"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        mlp_config=None,
        text_config=None,
        vision_select_layer=-2,
        vision_select_feature="patch",
        **kwargs,
    ):
        """Build the composite config from plain-dict sub-configs.

        Args:
            vision_config: Keyword arguments for ``CLIPVisionConfig``; ``None``
                means defaults.
            mlp_config: Keyword arguments for ``LlavaMlpConfig``; ``None``
                means defaults.
            text_config: Keyword arguments for the text config. Its
                ``"model_type"`` key selects the class from ``CONFIG_MAPPING``.
                The type defaults to ``"opt"`` when the key is missing or
                ``text_config`` is ``None``.
            vision_select_layer: Stored as-is (default -2). Presumably the index
                of the vision-encoder hidden layer whose output is used. Confirm
                against the model code.
            vision_select_feature: Either ``"cls_patch"`` or ``"patch"``.
            **kwargs: Passed through to ``PretrainedConfig.__init__``.

        Raises:
            AssertionError: If ``vision_select_feature`` is not one of
                ``"cls_patch"`` or ``"patch"``.
        """
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. initializing the CLIPVisionConfig with default values."
            )

        if mlp_config is None:
            mlp_config = {}
            logger.info(
                "mlp_config is None. Initializing the LlavaMlpConfig with default values."
            )

        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the text config with default values (`OPTConfig`)."
            )

        self.vision_config = CLIPVisionConfig(**vision_config)
        self.mlp_config = LlavaMlpConfig(**mlp_config)

        # Without a "model_type" key, fall back to OPT (the default the log
        # message above advertises) instead of raising KeyError.
        text_model_type = text_config.get("model_type", "opt")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # Mirror these flags from the text sub-config onto the composite config.
        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        # True when the text model type is registered as a causal-LM architecture.
        self.use_decoder_only_language_model = (
            self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        self.vision_select_layer = vision_select_layer

        # Explicit check instead of `assert`, so validation still happens under
        # `python -O`. AssertionError is kept for backward compatibility.
        if vision_select_feature not in ("cls_patch", "patch"):
            raise AssertionError(f"Unexpected select feature: {vision_select_feature}")
        self.vision_select_feature = vision_select_feature

        # Fixed initialization values. They are assigned after super().__init__(),
        # so they take precedence over any same-named values set there.
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_mlp_text_configs(
        cls,
        vision_config: CLIPVisionConfig,
        mlp_config: LlavaMlpConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        """Build a ``LlavaConfig`` from already-instantiated sub-configs.

        Each sub-config is serialized with ``to_dict()`` and passed to the
        constructor together with any extra ``kwargs``.

        Returns:
            A new instance of ``cls``.
        """
        return cls(
            vision_config=vision_config.to_dict(),
            mlp_config=mlp_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )