import torch
from typing import List, Tuple, Union
from transformers import BertForMaskedLM, BertTokenizerFast

class BertForLexPrediction(BertForMaskedLM):

    def __init__(self, config):
        super().__init__(config)
    
    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast) -> List[List[Tuple[str, str]]]:
        """Return, for each sentence, a list of (word, predicted lexical form) pairs."""
        if isinstance(sentences, str):
            sentences = [sentences]
        
        # run the model once to get per-token logits over the vocabulary
        inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = self(**inputs, return_dict=True).logits

        # for each input token, take the highest-scoring vocabulary entry as its
        # predicted lexical form; wordpiece continuations are merged back into the
        # preceding word rather than predicted separately
        input_ids = inputs['input_ids']
        batch_ret = []
        for batch_idx in range(len(sentences)):
            ret = []
            batch_ret.append(ret)
            for tok_idx in range(input_ids.shape[1]):
                token_id = input_ids[batch_idx, tok_idx].item()
                # skip special tokens (CLS, SEP, PAD)
                if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
                    continue

                token = tokenizer.convert_ids_to_tokens(token_id)
                # wordpieces should just be appended to the previous word
                if token.startswith('##'):
                    ret[-1] = (ret[-1][0] + token[2:], ret[-1][1])
                    continue
                # pair the surface word with the argmax prediction from the LM head
                pred_id = torch.argmax(logits[batch_idx, tok_idx]).item()
                ret.append((token, tokenizer.convert_ids_to_tokens(pred_id)))
        return batch_ret
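

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# 'path/to/bert-checkpoint' is a placeholder; substitute the actual masked-LM
# checkpoint this class is meant to be loaded from.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    checkpoint = 'path/to/bert-checkpoint'  # placeholder checkpoint name
    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    model = BertForLexPrediction.from_pretrained(checkpoint)
    model.eval()

    # predict() accepts a single sentence or a list of sentences and returns,
    # per sentence, a list of (word, predicted lexical form) pairs
    predictions = model.predict(['hello world'], tokenizer)
    for sentence_preds in predictions:
        print(sentence_preds)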