|
# wordllama |
|
|
|
## Installation |
|
|
|
Install from PyPI with pip, or from the GitHub repo: https://github.com/dleemiller/WordLlama
|
```bash
pip install wordllama
```
|
|
|
## Intended Use |
|
|
|
This model is intended for use in natural language processing applications that require text embeddings, such as text classification, sentiment analysis, and document clustering. |
|
It is a token embedding model comparable to word embedding models, but substantially smaller in size (16 MB for the default 256-dim model).
|
|
|
```python
from wordllama import load

wl = load()
similarity_score = wl.similarity("i went to the car", "i went to the pawn shop")
print(similarity_score)  # Output: 0.06641249096796882
```
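
Beyond pairwise similarity, the same object can be used to rank candidates against a query and to produce raw embedding vectors for downstream tasks. The sketch below assumes the `rank` and `embed` helpers described in the GitHub repo; check the repository for the current API.

```python
from wordllama import load

wl = load()

# Rank candidate strings against a query by embedding similarity
# (`rank` is assumed from the project README; check the repo for the current API).
candidates = ["i went to the park", "i went to the shop", "i went to the truck"]
print(wl.rank("i went to the car", candidates))  # list of (candidate, score) pairs

# Compute raw embeddings for clustering, classification, or retrieval
# (`embed` is likewise assumed from the project README).
vectors = wl.embed(["i went to the car", "i went to the pawn shop"])
print(vectors.shape)  # (2, 256) with the default 256-dim model
```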
|
|
|
## Model Architecture |
|
|
|
WordLlama is based on token embedding codebooks extracted from large language models. It is trained like a general-purpose embedding model, with MultipleNegativesRankingLoss from the sentence-transformers library, and with Matryoshka Representation Learning (MRL) so that embeddings can be truncated to 64, 128, 256, 512, or 1024 dimensions.
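
Because of the MRL training, the leading dimensions carry most of the information, so an embedding can simply be sliced to a smaller size. A minimal sketch, assuming `wl.embed` returns a 2-D NumPy array with one row per input text:

```python
import numpy as np
from wordllama import load

wl = load()  # default 256-dim model
a, b = wl.embed(["i went to the car", "i went to the pawn shop"])

def cosine(u, v):
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

# Truncating to the leading 64 dimensions gives a score close to the full
# 256-dim score, because MRL concentrates information in the early dims.
print(cosine(a, b))
print(cosine(a[:64], b[:64]))
```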
|
|
|
To create WordLlama L2 "supercat", we extract and concatenate the token embedding codebooks from several large language models that use the Llama 2 tokenizer vocabulary (32k vocab size), including models such as Llama 2 70B and Phi-3 Medium. We then add a trainable token weight parameter and initialize stopwords to a smaller value (0.1). Finally, we train a projection from the large, concatenated codebook down to a smaller dimension and average pool.
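
For illustration only, the forward pass described above might look roughly like the PyTorch sketch below; the class name, dimensions, and pooling details are assumptions, not the project's actual training code.

```python
import torch
import torch.nn as nn

class ProjectedCodebook(nn.Module):
    """Illustrative sketch of the architecture described above (not the actual training code)."""

    def __init__(self, concat_codebook: torch.Tensor, out_dim: int = 1024, stopword_ids=None):
        super().__init__()
        vocab_size, concat_dim = concat_codebook.shape
        # Frozen concatenation of the source models' token embedding codebooks.
        self.codebook = nn.Embedding.from_pretrained(concat_codebook, freeze=True)
        # Trainable per-token weight, with stopwords initialized to a smaller value (0.1).
        self.token_weight = nn.Parameter(torch.ones(vocab_size))
        if stopword_ids is not None:
            with torch.no_grad():
                self.token_weight[stopword_ids] = 0.1
        # Projection from the large concatenated dimension down to the output dimension.
        self.proj = nn.Linear(concat_dim, out_dim, bias=False)

    def forward(self, token_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        x = self.proj(self.codebook(token_ids))            # (batch, seq, out_dim)
        w = self.token_weight[token_ids] * attention_mask  # (batch, seq), mask as float
        # Weight tokens, then average pool over the sequence.
        return (x * w.unsqueeze(-1)).sum(dim=1) / w.sum(dim=1, keepdim=True).clamp(min=1e-9)
```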
|
|
|
We train on popular embedding datasets from sentence-transformers, using Matryoshka representation learning (MRL) so that dimensions can be truncated. For "binary" models, we train with a straight-through estimator so that the embeddings can be binarized, e.g. `(x > 0).sign()`, and packed into integers for Hamming distance computation.
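
A minimal NumPy sketch of the binarization and Hamming-distance path (illustrative only; the library handles this internally for binary models):

```python
import numpy as np

def binarize_and_pack(embedding: np.ndarray) -> np.ndarray:
    """Threshold each dimension at zero and pack the bits into uint8 words."""
    return np.packbits((embedding > 0).astype(np.uint8), axis=-1)

def hamming_distance(a: np.ndarray, b: np.ndarray) -> int:
    """Count differing bits between two packed binary embeddings."""
    return int(np.unpackbits(np.bitwise_xor(a, b)).sum())

# Random vectors stand in for real embeddings here.
x, y = np.random.randn(2, 64)
print(hamming_distance(binarize_and_pack(x), binarize_and_pack(y)))
```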
|
|
|
After training, we save a new, small token embedding codebook, which is analogous to the vectors of a word embedding model.
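
At inference time, that saved codebook is used like a word-embedding table: look up a vector per token and pool. A rough sketch, with the array shapes and token ids as stand-ins:

```python
import numpy as np

# Stand-ins: a (32000, 256) codebook and token ids from the Llama 2 tokenizer.
codebook = np.random.randn(32000, 256).astype(np.float32)
token_ids = [306, 3512, 304, 278, 1559]

# Embedding a text is a lookup followed by average pooling over its tokens.
sentence_embedding = codebook[token_ids].mean(axis=0)
print(sentence_embedding.shape)  # (256,)
```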
|
|
|
|
|
## MTEB Results (l2_supercat) |
|
|
|
| Metric              | WL64  | WL128 | WL256 (X) | WL512 | WL1024 | GloVe 300d | Komninos | all-MiniLM-L6-v2 |
|---------------------|-------|-------|-----------|-------|--------|------------|----------|------------------|
| Clustering          | 30.27 | 32.20 | 33.25     | 33.40 | 33.62  | 27.73      | 26.57    | 42.35            |
| Reranking           | 50.38 | 51.52 | 52.03     | 52.32 | 52.39  | 43.29      | 44.75    | 58.04            |
| Classification      | 53.14 | 56.25 | 58.21     | 59.13 | 59.50  | 57.29      | 57.65    | 63.05            |
| Pair Classification | 75.80 | 77.59 | 78.22     | 78.50 | 78.60  | 70.92      | 72.94    | 82.37            |
| STS                 | 66.24 | 67.53 | 67.91     | 68.22 | 68.27  | 61.85      | 62.46    | 78.90            |
| CQA DupStack        | 18.76 | 22.54 | 24.12     | 24.59 | 24.83  | 15.47      | 16.79    | 41.32            |
| SummEval            | 30.79 | 29.99 | 30.99     | 29.56 | 29.39  | 28.87      | 30.49    | 30.81            |
|
|
|
--- |
|
license: mit |
|
datasets: |
|
- sentence-transformers/all-nli |
|
- sentence-transformers/gooaq |
|
language: |
|
- en |
|
--- |
|
|