Spaces:
Running
Running
File size: 4,356 Bytes
341de97 0186ed1 341de97 7e8d9b9 cc8b2eb 341de97 52c67ef 11c7796 0186ed1 ee83d59 0186ed1 341de97 cc8b2eb 341de97 7e8d9b9 0186ed1 7e8d9b9 341de97 0186ed1 341de97 7e8d9b9 cc8b2eb 341de97 1f125f1 341de97 52c67ef 7e8d9b9 ee83d59 341de97 0186ed1 341de97 ee83d59 0186ed1 ee83d59 9da31aa 0186ed1 341de97 7e8d9b9 0186ed1 52c67ef 0186ed1 52c67ef 0186ed1 52c67ef 50b0c43 341de97 0186ed1 341de97 cc8b2eb 341de97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from typing import Union
import torch
import transformers
from processors import EncryptorLogitsProcessor, DecryptorProcessor
def generate(
    tokenizer,
    model,
    prompt: str | list[str],
    msg: bytes,
    start_pos_p: list[int],
    delta: float,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: int | None = None,
    private_key: int | None = None,
    min_new_tokens_ratio: float = 1,
    max_new_tokens_ratio: float = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    repetition_penalty: float = 1.0,
    generator: torch.Generator | None = None,
):
    """
    Generate a sequence containing the hidden data.

    Args:
        tokenizer: tokenizer to use.
        model: generative model to use.
        prompt: input prompt (single string or batch of strings).
        msg: message to hide in the text.
        start_pos_p: one- or two-element list; a single value fixes the
            embedding start position, a pair [lo, hi] samples it uniformly
            from the inclusive range.
        delta: bias added to scores of tokens in the valid list.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.
        min_new_tokens_ratio: lower bound on generated length, as a
            multiple of the encoded message length.
        max_new_tokens_ratio: upper bound on generated length, as a
            multiple of the encoded message length.
        do_sample: whether to sample instead of greedy decoding.
        num_beams: number of beams for beam search.
        repetition_penalty: penalty applied to repeated tokens.
        generator: optional torch generator for reproducible sampling.

    Returns:
        Tuple of (decoded output texts, message embedding rates,
        per-token infos) where the last two come from
        ``logits_processor.validate``.
    """
    # Pick the embedding start position: fixed when one bound is given,
    # otherwise sampled uniformly from the inclusive range [lo, hi].
    if len(start_pos_p) == 1:
        start_pos = start_pos_p[0]
    else:
        start_pos = torch.randint(
            start_pos_p[0], start_pos_p[1] + 1, (1,)
        ).item()
    # Shift by window_length so the seeding window is fully available
    # before embedding begins.
    start_pos = int(start_pos) + window_length

    tokenizer.pad_token = tokenizer.eos_token
    tokenized_input = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).to(model.device)
    prompt_size = tokenized_input.input_ids.size(1)

    logits_processor = EncryptorLogitsProcessor(
        prompt_ids=tokenized_input.input_ids,
        msg=msg,
        start_pos=start_pos,
        delta=delta,
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        tokenizer=tokenizer,
        device=model.device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )

    # Bound the total length by the message length scaled by the ratios.
    # Cast to int: the float ratios make these expressions floats, but
    # model.generate expects integer length arguments.
    min_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * min_new_tokens_ratio
    )
    max_length = int(
        prompt_size
        + start_pos
        + logits_processor.get_message_len() * max_new_tokens_ratio
    )
    max_length = min(max_length, tokenizer.model_max_length)
    min_length = min(min_length, max_length)

    output_tokens = model.generate(
        **tokenized_input,
        logits_processor=transformers.LogitsProcessorList([logits_processor]),
        min_length=min_length,
        max_length=max_length,
        do_sample=do_sample,
        num_beams=num_beams,
        repetition_penalty=float(repetition_penalty),
        pad_token_id=tokenizer.eos_token_id,
        generator=generator,
    )

    # Drop the prompt, keeping only the newly generated tokens.
    output_tokens = output_tokens[:, prompt_size:]
    output_texts = tokenizer.batch_decode(
        output_tokens, skip_special_tokens=True
    )
    # Re-tokenize the decoded text so validation sees exactly the token
    # ids a receiver would recover from the plain text.
    output_tokens_post = tokenizer(
        output_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    ).to(model.device)
    msg_rates, tokens_infos = logits_processor.validate(
        output_tokens_post.input_ids
    )
    return output_texts, msg_rates, tokens_infos
def decrypt(
    tokenizer,
    device: torch.device,
    text: str,
    msg_base: int,
    seed_scheme: str,
    window_length: int = 1,
    salt_key: Union[int, None] = None,
    private_key: Union[int, None] = None,
):
    """
    Extract the hidden data from the generated sequence.

    Args:
        tokenizer: tokenizer to use (must match the one used to encode).
        device: device to place the token ids and decryptor on.
        text: text to decode.
        msg_base: base of the message.
        seed_scheme: scheme used to compute the seed.
        window_length: length of window to compute the seed.
        salt_key: salt to add to the seed.
        private_key: private key used to compute the seed.

    Returns:
        The message recovered by ``DecryptorProcessor.decrypt`` from the
        tokenized input ids.
    """
    tokenized_input = tokenizer(text, return_tensors="pt").to(device)
    # Mirror the encoder's seeding configuration so the same valid-token
    # lists are reconstructed at each position.
    decryptor = DecryptorProcessor(
        msg_base=msg_base,
        vocab=list(tokenizer.get_vocab().values()),
        device=device,
        seed_scheme=seed_scheme,
        window_length=window_length,
        salt_key=salt_key,
        private_key=private_key,
    )
    msg = decryptor.decrypt(tokenized_input.input_ids)
    return msg
|