# Copyright 2023 Stability AI team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union

from transformers import PretrainedConfig, CLIPVisionConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging

logger = logging.get_logger(__name__)


class LlavaMlpConfig(PretrainedConfig):
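    """Configuration for the MLP projector that `LlavaConfig` composes with a
    vision and a text config. In LLaVA-style models this projector maps vision
    encoder features into the language model's embedding space;
    `num_hidden_layers` controls the number of layers in the projector."""
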
    model_type = "llava_mlp"

    def __init__(
        self,
        num_hidden_layers=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_hidden_layers = num_hidden_layers

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # get the mlp config dict if we are loading from a full LlavaConfig
        if config_dict.get("model_type") == "llava":
            config_dict = config_dict["mlp_config"]
        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class LlavaConfig(PretrainedConfig):
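    """Composite configuration holding the three sub-configurations of a
    LLaVA-style model: a CLIP vision encoder (`vision_config`), an MLP
    projector (`mlp_config`) and a language model (`text_config`).

    `vision_select_layer` selects which vision encoder hidden state is fed to
    the projector (default -2, i.e. the second-to-last layer), and
    `vision_select_feature` must be either "patch" or "cls_patch" (i.e.
    whether the vision encoder's CLS token is kept alongside the patch
    tokens).

    Example (illustrative; the sub-config values shown are library defaults,
    not those of any released checkpoint):

    >>> # Default sub-configs; the text model falls back to OPT.
    >>> config = LlavaConfig()

    >>> # Or compose from explicit sub-configs.
    >>> from transformers import CLIPVisionConfig, OPTConfig
    >>> config = LlavaConfig.from_vision_mlp_text_configs(
    ...     vision_config=CLIPVisionConfig(),
    ...     mlp_config=LlavaMlpConfig(),
    ...     text_config=OPTConfig(),
    ... )
    """
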
    model_type = "llava"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        mlp_config=None,
        text_config=None,
        vision_select_layer=-2,
        vision_select_feature="patch",
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. Initializing the CLIPVisionConfig with default values."
            )

        if mlp_config is None:
            mlp_config = {}
            logger.info(
                "mlp_config is None. Initializing the LlavaMlpConfig with default values."
            )

        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the text config with default values (`OPTConfig`)."
            )

        self.vision_config = CLIPVisionConfig(**vision_config)
        self.mlp_config = LlavaMlpConfig(**mlp_config)
        # fall back to "opt" so an empty text_config matches the log message above
        text_model_type = (
            text_config["model_type"] if "model_type" in text_config else "opt"
        )
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder
        self.use_decoder_only_language_model = (
            self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        self.vision_select_layer = vision_select_layer
        assert vision_select_feature in [
            "cls_patch",
            "patch",
        ], f"Unexpected select feature: {vision_select_feature}"
        self.vision_select_feature = vision_select_feature
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_mlp_text_configs(
        cls,
        vision_config: CLIPVisionConfig,
        mlp_config: LlavaMlpConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
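        """Instantiate a `LlavaConfig` (or a derived class) from separate
        vision, MLP-projector and language-model configurations."""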
        return cls(
            vision_config=vision_config.to_dict(),
            mlp_config=mlp_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )