Spaces:
Runtime error
Runtime error
import torch | |
from transformers import ElectraTokenizerFast | |
class AnswerAwareTokenizer:
    """Wrap an ELECTRA fast tokenizer with helpers for answer-span handling.

    Besides encoding (question, passage) pairs, the class enumerates
    word-aligned candidate spans over a passage and converts between
    character positions and token positions of an encoding.
    """

    def __init__(self, total_maxlen, bert_model='google/electra-base-discriminator'):
        """Load the pretrained fast tokenizer.

        Args:
            total_maxlen: ``max_length`` used when encoding question/passage
                pairs in :meth:`tensorize`.
            bert_model: Name or path handed to
                ``ElectraTokenizerFast.from_pretrained``.
        """
        self.total_maxlen = total_maxlen
        self.tok = ElectraTokenizerFast.from_pretrained(bert_model)

    def process(self, questions, passages, all_answers=None, mask=None):
        """Build a ``TokenizationObject`` bound to this tokenizer."""
        return TokenizationObject(self, questions, passages, all_answers, mask)

    def tensorize(self, questions, passages):
        """Encode question/passage pairs as padded PyTorch tensors.

        Returns:
            A pair ``(encoding, query_lengths)``. ``encoding`` is the joint
            encoding of each question with its passage, truncated to
            ``self.total_maxlen``. ``query_lengths`` is the attention-mask sum
            of each question when the questions are tokenized on their own.
        """
        questions_only = self.tok(questions, padding='longest', return_tensors='pt')
        query_lengths = questions_only.attention_mask.sum(-1)
        encoding = self.tok(questions, passages, padding='longest', truncation='longest_first',
                            return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True)
        return encoding, query_lengths

    def get_all_candidates(self, encoding, index, max_span_words=10):
        """List every word-aligned candidate span of one encoded sequence.

        Args:
            encoding: Encoding that provides ``word_ids(index)``.
            index: Batch index of the sequence inside ``encoding``.
            max_span_words: Maximum number of consecutive words one span may
                cover. Defaults to 10, the previously hard-coded limit.

        Returns:
            A list of ``(start, end)`` token positions, where ``start`` is the
            first token of a word and ``end`` is the first token of a later
            word (or the sequence length for the last word).
        """
        offsets, endpositions = self.all_word_positions(encoding, index)
        return [(offset, endpos)
                for idx, offset in enumerate(offsets)
                for endpos in endpositions[idx:idx + max_span_words]]

    def all_word_positions(self, encoding, index):
        """Return the token positions at which words start and end.

        A position starts a word when its word id differs from the word id of
        the preceding token; ``-1`` serves as the id "before" the first token.

        Returns:
            A pair ``(offsets, endpositions)``: ``offsets`` holds the position
            of each word's first token, and ``endpositions[i]`` is the start
            of the following word, or ``len(words)`` for the last word.
        """
        words = encoding.word_ids(index)
        offsets = [position
                   for position, (last_word_number, current_word_number) in enumerate(zip([-1] + words, words))
                   if last_word_number != current_word_number]
        endpositions = offsets[1:] + [len(words)]
        return offsets, endpositions

    @staticmethod
    def _first_token_from(encoding, index, start, stop):
        """Return the token of the first mapped character in ``[start, stop)``.

        ``char_to_token`` yields ``None`` for characters that belong to no
        token (whitespace, for example), so the scan moves forward until a
        character maps to a token. Returns ``None`` when nothing maps or the
        range is empty.
        """
        for char_position in range(start, stop):
            token_position = encoding.char_to_token(index, char_position)
            if token_position is not None:
                return token_position
        return None

    def characters_to_tokens(self, text, answers, encoding, index, offset, endpos):
        """Convert a character span of ``text`` into a token span.

        Args:
            text: The text that was encoded.
            answers: Only used in the assertion message.
            encoding: Encoding that provides ``char_to_token`` and ``tokens``.
            index: Batch index of the sequence inside ``encoding``.
            offset: Character position where the span starts.
            endpos: Character position where the span ends.

        Returns:
            ``(tokens_offset, tokens_endpos)``. When no character from
            ``endpos`` onwards maps to a token, the end falls back to the
            number of tokens in the sequence.

        Raises:
            AssertionError: If no character from ``offset`` onwards maps to a
                token.
        """
        scan_limit = len(text) + 1
        # The helper returns None for an empty scan range, so a start position
        # past the end of the text reaches the assertion / fallback below
        # rather than failing on an unbound local variable.
        tokens_offset = self._first_token_from(encoding, index, offset, scan_limit)
        tokens_endpos = self._first_token_from(encoding, index, endpos, scan_limit)

        assert tokens_offset is not None, (text, answers, offset)
        if tokens_endpos is None:
            tokens_endpos = len(encoding.tokens(index))
        return tokens_offset, tokens_endpos

    def tokens_to_answer(self, encoding, index, text, tokens_offset, tokens_endpos):
        """Recover the answer string covered by a token span.

        The text is cut from the first character of the word containing
        ``tokens_offset`` up to the first character of the word containing
        ``tokens_endpos``, then stripped of surrounding whitespace.
        """
        char_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_offset)).start
        try:
            char_endpos = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos)).start
        except Exception:
            # The end token could not be resolved to a word (presumably it lies
            # one past the last token), so use the end of the word holding the
            # previous token. Exception, unlike a bare except, lets
            # KeyboardInterrupt and SystemExit propagate.
            char_endpos = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos - 1)).end
        assert char_offset is not None
        assert char_endpos is not None
        return text[char_offset:char_endpos].strip()
class TokenizationObject:
    """Hold the encodings and candidate spans for a batch of passages.

    Attributes are filled in the constructor:
        mask: The given per-passage mask, or all ``True`` when none is given.
        tok: The tokenizer wrapper used for encoding.
        questions: One question per passage (a single question is repeated).
        passages: The passages as given.
        answers: The answers as given (may be ``None``).
        encoding, query_lengths: Result of ``tokenizer.tensorize``.
        passages_only_encoding: Passages encoded without special tokens.
        candidates: Integer tensor of shape ``(bsize, maxcands, 2)`` holding
            the candidate spans shifted by each query length.
        candidates_list: The padded, unshifted candidate lists.
        gold_candidates: Same object as ``answers``; ``None`` without answers.
    """

    def __init__(self, tokenizer: "AnswerAwareTokenizer", questions, passages, answers=None, mask=None):
        """Encode the batch and enumerate its candidate spans.

        Args:
            tokenizer: Object providing ``tensorize``, ``get_all_candidates``
                and the underlying tokenizer as ``tok``.
            questions: List with either one question or one per passage.
            passages: List of passages.
            answers: Optional answers, stored unchanged.
            mask: Optional per-passage mask; defaults to all ``True``.

        Raises:
            AssertionError: If the arguments are not lists or their lengths
                are incompatible.
        """
        # isinstance also accepts list subclasses, unlike an exact type check.
        assert isinstance(questions, list) and isinstance(passages, list)
        assert len(questions) in [1, len(passages)]

        if mask is None:
            mask = [True] * len(passages)

        self.mask = mask
        self.tok = tokenizer
        # Broadcast a single question over all passages.
        self.questions = questions if len(questions) == len(passages) else questions * len(passages)
        self.passages = passages
        self.answers = answers

        self.encoding, self.query_lengths = self._encode()
        self.passages_only_encoding, self.candidates, self.candidates_list = self._candidize()

        # Always define the attribute so later reads never hit AttributeError;
        # it stays None when no answers were supplied.
        self.gold_candidates = self.answers

    def _encode(self):
        """Encode the question/passage pairs with the tokenizer wrapper."""
        return self.tok.tensorize(self.questions, self.passages)

    def _candidize(self):
        """Enumerate, pad and shift the candidate spans of every passage.

        Returns:
            ``(encoding, candidates, all_candidates)``: the passage-only
            encoding, the shifted candidate tensor, and the padded lists.
        """
        encoding = self.tok.tok(self.passages, add_special_tokens=False)
        all_candidates = [self.tok.get_all_candidates(encoding, index) for index in range(len(self.passages))]

        bsize, maxcands = len(self.passages), max(map(len, all_candidates))
        # Pad every passage to the same number of candidates with (-1, -1).
        all_candidates = [cands + [(-1, -1)] * (maxcands - len(cands)) for cands in all_candidates]

        candidates = torch.tensor(all_candidates)
        assert candidates.size() == (bsize, maxcands, 2), (candidates.size(), (bsize, maxcands, 2), (self.questions, self.passages))

        # Shift passage-relative positions by each query length.
        # NOTE(review): the (-1, -1) padding entries are shifted too, so they
        # are no longer -1 in the returned tensor - confirm this is intended.
        candidates = candidates + self.query_lengths.unsqueeze(-1).unsqueeze(-1)
        return encoding, candidates, all_candidates