# PyTorch is a library for deep learning and machine learning
import torch
# Hugging Face Transformers provides state-of-the-art machine learning models for PyTorch, TensorFlow, and JAX
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
# A simple extractive question-answering model built on DistilBERT
class Model:
    # The DistilBERT model and tokenizer are loaded with the pre-trained
    # weights of the "distilbert-base-uncased-distilled-squad" checkpoint.
    # The answer, context, and question attributes are initialized empty.
    def __init__(self):
        self.model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
        self.answer = ""
        self.context = ""
        self.question = ""
    # This method reads the question (self.question) and the context
    # (self.context). It tokenizes both with DistilBERT's tokenizer and
    # stores the resulting model inputs in self.inputs.
    # It then runs inference with the DistilBERT model by calling
    # self.model(**self.inputs). The torch.no_grad() context manager
    # signals that no gradients need to be computed during inference,
    # which saves memory and computation.
    # After inference, the start and end indices of the answer span
    # within the context are obtained with torch.argmax, which picks
    # the positions with the highest logits. These indices are used to
    # slice the predicted answer tokens out of the input ids.
    # Finally, the answer tokens are decoded with the tokenizer to
    # produce the final answer, which is stored in self.answer and
    # returned by the method.
    def question_answerer(self):
        self.inputs = self.tokenizer(self.question, self.context, return_tensors="pt")
        # disable gradient calculations for inference
        with torch.no_grad():
            self.outputs = self.model(**self.inputs)
        self.answer_start_index = torch.argmax(self.outputs.start_logits)
        self.answer_end_index = torch.argmax(self.outputs.end_logits)
        self.predict_answer_tokens = self.inputs.input_ids[0, self.answer_start_index : self.answer_end_index + 1]
        self.answer = self.tokenizer.decode(self.predict_answer_tokens)
        return self.answer
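
# A minimal usage sketch (not part of the original file; the context and
# question strings below are illustrative): instantiate the model, set the
# context and question attributes, then call question_answerer().
if __name__ == "__main__":
    qa = Model()
    qa.context = "The Eiffel Tower was completed in 1889 and stands in Paris."
    qa.question = "When was the Eiffel Tower completed?"
    print(qa.question_answerer())  # expected to print something like "1889"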