---
{}
---
# LLARA-7B-Passage

This model is fine-tuned from LLaMA-2-7B using LoRA. Its embedding size is 4096.

## Training Data

The model is fine-tuned for one epoch on the training split of the [MS MARCO Passage Ranking](https://microsoft.github.io/msmarco/Datasets) dataset. Please see our paper for details.

## Usage

The example below encodes a query and a passage and then computes their similarity from their embeddings.

```python
import torch
from transformers import AutoModel, AutoTokenizer

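def get_query_inputs(queries, tokenizer, max_length=512):
    # Wrap each query in a prompt that ends with eight special tokens <s9>...<s16>;
    # their final-layer hidden states are later mean-pooled into the query embedding.
    prefix = '"'
    suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']  # includes the BOS token
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]  # drop the duplicate BOS token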
    queries_inputs = []
    for query in queries:
        inputs = tokenizer(query,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        queries_inputs.append(inputs)
    return tokenizer.pad(
            queries_inputs,
            padding=True,
            max_length=max_length,
            pad_to_multiple_of=8,
            return_tensors='pt',
        )

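def get_passage_inputs(passages, tokenizer, max_length=512):
    # Wrap each passage in a prompt that ends with eight special tokens <s1>...<s8>;
    # their final-layer hidden states are later mean-pooled into the passage embedding.
    prefix = '"'
    suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']  # includes the BOS token
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]  # drop the duplicate BOS token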
    passages_inputs = []
    for passage in passages:
        inputs = tokenizer(passage,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        passages_inputs.append(inputs)
    return tokenizer.pad(
            passages_inputs,
            padding=True,
            max_length=max_length,
            pad_to_multiple_of=8,
            return_tensors='pt',
        )

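# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('cfli/LLARA-passage')
# Embeddings are read from the last 8 positions, so padding must be added on the left
tokenizer.padding_side = 'left'
model = AutoModel.from_pretrained('cfli/LLARA-passage')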

# Define the query and passage, then wrap them in their prompt templates
query = "What is llama?"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
query_input = get_query_inputs([query], tokenizer)
passage_input = get_passage_inputs([passage], tokenizer)

with torch.no_grad():
    # compute query embedding: mean of the final-layer states at the last 8 positions, then L2-normalize
    query_outputs = model(**query_input, return_dict=True, output_hidden_states=True)
    query_embedding = query_outputs.hidden_states[-1][:, -8:, :]
    query_embedding = torch.mean(query_embedding, dim=1)
    query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)

    # compute passage embedding the same way
    passage_outputs = model(**passage_input, return_dict=True, output_hidden_states=True)
    passage_embeddings = passage_outputs.hidden_states[-1][:, -8:, :]
    passage_embeddings = torch.mean(passage_embeddings, dim=1)
    passage_embeddings = torch.nn.functional.normalize(passage_embeddings, dim=-1)

    # compute similarity score (cosine similarity, since both embeddings are L2-normalized)
    score = query_embedding @ passage_embeddings.T
    print(score)
```
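The pooling in the example above takes the final-layer hidden states at the last eight positions (the `<s1>`...`<s8>` or `<s9>`...`<s16>` prompt tokens), averages them, and L2-normalizes the result, so the dot product is a cosine similarity. The sketch below reproduces only that arithmetic on random tensors standing in for `hidden_states[-1]`, to illustrate scoring one query against a batch of passages without loading the 7B model. The helper names `pool_last_k` and `rank_passages` are hypothetical, the hidden size of 16 is a toy value (the model's real embedding size is 4096), and the sketch assumes left padding so that the eight special tokens sit at the end of every sequence.

```python
import torch
import torch.nn.functional as F


def pool_last_k(last_hidden: torch.Tensor, k: int = 8) -> torch.Tensor:
    """Mean-pool the last k positions of a (batch, seq_len, hidden) tensor, then L2-normalize.

    Hypothetical helper; assumes left padding, so the k prompt tokens end every row.
    """
    pooled = last_hidden[:, -k:, :].mean(dim=1)
    return F.normalize(pooled, dim=-1)


def rank_passages(query_hidden: torch.Tensor, passage_hidden: torch.Tensor, k: int = 8):
    """Score one query against a batch of passages; return (scores, indices best-first)."""
    q = pool_last_k(query_hidden, k)    # shape (1, hidden)
    p = pool_last_k(passage_hidden, k)  # shape (num_passages, hidden)
    scores = (q @ p.T).squeeze(0)       # cosine similarities, shape (num_passages,)
    order = torch.argsort(scores, descending=True)
    return scores, order


# Toy stand-ins for model(...).hidden_states[-1]; hidden size 16 instead of 4096.
torch.manual_seed(0)
query_hidden = torch.randn(1, 24, 16)
passage_hidden = torch.randn(3, 24, 16)
# Give passage 1 the same final 8 states as the query, so it should rank first.
passage_hidden[1, -8:, :] = query_hidden[0, -8:, :]

scores, order = rank_passages(query_hidden, passage_hidden)
print(scores)
print(order)
```

Because both sides are normalized, every score lies in [-1, 1]. With the real model, `passage_outputs.hidden_states[-1]` from a batch built by `get_passage_inputs` would take the place of the random tensors.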