# coding=utf-8
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/tokenization_bert.py
# ProkBERT tokenizer implementation (LCA k-mer tokenization of nucleotide sequences)
import collections
import os
import unicodedata
from typing import Dict, List, Optional, Tuple, Union
from copy import deepcopy
import numpy as np
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import _is_control, _is_punctuation, _is_whitespace
from transformers.utils import logging
# These utils contain the tools needed by the ProkBERT tokenizer
from config_utils import *
from sequtils import *
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
# models prokbert-mini-k6s1, prokbert-large-k6s2, prokbert-large-k6s1
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"prokbert-mini-k6s1": "prokbert-base-dna6/vocab.txt",
"prokbert-large-k6s1": "prokbert-base-dna6/vocab.txt",
"prokbert-large-k6s2": "prokbert-base-dna6/vocab.txt"
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"prokbert-mini-k6s1": 1024,
"prokbert-large-k6s1": 1024,
"prokbert-large-k6s2": 1024
}
PRETRAINED_INIT_CONFIGURATION = {
"prokbert-mini-k6s1": {"do_upper_case": True},
"prokbert-large-k6s1": {"do_upper_case": True},
"prokbert-large-k6s2": {"do_upper_case": True}
}
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n")
vocab[token] = index
return vocab
class ProkBERTTokenizer(PreTrainedTokenizer):
"""Custom tokenizer for ProkBERT."""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
nucleotide_abc = {'A', 'T', 'C', 'G'}
extended_nucleotide_abc = {'A', 'T', 'C', 'G', '*'}
sequence_unk_token = 'N'
    default_unk_token = "[UNK]"
    default_sep_token = "[SEP]"
    default_pad_token = "[PAD]"
    default_cls_token = "[CLS]"
    default_mask_token = "[MASK]"
def __init__(self,
tokenization_params: Dict = {},
segmentation_params: Dict = {},
comp_params: Dict = {},
operation_space: str = 'sequence',
**kwargs):
"""Initialize the ProkBERT tokenizer.
Args:
tokenization_params (Dict, optional): Tokenization parameters. Defaults to {}.
segmentation_params (Dict, optional): Segmentation parameters. Defaults to {}.
comp_params (Dict, optional): Computational parameters. Defaults to {}.
            operation_space (str, optional): Specifies the operation mode. Can be 'kmer' or 'sequence'. Defaults to 'sequence'.
"""
super().__init__(cls_token=ProkBERTTokenizer.default_cls_token,
**kwargs)
self.defconfig = SeqConfig()
self.tokenization_params = self.defconfig.get_and_set_tokenization_parameters(tokenization_params)
self.segmentation_params = self.defconfig.get_and_set_segmentation_parameters(segmentation_params)
self.comp_params = self.defconfig.get_and_set_computational_parameters(comp_params)
self.operation_space = operation_space
vocab_file = self.tokenization_params['vocabfile']
self.vocab = self.tokenization_params['vocabmap']
self.id2token = {v: k for k, v in self.vocab.items()}
self.max_len = self.tokenization_params['max_segment_length']
if self.operation_space == 'sequence':
token_extension = sorted(list(set(generate_kmers(ProkBERTTokenizer.extended_nucleotide_abc, self.tokenization_params['kmer'])) - \
set(generate_kmers(ProkBERTTokenizer.nucleotide_abc, self.tokenization_params['kmer'])) ))
            self.extended_vocab = deepcopy(self.vocab)
            # Every k-mer containing the extended mask character '*' is mapped to a single shared id (4).
            for token in token_extension:
                self.extended_vocab[token] = 4
            self.unk_token = ProkBERTTokenizer.sequence_unk_token * self.tokenization_params['shift']
            self.mask_token = '*'
            self.extended_vocab[self.mask_token] = self.vocab['[MASK]']
            # A run of k 'N' characters acts as the sequence-level unknown token; it is mapped
            # to id 1 and registered in id2token so unknown regions can be decoded.
            full_unk = 'N' * self.tokenization_params['kmer']
            self.vocab[full_unk] = 1
            self.id2token[1] = full_unk
            self.full_unk_token = full_unk
else:
self.extended_vocab = self.vocab
self.unk_token = '[UNK]'
self.sep_token = '[SEP]'
self.cls_token = '[CLS]'
self.pad_token = '[PAD]'
self.mask_token = '[MASK]'
self.special_tokens = list(self.special_tokens_map.values())
def __len__(self) -> int:
return len(self.vocab)
def tokenize(self, text: str, lca_shift: int = 0, all: bool = False) -> Union[List[str], Tuple[List[List[str]], List[List[str]]]]:
"""
Tokenizes a given segment.
Args:
text (str): The DNA segment to tokenize.
            lca_shift (int, optional): The LCA offset selecting which of the shifted tokenizations to return. Defaults to 0.
all (bool, optional): If True, returns all possible tokenizations. Defaults to False.
Returns:
Union[List[str], Tuple[List[List[str]], List[List[str]]]]: Tokenized segment or tuple of all possible tokenizations.
Usage Example:
>>> tokenizer = ProkBERTTokenizer(...)
>>> segment = 'AATCAAGGAATTATTATCGTT'
>>> tokens, kmers = tokenizer.tokenize(segment, all=True)
>>> print(tokens)
...
"""
tokenized_segments, kmerized_segments = lca_tokenize_segment(text, self.tokenization_params)
if all:
return tokenized_segments, kmerized_segments
else:
return kmerized_segments[lca_shift]
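    # Note on LCA offsets (configuration-dependent, see SeqConfig for the defaults):
    # lca_tokenize_segment produces `shift` tokenizations of the same segment, one per
    # offset 0..shift-1. A sketch, assuming 'kmer' and 'shift' overrides are accepted:
    #   tokenizer = ProkBERTTokenizer(tokenization_params={'kmer': 6, 'shift': 2})
    #   tokens_all, kmers_all = tokenizer.tokenize('AATCAAGGAATTATTATCGTT', all=True)
    #   kmers_all[0]   # same as tokenize(segment, lca_shift=0)
    #   kmers_all[1]   # the window shifted by one base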
    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))
    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.id2token.get(index, self.unk_token)
    def depr_convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """
        Deprecated helper. Converts tokens to their corresponding IDs.
        Args:
            tokens (Union[str, List[str]]): Token or list of tokens to convert.
        Returns:
            Union[int, List[int]]: Corresponding token ID or list of token IDs.
        Usage Example:
            >>> tokenizer = ProkBERTTokenizer(...)
            >>> tokens = ['AATCAA', 'TCAAGG']
            >>> ids = tokenizer.convert_tokens_to_ids(tokens)
            >>> print(ids)
            ...
        """
        if isinstance(tokens, str):
            tokens = [tokens]
        if self.operation_space == 'sequence':
            token_ids = [self.vocab.get(token, self.vocab[self.full_unk_token]) for token in tokens]
        else:
            token_ids = [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
        return token_ids
def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
"""
Converts token IDs back to their original tokens.
Args:
            ids (Union[int, List[int]]): Token ID or list of token IDs to convert.
        Returns:
            Union[str, List[str]]: Corresponding token or list of tokens.
Usage Example:
>>> tokenizer = ProkBERTTokenizer(...)
>>> ids = [213, 3343]
>>> tokens = tokenizer.convert_ids_to_tokens(ids)
>>> print(tokens)
...
"""
if isinstance(ids, int):
ids = [ids]
        if len(ids) == 1:
            return self.id2token.get(ids[0], self.unk_token)
if self.operation_space == 'kmer':
token_list = [self.id2token.get(id, self.unk_token) for id in ids]
        elif self.operation_space == 'sequence':
            token_list = []
            # Skip the sentence-start ([CLS]) token, id 2 in this vocabulary.
            if ids[0] != 2:
                token_list.append(self.id2token.get(ids[0], self.unk_token))
            if len(ids) > 1:
                # The first real k-mer is appended in full.
                true_start_token = self.id2token.get(ids[1], self.unk_token)
                token_list.append(true_start_token)
            if len(ids) > 2:
                # Later k-mers overlap the previous one, so only the last `shift`
                # characters of each token are appended; special tokens add nothing.
                for token_id in ids[2:]:
                    mapped_token_id = self.id2token.get(token_id, self.unk_token)
                    if mapped_token_id in self.special_tokens:
                        act_token_value = ''
                    else:
                        act_token_value = mapped_token_id[-self.tokenization_params['shift']:]
                    token_list.append(act_token_value)
        return token_list
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Saves the vocabulary to a file, one token per line."""
        if filename_prefix is None:
            filename_prefix = ""
        vocab_file_path = os.path.join(save_directory, filename_prefix + VOCAB_FILES_NAMES["vocab_file"])
        with open(vocab_file_path, "w", encoding="utf-8") as f:
            for token in self.vocab:
                f.write(token + "\n")
        return (vocab_file_path,)
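    # Sketch of a save/reload round trip (assumes `tokenizer` is an initialised
    # ProkBERTTokenizer; the directory path is only illustrative):
    #   path, = tokenizer.save_vocabulary('/tmp/prokbert_vocab')
    #   vocab = load_vocab(path)   # token -> id mapping, ids assigned by line order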
@classmethod
def from_pretrained(cls, vocab_file: str) -> 'ProkBERTTokenizer':
"""Loads a pre-trained tokenizer.
Args:
vocab_file (str): Path to the pre-trained tokenizer vocabulary file.
Returns:
ProkBERTTokenizer: Loaded tokenizer instance.
"""
        # Assumption: SeqConfig accepts a 'vocabfile' override (see __init__, which reads
        # self.tokenization_params['vocabfile']); passing the path positionally would be
        # misread as the tokenization_params dict.
        return cls(tokenization_params={'vocabfile': vocab_file})
def encode_plus(self, text: str, lca_shift: int = 0, **kwargs) -> Dict[str, np.ndarray]:
"""
Tokenizes a sequence and returns it in a format suitable for model input.
Args:
text (str): The sequence to tokenize.
lca_shift (int, optional): LCA offset for tokenization. Defaults to 0.
Returns:
Dict[str, np.ndarray]: Dictionary containing token IDs and attention masks.
Usage Example:
>>> tokenizer = ProkBERTTokenizer(...)
>>> segment = 'AATCAAGGAATTATTATCGTT'
>>> encoded = tokenizer.encode_plus(segment)
>>> print(encoded)
...
"""
tokenized_segments, kmerized_segments = lca_tokenize_segment(text, self.tokenization_params)
input_ids = tokenized_segments[lca_shift]
attention_mask = [1] * len(input_ids)
        # Pad up to the maximum segment length; id 0 is used as the padding id.
        while len(input_ids) < self.max_len:
            input_ids.append(0)
            attention_mask.append(0)
return {
"input_ids": np.array(input_ids, dtype=self.comp_params['np_tokentype']),
"attention_mask": np.array(attention_mask, dtype=self.comp_params['np_tokentype'])
}
def batch_encode_plus(self, sequences: List[str], lca_shift: int = 0, all: bool = False, **kwargs) -> Dict[str, List[List[int]]]:
"""
Tokenizes multiple sequences and returns them in a format suitable for model input. It is assumed that sequences
have already been preprocessed (i.e., segmented) and quality controlled.
Args:
- sequences (List[str]): A list of DNA sequences to be tokenized.
        - lca_shift (int, default=0): The LCA offset (window) selecting which tokenized vector to return. If the
          requested offset is >= shift, an error is raised.
- all (bool, default=False): Whether all possible tokenization vectors should be returned. If False, only the specified
offset is used.
- **kwargs: Additional arguments (like max_length, padding, etc.)
Returns:
- Dict[str, List[List[int]]]: A dictionary containing token IDs, attention masks, and token type IDs.
"""
shift = self.tokenization_params['shift']
if lca_shift >= shift:
raise ValueError(f'The required offset {lca_shift} is invalid. The maximum offset should be < {shift}')
# Parallel tokenization. First, create unique IDs for all sequences.
sequence_ids = list(range(len(sequences)))
to_tokenize_data = (sequences, sequence_ids)
# Tokenize each sequence
tokenization_results = batch_tokenize_segments_with_ids(
to_tokenize_data,
self.tokenization_params,
self.comp_params['cpu_cores_for_tokenization'],
self.comp_params['batch_size_tokenization'],
self.comp_params['np_tokentype']
)
# Generate input ids, token type ids, and attention masks
input_ids = []
token_type_ids = []
attention_masks = []
if all:
for tokenized_vectors in tokenization_results.values():
for tokenized_vector in tokenized_vectors:
input_ids.append(tokenized_vector)
token_type_ids.append([0] * len(tokenized_vector))
attention_masks.append([1] * len(tokenized_vector))
else:
for tokenized_vectors in tokenization_results.values():
selected_vector = tokenized_vectors[lca_shift]
input_ids.append(selected_vector)
token_type_ids.append([0] * len(selected_vector))
attention_masks.append([1] * len(selected_vector))
return {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": attention_masks
}
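    # Illustration (assumes an initialised tokenizer; the ids depend on the configured
    # vocabulary). batch_encode_plus does not pad, so each entry keeps its own length:
    #   batch = tokenizer.batch_encode_plus(['AATCAAGGAATTATTATCGTT', 'TTGCAATCCTG'])
    #   batch['input_ids']       # one id list per input segment
    #   batch['attention_mask']  # all ones, same length as the matching id list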
def encode(self, segment: str, lca_shift: int = 0, all: bool = False, add_special_tokens: bool = True, **kwargs) -> List[int]:
"""
Encode a DNA sequence into its corresponding token IDs.
        Args:
            segment (str): The DNA segment to encode.
            lca_shift (int, optional): The LCA offset selecting which tokenized vector to return. Defaults to 0.
            all (bool, optional): If True, returns the token IDs for every LCA offset. Defaults to False.
            add_special_tokens (bool, optional): Whether to include special tokens like [CLS] and [SEP]. Defaults to True.
Returns:
List[int]: Encoded token IDs.
Usage Example:
>>> tokenizer = ProkBERTTokenizer(...)
>>> segment = 'AATCAAGGAATTATTATCGTT'
>>> ids = tokenizer.encode(segment)
>>> print(ids)
...
"""
shift = self.tokenization_params['shift']
if lca_shift >= shift:
raise ValueError(f'The required offset {lca_shift} is invalid. The maximum offset should be < {shift}')
tokenized_segments, _ = lca_tokenize_segment(segment, self.tokenization_params)
        # If `all` is set to True, return the token IDs for every LCA offset.
        if all:
            token_ids = tokenized_segments
            if not add_special_tokens:
                # Strip the leading [CLS] and trailing [SEP] from each offset's ids.
                token_ids = [token_id_set[1:-1] for token_id_set in tokenized_segments]
        else:
            token_ids = tokenized_segments[lca_shift]
            if not add_special_tokens:
                # Strip the leading [CLS] and trailing [SEP].
                token_ids = token_ids[1:-1]
        return token_ids
    def decode(self, ids: List[int]) -> str:
        """Decodes token IDs back into a contiguous nucleotide sequence string."""
        tokens = self.convert_ids_to_tokens(ids)
        return ''.join(tokens)
def batch_decode(self, token_ids_list: List[List[int]], **kwargs) -> List[str]:
"""
Decodes multiple token ID sequences back into their original sequences.
Args:
token_ids_list (List[List[int]]): List of token ID sequences.
Returns:
List[str]: List of decoded sequences.
Usage Example:
>>> tokenizer = ProkBERTTokenizer(...)
>>> ids = [[2, 213, 3343, 165, 2580, 248, 3905, 978, 3296, 3]]
>>> sequences = tokenizer.batch_decode(ids)
>>> print(sequences)
...
"""
return [self.decode(token_ids) for token_ids in token_ids_list]
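
# Minimal usage sketch (guarded so it does not run on import). The concrete ids depend
# on the k-mer size, shift and vocabulary resolved by SeqConfig; the parameter overrides
# below are assumptions about the accepted keys.
if __name__ == "__main__":
    tokenizer = ProkBERTTokenizer(tokenization_params={'kmer': 6, 'shift': 1})
    segment = 'AATCAAGGAATTATTATCGTT'
    kmers = tokenizer.tokenize(segment)   # k-mer view for the single LCA offset
    ids = tokenizer.encode(segment)       # model-ready ids incl. [CLS]/[SEP]
    print(kmers)
    print(ids)
    print(tokenizer.decode(ids))          # round trip back to a nucleotide string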