from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer
from transformers import PreTrainedTokenizerFast

try:
    from clang import cindex
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "VulBERTa Clang tokenizer requires `libclang`. Please install it via `pip install libclang`.",
    ) from e


class ClangPreTokenizer:
    """Custom pre-tokenizer that splits C/C++ source using libclang's lexer."""

    cidx = cindex.Index.create()

    def clang_split(
        self,
        i: int,  # split index supplied by `tokenizers`; unused here
        normalized_string: NormalizedString,
    ) -> List[NormalizedString]:
        tok = []
        # Parse the text as an in-memory C file so libclang can lex it.
        tu = self.cidx.parse(
            "tmp.c",
            args=[""],
            unsaved_files=[("tmp.c", str(normalized_string.original))],
            options=0,
        )
        # Keep one NormalizedString per Clang token, skipping pure whitespace.
        for t in tu.get_tokens(extent=tu.cursor.extent):
            spelling = t.spelling.strip()
            if spelling == "":
                continue
            tok.append(NormalizedString(spelling))
        return tok

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.clang_split)


class VulBERTaTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that pre-tokenizes C/C++ source with ClangPreTokenizer."""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            **kwargs,
        )
        # Swap the backend pre-tokenizer for the custom Clang-based splitter.
        self._tokenizer.pre_tokenizer = PreTokenizer.custom(ClangPreTokenizer())
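

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumes a serialized `tokenizers` file at the hypothetical path
    # "tokenizer.json"; point `tokenizer_file` at your own checkpoint.
    tokenizer = VulBERTaTokenizer(tokenizer_file="tokenizer.json")
    print(tokenizer.tokenize("int main() { return 0; }"))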