dictabert-lex / BertForLexPrediction.py
Shaltiel's picture
Upload BertForLexPrediction.py
a039d5e
import torch
from typing import List, Union
from transformers import BertForMaskedLM, BertTokenizerFast
class BertForLexPrediction(BertForMaskedLM):
    """Masked-LM BERT model with a helper that predicts one vocabulary token per word.

    The model itself is an unmodified ``BertForMaskedLM``; the only addition is
    :meth:`predict`, which pairs every input word with the highest-scoring
    vocabulary token at the position of the word's first wordpiece.
    """

    def __init__(self, config):
        super().__init__(config)

    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast) -> List[List[tuple]]:
        """Return, for each sentence, a list of ``(word, predicted_token)`` pairs.

        Args:
            sentences: A single sentence or a list of sentences. A bare string
                is treated as a one-element batch.
            tokenizer: The tokenizer matching this model's vocabulary. It is
                used both to encode the sentences and to map ids back to tokens.

        Returns:
            One list per input sentence. Each entry is a ``(word, prediction)``
            tuple where ``word`` is the input word rebuilt from its wordpieces
            and ``prediction`` is the argmax vocabulary token at the word's
            first wordpiece. CLS, SEP and PAD positions are skipped.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        # Nothing to tokenize: avoid handing an empty batch to the tokenizer.
        if not sentences:
            return []

        # Encode the batch and move every tensor to the model's device.
        inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only: skip autograd bookkeeping, and call the module itself
        # (rather than .forward) so that any registered hooks still run.
        with torch.no_grad():
            logits = self(**inputs, return_dict=True).logits

        # Take the best-scoring vocabulary id at every position in one tensor
        # op, then drop to plain Python ints for the per-token loop below.
        predicted_ids = logits.argmax(dim=-1).tolist()
        input_ids = inputs['input_ids'].tolist()
        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

        batch_ret = []
        for row_ids, row_preds in zip(input_ids, predicted_ids):
            ret = []
            for token_id, pred_id in zip(row_ids, row_preds):
                # ignore cls, sep, pad
                if token_id in special_ids:
                    continue
                token = tokenizer.convert_ids_to_tokens(token_id)
                # Continuation wordpieces are glued onto the previous word and
                # keep that word's prediction. The `ret` check guards against
                # a continuation piece with no preceding word.
                if token.startswith('##') and ret:
                    ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
                    continue
                ret.append((token, tokenizer.convert_ids_to_tokens(pred_id)))
            batch_ret.append(ret)
        return batch_ret