wordllama

Installation

Install from the GitHub repo (https://github.com/dleemiller/WordLlama) or via pip:

pip install wordllama

Intended Use

This model is intended for natural language processing applications that require text embeddings, such as text classification, sentiment analysis, and document clustering. It is a token embedding model that is comparable to word embedding models, but substantially smaller in size (16 MB for the default 256-dim model).

from wordllama import load

wl = load()
similarity_score = wl.similarity("i went to the car", "i went to the pawn shop")
print(similarity_score)  # Output: 0.06641249096796882

Model Architecture

WordLlama is based on token embedding codebooks extracted from large language models. It is trained like a general-purpose embedding model, using MultipleNegativesRankingLoss from the Sentence Transformers library together with Matryoshka Representation Learning, so that embeddings can be truncated to 64, 128, 256, 512 or 1024 dimensions.
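To illustrate what truncating a Matryoshka-style embedding means in practice, here is a minimal sketch in plain Python. It is not the WordLlama implementation: the helper names (truncate_and_normalize, cosine) and the toy 8-dimensional vectors are made up for illustration. The only idea taken from the text above is that the leading dimensions of an embedding can be kept and the rest dropped.

```python
import math


def truncate_and_normalize(embedding, dim):
    """Keep the first `dim` components and rescale the result to unit length."""
    head = embedding[:dim]
    norm = math.sqrt(sum(v * v for v in head))
    if norm == 0.0:
        return list(head)
    return [v / norm for v in head]


def cosine(a, b):
    """Dot product of two unit-length vectors, which equals their cosine similarity."""
    return sum(x * y for x, y in zip(a, b))


# Toy 8-dim vectors standing in for real embeddings (hypothetical values).
full_a = [0.9, 0.1, -0.3, 0.2, 0.05, -0.02, 0.01, 0.03]
full_b = [0.8, 0.2, -0.1, 0.1, -0.04, 0.03, 0.02, -0.01]

# Compare the same pair of vectors at several truncation sizes.
for dim in (2, 4, 8):
    a = truncate_and_normalize(full_a, dim)
    b = truncate_and_normalize(full_b, dim)
    print(dim, round(cosine(a, b), 4))
```

Renormalizing after truncation keeps cosine similarity well defined at every size; whether a given pipeline needs that step depends on the similarity function it uses.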

To create WordLlama L2 "supercat", we extract and concatenate the token embedding codebooks from several large language models that share the Llama 2 tokenizer vocabulary (32k vocab size), including models like Llama2 70B and Phi-3 Medium. Then we add a trainable token weight parameter and initialize the weights of stopwords to a smaller value (0.1). Finally, we train a projection from the large, concatenated codebook down to a smaller dimension and average pool.
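The steps in this paragraph can be sketched with NumPy as follows. This is a toy forward pass under stated assumptions, not the actual training code: the sizes, the stopword ids, the random "codebooks", and the choice to apply the token weight before a linear projection are all hypothetical. In the real model the projection and token weights would be learned rather than random.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size = 100           # toy value; the text above describes a 32k vocabulary
codebook_dims = (16, 24)   # hypothetical widths of two source codebooks
out_dim = 8                # hypothetical target dimension

# 1. Concatenate per-model token embedding codebooks along the feature axis.
codebooks = [rng.normal(size=(vocab_size, d)) for d in codebook_dims]
supercat = np.concatenate(codebooks, axis=1)   # shape (vocab_size, 40)

# 2. One weight per token, with stopwords initialized to a smaller value.
token_weight = np.ones(vocab_size)
stopword_ids = [3, 7, 11]                      # hypothetical stopword token ids
token_weight[stopword_ids] = 0.1

# 3. Projection down to the small dimension (random here, trained in practice).
projection = rng.normal(size=(supercat.shape[1], out_dim)) * 0.1


def embed(token_ids):
    """Look up, weight, project, then average pool over the tokens."""
    ids = np.asarray(token_ids)
    weighted = supercat[ids] * token_weight[ids, None]
    projected = weighted @ projection
    return projected.mean(axis=0)


sentence = embed([3, 42, 7, 56])
print(sentence.shape)  # (8,)
```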

We train on popular embedding datasets from Sentence Transformers and use Matryoshka Representation Learning (MRL) so that dimensions can be truncated. For "binary" models, we train with a straight-through estimator, so that the embeddings can be binarized, e.g. (x>0).sign(), and packed into integers for Hamming distance computation.
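As a rough illustration of the inference-side half of this idea (binarize, pack, compare), here is a small pure-Python sketch. It does not cover the straight-through estimator used during training, and the function names and example vectors are hypothetical. Each component becomes one bit (1 if positive, 0 otherwise), the bits are packed into an integer, and the distance is the number of differing bits.

```python
def binarize_and_pack(embedding):
    """Map each component to one bit (1 if x > 0) and pack the bits into an int."""
    packed = 0
    for value in embedding:
        packed = (packed << 1) | (1 if value > 0 else 0)
    return packed


def hamming_distance(a, b):
    """Count the bit positions where two packed embeddings differ."""
    return bin(a ^ b).count("1")


# Hypothetical 8-dim embeddings.
u = [0.3, -1.2, 0.7, 0.0, -0.1, 2.4, -0.5, 0.9]
v = [0.1, 0.4, 0.6, -0.3, -0.2, 1.1, 0.8, -0.7]

pu = binarize_and_pack(u)  # bits 10100101
pv = binarize_and_pack(v)  # bits 11100110
print(hamming_distance(pu, pv))  # 3
```

Packed bits take one bit per dimension instead of a full float, and XOR plus a bit count is cheap, which is the usual motivation for this kind of representation.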

After training, we save a new, small token embedding codebook, which is analogous to vectors of a word embedding.
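One way to see why a small codebook can be saved: if the weighting and projection are applied to each token independently before pooling, they can be run once over the whole vocabulary and stored. The sketch below checks this with NumPy on random toy data. It assumes a linear projection and per-token weights applied before projection; all sizes and arrays are hypothetical, not WordLlama's actual parameters.

```python
import numpy as np

rng = np.random.default_rng(1)
vocab_size, cat_dim, out_dim = 50, 40, 8   # toy sizes

supercat = rng.normal(size=(vocab_size, cat_dim))       # large concatenated codebook
token_weight = rng.uniform(0.1, 1.0, size=vocab_size)   # stand-in for learned weights
projection = rng.normal(size=(cat_dim, out_dim))        # stand-in for learned projection

# Run every vocabulary entry through weighting and projection once.
small_codebook = (supercat * token_weight[:, None]) @ projection  # (vocab_size, out_dim)

token_ids = np.array([5, 17, 17, 33])

# Inference with the small codebook: a lookup followed by average pooling.
fast = small_codebook[token_ids].mean(axis=0)

# Same result computed the long way from the large codebook.
slow = ((supercat[token_ids] * token_weight[token_ids, None]) @ projection).mean(axis=0)

print(np.allclose(fast, slow))  # True
```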

MTEB Results (l2_supercat)

| Metric | WL64 | WL128 | WL256 (X) | WL512 | WL1024 | GloVe 300d | Komninos | all-MiniLM-L6-v2 |
|---|---|---|---|---|---|---|---|---|
| Clustering | 30.27 | 32.20 | 33.25 | 33.40 | 33.62 | 27.73 | 26.57 | 42.35 |
| Reranking | 50.38 | 51.52 | 52.03 | 52.32 | 52.39 | 43.29 | 44.75 | 58.04 |
| Classification | 53.14 | 56.25 | 58.21 | 59.13 | 59.50 | 57.29 | 57.65 | 63.05 |
| Pair Classification | 75.80 | 77.59 | 78.22 | 78.50 | 78.60 | 70.92 | 72.94 | 82.37 |
| STS | 66.24 | 67.53 | 67.91 | 68.22 | 68.27 | 61.85 | 62.46 | 78.90 |
| CQA DupStack | 18.76 | 22.54 | 24.12 | 24.59 | 24.83 | 15.47 | 16.79 | 41.32 |
| SummEval | 30.79 | 29.99 | 30.99 | 29.56 | 29.39 | 28.87 | 30.49 | 30.81 |

license: mit
datasets:
- sentence-transformers/all-nli
- sentence-transformers/gooaq
language:
- en