TaiVisionLM-base-v2 / configuration_taivisionlm.py
benchang1110's picture
Upload TaiVisionForCausalLM
b26c61c verified
raw
history blame
No virus
3.8 kB
"""TaiVisionLM configuration"""
from transformers import PretrainedConfig
from transformers import logging, CONFIG_MAPPING
import warnings
import transformers
logger = logging.get_logger(__name__)
class TaiVisionLMConfig(PretrainedConfig):
    """Configuration holding a vision sub-config and a text sub-config for TaiVisionLM.

    Args:
        vision_config (`dict` or config object, *optional*):
            Vision sub-configuration. A `dict` is turned into a config object
            through `CONFIG_MAPPING`, using its `"model_type"` entry
            (`"siglip_vision_model"` when the entry is missing). `None`
            builds a default `"siglip_vision_model"` config (224px images,
            16px patches, 12 layers, hidden size 768). Any other value is
            stored unchanged.
        text_config (`dict` or config object, *optional*):
            Text sub-configuration. A `dict` is turned into a config object
            through `CONFIG_MAPPING`, using its `"model_type"` entry
            (`"gpt2"` when the entry is missing). `None` builds a default
            `"llama"` config (22 layers, hidden size 2048, vocab 32001).
            Any other value is stored unchanged.
        ignore_index (`int`, *optional*, defaults to -100):
            Stored as `ignore_index`.
            NOTE(review): presumably the label id skipped by the loss -- confirm.
        image_token_idx (`int`, *optional*, defaults to 32000):
            Stored as `image_token_index` (note the different attribute name).
        vocab_size (`int`, *optional*, defaults to 32001):
            Stored privately and exposed through the deprecated `vocab_size`
            property.
        projection_dim (`int`, *optional*, defaults to 768):
            Stored as `projection_dim` and also written onto
            `vision_config.projection_dim`.
        hidden_size (`int`, *optional*, defaults to 2048):
            Stored as `hidden_size`.
        **kwargs:
            Forwarded to `PretrainedConfig.__init__`.

    Derived attributes:
        num_image_tokens: `(vision_config.image_size // vision_config.patch_size) ** 2`.
        pad_token_id: taken from `text_config.pad_token_id` unless a non-None
            `pad_token_id` is passed in `kwargs`.
    """

    model_type = "taivisionlm"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_idx=32000,
        vocab_size=32001,
        projection_dim=768,
        hidden_size=2048,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_idx
        # Backing field of the deprecated `vocab_size` property.
        self._vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.is_encoder_decoder = False

        if isinstance(vision_config, dict):
            # Work on a shallow copy so the caller's dict is not modified
            # when the default model type is filled in.
            vision_kwargs = dict(vision_config)
            vision_kwargs.setdefault("model_type", "siglip_vision_model")
            vision_config = CONFIG_MAPPING[vision_kwargs["model_type"]](**vision_kwargs)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["siglip_vision_model"](
                attention_dropout=0.0,
                hidden_act="gelu_pytorch_tanh",
                hidden_size=768,
                image_size=224,
                intermediate_size=3072,
                layer_norm_eps=1e-06,
                num_attention_heads=12,
                num_channels=3,
                num_hidden_layers=12,
                patch_size=16,
            )
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            # Same copy-before-default treatment as for the vision config.
            # NOTE(review): the fallback model type here is "gpt2" while the
            # `None` default below builds a "llama" config -- confirm that
            # this difference is intended.
            text_kwargs = dict(text_config)
            text_kwargs.setdefault("model_type", "gpt2")
            text_config = CONFIG_MAPPING[text_kwargs["model_type"]](**text_kwargs)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"](
                # `architectures` (plural) is the keyword PretrainedConfig
                # recognises; the singular spelling would only create an
                # unrelated attribute.
                architectures=["LlamaForCausalLM"],
                hidden_act="silu",
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=1,
                eos_token_id=2,
                hidden_size=2048,
                initializer_range=0.02,
                intermediate_size=5632,
                max_position_embeddings=2048,
                model_type="llama",
                num_attention_heads=32,
                num_hidden_layers=22,
                num_key_value_heads=4,
                pretraining_tp=1,
                rms_norm_eps=1e-05,
                rope_scaling=None,
                rope_theta=10000.0,
                tie_word_embeddings=False,
                torch_dtype="bfloat16",
                transformers_version="4.40.2",
                use_cache=True,
                vocab_size=32001,
            )
        self.text_config = text_config

        self.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
        self.vision_config.projection_dim = projection_dim

        # PretrainedConfig.__init__ (as of transformers 4.40) assigns
        # `pad_token_id` from kwargs, defaulting to None, so a value set on
        # `self` beforehand would be discarded. Forward the text model's pad
        # token through kwargs unless the caller supplied one explicitly.
        if kwargs.get("pad_token_id") is None:
            kwargs["pad_token_id"] = self.text_config.pad_token_id
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Return the stored vocabulary size, emitting a `FutureWarning` on every access."""
        warnings.warn(
            "The `vocab_size` attribute is deprecated and will be removed in v4.44, Please use `text_config.vocab_size` instead.",
            FutureWarning,
        )
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value):
        """Store `value` as the vocabulary size without emitting a warning."""
        self._vocab_size = value

    def to_dict(self):
        """Return the parent class's dict with the private `_vocab_size` key removed."""
        output = super().to_dict()
        output.pop("_vocab_size", None)
        return output