cfli committed on
Commit
e2ee5b6
1 Parent(s): 55b132c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +27 -20
README.md CHANGED
@@ -14,40 +14,46 @@ Below is an example to encode a query and a passage, and then compute their simi
14
  import torch
15
  from transformers import AutoModel, AutoTokenizer, LlamaModel
16
 
17
- def get_query_inputs(query, tokenizer, max_length=512):
18
  prefix = '"'
19
  suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
20
  prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
21
  suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
22
- inputs = tokenizer(query,
23
- return_tensors=None,
24
- max_length=max_length,
25
- truncation=True,
26
- add_special_tokens=False)
27
- inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
28
- inputs['attention_mask'] = [1] * len(inputs['input_ids'])
 
 
 
29
  return tokenizer.pad(
30
- [inputs],
31
  padding=True,
32
  max_length=max_length,
33
  pad_to_multiple_of=8,
34
  return_tensors='pt',
35
  )
36
 
37
- def get_passage_inputs(passage, tokenizer, max_length=512):
38
  prefix = '"'
39
  suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
40
  prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
41
  suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
42
- inputs = tokenizer(passage,
43
- return_tensors=None,
44
- max_length=max_length,
45
- truncation=True,
46
- add_special_tokens=False)
47
- inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
48
- inputs['attention_mask'] = [1] * len(inputs['input_ids'])
 
 
 
49
  return tokenizer.pad(
50
- [inputs],
51
  padding=True,
52
  max_length=max_length,
53
  pad_to_multiple_of=8,
@@ -62,8 +68,8 @@ model = AutoModel.from_pretrained('cfli/LLARA-beir')
62
  query = "What is llama?"
63
  title = "Llama"
64
  passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
65
- query_input = get_query_inputs(query, tokenizer)
66
- passage_input = get_passage_inputs(passage, tokenizer)
67
 
68
 
69
  with torch.no_grad():
@@ -84,4 +90,5 @@ with torch.no_grad():
84
  print(score)
85
 
86
 
 
87
  ```
 
14
  import torch
15
  from transformers import AutoModel, AutoTokenizer, LlamaModel
16
 
17
+ def get_query_inputs(queries, tokenizer, max_length=512):
18
  prefix = '"'
19
  suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
20
  prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
21
  suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
22
+ queries_inputs = []
23
+ for query in queries:
24
+ inputs = tokenizer(query,
25
+ return_tensors=None,
26
+ max_length=max_length,
27
+ truncation=True,
28
+ add_special_tokens=False)
29
+ inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
30
+ inputs['attention_mask'] = [1] * len(inputs['input_ids'])
31
+ queries_inputs.append(inputs)
32
  return tokenizer.pad(
33
+ queries_inputs,
34
  padding=True,
35
  max_length=max_length,
36
  pad_to_multiple_of=8,
37
  return_tensors='pt',
38
  )
39
 
40
+ def get_passage_inputs(passages, tokenizer, max_length=512):
41
  prefix = '"'
42
  suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
43
  prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
44
  suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
45
+ passages_inputs = []
46
+ for passage in passages:
47
+ inputs = tokenizer(passage,
48
+ return_tensors=None,
49
+ max_length=max_length,
50
+ truncation=True,
51
+ add_special_tokens=False)
52
+ inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
53
+ inputs['attention_mask'] = [1] * len(inputs['input_ids'])
54
+ passages_inputs.append(inputs)
55
  return tokenizer.pad(
56
+ passages_inputs,
57
  padding=True,
58
  max_length=max_length,
59
  pad_to_multiple_of=8,
 
68
  query = "What is llama?"
69
  title = "Llama"
70
  passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
71
+ query_input = get_query_inputs([query], tokenizer)
72
+ passage_input = get_passage_inputs([passage], tokenizer)
73
 
74
 
75
  with torch.no_grad():
 
90
  print(score)
91
 
92
 
93
+
94
  ```