|
import torch |
|
from typing import List, Union |
|
from transformers import BertForMaskedLM, BertTokenizerFast |
|
|
|
class BertForLexPrediction(BertForMaskedLM):
    """BERT masked-LM wrapper that emits a lexical prediction per word.

    ``predict`` tokenizes the input sentences, runs the MLM head, and for
    every input token takes the argmax vocabulary entry of the LM logits.
    Wordpiece continuations (tokens starting with ``'##'``) are merged back
    onto the preceding word; the merged word keeps the prediction made at
    its first wordpiece.
    """

    def __init__(self, config):
        super().__init__(config)

    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast):
        """Predict the top lexical token for each word of each sentence.

        Args:
            sentences: A single sentence or a list of sentences.
            tokenizer: The fast BERT tokenizer matching this model's vocab.

        Returns:
            A list with one entry per input sentence; each entry is a list of
            ``(word, predicted_token)`` tuples. Special tokens (CLS/SEP/PAD)
            are skipped, and ``'##'`` wordpieces are merged into their word.
        """
        if isinstance(sentences, str):
            sentences = [sentences]

        inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Pure inference: disable autograd to avoid building a graph
        # (saves memory and time on every call).
        with torch.no_grad():
            logits = self(**inputs, return_dict=True).logits

        # Build the special-id set once, rather than a fresh list per token.
        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}

        input_ids = inputs['input_ids']
        batch_ret = []
        for batch_idx in range(len(sentences)):
            ret = []
            batch_ret.append(ret)
            for tok_idx in range(input_ids.shape[1]):
                # .item() -> plain int so the set membership test is exact.
                token_id = input_ids[batch_idx, tok_idx].item()
                if token_id in special_ids:
                    continue

                # Public API instead of the private _convert_id_to_token.
                token = tokenizer.convert_ids_to_tokens(token_id)

                # Wordpiece continuation: append the suffix to the previous
                # word, keeping that word's original prediction. The `ret`
                # guard avoids an IndexError if (unexpectedly) the first
                # non-special token is a '##' continuation.
                if token.startswith('##') and ret:
                    ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
                    continue

                pred_id = torch.argmax(logits[batch_idx, tok_idx]).item()
                ret.append((token, tokenizer.convert_ids_to_tokens(pred_id)))
        return batch_ret