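"""Hide a byte message in model-generated text and recover it, using the
logit-biasing processors from `processors`."""
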
from typing import Union
import torch
import transformers
from processors import EncryptorLogitsProcessor, DecryptorProcessor


def generate(
tokenizer,
model,
prompt: str,
msg: bytes,
start_pos_p: list[int],
delta: float,
msg_base: int,
seed_scheme: str,
window_length: int = 1,
salt_key: Union[int, None] = None,
private_key: Union[int, None] = None,
min_new_tokens_ratio: float = 1,
max_new_tokens_ratio: float = 2,
num_beams: int = 4,
repetition_penalty: float = 1.0,
prompt_size: int = -1,
):
"""
Generate the sequence containing the hidden data.
Args:
tokenizer: tokenizer to use.
model: generative model to use.
prompt: input prompt.
msg: message to hide in the text.
delta: bias add to scores of token in valid list.
msg_base: base of the message.
seed_scheme: scheme used to compute the seed.
window_length: length of window to compute the seed.
salt_key: salt to add to the seed.
private_key: private key used to compute the seed.
"""
    # Pick the start position of the hidden message: fixed if a single
    # value is given, otherwise sampled uniformly from the inclusive range.
    if len(start_pos_p) == 1:
        start_pos = start_pos_p[0]
    else:
        start_pos = torch.randint(
            start_pos_p[0], start_pos_p[1] + 1, (1,)
        ).item()
    # Offset by the window length so a full seed window precedes the message.
    start_pos = int(start_pos) + window_length
tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
if prompt_size == -1:
prompt_size = tokenized_input.input_ids.size(1)
logits_processor = EncryptorLogitsProcessor(
        # Slice the sequence dimension (not the batch dimension).
        prompt_ids=tokenized_input.input_ids[:, :prompt_size],
msg=msg,
start_pos=start_pos,
delta=delta,
msg_base=msg_base,
vocab=list(tokenizer.get_vocab().values()),
tokenizer=tokenizer,
device=model.device,
seed_scheme=seed_scheme,
window_length=window_length,
salt_key=salt_key,
private_key=private_key,
)
    # Bound the number of generated tokens by the message length scaled
    # by the requested ratios; lengths passed to generate() must be ints.
    min_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * min_new_tokens_ratio
    )
    max_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * max_new_tokens_ratio
    )
    max_length = min(max_length, tokenizer.model_max_length)
    min_length = min(min_length, max_length)
output_tokens = model.generate(
input_ids=tokenized_input.input_ids[:, :prompt_size],
attention_mask=tokenized_input.attention_mask[:, :prompt_size],
logits_processor=transformers.LogitsProcessorList([logits_processor]),
min_length=min_length,
max_length=max_length,
do_sample=True,
num_beams=num_beams,
repetition_penalty=float(repetition_penalty),
pad_token_id=tokenizer.eos_token_id,
)
    # Strip the prompt, then detokenize and retokenize the output so that
    # validation sees exactly the token ids a receiver would recover from
    # the raw text.
    output_tokens = output_tokens[:, prompt_size:]
    output_text = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )[0]
    output_tokens_post = tokenizer(
        output_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
msg_rates, tokens_infos = logits_processor.validate(
output_tokens_post.input_ids
)
return output_text, msg_rates[0], tokens_infos[0]
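

# A minimal usage sketch for `generate`. The model name, seed scheme, and
# hyperparameter values below are illustrative assumptions, not values
# prescribed by this module:
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
#   model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
#   text, msg_rate, tokens_info = generate(
#       tokenizer,
#       model,
#       prompt="Once upon a time",
#       msg=b"hi",
#       start_pos_p=[0],
#       delta=5.0,
#       msg_base=2,
#       seed_scheme="...",  # a scheme name defined in processors
#   )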


def decrypt(
tokenizer,
device: torch.device,
text: str,
msg_base: int,
seed_scheme: str,
window_length: int = 1,
salt_key: Union[int, None] = None,
private_key: Union[int, None] = None,
):
"""
Extract the hidden data from the generated sequence.
Args:
tokenizer: tokenizer to use.
text: text to decode.
msg_base: base of the message.
delta: bias added to scores of valid list.
seed_scheme: scheme used to compute the seed.
window_length: length of window to compute the seed.
salt_key: salt to add to the seed.
private_key: private key used to compute the seed.
"""
tokenized_input = tokenizer(text, return_tensors="pt").to(device)
decryptor = DecryptorProcessor(
msg_base=msg_base,
vocab=list(tokenizer.get_vocab().values()),
device=device,
seed_scheme=seed_scheme,
window_length=window_length,
salt_key=salt_key,
private_key=private_key,
)
msg = decryptor.decrypt(tokenized_input.input_ids)
return msg
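

# Round-trip sketch (same illustrative assumptions as above): a receiver
# only needs the text plus the shared seed parameters to recover the
# message, e.g.
#
#   msg = decrypt(
#       tokenizer,
#       device=model.device,
#       text=text,
#       msg_base=2,
#       seed_scheme="...",  # must match the scheme used by generate()
#   )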