library_name: transformers
---
## Zhihui_LLM_Embedding
### Model Introduction

**Zhihui_LLM_Embedding** is an embedding model specifically designed to enhance Chinese text retrieval. It is built on a 7B LLM and uses a bidirectional attention mechanism for improved contextual understanding, and it is trained on an extensive multi-domain corpus with extremely large batches. **Zhihui_LLM_Embedding** excels at retrieval tasks, ranking **1st** on the C-MTEB leaderboard with a leading score of **76.74** as of June 25, 2024.
### Optimization points

* Data source enhancement: leverage the knowledge of LLMs (GPT-3.5 & GPT-4) through three types of distillation:
  * Data refinement: the LLM scores candidate positive passages to select the most relevant examples.
  * Query rewriting: the LLM generates queries that can be answered by the positive documents but are unrelated to the negatives, improving query quality and diversity.
  * Query expansion: queries are expanded across multiple topics for long documents.
* Negative example mining: multiple methods and different selection ranges are used to mine hard negative examples.
* Improved contrastive loss: a novel InfoNCE loss assigns higher weights to harder negative examples to improve the model's fine-grained feature representation (see the sketch after this list).
* Bidirectional attention: the causal attention mask of the decoder-only LLM is removed during contrastive training to produce richer contextualized representations.
* Training efficiency: Gradient Cache is used to scale contrastive-learning batches beyond GPU memory constraints, so the model can learn from more challenging negative examples.
* Others: dataset-homogeneous batching, cross-batch negative sampling.
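The exact training objective is not published in this card; the following is a minimal sketch of what a hardness-weighted InfoNCE loss of the kind described above can look like. The weighting scheme (`alpha`, softmax-based weights) and the temperature are illustrative assumptions, not the released training code.

```python
import torch
import torch.nn.functional as F


def weighted_infonce(query_emb, pos_emb, neg_embs, temperature=0.05, alpha=1.0):
    """Hardness-weighted InfoNCE sketch: harder (higher-similarity) negatives get larger weights.

    query_emb: (B, D) query embeddings
    pos_emb:   (B, D) positive passage embeddings
    neg_embs:  (B, N, D) hard-negative passage embeddings
    alpha, temperature: illustrative hyperparameters, not the model's actual training values
    """
    q = F.normalize(query_emb, dim=-1)
    p = F.normalize(pos_emb, dim=-1)
    n = F.normalize(neg_embs, dim=-1)

    pos_sim = (q * p).sum(dim=-1, keepdim=True) / temperature        # (B, 1)
    neg_sim = torch.einsum("bd,bnd->bn", q, n) / temperature         # (B, N)

    # Weight each negative by its hardness; weights are detached so they act as fixed scales.
    # The mean weight is 1, and harder negatives receive weights > 1.
    weights = (torch.softmax(alpha * neg_sim, dim=-1) * neg_sim.size(-1)).detach()

    # Scaling exp(neg_sim) by `weights` in the denominator is equivalent to adding log(weights) to the logits.
    logits = torch.cat([pos_sim, neg_sim + weights.log()], dim=-1)    # (B, 1 + N)
    labels = torch.zeros(q.size(0), dtype=torch.long, device=q.device)  # positive is at index 0
    return F.cross_entropy(logits, labels)
```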
### Model Details

* Base decoder-only LLM: [gte-Qwen2-7B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct)
* Pooling method: last token
* Embedding dimension: 3584
### Usage

##### Requirements

```
transformers>=4.40.2
flash_attn>=2.5.8
sentence-transformers>=2.7.0
```
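These can typically be installed with `pip install "transformers>=4.40.2" "flash_attn>=2.5.8" "sentence-transformers>=2.7.0"`; note that `flash_attn` usually requires a CUDA toolchain available at build time.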
##### How to use

Here is an example of how to encode queries and passages using Hugging Face Transformers and Sentence-Transformers.

##### Usage (HuggingFace Transformers)

```python
import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery: {query}'


task = 'Given a web search query, retrieve relevant passages that answer the query'
queries = [
    get_detailed_instruct(task, "国家法定节假日共多少天"),
    get_detailed_instruct(task, "如何查看好友申请")
]

documents = [
    "一年国家法定节假日为11天。根据公布的国家法定节假日调整方案,调整的主要内容包括:元旦放假1天不变;春节放假3天,放假时间为农历正月初一、初二、初三;“五一”国际劳动节1天不变;“十一”国庆节放假3天;清明节、端午节、中秋节增设为国家法定节假日,各放假1天(农历节日如遇闰月,以第一个月为休假日)。3、允许周末上移下错,与法定节假日形成连休。",
    "这个直接去我的QQ中心不就好了么那里可以查到 我的好友单向好友好友恢复、 以及好友申请 啊可以是你加别人的 或 别人加你的都可以查得到QQ空间里 这个没注意 要有的话也会在你进空间的时候会提示你的QQ 空间里 上面消息 就可以看见了!望采纳!谢谢这个直接去我的QQ中心不就好了么那里可以查到 我的好友单向好友好友恢复、 以及好友申请 啊可以是你加别人的 或 别人加你的都可以查得到",
]
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained('Lenovo-Zhihui/Zhihui_LLM_Embedding', trust_remote_code=True)
model = AutoModel.from_pretrained('Lenovo-Zhihui/Zhihui_LLM_Embedding', trust_remote_code=True)

max_length = 512

# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
outputs = model(**batch_dict)
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# Normalize embeddings so that dot products are cosine similarities
embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
```
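Since the embeddings are L2-normalized, each entry of `scores` is a cosine similarity between a query and a document. Continuing directly from the snippet above (it reuses `scores`, `queries`, and `documents`), the best-matching document per query can be surfaced like this:

```python
# Continuation of the snippet above: rank documents for each query by cosine similarity.
best = scores.argmax(dim=1)
for query, doc_idx in zip(queries, best.tolist()):
    print(f"{query!r} -> document {doc_idx}")
```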
##### Usage (Sentence-Transformers)

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Lenovo-Zhihui/Zhihui_LLM_Embedding", trust_remote_code=True)
model.max_seq_length = 512

# Data from DuRetrieval: https://huggingface.co/datasets/C-MTEB/DuRetrieval
queries = [
    "国家法定节假日共多少天",
    "如何查看好友申请",
]
documents = [
    "一年国家法定节假日为11天。根据公布的国家法定节假日调整方案,调整的主要内容包括:元旦放假1天不变;春节放假3天,放假时间为农历正月初一、初二、初三;“五一”国际劳动节1天不变;“十一”国庆节放假3天;清明节、端午节、中秋节增设为国家法定节假日,各放假1天(农历节日如遇闰月,以第一个月为休假日)。3、允许周末上移下错,与法定节假日形成连休。",
    "这个直接去我的QQ中心不就好了么那里可以查到 我的好友单向好友好友恢复、 以及好友申请 啊可以是你加别人的 或 别人加你的都可以查得到QQ空间里 这个没注意 要有的话也会在你进空间的时候会提示你的QQ 空间里 上面消息 就可以看见了!望采纳!谢谢这个直接去我的QQ中心不就好了么那里可以查到 我的好友单向好友好友恢复、 以及好友申请 啊可以是你加别人的 或 别人加你的都可以查得到",
]

query_embeddings = model.encode(queries, prompt_name="query")
document_embeddings = model.encode(documents)

scores = (query_embeddings @ document_embeddings.T)
print(scores.tolist())
```
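In the Sentence-Transformers snippet, `prompt_name="query"` applies the query-side instruction shipped with the model's configuration, while documents are encoded without a prompt. This mirrors the asymmetric setup in the Transformers example above, where only queries are wrapped with `get_detailed_instruct`.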
### Reproduce our results (C-MTEB)

Check out `scripts/eval_mteb.py` to reproduce the evaluation results on the C-MTEB benchmark.
| Model | T2Retrieval | MMarcoRetrieval | DuRetrieval | CovidRetrieval | CmedqaRetrieval | EcomRetrieval | MedicalRetrieval | VideoRetrieval | Avg |
|:-------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
| **Zhihui_LLM_Embedding** | 88.30 | 84.77 | 91.34 | 84.39 | 48.69 | 71.96 | 65.19 | 79.31 | **76.74** |
| zpoint_large_embedding_zh | 83.81 | 82.38 | 89.23 | 89.14 | 47.16 | 70.74 | 68.14 | 80.26 | 76.36 |
| gte-Qwen2-7B-instruct | 87.73 | 85.16 | 87.44 | 83.65 | 48.69 | 71.15 | 65.59 | 78.84 | 76.03 |
| 360Zhinao-search | 87.12 | 83.32 | 87.57 | 85.02 | 46.73 | 68.90 | 63.69 | 78.09 | 75.06 |
| AGE_Hybrid | 86.88 | 80.65 | 89.28 | 83.66 | 47.26 | 69.28 | 65.94 | 76.79 | 74.97 |
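If the repository script is not at hand, the retrieval tasks above can also be run with the open-source `mteb` package. The snippet below is a minimal sketch, not the official evaluation script: the output folder is illustrative, and exact scores may differ slightly from `scripts/eval_mteb.py` depending on library versions and prompt handling.

```python
# Minimal sketch: evaluate on the C-MTEB retrieval tasks with the `mteb` package.
from mteb import MTEB
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Lenovo-Zhihui/Zhihui_LLM_Embedding", trust_remote_code=True)
model.max_seq_length = 512

retrieval_tasks = [
    "T2Retrieval", "MMarcoRetrieval", "DuRetrieval", "CovidRetrieval",
    "CmedqaRetrieval", "EcomRetrieval", "MedicalRetrieval", "VideoRetrieval",
]
evaluation = MTEB(tasks=retrieval_tasks, task_langs=["zh"])
evaluation.run(model, output_folder="results/Zhihui_LLM_Embedding")
```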