---
pipeline_tag: sentence-similarity
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
license: mit
---
For more details, please refer to our GitHub repo: https://github.com/FlagOpen/FlagEmbedding
# LLARA ([paper](https://arxiv.org/pdf/2312.15503))
In this project, we introduce LLaRA, which adapts LLaMA into a dense retriever through two pretext tasks (their prompt templates are sketched just below):
- EBAE: Embedding-Based Auto-Encoding, where the text embedding is trained to reconstruct the input passage itself.
- EBAR: Embedding-Based Auto-Regression, where the text embedding is trained to predict the passage that follows.
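At inference time these two tasks appear as fixed prompt templates (used in the Usage code below): the passage side appends a "summarize" instruction followed by the placeholder tokens `<s1>`–`<s8>`, and the query side appends a "predict the following passage" instruction followed by `<s9>`–`<s16>`; the final-layer hidden states of those eight tokens are mean-pooled into the text embedding. A minimal illustration of the two templates (the EBAE/EBAR labels are our reading of how they correspond to the objectives):

```python
# Illustration only: the two prompt templates used by LLaRA at inference.
# See the Usage section below for the full encoding pipeline.

# Passage side (EBAE-style): summarize the passage itself into <s1>-<s8>.
PASSAGE_TEMPLATE = ('"{text}", summarize the above passage within eight words: '
                    '<s1><s2><s3><s4><s5><s6><s7><s8>')

# Query side (EBAR-style): predict the following passage into <s9>-<s16>.
QUERY_TEMPLATE = ('"{text}", predict the following passage within eight words: '
                  '<s9><s10><s11><s12><s13><s14><s15><s16>')

print(QUERY_TEMPLATE.format(text="What is llama?"))
```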
## Usage
```python
import torch
from transformers import AutoModel, AutoTokenizer
def get_query_inputs(queries, tokenizer, max_length=512):
    # Wrap each query in the query-side prompt: a leading quote, the query text,
    # and a suffix that appends the eight <s9>-<s16> placeholder tokens.
    prefix = '"'
    suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
    # Drop the BOS token that the tokenizer prepends to the suffix.
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
    queries_inputs = []
    for query in queries:
        inputs = tokenizer(query,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        queries_inputs.append(inputs)
    # Pad the batch to a common length and return PyTorch tensors.
    return tokenizer.pad(
        queries_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )
def get_passage_inputs(passages, tokenizer, max_length=512):
    # Wrap each passage in the passage-side prompt, ending with the
    # eight <s1>-<s8> placeholder tokens.
    prefix = '"'
    suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
    # Drop the BOS token that the tokenizer prepends to the suffix.
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
    passages_inputs = []
    for passage in passages:
        inputs = tokenizer(passage,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        passages_inputs.append(inputs)
    # Pad the batch to a common length and return PyTorch tensors.
    return tokenizer.pad(
        passages_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('BAAI/LLARA-beir')
model = AutoModel.from_pretrained('BAAI/LLARA-beir')

# Define query and passage inputs
query = "What is llama?"
title = "Llama"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
query_input = get_query_inputs([query], tokenizer)
passage_input = get_passage_inputs([passage], tokenizer)

with torch.no_grad():
    # Compute the query embedding: mean-pool the last hidden states of the eight
    # placeholder tokens, then L2-normalize.
    # NOTE: slicing the last 8 positions assumes the placeholder tokens remain at
    # the end of the padded sequence (i.e. left padding).
    query_outputs = model(**query_input, return_dict=True, output_hidden_states=True)
    query_embedding = query_outputs.hidden_states[-1][:, -8:, :]
    query_embedding = torch.mean(query_embedding, dim=1)
    query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)

    # Compute the passage embedding in the same way.
    passage_outputs = model(**passage_input, return_dict=True, output_hidden_states=True)
    passage_embeddings = passage_outputs.hidden_states[-1][:, -8:, :]
    passage_embeddings = torch.mean(passage_embeddings, dim=1)
    passage_embeddings = torch.nn.functional.normalize(passage_embeddings, dim=-1)

    # Compute the similarity score (cosine similarity, since both embeddings are normalized).
    score = query_embedding @ passage_embeddings.T
    print(score)
```
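Building on the script above, here is a small sketch for scoring one query against several passages and ranking them. It assumes `model`, `tokenizer`, `get_query_inputs`, and `get_passage_inputs` from the snippet above are in scope; the passage texts are illustrative, and the `encode` helper simply factors out the mean-pool and normalize steps:

```python
import torch

def encode(model, inputs):
    # Mean-pool the last hidden states of the eight placeholder tokens, then L2-normalize
    # (assumes the placeholder tokens occupy the last positions after padding).
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True, output_hidden_states=True)
        embeddings = outputs.hidden_states[-1][:, -8:, :].mean(dim=1)
    return torch.nn.functional.normalize(embeddings, dim=-1)

passages = [
    "The llama is a domesticated South American camelid, widely used as a meat and pack animal.",
    "Llama 2 is a family of large language models released by Meta AI.",
]

query_embedding = encode(model, get_query_inputs(["What is llama?"], tokenizer))
passage_embeddings = encode(model, get_passage_inputs(passages, tokenizer))

# Cosine similarities (embeddings are already normalized), printed highest first.
scores = (query_embedding @ passage_embeddings.T).squeeze(0)
for passage, score in sorted(zip(passages, scores.tolist()), key=lambda x: x[1], reverse=True):
    print(f"{score:.4f}  {passage}")
```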
## Acknowledgement
Thanks to the authors of the open-source datasets we use, including MS MARCO and BEIR.
Thanks to open-source libraries such as [Pyserini](https://github.com/castorini/pyserini).
## Citation
If you find this repository useful, please consider giving it a star :star: and a citation:
```
@misc{li2023making,
title={Making Large Language Models A Better Foundation For Dense Retrieval},
author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
year={2023},
eprint={2312.15503},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
``` |