Allow loading via AutoModelForSequenceClassification
Hello!
Pull Request overview
- Allow loading via `AutoModelForSequenceClassification`
Intro
This model is looking awesome! Looking forward to learning more about LoCo as well. I'd be more than happy to help out to make sure that these models load well for your users.
I also sent a message in Slack, but I'm not sure if you all are in there, so I'll repeat it here too:
As a quick introduction, I'm Tom & I am in charge of Sentence Transformers nowadays. I encountered a few slight issues in your model configurations, and I took some time to address them on togethercomputer/m2-bert-80M-8k-retrieval:
- Allow loading via `AutoModelForSequenceClassification` (#1): There was a bug preventing your README snippet from working.
- Allow loading via `AutoModel` (#2): The configuration to load with `AutoModel` was missing.
- Allow loading via `AutoTokenizer` (#3): The configuration to defer the `AutoTokenizer` to `bert-base-cased` did not work - the `auto_map` can't be used like that, sadly. This PR allows loading the tokenizer for this model directly, without having to override `model_max_length` (see the sketch below this list).
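As a rough sketch of what #2 and #3 enable once merged (the exact keyword arguments may differ from the final configuration), loading the model and tokenizer directly from this repository would look like:

```python
from transformers import AutoModel, AutoTokenizer

# Hypothetical usage after #2 and #3: both the model and the tokenizer load
# straight from this repository, with model_max_length already set to 8192.
model = AutoModel.from_pretrained(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("togethercomputer/m2-bert-80M-8k-retrieval")
```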
Feel free to check these out & distribute the fixes across your models if you wish, and don't hesitate to ask if you need any assistance (we can also add you to a Slack channel for contact with us, if you're not in one already).
Additionally, I would certainly recommend including the MTEB results for these models in the model README metadata - it could be great for additional visibility.
Lastly, I'm looking into first-party support in Sentence Transformers, allowing your models to be loaded directly with ST as well! That might help your models reach an even larger audience.
Details
I wanted to experiment with this model using:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

max_seq_length = 8192
testing_string = "Every morning, I make a cup of coffee to start my day."

model = AutoModelForSequenceClassification.from_pretrained(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    model_max_length=max_seq_length
)

input_ids = tokenizer(
    [testing_string],
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length
)

outputs = model(**input_ids)
embeddings = outputs['sentence_embedding']
print(embeddings[0, :10], embeddings[0].sum())
```
But I ran into this issue:
```
You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):
  File "c:\code\m2-bert-80M-8k-retrieval\demo.py", line 5, in <module>
    model = AutoModelForSequenceClassification.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom\.conda\envs\sentence-transformers\Lib\site-packages\transformers\models\auto\auto_factory.py", line 511, in from_pretrained
    cls.register(config.__class__, model_class, exist_ok=True)
  File "C:\Users\tom\.conda\envs\sentence-transformers\Lib\site-packages\transformers\models\auto\auto_factory.py", line 537, in register
    raise ValueError(
ValueError: The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has <class 'transformers.models.bert.configuration_bert.BertConfig'> and you passed <class 'transformers_modules.togethercomputer.m2-bert-80M-8k-retrieval.66ea5d6b12ab7e3d332bba708d76f83ce2909b2e.configuration_bert.BertConfig'>. Fix one of those so they match!
```
In short, the classes that I'm trying to initialize (e.g. your `BertForTextEncoding`) are configured to work with the `BertConfig` from `transformers` itself, rather than with your own custom `BertConfig`. This PR resolves the problem by overriding the `config_class` that is inherited from the `BertPreTrainedModel` superclass.
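For reference, a minimal sketch of that override, assuming the repo's remote-code modeling file defines `BertForTextEncoding` on top of `BertPreTrainedModel` (the exact surrounding code may differ):

```python
# Sketch only; the real change lives in this repo's remote-code modeling file.
from transformers.models.bert.modeling_bert import BertPreTrainedModel

from .configuration_bert import BertConfig  # this repo's custom (m2_bert) config


class BertForTextEncoding(BertPreTrainedModel):
    # Without this override, config_class is inherited from BertPreTrainedModel
    # and points at transformers' stock BertConfig, which is what triggers the
    # ValueError above during Auto class registration.
    config_class = BertConfig

    # ... rest of the model definition unchanged ...
```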
After this PR
The above script now returns:
```
You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05
tensor([ 0.0399,  0.2460,  0.4248,  0.1803, -0.0941, -0.1501,  0.0705,  0.0478,
         0.0119, -0.0807], grad_fn=<SliceBackward0>) tensor(-1.7552, grad_fn=<SumBackward0>)
```
- Tom Aarsen
Thank you!!