Allow users to change the start_pos
Changed files:
- main.py          +10 -1
- processors.py    +29 -16
- requirements.txt +1 -0
- seed_schemes.py  +0 -39 (deleted)
- stegno.py        +22 -5
- test.py          +14 -0 (added)
main.py
CHANGED
@@ -84,6 +84,14 @@ def create_args():
         default=None,
         help="Text or path to text containing the hidden message",
     )
+    # Encryption params
+    parser.add_argument(
+        "--start-pos",
+        type=int,
+        nargs="+",
+        default=[0],
+        help="Start position to input the text (not including window length). If 2 integers are provided, choose the position randomly between the two values.",
+    )
     # Mode
     parser.add_argument(
         "--encrypt",
@@ -166,6 +174,7 @@ def main(args):
         model=model,
         prompt=args.prompt,
         msg=args.msg,
+        start_pos_p=args.start_pos,
         gamma=args.gamma,
         msg_base=args.msg_base,
         seed_scheme=args.seed_scheme,
@@ -179,7 +188,7 @@ def main(args):
     print("-" * (os.get_terminal_size().columns))
     print(text)
     print("-" * (os.get_terminal_size().columns))
-    print(f"Successfully hide {msg_rate*100:.2f} of the message")
+    print(f"Successfully hide {msg_rate*100:.2f}% of the message")
     print("-" * (os.get_terminal_size().columns))
 
     if len(args.save_file) > 0:
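For reference, `nargs="+"` with `default=[0]` hands `args.start_pos` to the caller as a list of one or more integers. A standalone illustration of the parsing behaviour (hypothetical throwaway parser, not part of the commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--start-pos", type=int, nargs="+", default=[0])

    print(parser.parse_args([]).start_pos)                         # [0]
    print(parser.parse_args(["--start-pos", "5"]).start_pos)       # [5]
    print(parser.parse_args(["--start-pos", "2", "8"]).start_pos)  # [2, 8]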
processors.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Union
 import torch
 from transformers import LogitsProcessor
 
-from seed_schemes import seed_scheme_factory
+from seed_scheme_factory import SeedSchemeFactory
 from utils import bytes_to_base, base_to_bytes, get_values_per_byte
 
 
@@ -36,15 +36,20 @@ class BaseProcessor(object):
         self.device = device
 
         # Seed parameters
-        self.seed_fn = seed_scheme_factory.get(
+        seed_fn = SeedSchemeFactory.get_instance(
             seed_scheme,
             salt_key=salt_key,
             private_key=private_key,
         )
+        if seed_fn is None:
+            raise ValueError(f'Seed scheme "{seed_scheme}" is invalid')
+        else:
+            self.seed_fn = seed_fn
+
         self.window_length = window_length
 
-        # Initialize RNG
-        self.rng = torch.Generator(device=device)
+        # Initialize RNG, always use cpu generator
+        self.rng = torch.Generator(device="cpu")
 
         # Compute the ranges of each value in base
         self.ranges = torch.zeros((self.msg_base + 1), dtype=torch.int64)
@@ -69,7 +74,9 @@ class BaseProcessor(object):
         Get ids of tokens in the valid list for the current sequences.
         """
         self._seed_rng(input_ids)
-        vocab_perm = torch.randperm(self.vocab_size, generator=self.rng, device=self.device)
+        vocab_perm = torch.randperm(
+            self.vocab_size, generator=self.rng, device="cpu"
+        ).to(self.device)
         vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]
 
         return vocab_list
@@ -79,7 +86,9 @@ class BaseProcessor(object):
         Check whether the token is in the valid list.
         """
         self._seed_rng(input_ids[:-1])
-        vocab_perm = torch.randperm(self.vocab_size, generator=self.rng, device=self.device)
+        vocab_perm = torch.randperm(
+            self.vocab_size, generator=self.rng, device="cpu"
+        ).to(self.device)
 
         cur_token = input_ids[-1]
         cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
@@ -94,8 +103,9 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         prompt_ids: torch.Tensor,
         msg: bytes,
         gamma: float,
+        start_pos: int = 0,
         *args,
-        **kwargs
+        **kwargs,
     ):
         """
         Args:
@@ -103,10 +113,14 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
             gamma: bias add to scores of token in valid list.
         """
         super().__init__(*args, **kwargs)
+        if prompt_ids.size(0) != 1:
+            raise RuntimeError(
+                "EncryptorLogitsProcessor does not support multiple prompts input."
+            )
+
+        self.prompt_size = prompt_ids.size(1)
+        self.start_pos = start_pos
 
-        self.start_pos = []
-        for i in range(prompt_ids.size(0)):
-            self.start_pos.append(prompt_ids[i].size(0))
         self.raw_msg = msg
         self.msg = bytes_to_base(msg, self.msg_base)
         self.gamma = gamma
@@ -118,8 +132,8 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
 
         for i, input_ids in enumerate(input_ids_batch):
             cur_pos = input_ids.size(0)
-            msg_ptr = cur_pos - self.start_pos[i]
-            if msg_ptr >= len(self.msg):
+            msg_ptr = cur_pos - (self.prompt_size + self.start_pos)
+            if msg_ptr < 0 or msg_ptr >= len(self.msg):
                 continue
             scores_batch[i] = self._add_bias_to_valid_list(
                 input_ids, scores_batch[i], self.msg[msg_ptr]
@@ -144,7 +158,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
         res = []
         for input_ids in input_ids_batch:
             values = []
-            for i in range(self.start_pos[0], input_ids.size(0)):
+            for i in range(self.start_pos, input_ids.size(0)):
                 values.append(self._get_value(input_ids[: i + 1]))
             enc_msg = base_to_bytes(values, self.msg_base)
             cnt = 0
@@ -152,7 +166,6 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
                 if self.raw_msg[i] == enc_msg[i]:
                     cnt += 1
             res.append(cnt / len(self.raw_msg))
-
 
         return res
 
@@ -166,12 +179,12 @@ class DecryptorProcessor(BaseProcessor):
         Decrypt the text sequences.
         """
         shift_msg = []
-        for …
+        for shift in range(get_values_per_byte(self.msg_base)):
             msg = []
             bytes_msg = []
             for i, input_ids in enumerate(input_ids_batch):
                 msg.append(list())
-                for j in range(self.window_length + …
+                for j in range(self.window_length + shift, len(input_ids)):
                     # TODO: this could be slow. Considering reimplement this.
                     value = self._get_value(input_ids[: j + 1])
                     msg[i].append(value)
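The reworked `__call__` offsets the message pointer by the prompt length plus the chosen start position, and skips biasing whenever the pointer falls outside the message. A toy walk-through of that arithmetic (the numbers are made up for illustration):

    # Bias is applied only while 0 <= msg_ptr < msg_len.
    prompt_size, start_pos, msg_len = 12, 7, 20

    for cur_pos in (15, 19, 25, 38, 39, 45):
        msg_ptr = cur_pos - (prompt_size + start_pos)
        biased = 0 <= msg_ptr < msg_len
        print(f"cur_pos={cur_pos:2d}  msg_ptr={msg_ptr:3d}  biased={biased}")
    # Tokens before position prompt_size + start_pos (19) and at or past
    # prompt_size + start_pos + msg_len (39) are generated without bias.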
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ PyYAML==6.0.1
 scikit-learn==1.5.0
 torch==2.3.0
 cryptography==42.0.8
+fastapi
seed_schemes.py
DELETED
@@ -1,39 +0,0 @@
-from typing import Union, Callable
-
-import torch
-
-
-class SeedSchemeFactory:
-    def __init__(self):
-        self.seed_scheme_dict = dict()
-
-    def register(self, name: str, seed_scheme: type):
-        """
-        Register the hash scheme by name. Hash scheme must be callable.
-
-        Args:
-            name: name of seed scheme.
-            func: seed function.
-        """
-        self.seed_scheme_dict[name] = seed_scheme
-
-    def get(self, name: str, **kwargs):
-        """
-        Get the hash scheme by name.
-
-        Args:
-            name: name of seed scheme.
-        """
-        return self.seed_scheme_dict[name](**kwargs)
-
-
-class DummyHash:
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def __call__(self, input_ids: torch.Tensor):
-        return input_ids[-1].item()
-
-
-seed_scheme_factory = SeedSchemeFactory()
-seed_scheme_factory.register("dummy_hash", DummyHash)
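The deleted instance-based registry is superseded by the `SeedSchemeFactory` that processors.py now imports from `seed_scheme_factory`. That module is not shown in this commit; a minimal hypothetical sketch consistent with the calls visible in the diff (class-level `get_instance` that returns `None` for unknown names) might look like:

    # Hypothetical sketch only -- the real seed_scheme_factory.py is not part
    # of this commit. It reconstructs just the contract the diff relies on.
    from typing import Callable, Dict, Optional


    class SeedSchemeFactory:
        _registry: Dict[str, type] = {}

        @classmethod
        def register(cls, name: str) -> Callable[[type], type]:
            """Class decorator that registers a seed scheme under `name`."""
            def wrapper(scheme: type) -> type:
                cls._registry[name] = scheme
                return scheme
            return wrapper

        @classmethod
        def get_instance(cls, name: str, **kwargs) -> Optional[object]:
            """Instantiate the scheme registered under `name`, or return None."""
            scheme = cls._registry.get(name)
            return scheme(**kwargs) if scheme is not None else None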
stegno.py
CHANGED
@@ -11,6 +11,7 @@ def generate(
     model,
     prompt: str,
     msg: bytes,
+    start_pos_p: list[int],
     gamma: float,
     msg_base: int,
     seed_scheme: str,
@@ -36,10 +37,20 @@ def generate(
         private_key: private key used to compute the seed.
 
     """
+    if len(start_pos_p) == 1:
+        start_pos = start_pos_p[0]
+    else:
+        start_pos = torch.randint(
+            start_pos_p[0], start_pos_p[1] + 1, (1,)
+        ).item()
+    start_pos = int(start_pos) + window_length
+
     tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
+    prompt_size = tokenized_input.input_ids.size(1)
     logits_processor = EncryptorLogitsProcessor(
         prompt_ids=tokenized_input.input_ids,
         msg=msg,
+        start_pos=start_pos,
         gamma=gamma,
         msg_base=msg_base,
         vocab=list(tokenizer.get_vocab().values()),
@@ -49,20 +60,26 @@ def generate(
         salt_key=salt_key,
         private_key=private_key,
     )
+    min_length = start_pos + logits_processor.get_message_len()
+    max_length = int(
+        start_pos + logits_processor.get_message_len() * max_new_tokens_ratio
+    )
    output_tokens = model.generate(
         **tokenized_input,
         logits_processor=transformers.LogitsProcessorList([logits_processor]),
-        min_new_tokens=logits_processor.get_message_len(),
-        max_new_tokens=int(
-            logits_processor.get_message_len() * max_new_tokens_ratio
-        ),
+        min_new_tokens=min_length,
+        max_new_tokens=max_length,
         do_sample=True,
         num_beams=num_beams
     )
+
+    output_tokens = output_tokens[:, prompt_size:]
     output_text = tokenizer.batch_decode(
         output_tokens, skip_special_tokens=True
     )[0]
-    output_tokens_post = tokenizer(output_text, return_tensors="pt")
+    output_tokens_post = tokenizer(output_text, return_tensors="pt").to(
+        model.device
+    )
    msg_rates = logits_processor.validate(output_tokens_post.input_ids)
 
     return output_text, msg_rates[0]
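The start-position logic added to `generate()` can be restated as a small standalone helper (same arithmetic as the diff; the wrapper function name is ours): one value is used as-is, two values select uniformly over the inclusive range, and `window_length` is always added on top.

    import torch


    def pick_start_pos(start_pos_p: list[int], window_length: int) -> int:
        if len(start_pos_p) == 1:
            start_pos = start_pos_p[0]
        else:
            # torch.randint's upper bound is exclusive, hence the +1.
            start_pos = torch.randint(
                start_pos_p[0], start_pos_p[1] + 1, (1,)
            ).item()
        return int(start_pos) + window_length


    print(pick_start_pos([0], 4))     # always 4
    print(pick_start_pos([2, 8], 4))  # uniform over 6..12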
test.py
ADDED
@@ -0,0 +1,14 @@
+from seed_scheme_factory import SeedSchemeFactory
+import torch
+
+if __name__ == "__main__":
+    seed_fn = SeedSchemeFactory.get_instance("sha_left_hash", private_key=18)
+    rng = torch.Generator()
+    rng.manual_seed(1)
+    input_ids = torch.randint(0, 2**32, (8,), generator=rng)
+    print("input_ids =", input_ids)
+    if seed_fn is not None:
+        seed = seed_fn(input_ids)
+        print(" ", 2**64-1)
+        print("seed =", seed)
+