Add Sentence Transformers integration (#2)
Browse files- Add Sentence Transformers integration with CLS pooling (fc6047009e0ee0f2ad4a36e4bae86bb10ba961fd)
- Add tags/library_name for tighter integration with HF (249dd29816720d49dc21809339445932237b866b)
- 1_Pooling/config.json +10 -0
- README.md +33 -0
- config_sentence_transformers.json +10 -0
- modules.json +14 -0
- sentence_bert_config.json +4 -0
1_Pooling/config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"word_embedding_dimension": 768,
|
3 |
+
"pooling_mode_cls_token": true,
|
4 |
+
"pooling_mode_mean_tokens": false,
|
5 |
+
"pooling_mode_max_tokens": false,
|
6 |
+
"pooling_mode_mean_sqrt_len_tokens": false,
|
7 |
+
"pooling_mode_weightedmean_tokens": false,
|
8 |
+
"pooling_mode_lasttoken": false,
|
9 |
+
"include_prompt": true
|
10 |
+
}
|
README.md
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
---
|
2 |
tags:
|
3 |
- feature-extraction
|
|
|
|
|
|
|
4 |
language: en
|
5 |
datasets:
|
6 |
- SciDocs
|
@@ -28,6 +31,30 @@ PubMedNCL: Working with biomedical papers? Try [PubMedNCL](https://huggingface.c
|
|
28 |
|
29 |
## How to use the pretrained model
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
```python
|
32 |
from transformers import AutoTokenizer, AutoModel
|
33 |
|
@@ -49,6 +76,12 @@ result = model(**inputs)
|
|
49 |
|
50 |
# take the first token ([CLS] token) in the batch as the embedding
|
51 |
embeddings = result.last_hidden_state[:, 0, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
```
|
53 |
|
54 |
## Triplet Mining Parameters
|
|
|
1 |
---
|
2 |
tags:
|
3 |
- feature-extraction
|
4 |
+
- sentence-transformers
|
5 |
+
- transformers
|
6 |
+
library_name: sentence-transformers
|
7 |
language: en
|
8 |
datasets:
|
9 |
- SciDocs
|
|
|
31 |
|
32 |
## How to use the pretrained model
|
33 |
|
34 |
+
### Sentence Transformers
|
35 |
+
|
36 |
+
```python
|
37 |
+
from sentence_transformers import SentenceTransformer
|
38 |
+
|
39 |
+
# Load the model
|
40 |
+
model = SentenceTransformer("malteos/scincl")
|
41 |
+
|
42 |
+
# Concatenate the title and abstract with the [SEP] token
|
43 |
+
papers = [
|
44 |
+
"BERT [SEP] We introduce a new language representation model called BERT",
|
45 |
+
"Attention is all you need [SEP] The dominant sequence transduction models are based on complex recurrent or convolutional neural networks",
|
46 |
+
]
|
47 |
+
# Inference
|
48 |
+
embeddings = model.encode(papers)
|
49 |
+
|
50 |
+
# Compute the (cosine) similarity between embeddings
|
51 |
+
similarity = model.similarity(embeddings[0], embeddings[1])
|
52 |
+
print(similarity.item())
|
53 |
+
# => 0.8440517783164978
|
54 |
+
```
|
55 |
+
|
56 |
+
### Transformers
|
57 |
+
|
58 |
```python
|
59 |
from transformers import AutoTokenizer, AutoModel
|
60 |
|
|
|
76 |
|
77 |
# take the first token ([CLS] token) in the batch as the embedding
|
78 |
embeddings = result.last_hidden_state[:, 0, :]
|
79 |
+
|
80 |
+
# calculate the similarity
|
81 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
82 |
+
similarity = (embeddings[0] @ embeddings[1].T)
|
83 |
+
print(similarity.item())
|
84 |
+
# => 0.8440518379211426
|
85 |
```
|
86 |
|
87 |
## Triplet Mining Parameters
|
config_sentence_transformers.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": {
|
3 |
+
"sentence_transformers": "3.0.0",
|
4 |
+
"transformers": "4.41.2",
|
5 |
+
"pytorch": "2.3.0+cu121"
|
6 |
+
},
|
7 |
+
"prompts": {},
|
8 |
+
"default_prompt_name": null,
|
9 |
+
"similarity_fn_name": "cosine"
|
10 |
+
}
|
modules.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "sentence_transformers.models.Transformer"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"idx": 1,
|
10 |
+
"name": "1",
|
11 |
+
"path": "1_Pooling",
|
12 |
+
"type": "sentence_transformers.models.Pooling"
|
13 |
+
}
|
14 |
+
]
|
sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_seq_length": 512,
|
3 |
+
"do_lower_case": false
|
4 |
+
}
|