# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class."""

from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)


# ---------------------------------------------------------------------------
# Name tables
#
# Each `*_MAPPING_NAMES` table is an ordered collection of
# (key, Flax class name) string pairs for one task. Only names are stored
# here; no model class is imported by this module directly.
# NOTE(review): the keys look like model-type identifiers that are expected to
# also appear in CONFIG_MAPPING_NAMES (both are handed to `_LazyAutoMapping`
# below) -- confirm against `configuration_auto`.
# Entries within each table are kept in alphabetical order of the key.
# ---------------------------------------------------------------------------

FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "FlaxAlbertModel"),
        ("bart", "FlaxBartModel"),
        ("beit", "FlaxBeitModel"),
        ("bert", "FlaxBertModel"),
        ("big_bird", "FlaxBigBirdModel"),
        ("blenderbot", "FlaxBlenderbotModel"),
        ("blenderbot-small", "FlaxBlenderbotSmallModel"),
        ("bloom", "FlaxBloomModel"),
        ("clip", "FlaxCLIPModel"),
        ("distilbert", "FlaxDistilBertModel"),
        ("electra", "FlaxElectraModel"),
        # "gpt-sw3" deliberately shares the GPT-2 class with "gpt2".
        ("gpt-sw3", "FlaxGPT2Model"),
        ("gpt2", "FlaxGPT2Model"),
        ("gpt_neo", "FlaxGPTNeoModel"),
        ("gptj", "FlaxGPTJModel"),
        ("llama", "FlaxLlamaModel"),
        ("longt5", "FlaxLongT5Model"),
        ("marian", "FlaxMarianModel"),
        ("mbart", "FlaxMBartModel"),
        ("mt5", "FlaxMT5Model"),
        ("opt", "FlaxOPTModel"),
        ("pegasus", "FlaxPegasusModel"),
        ("regnet", "FlaxRegNetModel"),
        ("resnet", "FlaxResNetModel"),
        ("roberta", "FlaxRobertaModel"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
        ("roformer", "FlaxRoFormerModel"),
        ("t5", "FlaxT5Model"),
        ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
        ("vit", "FlaxViTModel"),
        ("wav2vec2", "FlaxWav2Vec2Model"),
        ("whisper", "FlaxWhisperModel"),
        ("xglm", "FlaxXGLMModel"),
        ("xlm-roberta", "FlaxXLMRobertaModel"),
    ]
)

FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        # Several entries point at a `...ForConditionalGeneration` or
        # `...ForMaskedLM` class rather than a dedicated `...ForPreTraining` one.
        ("albert", "FlaxAlbertForPreTraining"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForPreTraining"),
        ("big_bird", "FlaxBigBirdForPreTraining"),
        ("electra", "FlaxElectraForPreTraining"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)

FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        # "bart" and "mbart" are served by their conditional-generation classes.
        ("albert", "FlaxAlbertForMaskedLM"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForMaskedLM"),
        ("big_bird", "FlaxBigBirdForMaskedLM"),
        ("distilbert", "FlaxDistilBertForMaskedLM"),
        ("electra", "FlaxElectraForMaskedLM"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)

FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "FlaxBartForConditionalGeneration"),
        ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "FlaxEncoderDecoderModel"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("marian", "FlaxMarianMTModel"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("pegasus", "FlaxPegasusForConditionalGeneration"),
        ("t5", "FlaxT5ForConditionalGeneration"),
    ]
)

FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "FlaxBeitForImageClassification"),
        ("regnet", "FlaxRegNetForImageClassification"),
        ("resnet", "FlaxResNetForImageClassification"),
        ("vit", "FlaxViTForImageClassification"),
    ]
)

FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Vision-to-Sequence mapping
        ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
    ]
)

FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "FlaxBartForCausalLM"),
        ("bert", "FlaxBertForCausalLM"),
        ("big_bird", "FlaxBigBirdForCausalLM"),
        ("bloom", "FlaxBloomForCausalLM"),
        ("electra", "FlaxElectraForCausalLM"),
        # As in the base table, "gpt-sw3" shares the GPT-2 class with "gpt2".
        ("gpt-sw3", "FlaxGPT2LMHeadModel"),
        ("gpt2", "FlaxGPT2LMHeadModel"),
        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
        ("gptj", "FlaxGPTJForCausalLM"),
        ("llama", "FlaxLlamaForCausalLM"),
        ("opt", "FlaxOPTForCausalLM"),
        ("roberta", "FlaxRobertaForCausalLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
        ("xglm", "FlaxXGLMForCausalLM"),
        ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
    ]
)

FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "FlaxAlbertForSequenceClassification"),
        ("bart", "FlaxBartForSequenceClassification"),
        ("bert", "FlaxBertForSequenceClassification"),
        ("big_bird", "FlaxBigBirdForSequenceClassification"),
        ("distilbert", "FlaxDistilBertForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
        ("roberta", "FlaxRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "FlaxRoFormerForSequenceClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
    ]
)

FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "FlaxAlbertForQuestionAnswering"),
        ("bart", "FlaxBartForQuestionAnswering"),
        ("bert", "FlaxBertForQuestionAnswering"),
        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
        ("distilbert", "FlaxDistilBertForQuestionAnswering"),
        ("electra", "FlaxElectraForQuestionAnswering"),
        ("mbart", "FlaxMBartForQuestionAnswering"),
        ("roberta", "FlaxRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "FlaxRoFormerForQuestionAnswering"),
        ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
    ]
)

FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "FlaxAlbertForTokenClassification"),
        ("bert", "FlaxBertForTokenClassification"),
        ("big_bird", "FlaxBigBirdForTokenClassification"),
        ("distilbert", "FlaxDistilBertForTokenClassification"),
        ("electra", "FlaxElectraForTokenClassification"),
        ("roberta", "FlaxRobertaForTokenClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
        ("roformer", "FlaxRoFormerForTokenClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
    ]
)

FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "FlaxAlbertForMultipleChoice"),
        ("bert", "FlaxBertForMultipleChoice"),
        ("big_bird", "FlaxBigBirdForMultipleChoice"),
        ("distilbert", "FlaxDistilBertForMultipleChoice"),
        ("electra", "FlaxElectraForMultipleChoice"),
        ("roberta", "FlaxRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "FlaxRoFormerForMultipleChoice"),
        ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
    ]
)

FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "FlaxBertForNextSentencePrediction"),
    ]
)

FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Speech Seq2Seq mapping
        # "whisper" uses the same class here as in the pre-training table.
        ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
    ]
)

FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("whisper", "FlaxWhisperForAudioClassification"),
    ]
)


# ---------------------------------------------------------------------------
# Lazy mappings
#
# One `_LazyAutoMapping` per name table, each built from CONFIG_MAPPING_NAMES
# and the corresponding table above.
# NOTE(review): presumably the mapping resolves config classes to model classes
# on first access rather than at import time -- confirm in `auto_factory`.
# ---------------------------------------------------------------------------

FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
# NOTE(review): this is the only mapping in the visible part of the file that
# is not bound to a `FlaxAutoModelFor...` class below -- confirm whether that
# is intentional.
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)


# ---------------------------------------------------------------------------
# Auto classes
#
# Every auto class follows the same two-step pattern:
#   1. subclass `_BaseAutoModelClass` and bind one lazy mapping through the
#      `_model_mapping` class attribute;
#   2. pass the class through `auto_class_update` and rebind the class name to
#      whatever that call returns.
# `head_doc` is a short human-readable description of the task head.
# NOTE(review): `auto_class_update` presumably fills in the generated class
# documentation from `head_doc` / `checkpoint_for_example` -- confirm in
# `auto_factory`.
# ---------------------------------------------------------------------------


class FlaxAutoModel(_BaseAutoModelClass):
    # Base (headless) models.
    _model_mapping = FLAX_MODEL_MAPPING


# No `head_doc` is given for the base class; the callee's default is used.
FlaxAutoModel = auto_class_update(FlaxAutoModel)


class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
    # Models with a pre-training head.
    _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING


FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")


class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
    # Models with a causal language modeling head.
    _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING


FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")


class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
    # Models with a masked language modeling head.
    _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING


FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")


class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Models with a sequence-to-sequence language modeling head.
    _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# The only registration in this file that overrides `checkpoint_for_example`;
# every other call leaves it at the callee's default.
FlaxAutoModelForSeq2SeqLM = auto_class_update(
    FlaxAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)


class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
    # Models with a sequence classification head.
    _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


FlaxAutoModelForSequenceClassification = auto_class_update(
    FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
)


class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Models with a question answering head.
    _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING


FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")


class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
    # Models with a token classification head.
    _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


FlaxAutoModelForTokenClassification = auto_class_update(
    FlaxAutoModelForTokenClassification, head_doc="token classification"
)


class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
    # Models with a multiple choice head.
    _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")


class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Models with a next sentence prediction head.
    _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


FlaxAutoModelForNextSentencePrediction = auto_class_update(
    FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
    # Models with an image classification head.
    _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


FlaxAutoModelForImageClassification = auto_class_update(
    FlaxAutoModelForImageClassification, head_doc="image classification"
)


class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
    # Models with a vision-to-text modeling head.
    _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING


FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")


class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Models with a sequence-to-sequence speech-to-text modeling head.
    _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
    FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)