Spaces:
Runtime error
Runtime error
import torch | |
# from transformers import BertTokenizerFast | |
from colbert.modeling.hf_colbert import HF_ColBERT | |
from colbert.infra import ColBERTConfig | |
from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length | |
class DocTokenizer:
    """Tokenize documents and mark them with a document marker.

    Wraps the raw tokenizer loaded from ``config.checkpoint``. Whenever
    special tokens are requested, the sequence is laid out as
    ``CLS, <document marker>, ...tokens..., SEP``. The marker is the string
    ``'[D]'`` in token form and the id of ``'[unused1]'`` in id form.
    """

    def __init__(self, config: "ColBERTConfig"):
        """Load the tokenizer for ``config.checkpoint`` and cache special tokens.

        Args:
            config: Configuration object; ``checkpoint`` and ``doc_maxlen``
                are read from it.
        """
        self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint)

        self.config = config
        self.doc_maxlen = config.doc_maxlen

        # The marker is displayed as '[D]' but is encoded with the id of the
        # vocabulary entry '[unused1]'.
        self.D_marker_token = '[D]'
        self.D_marker_token_id = self.tok.convert_tokens_to_ids('[unused1]')
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id

        # NOTE(review): a sanity check `assert self.D_marker_token_id == 2`
        # used to live here but was disabled; presumably the id depends on
        # the checkpoint's vocabulary -- confirm before re-enabling.

    def tokenize(self, batch_text, add_special_tokens=False):
        """Split every text of the batch into token strings.

        Args:
            batch_text: A list or tuple of strings.
            add_special_tokens: When true, wrap every token list as
                ``[cls_token, '[D]', ..., sep_token]``.

        Returns:
            One list of token strings per input text.

        Raises:
            AssertionError: If ``batch_text`` is not a list or tuple.
        """
        # isinstance (rather than an exact type match) also admits list/tuple
        # subclasses; a bare string is still rejected.
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
        return [prefix + lst + suffix for lst in tokens]

    def encode(self, batch_text, add_special_tokens=False):
        """Convert every text of the batch into a list of token ids.

        Args:
            batch_text: A list or tuple of strings.
            add_special_tokens: When true, wrap every id list as
                ``[cls_token_id, D_marker_token_id, ..., sep_token_id]``.

        Returns:
            One list of token ids per input text.

        Raises:
            AssertionError: If ``batch_text`` is not a list or tuple.
        """
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
        return [prefix + lst + suffix for lst in ids]

    def tensorize(self, batch_text, bsize=None):
        """Encode the batch as padded tensors carrying the document marker.

        Args:
            batch_text: A list or tuple of strings.
            bsize: Optional batch size. When truthy, the encoded batch is
                passed through ``_sort_by_length`` and ``_split_into_batches``.

        Returns:
            ``(ids, mask)`` when ``bsize`` is falsy, otherwise
            ``(batches, reverse_indices)`` as produced by the two helpers.

        Raises:
            AssertionError: If ``batch_text`` is not a list or tuple.
        """
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        # Prepend a throwaway '. ' so that a placeholder token occupies
        # position 1; it is overwritten with the marker id below.
        # NOTE(review): assumes the tokenizer itself puts CLS at position 0.
        batch_text = ['. ' + x for x in batch_text]

        obj = self.tok(batch_text, padding='longest', truncation='longest_first',
                       return_tensors='pt', max_length=self.doc_maxlen)

        ids, mask = obj['input_ids'], obj['attention_mask']

        # Replace the placeholder in column 1 of every row with the marker id.
        ids[:, 1] = self.D_marker_token_id

        if bsize:
            ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
            batches = _split_into_batches(ids, mask, bsize)
            return batches, reverse_indices

        return ids, mask