irow committed
Commit 4529fd0
1 Parent(s): d3eef20

Upload inference.py

Files changed (1):
  1. inference.py +79 -0
inference.py ADDED
@@ -0,0 +1,79 @@
import torch
from tqdm import tqdm
from transformers import BertTokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DualBertModel:  # use this class to encode your documents and queries
    """
    Encodes queries and documents with a pair of BERT encoders.
    """
    def __init__(self, **kwargs):
        """
        Loads the models found in "model/passage_model_3.pt" and "model/query_model_3.pt".
        """
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
        # load the pickled modules onto the active device and put them into eval mode
        self.p_model = torch.load("model/passage_model_3.pt", map_location=device).eval()
        self.q_model = torch.load("model/query_model_3.pt", map_location=device).eval()

    def chunks(self, lst, n):
        """Yield successive batches of n items from lst, with a progress bar."""
        for i in tqdm(range(0, len(lst), n)):
            yield lst[i:i + n]

    def process(self, lst, query):
        """Tokenize a batch of strings; queries are truncated shorter than passages."""
        max_length = 30 if query else 450
        encoded = self.tokenizer(
            text=lst,
            add_special_tokens=True,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
            max_length=max_length,
        )
        input_ids = torch.tensor(encoded['input_ids']).to(device)
        attention_mask = torch.tensor(encoded['attention_mask']).to(device)
        return input_ids, attention_mask

    def encode_queries(self, queries, batch_size: int, **kwargs) -> torch.Tensor:
        """
        Encodes a list of strings (queries) and returns a tensor of encodings.
        """
        with torch.no_grad():
            to_return = []
            for c in self.chunks(queries, batch_size):
                ids, atm = self.process(c, True)
                out = self.q_model(ids, atm)
                to_return.append(out.cpu())
        return torch.concat(to_return)

    def encode_corpus(self, corpus, batch_size: int, **kwargs) -> torch.Tensor:
        """
        Encodes a list of documents (dicts with a 'text' field, or plain strings)
        and returns a tensor of encodings.
        """
        with torch.no_grad():
            to_return = []
            for c in self.chunks(corpus, batch_size):
                to_encode = [d['text'] if isinstance(d, dict) else d for d in c]
                ids, atm = self.process(to_encode, False)
                out = self.p_model(ids, atm)
                to_return.append(out.cpu())
        return torch.concat(to_return)

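# NOTE: the checkpoints loaded in __init__ are full pickled nn.Module objects, so
# their architecture is not visible in this file. The class below is only a sketch
# of an encoder compatible with the `model(input_ids, attention_mask)` calls above;
# the mean-pooling strategy and the bert-base-uncased backbone are assumptions for
# illustration, not necessarily what the saved checkpoints contain.
import torch.nn as nn
from transformers import BertModel

class BertEncoder(nn.Module):
    """Hypothetical dual-encoder tower: BERT with mean pooling over tokens."""
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        mask = attention_mask.unsqueeze(-1).float()
        # average the last hidden state over non-padding tokens only
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1)
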
if __name__ == "__main__":
    model = DualBertModel()
    queries = ["What is deep learning"]
    documents = ['The adjective "deep" in deep learning refers to the use of multiple layers in the network. Early work showed that a linear perceptron cannot be a universal classifier, but that a network with a nonpolynomial activation function with one hidden layer of unbounded width can. Deep learning is a modern variation which is concerned with an unbounded number of layers of bounded size, which permits practical application and optimized implementation, while retaining theoretical universality under mild conditions. In deep learning the layers are also permitted to be heterogeneous and to deviate widely from biologically informed connectionist models, for the sake of efficiency, trainability and understandability.']

    encoded_queries = model.encode_queries(queries, 1)
    encoded_docs = model.encode_corpus(documents, 1)

    print("Distance between query-documents", torch.cdist(encoded_queries, encoded_docs))
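    # For retrieval, dense dual-encoder embeddings are usually ranked by dot product
    # or cosine similarity rather than raw L2 distance; a minimal sketch, assuming
    # each encoder returns one fixed-size vector per input:
    q = torch.nn.functional.normalize(encoded_queries, dim=-1)
    d = torch.nn.functional.normalize(encoded_docs, dim=-1)
    print("Cosine similarity between query-documents", q @ d.T)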