import collections |
|
import os |
|
import unicodedata |
|
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
|
from copy import deepcopy |
|
from transformers import PreTrainedTokenizer |
|
from transformers.tokenization_utils import _is_control, _is_punctuation, _is_whitespace |
|
from transformers.utils import logging |
|
|
|
|
|
|
|
from config_utils import * |
|
from sequtils import * |
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} |
|
|
|
|
|
|
|
|
|
PRETRAINED_VOCAB_FILES_MAP = { |
|
"vocab_file": { |
|
"prokbert-mini-k6s1": "prokbert-base-dna6/vocab.txt", |
|
"prokbert-large-k6s1": "prokbert-base-dna6/vocab.txt", |
|
"prokbert-large-k6s2": "prokbert-base-dna6/vocab.txt" |
|
} |
|
} |
|
|
|
|
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { |
|
"prokbert-mini-k6s1": 1024, |
|
"prokbert-large-k6s1": 1024, |
|
"prokbert-large-k6s2": 1024 |
|
} |
|
|
|
PRETRAINED_INIT_CONFIGURATION = { |
|
"prokbert-mini-k6s1": {"do_upper_case": True}, |
|
"prokbert-large-k6s1": {"do_upper_case": True}, |
|
"prokbert-large-k6s2": {"do_upper_case": True} |
|
|
|
} |
|
|
|
|
|
def load_vocab(vocab_file): |
|
"""Loads a vocabulary file into a dictionary.""" |
|
vocab = collections.OrderedDict() |
|
with open(vocab_file, "r", encoding="utf-8") as reader: |
|
tokens = reader.readlines() |
|
for index, token in enumerate(tokens): |
|
token = token.rstrip("\n") |
|
vocab[token] = index |
|
return vocab |
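
# A minimal sketch of load_vocab's behaviour, assuming a hypothetical vocab.txt:
#
#     line 0: "[PAD]"   ->  vocab["[PAD]"] == 0
#     line 1: "[UNK]"   ->  vocab["[UNK]"] == 1
#     line 2: "AAAAAA"  ->  vocab["AAAAAA"] == 2
#
# Each stripped line becomes a key and its 0-based line number becomes its id.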
|
|
|
|
|
class ProkBERTTokenizer(PreTrainedTokenizer): |
|
"""Custom tokenizer for ProkBERT.""" |
|
|
|
vocab_files_names = VOCAB_FILES_NAMES |
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP |
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION |
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES |
|
nucleotide_abc = {'A', 'T', 'C', 'G'} |
|
extended_nucleotide_abc = {'A', 'T', 'C', 'G', '*'} |
|
sequence_unk_token = 'N' |
|
    default_unk_token = "[UNK]"
    default_sep_token = "[SEP]"
    default_pad_token = "[PAD]"
    default_cls_token = "[CLS]"
    default_mask_token = "[MASK]"
|
|
|
|
|
def __init__(self, |
|
tokenization_params: Dict = {}, |
|
segmentation_params: Dict = {}, |
|
comp_params: Dict = {}, |
|
operation_space: str = 'sequence', |
|
**kwargs): |
|
"""Initialize the ProkBERT tokenizer. |
|
|
|
Args: |
|
tokenization_params (Dict, optional): Tokenization parameters. Defaults to {}. |
|
segmentation_params (Dict, optional): Segmentation parameters. Defaults to {}. |
|
comp_params (Dict, optional): Computational parameters. Defaults to {}. |
|
            operation_space (str, optional): Specifies the operation mode. Can be 'kmer' or 'sequence'. Defaults to 'sequence'.
|
""" |
|
super().__init__(cls_token=ProkBERTTokenizer.default_cls_token, |
|
**kwargs) |
|
|
|
self.defconfig = SeqConfig() |
|
self.tokenization_params = self.defconfig.get_and_set_tokenization_parameters(tokenization_params) |
|
self.segmentation_params = self.defconfig.get_and_set_segmentation_parameters(segmentation_params) |
|
self.comp_params = self.defconfig.get_and_set_computational_parameters(comp_params) |
|
self.operation_space = operation_space |
|
|
|
vocab_file = self.tokenization_params['vocabfile'] |
|
self.vocab = self.tokenization_params['vocabmap'] |
|
self.id2token = {v: k for k, v in self.vocab.items()} |
|
self.max_len = self.tokenization_params['max_segment_length'] |
|
|
|
if self.operation_space == 'sequence': |
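            # In 'sequence' operation space the base k-mer vocabulary is extended:
            # every k-mer that contains the masking character '*' is added and mapped
            # to one shared id, and a full k-mer of 'N's is registered as the
            # in-sequence unknown token.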
|
token_extension = sorted(list(set(generate_kmers(ProkBERTTokenizer.extended_nucleotide_abc, self.tokenization_params['kmer'])) - \ |
|
set(generate_kmers(ProkBERTTokenizer.nucleotide_abc, self.tokenization_params['kmer'])) )) |
|
self.extended_vocab = deepcopy(self.vocab) |
|
for token in token_extension: |
|
self.extended_vocab[token] = 4 |
|
|
|
self.unk_token = ProkBERTTokenizer.sequence_unk_token * self.tokenization_params['shift'] |
|
self.mask_token = '*' |
|
self.extended_vocab[self.mask_token] = self.vocab['[MASK]'] |
|
|
|
full_unk = 'N' * self.tokenization_params['kmer'] |
|
self.vocab[full_unk] = 1 |
|
self.id2token[1] = full_unk |
|
self.full_unk_token = full_unk |
|
|
|
else: |
|
self.extended_vocab = self.vocab |
|
self.unk_token = '[UNK]' |
|
self.sep_token = '[SEP]' |
|
self.cls_token = '[CLS]' |
|
self.pad_token = '[PAD]' |
|
self.mask_token = '[MASK]' |
|
self.special_tokens = list(self.special_tokens_map.values()) |
|
|
|
def __len__(self) -> int: |
|
return len(self.vocab) |
|
|
|
|
|
def tokenize(self, text: str, lca_shift: int = 0, all: bool = False) -> Union[List[str], Tuple[List[List[str]], List[List[str]]]]: |
|
""" |
|
Tokenizes a given segment. |
|
|
|
Args: |
|
text (str): The DNA segment to tokenize. |
|
lca_shift (int, optional): Which tokenized vector belonging to the specified LCA offset should be returned. Defaults to 0. |
|
all (bool, optional): If True, returns all possible tokenizations. Defaults to False. |
|
|
|
Returns: |
|
Union[List[str], Tuple[List[List[str]], List[List[str]]]]: Tokenized segment or tuple of all possible tokenizations. |
|
|
|
Usage Example: |
|
>>> tokenizer = ProkBERTTokenizer(...) |
|
>>> segment = 'AATCAAGGAATTATTATCGTT' |
|
>>> tokens, kmers = tokenizer.tokenize(segment, all=True) |
|
>>> print(tokens) |
|
... |
|
""" |
|
tokenized_segments, kmerized_segments = lca_tokenize_segment(text, self.tokenization_params) |
|
if all: |
|
return tokenized_segments, kmerized_segments |
|
else: |
|
return kmerized_segments[lca_shift] |
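
    # Illustrative sketch of the LCA tokenization above, assuming k=6 and shift=1
    # (the actual values come from SeqConfig, so this is not authoritative):
    #
    #     tokenizer.tokenize('AATCAAGGAATTATTATCGTT')
    #     # -> ['AATCAA', 'ATCAAG', 'TCAAGG', ...]   overlapping k-mers for offset 0
    #
    #     ids, kmers = tokenizer.tokenize('AATCAAGGAATTATTATCGTT', all=True)
    #     # -> one tokenized vector per LCA offset (0 .. shift-1)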
|
|
|
def _convert_token_to_id(self, token): |
|
"""Converts a token (str) in an id using the vocab.""" |
|
return self.vocab.get(token, self.vocab.get(self.unk_token)) |
|
|
|
def _convert_id_to_token(self, index): |
|
"""Converts an index (integer) in a token (str) using the vocab.""" |
|
        return self.id2token.get(index, self.unk_token)
|
|
|
|
|
    def depr_convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> List[int]:
        """
        Converts tokens to their corresponding IDs.

        Args:
            tokens (Union[str, List[str]]): Token or list of tokens to convert.

        Returns:
            List[int]: List of corresponding token IDs.

        Usage Example:
            >>> tokenizer = ProkBERTTokenizer(...)
            >>> tokens = ['AATCAA', 'TCAAGG']
            >>> ids = tokenizer.depr_convert_tokens_to_ids(tokens)
            >>> print(ids)
            ...
        """
        if isinstance(tokens, str):
            tokens = [tokens]

        if self.operation_space == 'sequence':
            # unknown k-mers fall back to the full 'N...N' unknown token
            token_ids = [self.vocab.get(token, self.vocab[self.full_unk_token]) for token in tokens]
        else:
            token_ids = [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]

        return token_ids
|
|
|
def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]: |
|
""" |
|
Converts token IDs back to their original tokens. |
|
|
|
Args: |
|
ids (List[int]): List of token IDs to convert. |
|
|
|
Returns: |
|
List[str]: List of corresponding tokens. |
|
|
|
Usage Example: |
|
>>> tokenizer = ProkBERTTokenizer(...) |
|
>>> ids = [213, 3343] |
|
>>> tokens = tokenizer.convert_ids_to_tokens(ids) |
|
>>> print(tokens) |
|
... |
|
""" |
|
if isinstance(ids, int): |
|
ids = [ids] |
|
if len(ids) == 1: |
|
|
|
return self.id2token.get(ids[0], self.unk_token) |
|
|
|
if self.operation_space == 'kmer': |
|
token_list = [self.id2token.get(id, self.unk_token) for id in ids] |
|
|
|
        elif self.operation_space == 'sequence':
            token_list = []

            # Drop a leading id 2 (presumably the sentence-start token) rather than decoding it.
            if ids[0] == 2:
                pass
            else:
                token_list.append(self.id2token.get(ids[0], self.unk_token))
            if len(ids) > 1:
                # The first real k-mer is kept in full ...
                true_start_token = self.id2token.get(ids[1], self.unk_token)
                token_list.append(true_start_token)
            if len(ids) > 2:
                # ... while every later k-mer only contributes its last `shift` characters,
                # so the overlapping windows reassemble into a contiguous sequence.
                for token_id in ids[2:]:
                    mapped_token_id = self.id2token.get(token_id, self.unk_token)
                    if mapped_token_id in self.special_tokens:
                        act_token_value = ''
                    else:
                        act_token_value = mapped_token_id[-1 * self.tokenization_params['shift']:]
                    token_list.append(act_token_value)
|
|
|
return token_list |
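
    # Illustrative sketch of sequence-space decoding (ids are hypothetical): for ids
    # mapping to ['[CLS]', 'AATCAA', 'ATCAAG', 'TCAAGG', '[SEP]'] with shift=1, the
    # first real k-mer is kept whole and every later k-mer contributes only its last
    # `shift` characters, so the pieces ['AATCAA', 'G', 'G', ''] re-join to 'AATCAAGG'.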
|
|
|
|
|
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Saves the vocabulary to a file."""
        if filename_prefix is None:
            filename_prefix = ""
        vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.txt")
        with open(vocab_file_path, "w", encoding="utf-8") as f:
            for token in self.vocab:
                f.write(token + "\n")
        return (vocab_file_path,)
|
|
|
@classmethod |
|
def from_pretrained(cls, vocab_file: str) -> 'ProkBERTTokenizer': |
|
"""Loads a pre-trained tokenizer. |
|
|
|
Args: |
|
vocab_file (str): Path to the pre-trained tokenizer vocabulary file. |
|
|
|
Returns: |
|
ProkBERTTokenizer: Loaded tokenizer instance. |
|
""" |
|
return cls(vocab_file) |
|
|
|
def encode_plus(self, text: str, lca_shift: int = 0, **kwargs) -> Dict[str, np.ndarray]: |
|
""" |
|
Tokenizes a sequence and returns it in a format suitable for model input. |
|
|
|
Args: |
|
text (str): The sequence to tokenize. |
|
lca_shift (int, optional): LCA offset for tokenization. Defaults to 0. |
|
|
|
Returns: |
|
Dict[str, np.ndarray]: Dictionary containing token IDs and attention masks. |
|
|
|
Usage Example: |
|
>>> tokenizer = ProkBERTTokenizer(...) |
|
>>> segment = 'AATCAAGGAATTATTATCGTT' |
|
>>> encoded = tokenizer.encode_plus(segment) |
|
>>> print(encoded) |
|
... |
|
""" |
|
tokenized_segments, kmerized_segments = lca_tokenize_segment(text, self.tokenization_params) |
|
input_ids = tokenized_segments[lca_shift] |
|
attention_mask = [1] * len(input_ids) |
|
|
|
|
|
        # Right-pad the ids (with 0, the padding id) and extend the attention mask
        # up to the maximum segment length.
        while len(input_ids) < self.max_len:
|
input_ids.append(0) |
|
attention_mask.append(0) |
|
|
|
return { |
|
"input_ids": np.array(input_ids, dtype=self.comp_params['np_tokentype']), |
|
"attention_mask": np.array(attention_mask, dtype=self.comp_params['np_tokentype']) |
|
} |
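
    # Illustratively (ids are placeholders), a short segment yields something like
    #     {'input_ids':      [2, 213, 3343, ..., 3, 0, 0, ...],
    #      'attention_mask': [1, 1,   1,    ..., 1, 0, 0, ...]}
    # where the trailing zeros are padded positions the model should ignore.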
|
|
|
def batch_encode_plus(self, sequences: List[str], lca_shift: int = 0, all: bool = False, **kwargs) -> Dict[str, List[List[int]]]: |
|
""" |
|
Tokenizes multiple sequences and returns them in a format suitable for model input. It is assumed that sequences |
|
have already been preprocessed (i.e., segmented) and quality controlled. |
|
|
|
Args: |
|
- sequences (List[str]): A list of DNA sequences to be tokenized. |
|
- lca_shift (int, default=0): The LCA offset or windows to get the tokenized vector. If the required offset is >= shift, |
|
an error is raised. |
|
- all (bool, default=False): Whether all possible tokenization vectors should be returned. If False, only the specified |
|
offset is used. |
|
- **kwargs: Additional arguments (like max_length, padding, etc.) |
|
|
|
Returns: |
|
- Dict[str, List[List[int]]]: A dictionary containing token IDs, attention masks, and token type IDs. |
|
""" |
|
shift = self.tokenization_params['shift'] |
|
if lca_shift >= shift: |
|
raise ValueError(f'The required offset {lca_shift} is invalid. The maximum offset should be < {shift}') |
|
|
|
|
|
sequence_ids = list(range(len(sequences))) |
|
to_tokenize_data = (sequences, sequence_ids) |
|
|
|
|
|
tokenization_results = batch_tokenize_segments_with_ids( |
|
to_tokenize_data, |
|
self.tokenization_params, |
|
self.comp_params['cpu_cores_for_tokenization'], |
|
self.comp_params['batch_size_tokenization'], |
|
self.comp_params['np_tokentype'] |
|
) |
|
|
|
|
|
input_ids = [] |
|
token_type_ids = [] |
|
attention_masks = [] |
|
|
|
if all: |
|
for tokenized_vectors in tokenization_results.values(): |
|
for tokenized_vector in tokenized_vectors: |
|
input_ids.append(tokenized_vector) |
|
token_type_ids.append([0] * len(tokenized_vector)) |
|
attention_masks.append([1] * len(tokenized_vector)) |
|
else: |
|
for tokenized_vectors in tokenization_results.values(): |
|
selected_vector = tokenized_vectors[lca_shift] |
|
input_ids.append(selected_vector) |
|
token_type_ids.append([0] * len(selected_vector)) |
|
attention_masks.append([1] * len(selected_vector)) |
|
|
|
return { |
|
"input_ids": input_ids, |
|
"token_type_ids": token_type_ids, |
|
"attention_mask": attention_masks |
|
} |
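
    # Note: unlike encode_plus, batch_encode_plus returns plain (unpadded) Python lists;
    # with all=True every LCA offset of every input sequence contributes its own row.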
|
|
|
def encode(self, segment: str, lca_shift: int = 0, all: bool = False, add_special_tokens: bool = True, **kwargs) -> List[int]: |
|
""" |
|
Encode a DNA sequence into its corresponding token IDs. |
|
|
|
Args: |
|
            segment (str): The DNA segment to encode.
            lca_shift (int, optional): LCA offset of the tokenized vector to return. Defaults to 0.
            all (bool, optional): If True, returns the encoded vectors for all LCA offsets. Defaults to False.
            add_special_tokens (bool, optional): Whether to add special tokens like [CLS] and [SEP]. Defaults to True.
|
|
|
Returns: |
|
List[int]: Encoded token IDs. |
|
|
|
Usage Example: |
|
>>> tokenizer = ProkBERTTokenizer(...) |
|
>>> segment = 'AATCAAGGAATTATTATCGTT' |
|
>>> ids = tokenizer.encode(segment) |
|
>>> print(ids) |
|
... |
|
""" |
|
shift = self.tokenization_params['shift'] |
|
if lca_shift >= shift: |
|
raise ValueError(f'The required offset {lca_shift} is invalid. The maximum offset should be < {shift}') |
|
|
|
tokenized_segments, _ = lca_tokenize_segment(segment, self.tokenization_params) |
|
|
|
|
|
if all: |
|
token_ids = tokenized_segments |
|
if not add_special_tokens: |
|
new_token_ids = [] |
|
for token_id_set in tokenized_segments: |
|
new_token_ids.append(token_id_set[1:len(token_id_set)-1]) |
|
token_ids = new_token_ids |
|
|
|
else: |
|
token_ids = tokenized_segments[lca_shift] |
|
|
|
|
|
if not add_special_tokens: |
|
token_ids = token_ids[1:len(token_ids)-1] |
|
|
|
return token_ids |
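
    # Illustrative call patterns (ids shown are placeholders):
    #     tokenizer.encode(seg)                            -> [2, ..., 3]   one offset, with boundary tokens
    #     tokenizer.encode(seg, add_special_tokens=False)  -> [...]         boundary ids stripped
    #     tokenizer.encode(seg, all=True)                  -> one id vector per LCA offset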
|
|
|
    def decode(self, ids: List[int]) -> str:
        """Decodes a list of token IDs back into their sequence representation."""
        tokens = self.convert_ids_to_tokens(ids)
        return ''.join(tokens)
|
|
|
def batch_decode(self, token_ids_list: List[List[int]], **kwargs) -> List[str]: |
|
""" |
|
Decodes multiple token ID sequences back into their original sequences. |
|
|
|
Args: |
|
token_ids_list (List[List[int]]): List of token ID sequences. |
|
|
|
Returns: |
|
List[str]: List of decoded sequences. |
|
|
|
Usage Example: |
|
>>> tokenizer = ProkBERTTokenizer(...) |
|
>>> ids = [[2, 213, 3343, 165, 2580, 248, 3905, 978, 3296, 3]] |
|
>>> sequences = tokenizer.batch_decode(ids) |
|
>>> print(sequences) |
|
... |
|
""" |
|
return [self.decode(token_ids) for token_ids in token_ids_list] |
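

# A minimal end-to-end sketch (the parameter overrides below are illustrative; the
# real defaults and the vocabulary come from SeqConfig and the packaged vocab file):
if __name__ == "__main__":
    tokenizer = ProkBERTTokenizer(tokenization_params={'kmer': 6, 'shift': 1},
                                  operation_space='sequence')
    segment = 'AATCAAGGAATTATTATCGTT'

    kmers = tokenizer.tokenize(segment)       # overlapping k-mers for LCA offset 0
    ids = tokenizer.encode(segment)           # token ids, incl. boundary tokens
    encoded = tokenizer.encode_plus(segment)  # padded ids + attention mask

    print(kmers)
    print(ids)
    print(tokenizer.decode(ids))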