Patch Sentence Transformers integration
#2
by
tomaarsen
HF staff
- opened
- {1_Pool → 1_Pooling}/config.json +1 -1
- README.md +5 -7
- config_sentence_transformers.json +25 -1
- modules.json +6 -0
- sentence_bert_config.json +3 -0
{1_Pool → 1_Pooling}/config.json
RENAMED
@@ -6,5 +6,5 @@
|
|
6 |
"pooling_mode_mean_sqrt_len_tokens": false,
|
7 |
"pooling_mode_weightedmean_tokens": false,
|
8 |
"pooling_mode_lasttoken": false,
|
9 |
-
"include_prompt":
|
10 |
}
|
|
|
6 |
"pooling_mode_mean_sqrt_len_tokens": false,
|
7 |
"pooling_mode_weightedmean_tokens": false,
|
8 |
"pooling_mode_lasttoken": false,
|
9 |
+
"include_prompt": true
|
10 |
}
|
README.md
CHANGED
@@ -262,6 +262,7 @@ model-index:
|
|
262 |
pipeline_tag: feature-extraction
|
263 |
tags:
|
264 |
- mteb
|
|
|
265 |
library_name: transformers
|
266 |
---
|
267 |
## MiniCPM-Embedding
|
@@ -401,21 +402,18 @@ import torch
|
|
401 |
from sentence_transformers import SentenceTransformer
|
402 |
|
403 |
model_name = "openbmb/MiniCPM-Embedding"
|
404 |
-
model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={"attn_implementation":"flash_attention_2", "torch_dtype":torch.float16})
|
405 |
-
model.max_seq_length = 512
|
406 |
-
model.tokenizer.padding_side="right"
|
407 |
|
408 |
queries = ["中国的首都是哪里?"]
|
409 |
passages = ["beijing", "shanghai"]
|
410 |
|
411 |
-
|
412 |
INSTRUCTION = "Query: "
|
413 |
|
414 |
-
embeddings_query = model.encode(queries, prompt=INSTRUCTION
|
415 |
-
embeddings_doc = model.encode(passages
|
416 |
|
417 |
scores = (embeddings_query @ embeddings_doc.T)
|
418 |
-
print(scores.tolist()) # [[0.
|
419 |
```
|
420 |
|
421 |
## 实验结果 Evaluation Results
|
|
|
262 |
pipeline_tag: feature-extraction
|
263 |
tags:
|
264 |
- mteb
|
265 |
+
- sentence-transformers
|
266 |
library_name: transformers
|
267 |
---
|
268 |
## MiniCPM-Embedding
|
|
|
402 |
from sentence_transformers import SentenceTransformer
|
403 |
|
404 |
model_name = "openbmb/MiniCPM-Embedding"
|
405 |
+
model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": torch.float16})
|
|
|
|
|
406 |
|
407 |
queries = ["中国的首都是哪里?"]
|
408 |
passages = ["beijing", "shanghai"]
|
409 |
|
|
|
410 |
INSTRUCTION = "Query: "
|
411 |
|
412 |
+
embeddings_query = model.encode(queries, prompt=INSTRUCTION)
|
413 |
+
embeddings_doc = model.encode(passages)
|
414 |
|
415 |
scores = (embeddings_query @ embeddings_doc.T)
|
416 |
+
print(scores.tolist()) # [[0.35365450382232666, 0.18592746555805206]]
|
417 |
```
|
418 |
|
419 |
## 实验结果 Evaluation Results
|
config_sentence_transformers.json
CHANGED
@@ -4,6 +4,30 @@
|
|
4 |
"transformers": "4.37.2",
|
5 |
"pytorch": "2.0.1+cu121"
|
6 |
},
|
7 |
-
"prompts": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
"default_prompt_name": null
|
9 |
}
|
|
|
4 |
"transformers": "4.37.2",
|
5 |
"pytorch": "2.0.1+cu121"
|
6 |
},
|
7 |
+
"prompts": {
|
8 |
+
"fiqa": "Instruction: Given a financial question, retrieve user replies that best answer the question. Query: ",
|
9 |
+
"dbpedia": "Instruction: Given a query, retrieve relevant entity descriptions from DBPedia. Query: ",
|
10 |
+
"CmedqaRetrieval": "Instruction: 为这个医疗问题检索相关回答。 Query: ",
|
11 |
+
"nfcorpus": "Instruction: Given a question, retrieve relevant documents that best answer the question. Query: ",
|
12 |
+
"touche2020": "Instruction: Given a question, retrieve detailed and persuasive arguments that answer the question. Query: ",
|
13 |
+
"CovidRetrieval": "Instruction: 为这个问题检索相关政策回答。 Query: ",
|
14 |
+
"scifact": "Instruction: Given a scientific claim, retrieve documents that support or refute the claim. Query: ",
|
15 |
+
"scidocs": "Instruction: Given a scientific paper title, retrieve paper abstracts that are cited by the given paper. Query: ",
|
16 |
+
"nq": "Instruction: Given a question, retrieve Wikipedia passages that answer the question. Query: ",
|
17 |
+
"T2Retrieval": "Instruction: 为这个问题检索相关段落。 Query: ",
|
18 |
+
"VideoRetrieval": "Instruction: 为这个电影标题检索相关段落。 Query: ",
|
19 |
+
"DuRetrieval": "Instruction: 为这个问题检索相关百度知道回答。 Query: ",
|
20 |
+
"MMarcoRetrieval": "Instruction: 为这个查询检索相关段落。 Query: ",
|
21 |
+
"hotpotqa": "Instruction: Given a multi-hop question, retrieve documents that can help answer the question. Query: ",
|
22 |
+
"quora": "Instruction: Given a question, retrieve questions that are semantically equivalent to the given question. Query: ",
|
23 |
+
"climate-fever": "Instruction: Given a claim about climate change, retrieve documents that support or refute the claim. Query: ",
|
24 |
+
"arguana": "Instruction: Given a claim, find documents that refute the claim. Query: ",
|
25 |
+
"fever": "Instruction: Given a claim, retrieve documents that support or refute the claim. Query: ",
|
26 |
+
"trec-covid": "Instruction: Given a query on COVID-19, retrieve documents that answer the query. Query: ",
|
27 |
+
"msmarco": "Instruction: Given a web search query, retrieve relevant passages that answer the query. Query: ",
|
28 |
+
"EcomRetrieval": "Instruction: 为这个查询检索相关商品标题。 Query: ",
|
29 |
+
"MedicalRetrieval": "Instruction: 为这个医学问题检索相关回答。 Query: ",
|
30 |
+
"CAQstack":"Instruction: Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question. Query: "
|
31 |
+
},
|
32 |
"default_prompt_name": null
|
33 |
}
|
modules.json
CHANGED
@@ -10,5 +10,11 @@
|
|
10 |
"name": "1",
|
11 |
"path": "1_Pooling",
|
12 |
"type": "sentence_transformers.models.Pooling"
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
}
|
14 |
]
|
|
|
10 |
"name": "1",
|
11 |
"path": "1_Pooling",
|
12 |
"type": "sentence_transformers.models.Pooling"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"idx": 2,
|
16 |
+
"name": "2",
|
17 |
+
"path": "2_Normalize",
|
18 |
+
"type": "sentence_transformers.models.Normalize"
|
19 |
}
|
20 |
]
|
sentence_bert_config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_seq_length": 512
|
3 |
+
}
|