ai-text-steganography / processors.py
tnk2908's picture
Improve UI and reduce repetitiveness of generation
ee83d59
raw
history blame
8.74 kB
import os
from typing import Union
import torch
from transformers import LogitsProcessor
from seed_scheme_factory import SeedSchemeFactory
from utils import bytes_to_base, base_to_bytes, get_values_per_byte
class BaseProcessor(object):
    """
    Shared logic for mapping base-``msg_base`` values to token choices.

    For every position, the vocabulary indices ``0 .. vocab_size - 1`` are
    shuffled with an RNG seeded from the preceding tokens, and the shuffled
    order is cut into ``msg_base`` contiguous chunks. Chunk ``v`` is the
    "valid list" of tokens that encode value ``v``.
    """

    def __init__(
        self,
        msg_base: int,
        vocab: list[int],
        device: torch.device,
        seed_scheme: str,
        window_length: int = 1,
        salt_key: Union[int, None] = None,
        private_key: Union[int, None] = None,
    ):
        """
        Args:
            msg_base: base of the message (number of chunks the vocabulary is split into).
            vocab: vocabulary list; only its length is used here.
            device: device to load processor.
            seed_scheme: name of the scheme used to compute the seed.
            window_length: number of trailing tokens used to compute the seed.
            salt_key: salt to add to the seed.
            private_key: private key used to compute the seed.

        Raises:
            ValueError: if the seed scheme is unknown or ``msg_base < 1``.
        """
        # Universal parameters
        self.msg_base = msg_base
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.device = device

        # Seed parameters
        seed_fn = SeedSchemeFactory.get_instance(
            seed_scheme,
            salt_key=salt_key,
            private_key=private_key,
        )
        if seed_fn is None:
            raise ValueError(f'Seed scheme "{seed_scheme}" is invalid')
        self.seed_fn = seed_fn
        self.window_length = window_length

        # Always use a CPU generator so the permutation depends only on the
        # seed, not on the device the processor runs on.
        self.rng = torch.Generator(device="cpu")

        # Boundaries of each value's chunk inside the permuted vocabulary
        self.ranges = self._compute_ranges(self.vocab_size, self.msg_base).to(
            self.device
        )

    @staticmethod
    def _compute_ranges(vocab_size: int, msg_base: int) -> torch.Tensor:
        """
        Split ``[0, vocab_size)`` into ``msg_base`` contiguous, near-equal chunks.

        The first ``vocab_size % msg_base`` chunks receive one extra slot.

        Returns:
            An int64 CPU tensor of ``msg_base + 1`` ascending boundaries, so that
            chunk ``v`` is ``[ranges[v], ranges[v + 1])``.

        Raises:
            ValueError: if ``msg_base < 1``.
        """
        if msg_base < 1:
            raise ValueError(f"msg_base must be a positive integer, got {msg_base}")
        # Integer division: the chunk sizes must be exact integers.
        chunk_size, remainder = divmod(vocab_size, msg_base)
        sizes = torch.zeros(msg_base + 1, dtype=torch.int64)
        sizes[1:] = chunk_size
        sizes[1 : remainder + 1] += 1
        return torch.cumsum(sizes, dim=0)

    def _seed_rng(self, input_ids: torch.Tensor):
        """
        Seed the rng from the last ``window_length`` ids of the sequence.

        Args:
            input_ids: id in the input sequence.
        """
        # NOTE(review): window_length == 0 slices the whole sequence (-0 == 0).
        seed = self.seed_fn(input_ids[-self.window_length :])
        self.rng.manual_seed(seed)

    def _get_valid_list_ids(self, input_ids: torch.Tensor, value: int):
        """
        Get ids of the tokens that encode ``value`` after the given context.

        Args:
            input_ids: context tokens used to seed the permutation.
            value: message value in ``0 .. msg_base - 1``.

        Returns:
            A 1-D tensor (on ``self.device``) with the token ids of chunk ``value``.
        """
        self._seed_rng(input_ids)
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        ).to(self.device)
        return vocab_perm[self.ranges[value] : self.ranges[value + 1]]

    def _get_value(self, input_ids: torch.Tensor):
        """
        Recover the value encoded by the last token of a sequence.

        The rng is seeded from the tokens preceding the last one (the context
        the token was generated in). The value is the index of the chunk of
        the permuted vocabulary that contains the last token.

        Args:
            input_ids: token ids; the last element is the token to decode.

        Returns:
            The chunk index, in ``0 .. msg_base - 1``.

        Raises:
            RuntimeError: if the last token id is not in ``0 .. vocab_size - 1``.
        """
        self._seed_rng(input_ids[:-1])
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        ).to(self.device)
        cur_token = input_ids[-1]
        cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
        if cur_id.numel() == 0:
            raise RuntimeError(
                f"Token id {int(cur_token)} is outside the vocabulary range "
                f"[0, {self.vocab_size})"
            )
        # ranges is ascending; the first boundary strictly greater than the
        # token's position is the end of its chunk, hence the "- 1".
        value = (cur_id < self.ranges).type(torch.int).argmax().item() - 1
        return value
class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
    """
    Logits processor that hides a message by biasing, at each generation
    step, the tokens whose chunk encodes the next message value.
    """

    def __init__(
        self,
        prompt_ids: torch.Tensor,
        msg: bytes,
        gamma: float,
        tokenizer,
        start_pos: int = 0,
        *args,
        **kwargs,
    ):
        """
        Args:
            prompt_ids: prompt token ids; exactly one prompt (size(0) == 1) is supported.
            msg: message to hide in the text.
            gamma: bias added to the scores of tokens in the valid list.
            tokenizer: tokenizer used for special token ids and for decoding in `validate`.
            start_pos: number of generated tokens to skip before the first message value.
            *args, **kwargs: forwarded to BaseProcessor.

        Raises:
            RuntimeError: if more than one prompt is given.
        """
        super().__init__(*args, **kwargs)
        if prompt_ids.size(0) != 1:
            raise RuntimeError(
                "EncryptorLogitsProcessor does not support multiple prompts input."
            )
        self.prompt_size = prompt_ids.size(1)
        self.start_pos = start_pos
        self.raw_msg = msg
        self.msg = bytes_to_base(msg, self.msg_base)
        self.gamma = gamma
        self.tokenizer = tokenizer
        special_tokens = [
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
        ]
        special_tokens = [x for x in special_tokens if x is not None]
        # Explicit long dtype: an empty list would otherwise yield a float
        # tensor, which cannot be concatenated with ids / used as an index.
        self.special_tokens = torch.tensor(
            special_tokens, dtype=torch.long, device=self.device
        )

    def __call__(
        self, input_ids_batch: torch.LongTensor, scores_batch: torch.FloatTensor
    ):
        """
        Bias each row's scores toward tokens encoding the next message value.

        Rows outside the message window (before ``prompt_size + start_pos`` or
        after the whole message is hidden) keep their raw scores.
        """
        for i, input_ids in enumerate(input_ids_batch):
            cur_pos = input_ids.size(0)
            msg_ptr = cur_pos - (self.prompt_size + self.start_pos)
            if msg_ptr < 0 or msg_ptr >= len(self.msg):
                continue
            scores_batch[i] = self._add_bias_to_valid_list(
                input_ids, scores_batch[i], self.msg[msg_ptr]
            )
        return scores_batch

    def _add_bias_to_valid_list(
        self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
    ):
        """
        Add the bias (gamma) to the valid-list tokens and to the special tokens.
        """
        ids = torch.cat(
            [self._get_valid_list_ids(input_ids, value), self.special_tokens]
        )
        scores[ids] = scores[ids] + self.gamma
        return scores

    def get_message_len(self):
        """Return the number of base-``msg_base`` values to hide."""
        return len(self.msg)

    def __map_input_ids(self, input_ids: torch.Tensor, base_arr, byte_arr):
        """
        Align decoded and expected message values/bytes with each token.

        Returns:
            (base_msg, byte_msg, base_enc_msg, byte_enc_msg). Each is a list as
            long as ``input_ids``, with -1 where nothing applies:
            - base_msg / byte_msg: the expected message value / byte.
            - base_enc_msg / byte_enc_msg: the decoded value / byte.
        """
        seq_len = input_ids.size(0)
        values_per_byte = get_values_per_byte(self.msg_base)
        start = self.start_pos % values_per_byte

        base_enc_msg = [-1] * seq_len
        byte_enc_msg = [-1] * seq_len
        for i, value in enumerate(base_arr):
            base_enc_msg[i] = value
            byte_idx = (i - start) // values_per_byte
            # Tokens before the byte-aligned start give a negative index, and
            # trailing tokens may lie past the last decoded byte. Neither has
            # a decoded byte (a negative index must not wrap to the end).
            if 0 <= byte_idx < len(byte_arr):
                byte_enc_msg[i] = byte_arr[byte_idx]

        base_msg = [-1] * seq_len
        byte_msg = [-1] * seq_len
        for i, value in enumerate(self.msg):
            pos = i + self.start_pos
            # The sequence may end (e.g. on EOS) before the whole message is embedded.
            if pos >= seq_len:
                break
            base_msg[pos] = value
            byte_msg[pos] = self.raw_msg[i // values_per_byte]
        return base_msg, byte_msg, base_enc_msg, byte_enc_msg

    def validate(self, input_ids_batch: torch.Tensor):
        """
        Decode each sequence and compare it with the hidden message.

        Args:
            input_ids_batch: sequences to check, iterated row by row.
                NOTE(review): message positions are offset by ``start_pos``
                only (not ``prompt_size``), so this appears to expect the
                generated tokens without the prompt. Confirm with callers.

        Returns:
            (res, tokens_infos):
            - res[k]: fraction of message bytes recovered from sequence k.
            - tokens_infos[k]: per-token list of dicts with keys "token",
              "base_enc", "byte_enc", "base_msg", "byte_msg", "byte_id".
        """
        values_per_byte = get_values_per_byte(self.msg_base)
        start = self.start_pos % values_per_byte
        res = []
        tokens_infos = []
        for input_ids in input_ids_batch:
            # Decoded value of every token
            base_arr = [
                self._get_value(input_ids[: i + 1]) for i in range(input_ids.size(0))
            ]
            # Transform the values to bytes, starting at the byte-aligned offset
            byte_arr = base_to_bytes(base_arr[start:], self.msg_base)

            # Count the decoded bytes that match the original message
            enc_msg = byte_arr[self.start_pos // values_per_byte :]
            cnt = sum(
                1
                for expected, decoded in zip(self.raw_msg, enc_msg)
                if expected == decoded
            )
            # An empty message is trivially fully recovered (avoids a division by zero).
            res.append(cnt / len(self.raw_msg) if len(self.raw_msg) > 0 else 1.0)

            base_msg, byte_msg, base_enc_msg, byte_enc_msg = self.__map_input_ids(
                input_ids, base_arr, byte_arr
            )
            input_strs = [self.tokenizer.decode([tok]) for tok in input_ids]
            tokens = [
                {
                    "token": input_strs[i],
                    "base_enc": base_enc_msg[i],
                    "byte_enc": byte_enc_msg[i],
                    "base_msg": base_msg[i],
                    "byte_msg": byte_msg[i],
                    "byte_id": (i - start) // values_per_byte,
                }
                for i in range(len(base_enc_msg))
            ]
            tokens_infos.append(tokens)
        return res, tokens_infos
class DecryptorProcessor(BaseProcessor):
    """Recovers hidden bytes from token sequences produced with BaseProcessor chunks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def decrypt(self, input_ids_batch: torch.Tensor):
        """
        Decrypt the text sequences.

        The message may begin at any offset inside a byte (one byte spans
        ``values_per_byte`` token values), so decoding is tried for every shift.

        Args:
            input_ids_batch: token id sequences, iterated row by row.

        Returns:
            One entry per shift ``0 .. values_per_byte - 1``. Each entry is a
            list holding, for every sequence in the batch, the bytes decoded
            from that sequence's token values starting at the shift.
        """
        values_per_byte = get_values_per_byte(self.msg_base)
        # The value of token j depends only on input_ids[: j + 1], not on the
        # shift. Compute every sequence's values once and slice per shift,
        # instead of re-running the vocab-sized permutation once per shift.
        # (Relies on seed_fn being a pure function of its input, which decoding
        # requires anyway.)
        values_batch = [
            [self._get_value(input_ids[: j + 1]) for j in range(len(input_ids))]
            for input_ids in input_ids_batch
        ]
        return [
            [base_to_bytes(values[shift:], self.msg_base) for values in values_batch]
            for shift in range(values_per_byte)
        ]