from dataclasses import dataclass

import numpy as np
import torch
from transformers import AutoTokenizer

from MetaQA_Model import MetaQA_Model


@dataclass
class PredictionRequest:
    input_question: str
    input_predictions: list[tuple[str, float]]
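
# Shape notes (inferred from the code below): input_predictions holds one
# (answer, score) pair per QA agent, in the agent order of the model config;
# an empty answer string is treated as "no prediction" by _get_predictions.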


class MetaQA:
    def __init__(self, path_to_model):
        self.metaqa_model = MetaQA_Model.from_pretrained(path_to_model)
        self.tokenizer = AutoTokenizer.from_pretrained(path_to_model)

    def run_metaqa(self, request: PredictionRequest):
        '''
        Runs MetaQA on a single instance.
        '''
        encoded = self._encode_metaQA_instance(request)
        if encoded is None:  # instance exceeds the model's maximum input length
            return None
        input_ids, token_ids, attention_masks, ans_sc = encoded

        logits = self.metaqa_model(input_ids, token_ids, attention_masks, ans_sc).logits

        return self._get_predictions(logits.detach().numpy(), request.input_predictions)

    def _encode_metaQA_instance(self, request: PredictionRequest, max_len=512):
        '''
        Creates input ids, token type ids, attention masks, and answer-score
        features for an instance of MetaQA.
        '''
        list_input_ids = []
        list_token_ids = []
        list_attention_masks = []
        list_ans_sc = []

        # Question segment: [CLS] question [SEP] with token type 0 and no answer score.
        list_input_ids.extend(self.tokenizer.encode("[CLS]", add_special_tokens=False))
        list_input_ids.extend(self.tokenizer.encode(request.input_question, add_special_tokens=False))
        list_input_ids.extend(self.tokenizer.encode("[SEP]", add_special_tokens=False))
        list_token_ids.extend(len(list_input_ids) * [0])
        list_ans_sc.extend(len(list_input_ids) * [0])

        # Answer segment: each agent answer is prefixed with the marker token
        # (id 1), uses token type 1, and carries the agent's confidence score
        # on all of its tokens.
        for qa_agent_pred in request.input_predictions:
            list_input_ids.append(1)
            ans_input_ids = self.tokenizer.encode(qa_agent_pred[0], add_special_tokens=False)
            list_input_ids.extend(ans_input_ids)
            list_token_ids.extend((len(ans_input_ids) + 1) * [1])
            ans_score = qa_agent_pred[1]
            list_ans_sc.extend((len(ans_input_ids) + 1) * [ans_score])

        # Closing [SEP].
        list_input_ids.extend(self.tokenizer.encode("[SEP]", add_special_tokens=False))
        list_token_ids.append(1)
        list_ans_sc.append(0)

        # Attend to all real tokens; the padding added below is masked out.
        list_attention_masks.extend(len(list_input_ids) * [1])

        # Reject instances that do not fit into the model input. This check
        # must happen before padding: afterwards every instance has max_len tokens.
        if len(list_input_ids) > max_len:
            return None

        # Pad all features up to max_len.
        len_padding = max_len - len(list_input_ids)
        list_input_ids.extend([0] * len_padding)
        list_token_ids.extend((len(list_input_ids) - len(list_token_ids)) * [1])
        list_ans_sc.extend((len(list_input_ids) - len(list_ans_sc)) * [0])
        list_attention_masks.extend((len(list_input_ids) - len(list_attention_masks)) * [0])

        # Add a batch dimension. Ids and masks are integral; the answer scores
        # are continuous confidences and therefore stay float.
        input_ids = torch.tensor(list_input_ids, dtype=torch.long).unsqueeze(0)
        token_ids = torch.tensor(list_token_ids, dtype=torch.long).unsqueeze(0)
        attention_masks = torch.tensor(list_attention_masks, dtype=torch.long).unsqueeze(0)
        ans_sc = torch.tensor(list_ans_sc, dtype=torch.float).unsqueeze(0)

        return (input_ids, token_ids, attention_masks, ans_sc)

    def _get_predictions(self, logits, input_predictions):
        '''
        Selects the best non-empty agent prediction, falling back to the
        top-scored agent if every prediction is empty.
        '''
        top_k = lambda a, k: np.argsort(-a)[:k]  # indices of the k largest values

        # Rank agents by their column-1 logit and return the highest-ranked
        # agent that produced a non-empty answer.
        for idx in top_k(logits[0][:, 1], self.metaqa_model.num_agents):
            pred = input_predictions[idx][0]
            if pred != '':
                agent_name = self.metaqa_model.config.agents[idx]
                metaqa_score = logits[0][idx][1]
                agent_score = input_predictions[idx][1]
                return (pred, agent_name, metaqa_score, agent_score)

        # Every agent returned an empty answer: fall back to the top-scored one.
        idx = top_k(logits[0][:, 1], 1)[0]
        pred = input_predictions[idx][0]
        agent_name = self.metaqa_model.config.agents[idx]
        metaqa_score = logits[0][idx][1]
        agent_score = input_predictions[idx][1]
        return (pred, agent_name, metaqa_score, agent_score)
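

# A minimal usage sketch, assuming a MetaQA checkpoint directory on disk; the
# path, question, and agent predictions below are hypothetical examples, and
# input_predictions must supply one (answer, score) pair per QA agent.
if __name__ == "__main__":
    metaqa = MetaQA("path/to/metaqa-checkpoint")  # hypothetical checkpoint path
    request = PredictionRequest(
        input_question="Who wrote Hamlet?",
        input_predictions=[
            ("William Shakespeare", 0.91),  # (answer, agent confidence)
            ("Christopher Marlowe", 0.12),
        ],
    )
    result = metaqa.run_metaqa(request)
    if result is not None:  # None means the encoded instance exceeded max_len
        pred, agent_name, metaqa_score, agent_score = result
        print(f"{pred} (agent={agent_name}, metaqa_score={metaqa_score:.3f}, agent_score={agent_score:.3f})")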