# PyTorch is a library for machine learning and deep learning.
import torch
# The Hugging Face Transformers library provides state-of-the-art machine
# learning models for PyTorch, TensorFlow, and JAX.
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering


# The question-answering model.
class Model:
    # The DistilBERT model and tokenizer are loaded with the pre-trained
    # weights of the "distilbert-base-uncased-distilled-squad" checkpoint.
    # The answer, context, and question attributes are initialized as
    # empty strings.
    def __init__(self):
        self.model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
        self.answer = ""
        self.context = ""
        self.question = ""

    # This method reads a question (self.question) and a context
    # (self.context). It tokenizes both with DistilBERT's tokenizer and
    # stores the resulting model inputs in self.inputs.
    # It then runs inference with the DistilBERT model by passing the
    # inputs to self.model(**self.inputs). The torch.no_grad() context
    # manager indicates that gradients are not needed during inference,
    # which saves memory.
    # After inference, the start and end indices of the answer within the
    # context are obtained with torch.argmax, which picks the positions
    # with the highest logits. These indices are used to slice the answer
    # tokens out of the input sequence.
    # Finally, the answer tokens are decoded with the tokenizer; the
    # resulting string is stored in self.answer and returned.
    def question_answerer(self):
        self.inputs = self.tokenizer(self.question, self.context, return_tensors="pt")
        # Disable gradient calculations for inference.
        with torch.no_grad():
            self.outputs = self.model(**self.inputs)
        # Most likely start and end token positions of the answer span.
        self.answer_start_index = torch.argmax(self.outputs.start_logits)
        self.answer_end_index = torch.argmax(self.outputs.end_logits)
        # Slice the token ids that make up the predicted answer span.
        self.predict_answer_tokens = self.inputs.input_ids[0, self.answer_start_index : self.answer_end_index + 1]
        self.answer = self.tokenizer.decode(self.predict_answer_tokens)
        return self.answer
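

# A minimal usage sketch, not part of the original file: the question and
# context strings below are invented examples, and the __main__ guard
# assumes this module may also be run directly as a script.
if __name__ == "__main__":
    qa = Model()
    qa.question = "What is PyTorch?"
    qa.context = "PyTorch is an open-source machine learning library used for deep learning applications."
    # Prints the answer span extracted from the context, e.g. an
    # "open-source machine learning library"-style phrase.
    print(qa.question_answerer())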