japanese-llava-small-sfcoco2023 / configuration_llava.py
ohashi56225's picture
Upload LlavaForConditionalGeneration
f188f75
raw
history blame contribute delete
No virus
4.34 kB
# Copyright 2023 Stability AI team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union
from transformers import PretrainedConfig, CLIPVisionConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging
# Module-level logger obtained from transformers' logging utilities; used below
# for informational and warning messages while building configs.
logger = logging.get_logger(__name__)
class LlavaMlpConfig(PretrainedConfig):
    """Configuration for the LLaVA MLP sub-module.

    Args:
        num_hidden_layers: Stored as ``self.num_hidden_layers``; defaults to 2.
        **kwargs: Forwarded unchanged to :class:`PretrainedConfig`.
    """

    model_type = "llava_mlp"

    def __init__(
        self,
        num_hidden_layers=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_hidden_layers = num_hidden_layers

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load this config, also accepting a saved composite ``LlavaConfig``.

        When the loaded config dict has ``model_type == "llava"``, its nested
        ``"mlp_config"`` entry is used instead of the top-level dict. A warning
        is logged if the resulting dict declares a ``model_type`` that differs
        from ``cls.model_type``.

        Args:
            pretrained_model_name_or_path: Location passed straight to
                ``cls.get_config_dict``.
            **kwargs: Loading options; whatever ``get_config_dict`` does not
                consume is forwarded to ``cls.from_dict``.

        Returns:
            The instance built by ``cls.from_dict``.
        """
        cls._set_token_in_kwargs(kwargs)
        raw_dict, remaining_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full LlavaConfig checkpoint nests this config under "mlp_config";
        # unwrap it so this class can be loaded from either kind of checkpoint.
        if raw_dict.get("model_type") == "llava":
            raw_dict = raw_dict["mlp_config"]

        if "model_type" in raw_dict and hasattr(cls, "model_type"):
            found_type = raw_dict["model_type"]
            if found_type != cls.model_type:
                logger.warning(
                    f"You are using a model of type {found_type} to instantiate a model of type "
                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(raw_dict, **remaining_kwargs)
class LlavaConfig(PretrainedConfig):
    """Composite configuration for a LLaVA-style vision-language model.

    Holds three sub-configurations, each rebuilt into a config object:

    * ``vision_config`` -> :class:`CLIPVisionConfig`
    * ``mlp_config``    -> :class:`LlavaMlpConfig`
    * ``text_config``   -> the class registered in ``CONFIG_MAPPING`` under
      ``text_config["model_type"]`` (``"opt"`` when no model type is given)

    Args:
        vision_config: Keyword arguments (or a config instance) for
            :class:`CLIPVisionConfig`; ``None`` uses its defaults.
        mlp_config: Keyword arguments (or a config instance) for
            :class:`LlavaMlpConfig`; ``None`` uses its defaults.
        text_config: Keyword arguments (or a config instance) for the text
            config. Its ``"model_type"`` selects the config class; ``None`` or
            a dict without ``"model_type"`` falls back to ``OPTConfig``.
        vision_select_layer: Stored as ``self.vision_select_layer``
            (default -2). Presumably indexes the vision encoder's hidden
            states -- TODO confirm against the model code.
        vision_select_feature: Either ``"cls_patch"`` or ``"patch"``.
        **kwargs: Forwarded to :class:`PretrainedConfig`.

    Raises:
        ValueError: If ``vision_select_feature`` is not a supported value.
    """

    model_type = "llava"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        mlp_config=None,
        text_config=None,
        vision_select_layer=-2,
        vision_select_feature="patch",
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. initializing the CLIPVisionConfig with default values."
            )
        if mlp_config is None:
            mlp_config = {}
            logger.info(
                "mlp_config is None. Initializing the LlavaMlpConfig with default values."
            )
        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the text config with default values (`OPTConfig`)."
            )

        # Accept already-instantiated config objects as well as dicts, converting
        # them with `to_dict()` exactly as `from_vision_mlp_text_configs` does.
        if isinstance(vision_config, PretrainedConfig):
            vision_config = vision_config.to_dict()
        if isinstance(mlp_config, PretrainedConfig):
            mlp_config = mlp_config.to_dict()
        if isinstance(text_config, PretrainedConfig):
            text_config = text_config.to_dict()

        self.vision_config = CLIPVisionConfig(**vision_config)
        self.mlp_config = LlavaMlpConfig(**mlp_config)

        # A None/empty text_config carries no "model_type"; fall back to OPT as
        # the log message above promises, instead of raising KeyError.
        text_model_type = text_config.get("model_type", "opt")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # Mirror the language model's settings at the top level of this config.
        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder
        self.use_decoder_only_language_model = (
            self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )

        self.vision_select_layer = vision_select_layer
        # Explicit check rather than `assert`, so validation survives `python -O`.
        if vision_select_feature not in ("cls_patch", "patch"):
            raise ValueError(f"Unexpected select feature: {vision_select_feature}")
        self.vision_select_feature = vision_select_feature

        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_mlp_text_configs(
        cls,
        vision_config: CLIPVisionConfig,
        mlp_config: LlavaMlpConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        """Build a ``LlavaConfig`` from three already-constructed sub-configs.

        Each sub-config is serialized with ``to_dict()`` and passed to the
        constructor; extra ``kwargs`` are forwarded unchanged.

        Returns:
            A new instance of ``cls``.
        """
        return cls(
            vision_config=vision_config.to_dict(),
            mlp_config=mlp_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )