""" PyTorch JapaneseStableLMAlpha model. """ |
|
import torch |
|
from torch import nn |
|
from transformers import ( |
|
InstructBlipPreTrainedModel, |
|
InstructBlipVisionModel, |
|
InstructBlipQFormerModel, |
|
InstructBlipForConditionalGeneration, |
|
AutoModelForCausalLM, |
|
AutoModelForSeq2SeqLM, |
|
) |
|
from transformers.utils import logging |
|
from .modeling_japanese_stablelm_alpha import JapaneseStableLMAlphaForCausalLM |
|
from .configuration_japanese_instructblip_alpha import JapaneseInstructBlipAlphaConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class JapaneseInstructBlipAlphaForConditionalGeneration(InstructBlipForConditionalGeneration): |
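    """InstructBLIP-style vision-language model that pairs the InstructBLIP vision
    encoder and Q-Former with a JapaneseStableLMAlpha causal language model: image
    features are distilled into query tokens, projected into the language model's
    embedding space, and used to condition Japanese text generation.
    """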

    config_class = JapaneseInstructBlipAlphaConfig

    def __init__(self, config: JapaneseInstructBlipAlphaConfig):
        InstructBlipPreTrainedModel.__init__(self, config)

        # ViT-based image encoder that produces patch-level visual features.
        self.vision_model = InstructBlipVisionModel(config.vision_config)

        # Learnable query tokens through which the Q-Former attends over the
        # visual features (and the instruction text).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = InstructBlipQFormerModel(config.qformer_config)

        # Projects Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)

        if config.use_decoder_only_language_model:
            language_model = JapaneseStableLMAlphaForCausalLM(config.text_config)
        else:
            raise NotImplementedError("Only decoder-only language models are supported.")

        # Inherit device-placement and precision constraints from the language model.
        if language_model._no_split_modules is not None:
            self._no_split_modules.extend(language_model._no_split_modules)

        if language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)

        self.language_model = language_model

        self.post_init()
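

if __name__ == "__main__":
    # Minimal usage sketch, a hedged illustration rather than part of the module:
    # it assumes the published checkpoint "stabilityai/japanese-instructblip-alpha"
    # and the trust_remote_code loading path from its model card. Loading downloads
    # the weights and may require substantial memory.
    from transformers import AutoModelForVision2Seq

    model = AutoModelForVision2Seq.from_pretrained(
        "stabilityai/japanese-instructblip-alpha", trust_remote_code=True
    )
    print(type(model).__name__)  # JapaneseInstructBlipAlphaForConditionalGeneration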