# dual-bert-IR / inference.py
# Commit b6c11fd (irow): "Update inference.py, for more general use."
import torch
from tqdm import tqdm
from transformers import BertTokenizerFast
import numpy as np
# Pick the compute device once at import time: use the GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class DualBertModel:  # use this class to encode your documents and queries.
    """Encode queries and documents with two separately loaded models.

    Queries are run through ``q_model`` and documents through ``p_model``.
    Both share one ``bert-base-uncased`` tokenizer; they differ only in the
    maximum token length used for padding/truncation.
    """

    # Token budgets used when padding/truncating each kind of input.
    QUERY_MAX_LENGTH = 30
    PASSAGE_MAX_LENGTH = 450

    def __init__(
        self,
        passage_model_path="model/passage_model_3.pt",
        query_model_path="model/query_model_3.pt",
        **kwargs,
    ):
        """Load the tokenizer and both encoder checkpoints.

        Args:
            passage_model_path: Checkpoint used for documents. Defaults to
                "model/passage_model_3.pt".
            query_model_path: Checkpoint used for queries. Defaults to
                "model/query_model_3.pt".
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
        self.p_model = self._load_model(passage_model_path)
        self.q_model = self._load_model(query_model_path)

    @staticmethod
    def _load_model(path):
        """Load one checkpoint onto ``device`` and switch it to eval mode."""
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code -- only load checkpoints from a trusted source.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which rejects whole-module checkpoints; pass weights_only=False here
        # if that applies to your install -- confirm against your torch version.
        #
        # map_location remaps storages saved on another device (for example a
        # CUDA checkpoint opened on a CPU-only host) instead of failing.
        model = torch.load(path, map_location=device)
        # `process` moves the inputs to `device`, so the model has to live
        # there too. Assumes the checkpoint is a torch.nn.Module.
        return model.to(device).eval()

    def chunks(self, lst, n):  # chunk list of strings
        """Yield consecutive slices of ``lst`` holding at most ``n`` items.

        Iteration is wrapped in ``tqdm`` so a progress bar is displayed.
        """
        for start in tqdm(range(0, len(lst), n)):
            yield lst[start:start + n]

    def process(self, lst, query):
        """Tokenize a batch of strings and move the tensors to ``device``.

        Args:
            lst: Strings to tokenize.
            query: Truthy to use the query length limit (30 tokens), falsy to
                use the passage limit (450 tokens).

        Returns:
            A ``(input_ids, attention_mask)`` tuple of tensors on ``device``.
        """
        max_length = self.QUERY_MAX_LENGTH if query else self.PASSAGE_MAX_LENGTH
        # Padding to max_length (not to the longest item) means every batch of
        # the same kind has the same sequence length.
        encoded = self.tokenizer(
            text=lst,
            add_special_tokens=True,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
            max_length=max_length,
        )
        input_ids = torch.tensor(encoded['input_ids']).to(device)
        attention_mask = torch.tensor(encoded['attention_mask']).to(device)
        return input_ids, attention_mask

    def _encode(self, texts, batch_size, model, query):
        """Run ``texts`` through ``model`` in batches and join the outputs.

        Each batch output is moved to the CPU before the next batch runs, so
        only one batch of outputs is kept on ``device`` at a time.

        NOTE: an empty ``texts`` produces no batches, leaving nothing to
        concatenate; ``torch.cat`` raises in that case.
        """
        with torch.no_grad():
            outputs = []
            for batch in self.chunks(texts, batch_size):
                ids, mask = self.process(batch, query)
                outputs.append(model(ids, mask).cpu())
            return torch.cat(outputs)

    def encode_queries(self, queries, batch_size: int, **kwargs):
        """Encode a list of query strings.

        Args:
            queries: List of strings.
            batch_size: Number of queries per forward pass.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A CPU tensor holding the per-batch outputs of the query model
            concatenated along the first dimension.
        """
        return self._encode(queries, batch_size, self.q_model, True)

    def encode_corpus(self, corpus, batch_size: int, **kwargs):
        """Encode a list of document strings.

        ``corpus`` must contain plain strings; if your documents are dicts
        (for example ``{'text': ...}``), extract the text before calling.

        Args:
            corpus: List of strings.
            batch_size: Number of documents per forward pass.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A CPU tensor holding the per-batch outputs of the passage model
            concatenated along the first dimension.
        """
        return self._encode(corpus, batch_size, self.p_model, False)
if __name__ == "__main__":
    # Small end-to-end demo: embed one query and one document, then print the
    # pairwise distance matrix between the two sets of embeddings.
    encoder = DualBertModel()

    sample_queries = ["What is deep learning"]
    sample_documents = [
        'The adjective "deep" in deep learning refers to the use of multiple layers in the network. '
        'Early work showed that a linear perceptron cannot be a universal classifier, but that a network with a nonpolynomial activation function with one hidden layer of unbounded width can. '
        'Deep learning is a modern variation which is concerned with an unbounded number of layers of bounded size, which permits practical application and optimized implementation, while retaining theoretical universality under mild conditions. '
        'In deep learning the layers are also permitted to be heterogeneous and to deviate widely from biologically informed connectionist models, for the sake of efficiency, trainability and understandability.'
    ]

    # Batch size 1 is plenty for a single-item demo.
    query_vectors = encoder.encode_queries(sample_queries, 1)
    document_vectors = encoder.encode_corpus(sample_documents, 1)

    distances = torch.cdist(query_vectors, document_vectors)
    print("Distance between query-documents", distances)