from typing import Union

import torch
from transformers import LogitsProcessor

from seed_scheme_factory import SeedSchemeFactory
from utils import base_to_bytes, bytes_to_base, get_values_per_byte


class BaseProcessor(object):
    def __init__(
        self,
        msg_base: int,
        vocab: list[int],
        device: torch.device,
        seed_scheme: str,
        window_length: int = 1,
        salt_key: Union[int, None] = None,
        private_key: Union[int, None] = None,
    ):
        """
        Args:
            msg_base: base of the message.
            vocab: vocabulary list.
            device: device to load the processor on.
            seed_scheme: scheme used to compute the seed.
            window_length: length of the window used to compute the seed.
            salt_key: salt to add to the seed.
            private_key: private key used to compute the seed.
        """
        # Universal parameters
        self.msg_base = msg_base
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.device = device

        # Seed parameters
        seed_fn = SeedSchemeFactory.get_instance(
            seed_scheme,
            salt_key=salt_key,
            private_key=private_key,
        )
        if seed_fn is None:
            raise ValueError(f'Seed scheme "{seed_scheme}" is invalid')
        else:
            self.seed_fn = seed_fn
        self.window_length = window_length

        # Initialize the RNG; always use a CPU generator so the permutation
        # is reproducible regardless of the target device.
        self.rng = torch.Generator(device="cpu")

        # Compute the ranges of each value in the base: the vocabulary is
        # split into msg_base contiguous buckets, and the first r buckets
        # get one extra token when vocab_size is not divisible by msg_base.
        self.ranges = torch.zeros((self.msg_base + 1), dtype=torch.int64).to(
            self.device
        )
        chunk_size = self.vocab_size // self.msg_base
        r = self.vocab_size % self.msg_base
        self.ranges[1:] = chunk_size
        self.ranges[1 : r + 1] += 1
        self.ranges = torch.cumsum(self.ranges, dim=0)

    def _seed_rng(self, input_ids: torch.Tensor):
        """
        Seed the RNG from the last window_length tokens of the sequence.

        Args:
            input_ids: ids in the input sequence.
        """
        seed = self.seed_fn(input_ids[-self.window_length :])
        self.rng.manual_seed(seed)

    def _get_valid_list_ids(self, input_ids: torch.Tensor, value: int):
        """
        Get the ids of tokens in the valid list for the current sequence.
        """
        self._seed_rng(input_ids)
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        ).to(self.device)
        # Tokens whose permuted position falls inside the bucket of `value`.
        vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]

        return vocab_list

    def _get_value(self, input_ids: torch.Tensor):
        """
        Recover the base-msg_base value encoded by the last token.
        """
        self._seed_rng(input_ids[:-1])
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        ).to(self.device)
        cur_token = input_ids[-1]
        # The bucket containing the token's permuted position is the value.
        cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
        value = (cur_id < self.ranges).type(torch.int).argmax().item() - 1

        return value
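
# A worked illustration of the bucket layout computed in BaseProcessor
# (the numbers are hypothetical, not taken from any real tokenizer): with
# vocab_size = 10 and msg_base = 3 we get chunk_size = 3 and r = 1, so the
# cumulative ranges are [0, 4, 7, 10]. A token at permuted position 0-3
# encodes value 0, 4-6 encodes value 1, and 7-9 encodes value 2;
# _get_value inverts this by locating the token in the permutation.
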
class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
    def __init__(
        self,
        prompt_ids: torch.Tensor,
        msg: bytes,
        gamma: float,
        tokenizer,
        start_pos: int = 0,
        *args,
        **kwargs,
    ):
        """
        Args:
            prompt_ids: ids of the prompt (a batch of size 1).
            msg: message to hide in the text.
            gamma: bias added to the scores of tokens in the valid list.
            tokenizer: tokenizer of the language model.
            start_pos: position in the generated text at which the message
                starts being embedded.
        """
        super().__init__(*args, **kwargs)
        if prompt_ids.size(0) != 1:
            raise RuntimeError(
                "EncryptorLogitsProcessor does not support multiple prompts input."
            )
        self.prompt_size = prompt_ids.size(1)
        self.start_pos = start_pos
        self.raw_msg = msg
        self.msg = bytes_to_base(msg, self.msg_base)
        self.gamma = gamma
        self.tokenizer = tokenizer
        # Special tokens are always kept in the valid list so that
        # generation can still terminate and pad normally.
        special_tokens = [
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
        ]
        special_tokens = [x for x in special_tokens if x is not None]
        self.special_tokens = torch.tensor(special_tokens, device=self.device)

    def __call__(
        self, input_ids_batch: torch.LongTensor, scores_batch: torch.FloatTensor
    ):
        for i, input_ids in enumerate(input_ids_batch):
            cur_pos = input_ids.size(0)
            msg_ptr = cur_pos - (self.prompt_size + self.start_pos)
            # If embedding has not started yet, or the whole message is
            # already hidden, leave the raw scores untouched.
            if msg_ptr < 0 or msg_ptr >= len(self.msg):
                continue
            scores_batch[i] = self._add_bias_to_valid_list(
                input_ids, scores_batch[i], self.msg[msg_ptr]
            )

        return scores_batch

    def _add_bias_to_valid_list(
        self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
    ):
        """
        Add the bias (gamma) to the scores of the valid list tokens.
        """
        ids = torch.cat(
            [self._get_valid_list_ids(input_ids, value), self.special_tokens]
        )
        scores[ids] = scores[ids] + self.gamma

        return scores

    def get_message_len(self):
        return len(self.msg)

    def __map_input_ids(self, input_ids: torch.Tensor, base_arr, byte_arr):
        """
        Align the embedded values/bytes and the original message with each
        token position, for inspection.
        """
        byte_enc_msg = [-1 for _ in range(input_ids.size(0))]
        base_enc_msg = [-1 for _ in range(input_ids.size(0))]
        base_msg = [-1 for _ in range(input_ids.size(0))]
        byte_msg = [-1 for _ in range(input_ids.size(0))]

        values_per_byte = get_values_per_byte(self.msg_base)
        start = self.start_pos % values_per_byte
        for i, b in enumerate(base_arr):
            base_enc_msg[i] = b
            byte_enc_msg[i] = byte_arr[(i - start) // values_per_byte]
        for i, b in enumerate(self.msg):
            base_msg[i + self.start_pos] = b
            byte_msg[i + self.start_pos] = self.raw_msg[i // values_per_byte]

        return base_msg, byte_msg, base_enc_msg, byte_enc_msg

    def validate(self, input_ids_batch: torch.Tensor):
        res = []
        tokens_infos = []
        for input_ids in input_ids_batch:
            # Initialization
            base_arr = []

            # Loop and obtain the values of all tokens
            for i in range(0, input_ids.size(0)):
                base_arr.append(self._get_value(input_ids[: i + 1]))
            values_per_byte = get_values_per_byte(self.msg_base)

            # Transform the values to bytes
            start = self.start_pos % values_per_byte
            byte_arr = base_to_bytes(base_arr[start:], self.msg_base)

            # Count how many bytes of the original message were recovered
            cnt = 0
            enc_msg = byte_arr[self.start_pos // values_per_byte :]
            for i in range(min(len(enc_msg), len(self.raw_msg))):
                if self.raw_msg[i] == enc_msg[i]:
                    cnt += 1
            res.append(cnt / len(self.raw_msg))

            base_msg, byte_msg, base_enc_msg, byte_enc_msg = (
                self.__map_input_ids(input_ids, base_arr, byte_arr)
            )
            tokens = []
            input_strs = [
                self.tokenizer.decode([token_id]) for token_id in input_ids
            ]
            for i in range(len(base_enc_msg)):
                tokens.append(
                    {
                        "token": input_strs[i],
                        "base_enc": base_enc_msg[i],
                        "byte_enc": byte_enc_msg[i],
                        "base_msg": base_msg[i],
                        "byte_msg": byte_msg[i],
                        "byte_id": (i - start) // values_per_byte,
                    }
                )
            tokens_infos.append(tokens)

        return res, tokens_infos


class DecryptorProcessor(BaseProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def decrypt(self, input_ids_batch: torch.Tensor):
        """
        Decrypt the text sequences.
        """
        shift_msg = []
        # The embedding may start mid-byte, so try every shift within a
        # byte and return one candidate message per shift.
        for shift in range(get_values_per_byte(self.msg_base)):
            msg = []
            bytes_msg = []
            for i, input_ids in enumerate(input_ids_batch):
                msg.append(list())
                for j in range(shift, len(input_ids)):
                    # TODO: this could be slow. Consider reimplementing it.
                    value = self._get_value(input_ids[: j + 1])
                    msg[i].append(value)
                bytes_msg.append(base_to_bytes(msg[i], self.msg_base))
            shift_msg.append(bytes_msg)

        return shift_msg
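

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module proper). Everything below is
# hypothetical: "gpt2" is an arbitrary model choice and "dummy_hash" is a
# placeholder seed-scheme name; substitute whichever scheme your
# SeedSchemeFactory actually registers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        LogitsProcessorList,
    )

    device = torch.device("cpu")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    prompt_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids

    # Parameters shared by the encryptor and the decryptor; they must match
    # exactly for the message to be recoverable.
    common = dict(
        msg_base=2,
        vocab=list(range(len(tokenizer))),
        device=device,
        seed_scheme="dummy_hash",  # placeholder: must be a registered scheme
        window_length=1,
    )
    encryptor = EncryptorLogitsProcessor(
        prompt_ids=prompt_ids,
        msg=b"hi",
        gamma=4.0,
        tokenizer=tokenizer,
        **common,
    )
    out_ids = model.generate(
        prompt_ids,
        logits_processor=LogitsProcessorList([encryptor]),
        max_new_tokens=encryptor.get_message_len() + 16,
        do_sample=True,
    )

    # Decrypt the generated continuation only; decrypt() returns one
    # candidate message per alignment shift.
    decryptor = DecryptorProcessor(**common)
    print(decryptor.decrypt(out_ids[:, prompt_ids.size(1) :]))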