import re
import os
from functools import partial

import fire
import torch
from transformers import AutoTokenizer, AutoModelForPreTraining

from pya0.preprocess import preprocess_for_transformer
def highlight_masked(txt):
    # Wrap each [MASK] token in ANSI green so it stands out in terminal output.
    return re.sub(r"(\[MASK\])", '\033[92m' + r"\1" + '\033[0m', txt)
def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
    # Forward hook on the pre-training head, whose outputs are a tuple of
    # (MLM prediction scores, next-sentence prediction scores).
    unmask_scores, seq_rel_scores = outputs
    MSK_CODE = 103  # token id of [MASK] in the BERT vocabulary
    token_ids = tokens['input_ids'][0]
    masked_idx = (token_ids == torch.tensor([MSK_CODE]))
    scores = unmask_scores[0][masked_idx]
    # Rank vocabulary candidates for each masked position.
    cands = torch.argsort(scores, dim=1, descending=True)
    for i, mask_cands in enumerate(cands):
        top_cands = mask_cands[:topk].detach().cpu()
        print(f'MASK[{i}] top candidates: ' +
              str(tokenizer.convert_ids_to_tokens(top_cands)))
def test(tokenizer_name_or_path, model_name_or_path, test_file='test.txt'):
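    """Print top MLM predictions for the masked positions in each test sentence.

    Each line of the test file holds two tab-separated fields: comma-separated
    1-based token positions to mask (0 masks nothing) and the sentence text,
    e.g. "2,5<TAB>some sentence to be masked" (hypothetical example).
    """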
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    model = AutoModelForPreTraining.from_pretrained(model_name_or_path,
        tie_word_embeddings=True
    )
    with open(test_file, 'r') as fh:
        for line in fh:
            line = line.rstrip()
            fields = line.split('\t')
            maskpos = list(map(int, fields[0].split(',')))

            # Preprocess the sentence for the transformer, then replace the
            # requested (1-based) token positions with [MASK].
            sentence = preprocess_for_transformer(fields[1])
            tokens = sentence.split()
            for pos in filter(lambda x: x != 0, maskpos):
                tokens[pos - 1] = '[MASK]'
            sentence = ' '.join(tokens)
            sentence = sentence.replace('[mask]', '[MASK]')
            tokens = tokenizer(sentence,
                padding=True, truncation=True, return_tensors="pt")

            print('*', highlight_masked(sentence))

            # Forward pass with a hook on the classifier head, which prints
            # the top unmasking candidates for each [MASK].
            with torch.no_grad():
                classifier = model.cls
                partial_hook = partial(classifier_hook, tokenizer, tokens, 3)
                hook = classifier.register_forward_hook(partial_hook)
                model(**tokens)
                hook.remove()
if __name__ == '__main__':
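    # fire exposes test() as a command line; for example (assuming this file
    # is saved as test_mlm.py):
    #   python test_mlm.py ./path/to/tokenizer ./path/to/model --test_file test.txt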
    os.environ["PAGER"] = 'cat'  # print fire's output directly instead of paging it
    fire.Fire(test)