|
import torch |
|
from tqdm import tqdm |
|
from transformers import BertTokenizerFast |
|
import numpy as np |
|
|
|
# Target device for models and input tensors: CUDA when PyTorch can see a GPU,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
class DualBertModel:
    """
    Dual encoder that embeds queries and documents with two separately
    loaded models sharing one ``bert-base-uncased`` tokenizer.

    Queries are encoded by ``q_model`` and documents by ``p_model``. Both are
    whole pickled ``torch.nn.Module`` checkpoints loaded with ``torch.load``.
    """

    # Default checkpoint locations, resolved relative to the working directory.
    DEFAULT_PASSAGE_MODEL_PATH = "model/passage_model_3.pt"
    DEFAULT_QUERY_MODEL_PATH = "model/query_model_3.pt"

    # Every sequence is padded/truncated to exactly this many tokens.
    # Instances may override these via __init__ keyword arguments.
    query_max_length = 30
    passage_max_length = 450

    def __init__(self, **kwargs):
        """
        Load the tokenizer and both encoder models (in eval mode, on `device`).

        Optional keyword arguments (any other keywords are ignored):
            passage_model_path: document-encoder checkpoint
                (default "model/passage_model_3.pt").
            query_model_path: query-encoder checkpoint
                (default "model/query_model_3.pt").
            query_max_length: token length for queries (default 30).
            passage_max_length: token length for documents (default 450).
        """
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)

        self.query_max_length = kwargs.get("query_max_length", self.query_max_length)
        self.passage_max_length = kwargs.get("passage_max_length", self.passage_max_length)

        self.p_model = self._load_model(
            kwargs.get("passage_model_path", self.DEFAULT_PASSAGE_MODEL_PATH)
        )
        self.q_model = self._load_model(
            kwargs.get("query_model_path", self.DEFAULT_QUERY_MODEL_PATH)
        )

    @staticmethod
    def _load_model(path):
        """Load a pickled nn.Module from `path`, place it on `device`, set eval mode."""
        import inspect

        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        # (and vice versa).
        load_kwargs = {"map_location": device}
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # a whole nn.Module. These checkpoints are full modules, so opt out on
        # versions that know the flag (older versions reject the keyword).
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects --
        # only load checkpoints from a trusted source.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        model = torch.load(path, **load_kwargs)
        # process() always sends inputs to `device`; keep the weights there too.
        return model.to(device).eval()

    def chunks(self, lst, n):
        """Yield consecutive slices of `lst` of length `n` (the last may be shorter), with a tqdm progress bar."""
        for i in tqdm(range(0, len(lst), n)):
            yield lst[i:i + n]

    def process(self, lst, query):
        """
        Tokenize a batch of strings and move the resulting tensors to `device`.

        Args:
            lst: list of strings to tokenize.
            query: True for queries (length `query_max_length`), False for
                documents (length `passage_max_length`).

        Returns:
            (input_ids, attention_mask) tensors, each padded/truncated to
            (len(lst), max_length).
        """
        max_length = self.query_max_length if query else self.passage_max_length

        encoded = self.tokenizer(
            text=lst,
            add_special_tokens=True,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        return encoded['input_ids'].to(device), encoded['attention_mask'].to(device)

    def _encode(self, texts, batch_size, model, query):
        """
        Run `model` over `texts` in batches of `batch_size` without gradients.

        Returns the per-batch outputs moved to CPU and concatenated along dim 0.

        Raises:
            ValueError: if `batch_size` is smaller than 1.
            RuntimeError: from torch.cat if `texts` is empty (no batches).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        outputs = []
        with torch.no_grad():
            for batch in self.chunks(texts, batch_size):
                ids, atm = self.process(batch, query)
                # Move each batch off the accelerator immediately to bound device memory.
                outputs.append(model(ids, atm).cpu())
        return torch.cat(outputs)

    def encode_queries(self, queries, batch_size: int, **kwargs):
        """
        Encode a list of query strings with the query model.

        Extra keyword arguments are accepted and ignored. Returns a CPU tensor
        of the concatenated model outputs (presumably (len(queries), dim)
        embeddings -- depends on the loaded model).
        """
        return self._encode(queries, batch_size, self.q_model, True)

    def encode_corpus(self, corpus, batch_size: int, **kwargs):
        """
        Encode a list of document strings with the passage model.

        Extra keyword arguments are accepted and ignored. Returns a CPU tensor
        of the concatenated model outputs (presumably (len(corpus), dim)
        embeddings -- depends on the loaded model).
        """
        return self._encode(corpus, batch_size, self.p_model, False)
|
|
|
if __name__ == "__main__":
    # Demo: embed one query and one passage, then print the pairwise
    # distances (torch.cdist, default p=2) between query and document vectors.
    encoder = DualBertModel()

    sample_queries = ["What is deep learning"]
    sample_documents = [
        'The adjective "deep" in deep learning refers to the use of multiple layers in the network. Early work showed that a linear perceptron cannot be a universal classifier, but that a network with a nonpolynomial activation function with one hidden layer of unbounded width can. Deep learning is a modern variation which is concerned with an unbounded number of layers of bounded size, which permits practical application and optimized implementation, while retaining theoretical universality under mild conditions. In deep learning the layers are also permitted to be heterogeneous and to deviate widely from biologically informed connectionist models, for the sake of efficiency, trainability and understandability.',
    ]

    query_vectors = encoder.encode_queries(sample_queries, batch_size=1)
    document_vectors = encoder.encode_corpus(sample_documents, batch_size=1)

    pairwise_distances = torch.cdist(query_vectors, document_vectors)
    print("Distance between query-documents", pairwise_distances)