---
license: apache-2.0
language:
- zh
library_name: transformers
---

### IACC-ranker

Instruction-Aware Contextual Compressor (IACC) is an open-source re-ranking/context-compression model developed by the Guangdong Laboratory of Artificial Intelligence and Digital Economy (Shenzhen Guangming Laboratory).

This repository, IACC-ranker, hosts the ranker; the compressor is published on a separate page.

The model is trained on a dataset of 15 million Chinese sentence pairs and has consistently delivered good results across various Chinese test datasets.

For the more extensive features of RankingPrompter, such as the complete document encoding-retrieval-fine-tuning pipeline, we recommend the accompanying [codebase](https://github.com/howard-hou/instruction-aware-contextual-compressor/tree/main).

### How to use

You can use this model simply as a re-ranker. Note that the model currently supports Chinese only.

```python
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("howard-hou/IACC-ranker-small")
# trust_remote_code=True is required; otherwise the correct model code will not be loaded
model = AutoModel.from_pretrained("howard-hou/IACC-ranker-small",
                                  trust_remote_code=True)

# candidate documents to re-rank (the model currently supports Chinese only)
documents = [
    '水库诱发地震的震中多在库底和水库边缘。',
    '双标紫斑蝶广泛分布于南亚、东南亚、澳洲、新几内亚等地。台湾地区于本岛中海拔地区可见,多以特有亚种归类。',
    '月经停止是怀孕最显著也是最早的一个信号,如果在无避孕措施下进行了性生活而出现月经停止的话,很可能就是怀孕了。'
]

question = "什么是怀孕最显著也是最早的信号?"

question_input = tokenizer(question, padding=True, return_tensors="pt")
docs_input = tokenizer(documents, padding=True, return_tensors="pt")
# the document input shape should be [batch_size, num_docs, seq_len],
# so for a single set of documents add a batch dimension with unsqueeze(0)
output = model(
    document_input_ids=docs_input.input_ids.unsqueeze(0),
    document_attention_mask=docs_input.attention_mask.unsqueeze(0),
    question_input_ids=question_input.input_ids,
    question_attention_mask=question_input.attention_mask
)
print("reranking scores: ", output.logits)
```
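
To turn the scores into a ranking, you can sort the documents by their logits. A minimal sketch, assuming `output.logits` holds one relevance score per document for the single batch above (higher means more relevant):

```python
import torch

# assumption: output.logits has shape [batch_size, num_docs]; take the first (only) batch
scores = output.logits[0]
ranked = torch.argsort(scores, descending=True)
for rank, idx in enumerate(ranked.tolist(), start=1):
    print(f"rank {rank}: score={scores[idx].item():.4f}  doc={documents[idx]}")
```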