Allow loading via AutoModel
#2 opened by tomaarsen (HF staff)
Hello!
Pull Request overview
- Allow loading via AutoModel
Details
I wanted to experiment with loading this model with just AutoModel:
from transformers import AutoTokenizer, AutoModel
max_seq_length = 8192
testing_string = "Every morning, I make a cup of coffee to start my day."
model = AutoModel.from_pretrained(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    model_max_length=max_seq_length
)
input_ids = tokenizer(
    [testing_string],
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length
)
encoder_outputs, pooled_output = model(**input_ids)
print(encoder_outputs.shape, pooled_output.shape)
But I ran into this issue:
You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):
  File "c:\code\m2-bert-80M-8k-retrieval\demo.py", line 5, in <module>
    model = AutoModel.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom\.conda\envs\sentence-transformers\Lib\site-packages\transformers\models\auto\auto_factory.py", line 519, in from_pretrained
    raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers_modules.togethercomputer.m2-bert-80M-8k-retrieval.66ea5d6b12ab7e3d332bba708d76f83ce2909b2e.configuration_bert.BertConfig'> for this kind of AutoModel: AutoModel.
Model type should be one of AlbertConfig, AlignConfig, AltCLIPConfig, ASTConfig, AutoformerConfig, BarkConfig, BartConfig, BeitConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitConfig, BlenderbotConfig, BlenderbotSmallConfig, BlipConfig, Blip2Config, BloomConfig, BridgeTowerConfig, CamembertConfig, CanineConfig, ChineseCLIPConfig, ClapConfig, CLIPConfig, CLIPSegConfig, CodeGenConfig, ConditionalDetrConfig, ConvBertConfig, ConvNextConfig, ConvNextV2Config, CpmAntConfig, CTRLConfig, CvtConfig, Data2VecAudioConfig, Data2VecTextConfig, Data2VecVisionConfig, DebertaConfig, DebertaV2Config, DecisionTransformerConfig, DeformableDetrConfig, DeiTConfig, DetaConfig, DetrConfig, DinatConfig, Dinov2Config, DistilBertConfig, DonutSwinConfig, DPRConfig, DPTConfig, EfficientFormerConfig, EfficientNetConfig, ElectraConfig, EncodecConfig, ErnieConfig, ErnieMConfig, EsmConfig, FalconConfig, FlaubertConfig, FlavaConfig, FNetConfig, FocalNetConfig, FSMTConfig, FunnelConfig, GitConfig, GLPNConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GPTSanJapaneseConfig, GraphormerConfig, GroupViTConfig, HubertConfig, IBertConfig, IdeficsConfig, ImageGPTConfig, InformerConfig, JukeboxConfig, LayoutLMConfig, LayoutLMv2Config, LayoutLMv3Config, LEDConfig, LevitConfig, LiltConfig, LlamaConfig, LongformerConfig, LongT5Config, LukeConfig, LxmertConfig, M2M100Config, MarianConfig, MarkupLMConfig, Mask2FormerConfig, MaskFormerConfig, MaskFormerSwinConfig, MBartConfig, MCTCTConfig, MegaConfig, MegatronBertConfig, MgpstrConfig, MobileBertConfig, MobileNetV1Config, MobileNetV2Config, MobileViTConfig, MobileViTV2Config, MPNetConfig, MptConfig, MraConfig, MT5Config, MvpConfig, NatConfig, NezhaConfig, NllbMoeConfig, NystromformerConfig, OneFormerConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, OwlViTConfig, PegasusConfig, PegasusXConfig, PerceiverConfig, PLBartConfig, PoolFormerConfig, ProphetNetConfig, PvtConfig, QDQBertConfig, ReformerConfig, RegNetConfig, RemBertConfig, ResNetConfig, RetriBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, SamConfig, SegformerConfig, SEWConfig, SEWDConfig, Speech2TextConfig, SpeechT5Config, SplinterConfig, SqueezeBertConfig, SwiftFormerConfig, SwinConfig, Swin2SRConfig, Swinv2Config, SwitchTransformersConfig, T5Config, TableTransformerConfig, TapasConfig, TimeSeriesTransformerConfig, TimesformerConfig, TimmBackboneConfig, TrajectoryTransformerConfig, TransfoXLConfig, TvltConfig, UMT5Config, UniSpeechConfig, UniSpeechSatConfig, VanConfig, VideoMAEConfig, ViltConfig, VisionTextDualEncoderConfig, VisualBertConfig, ViTConfig, ViTHybridConfig, ViTMAEConfig, ViTMSNConfig, VivitConfig, Wav2Vec2Config, Wav2Vec2ConformerConfig, WavLMConfig, WhisperConfig, XCLIPConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, YolosConfig, YosoConfig.
In short, the custom BertConfig from this repository is not registered with AutoModel, so AutoModel doesn't know which model class to instantiate. This PR fixes that by also adding an AutoModel entry to the auto_map in config.json.
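To make the failure mode concrete, here is a minimal sketch (not the actual transformers code) of the lookup involved. With trust_remote_code=True, transformers reads an auto_map dictionary in config.json that maps auto-class names such as AutoConfig or AutoModel to "module.ClassName" references inside the repository. If the requested auto class has no entry there, and the config class isn't one of the built-in ones, loading fails as in the traceback above. The helper functions, the "before" config contents, and the class reference "modeling_module.ModelClass" are hypothetical placeholders; only "configuration_bert.BertConfig" and the m2_bert model type come from the output above.

```python
import copy


def resolve_auto_class(config: dict, auto_class_name: str) -> str:
    """Return the "module.ClassName" reference registered for an auto class.

    Simplified sketch: the real transformers logic also consults its built-in
    model mappings before giving up.
    """
    auto_map = config.get("auto_map", {})
    if auto_class_name not in auto_map:
        raise ValueError(
            f"No {auto_class_name} entry in auto_map; cannot pick a class to instantiate."
        )
    return auto_map[auto_class_name]


def add_auto_model_entry(config: dict, class_ref: str) -> dict:
    """Return a copy of the config with an "AutoModel" entry added to auto_map."""
    updated = copy.deepcopy(config)
    updated.setdefault("auto_map", {})["AutoModel"] = class_ref
    return updated


# Hypothetical "before" config: only AutoConfig is mapped.
before = {
    "model_type": "m2_bert",
    "auto_map": {"AutoConfig": "configuration_bert.BertConfig"},
}
# Placeholder class reference; the real path is whatever the PR writes.
after = add_auto_model_entry(before, "modeling_module.ModelClass")

try:
    resolve_auto_class(before, "AutoModel")
except ValueError as err:
    print("before:", err)
print("after:", resolve_auto_class(after, "AutoModel"))
```

The copy keeps the original dict untouched, so the "before" and "after" states can be compared side by side.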
After this PR
The above script should now return:
You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05
torch.Size([1, 8192, 768]) torch.Size([1, 768])
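The 8192 in torch.Size([1, 8192, 768]) comes from the tokenizer call: with padding="max_length", truncation=True and max_length=8192, every input is cut or padded to exactly 8192 tokens, and the attention mask marks which positions hold real tokens. Below is a rough pure-Python sketch of that behaviour for a single sequence. pad_or_truncate is a hypothetical helper, the token ids are made up, and pad_id=0 is an assumption (it matches [PAD] in bert-base-uncased). The real tokenizer also preserves [CLS]/[SEP] when truncating, which this sketch ignores.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Mimic padding="max_length" + truncation=True for one sequence.

    Returns (input_ids, attention_mask), both exactly max_length long.
    pad_id=0 is an assumption (the [PAD] id in bert-base-uncased).
    """
    ids = list(token_ids[:max_length])  # truncation: keep at most max_length ids
    mask = [1] * len(ids)               # 1 marks real tokens
    n_pad = max_length - len(ids)
    # padding: append pad ids, with 0 in the mask for padded positions
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


ids, mask = pad_or_truncate([101, 11, 12, 102], max_length=8192)  # made-up ids
print(len(ids), sum(mask))  # 8192 4
```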
- Tom Aarsen
tomaarsen changed pull request status to open
Thank you!!
danfu09 changed pull request status to merged