Spaces:
Sleeping
Sleeping
Show the percentage of the message hidden in the results; improve the UI of the command-line interface
Browse files- main.py +85 -27
- processors.py +17 -0
- stegno.py +9 -4
main.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
from argparse import ArgumentParser
|
3 |
|
4 |
import torch
|
@@ -51,18 +52,37 @@ def create_args():
|
|
51 |
parser.add_argument(
|
52 |
"--private-key", type=str, default="", help="Path to private key"
|
53 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
# Input
|
55 |
parser.add_argument(
|
56 |
-
"--msg",
|
|
|
|
|
|
|
57 |
)
|
58 |
parser.add_argument(
|
59 |
-
"--prompt",
|
|
|
|
|
|
|
60 |
)
|
61 |
parser.add_argument(
|
62 |
"--text",
|
63 |
type=str,
|
64 |
default=None,
|
65 |
-
help="Text
|
66 |
)
|
67 |
# Mode
|
68 |
parser.add_argument(
|
@@ -89,74 +109,109 @@ def main(args):
|
|
89 |
|
90 |
if os.path.isfile(args.salt_key):
|
91 |
with open(args.salt_key, "r") as f:
|
92 |
-
salt_key = int(f.readline())
|
|
|
93 |
else:
|
94 |
-
salt_key = None
|
95 |
|
96 |
if os.path.isfile(args.private_key):
|
97 |
with open(args.private_key, "r") as f:
|
98 |
-
private_key = int(f.readline())
|
|
|
99 |
else:
|
100 |
-
private_key =
|
|
|
|
|
101 |
|
102 |
if args.encrypt:
|
103 |
if len(args.prompt) == 0:
|
104 |
raise ValueError("Prompt cannot be empty in encrypt mode")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
if os.path.isfile(args.msg):
|
|
|
106 |
with open(args.msg, "rb") as f:
|
107 |
-
msg = f.read()
|
108 |
else:
|
109 |
-
|
110 |
|
111 |
print("=" * os.get_terminal_size().columns)
|
112 |
print("Encryption Parameters:")
|
113 |
print(f" GenModel: {args.gen_model}")
|
114 |
-
print(f" Prompt:
|
115 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
print(f" Gamma: {args.gamma}")
|
117 |
print(f" Message Base: {args.msg_base}")
|
118 |
print(f" Seed Scheme: {args.seed_scheme}")
|
119 |
print(f" Window Length: {args.window_length}")
|
120 |
-
print(f" Salt Key: {salt_key}")
|
121 |
-
print(f" Private Key: {private_key}")
|
|
|
|
|
122 |
print("=" * os.get_terminal_size().columns)
|
123 |
-
text = generate(
|
124 |
tokenizer=tokenizer,
|
125 |
model=model,
|
126 |
prompt=args.prompt,
|
127 |
-
msg=msg,
|
128 |
gamma=args.gamma,
|
129 |
msg_base=args.msg_base,
|
130 |
seed_scheme=args.seed_scheme,
|
131 |
window_length=args.window_length,
|
132 |
-
salt_key=salt_key,
|
133 |
-
private_key=private_key,
|
|
|
|
|
134 |
)
|
135 |
-
print(f"Text contains message
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
with open(args.save_file, "w") as f:
|
139 |
f.write(text)
|
140 |
-
|
141 |
-
args.text = text
|
142 |
|
143 |
if args.decrypt:
|
144 |
if len(args.text) == 0:
|
145 |
raise ValueError("Text cannot be empty in decrypt mode")
|
|
|
146 |
if os.path.isfile(args.text):
|
|
|
147 |
with open(args.text, "r") as f:
|
148 |
lines = f.readlines()
|
149 |
args.text = "".join(lines)
|
|
|
150 |
print("=" * os.get_terminal_size().columns)
|
151 |
-
print("
|
152 |
print(f" GenModel: {args.gen_model}")
|
153 |
-
print(f" Text: {args.text}")
|
154 |
print(f" Message Base: {args.msg_base}")
|
155 |
print(f" Seed Scheme: {args.seed_scheme}")
|
156 |
print(f" Window Length: {args.window_length}")
|
157 |
-
print(f" Salt Key: {salt_key}")
|
158 |
-
print(f" Private Key: {private_key}")
|
|
|
|
|
|
|
|
|
159 |
print("=" * os.get_terminal_size().columns)
|
|
|
160 |
msgs = decrypt(
|
161 |
tokenizer=tokenizer,
|
162 |
device=args.device,
|
@@ -169,7 +224,10 @@ def main(args):
|
|
169 |
)
|
170 |
print("Message:")
|
171 |
for s, msg in enumerate(msgs):
|
172 |
-
print(
|
|
|
|
|
|
|
173 |
|
174 |
|
175 |
if __name__ == "__main__":
|
|
|
1 |
import os
|
2 |
+
import json
|
3 |
from argparse import ArgumentParser
|
4 |
|
5 |
import torch
|
|
|
52 |
parser.add_argument(
|
53 |
"--private-key", type=str, default="", help="Path to private key"
|
54 |
)
|
55 |
+
# Generation Params
|
56 |
+
parser.add_argument(
|
57 |
+
"--num-beams",
|
58 |
+
type=int,
|
59 |
+
default=4,
|
60 |
+
help="Number of beams used in beam search",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--max-new-tokens-ratio",
|
64 |
+
type=float,
|
65 |
+
default=2,
|
66 |
+
help="Ratio of max new tokens to minimum tokens required to hide message",
|
67 |
+
)
|
68 |
# Input
|
69 |
parser.add_argument(
|
70 |
+
"--msg",
|
71 |
+
type=str,
|
72 |
+
default=None,
|
73 |
+
help="Message or path to message to be hidden",
|
74 |
)
|
75 |
parser.add_argument(
|
76 |
+
"--prompt",
|
77 |
+
type=str,
|
78 |
+
default=None,
|
79 |
+
help="Prompt or path to prompt used to generate text",
|
80 |
)
|
81 |
parser.add_argument(
|
82 |
"--text",
|
83 |
type=str,
|
84 |
default=None,
|
85 |
+
help="Text or path to text containing the hidden message",
|
86 |
)
|
87 |
# Mode
|
88 |
parser.add_argument(
|
|
|
109 |
|
110 |
if os.path.isfile(args.salt_key):
|
111 |
with open(args.salt_key, "r") as f:
|
112 |
+
args.salt_key = int(f.readline())
|
113 |
+
print(f"Read salt key from {args.salt_key}")
|
114 |
else:
|
115 |
+
args.salt_key = int(args.salt_key) if len(args.salt_key) > 0 else None
|
116 |
|
117 |
if os.path.isfile(args.private_key):
|
118 |
with open(args.private_key, "r") as f:
|
119 |
+
args.private_key = int(f.readline())
|
120 |
+
print(f"Read private key from {args.private_key}")
|
121 |
else:
|
122 |
+
args.private_key = (
|
123 |
+
int(args.private_key) if len(args.private_key) > 0 else None
|
124 |
+
)
|
125 |
|
126 |
if args.encrypt:
|
127 |
if len(args.prompt) == 0:
|
128 |
raise ValueError("Prompt cannot be empty in encrypt mode")
|
129 |
+
if len(args.msg) == 0:
|
130 |
+
raise ValueError("Message cannot be empty in encrypt mode")
|
131 |
+
|
132 |
+
if os.path.isfile(args.prompt):
|
133 |
+
print(f"Read prompt from {args.prompt}")
|
134 |
+
with open(args.prompt, "r") as f:
|
135 |
+
args.prompt = "".join(f.readlines())
|
136 |
+
|
137 |
if os.path.isfile(args.msg):
|
138 |
+
print(f"Read message from {args.msg}")
|
139 |
with open(args.msg, "rb") as f:
|
140 |
+
args.msg = f.read()
|
141 |
else:
|
142 |
+
args.msg = bytes(args.msg)
|
143 |
|
144 |
print("=" * os.get_terminal_size().columns)
|
145 |
print("Encryption Parameters:")
|
146 |
print(f" GenModel: {args.gen_model}")
|
147 |
+
print(f" Prompt:")
|
148 |
+
print("- " * (os.get_terminal_size().columns // 2))
|
149 |
+
print(args.prompt)
|
150 |
+
print("- " * (os.get_terminal_size().columns // 2))
|
151 |
+
print(f" Message:")
|
152 |
+
print("- " * (os.get_terminal_size().columns // 2))
|
153 |
+
print(args.msg)
|
154 |
+
print("- " * (os.get_terminal_size().columns // 2))
|
155 |
print(f" Gamma: {args.gamma}")
|
156 |
print(f" Message Base: {args.msg_base}")
|
157 |
print(f" Seed Scheme: {args.seed_scheme}")
|
158 |
print(f" Window Length: {args.window_length}")
|
159 |
+
print(f" Salt Key: {args.salt_key}")
|
160 |
+
print(f" Private Key: {args.private_key}")
|
161 |
+
print(f" Max New Tokens Ratio: {args.max_new_tokens_ratio}")
|
162 |
+
print(f" Number of Beams: {args.num_beams}")
|
163 |
print("=" * os.get_terminal_size().columns)
|
164 |
+
text, msg_rate = generate(
|
165 |
tokenizer=tokenizer,
|
166 |
model=model,
|
167 |
prompt=args.prompt,
|
168 |
+
msg=args.msg,
|
169 |
gamma=args.gamma,
|
170 |
msg_base=args.msg_base,
|
171 |
seed_scheme=args.seed_scheme,
|
172 |
window_length=args.window_length,
|
173 |
+
salt_key=args.salt_key,
|
174 |
+
private_key=args.private_key,
|
175 |
+
max_new_tokens_ratio=args.max_new_tokens_ratio,
|
176 |
+
num_beams=args.num_beams,
|
177 |
)
|
178 |
+
print(f"Text contains message:")
|
179 |
+
print("-" * (os.get_terminal_size().columns))
|
180 |
+
print(text)
|
181 |
+
print("-" * (os.get_terminal_size().columns))
|
182 |
+
print(f"Successfully hide {msg_rate*100:.2f} of the message")
|
183 |
+
print("-" * (os.get_terminal_size().columns))
|
184 |
+
|
185 |
+
if len(args.save_file) > 0:
|
186 |
+
os.makedirs(os.path.dirname(args.save_file), exist_ok=True)
|
187 |
with open(args.save_file, "w") as f:
|
188 |
f.write(text)
|
189 |
+
print(f"Saved result to {args.save_file}")
|
|
|
190 |
|
191 |
if args.decrypt:
|
192 |
if len(args.text) == 0:
|
193 |
raise ValueError("Text cannot be empty in decrypt mode")
|
194 |
+
|
195 |
if os.path.isfile(args.text):
|
196 |
+
print(f"Read text from {args.text}")
|
197 |
with open(args.text, "r") as f:
|
198 |
lines = f.readlines()
|
199 |
args.text = "".join(lines)
|
200 |
+
|
201 |
print("=" * os.get_terminal_size().columns)
|
202 |
+
print("Decryption Parameters:")
|
203 |
print(f" GenModel: {args.gen_model}")
|
|
|
204 |
print(f" Message Base: {args.msg_base}")
|
205 |
print(f" Seed Scheme: {args.seed_scheme}")
|
206 |
print(f" Window Length: {args.window_length}")
|
207 |
+
print(f" Salt Key: {args.salt_key}")
|
208 |
+
print(f" Private Key: {args.private_key}")
|
209 |
+
print(f" Text:")
|
210 |
+
print("- " * (os.get_terminal_size().columns // 2))
|
211 |
+
print(args.text)
|
212 |
+
print("- " * (os.get_terminal_size().columns // 2))
|
213 |
print("=" * os.get_terminal_size().columns)
|
214 |
+
|
215 |
msgs = decrypt(
|
216 |
tokenizer=tokenizer,
|
217 |
device=args.device,
|
|
|
224 |
)
|
225 |
print("Message:")
|
226 |
for s, msg in enumerate(msgs):
|
227 |
+
print("-" * (os.get_terminal_size().columns))
|
228 |
+
print(f"Shift {s}: ")
|
229 |
+
print(msg[0])
|
230 |
+
print("-" * (os.get_terminal_size().columns))
|
231 |
|
232 |
|
233 |
if __name__ == "__main__":
|
processors.py
CHANGED
@@ -107,6 +107,7 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
|
|
107 |
self.start_pos = []
|
108 |
for i in range(prompt_ids.size(0)):
|
109 |
self.start_pos.append(prompt_ids[i].size(0))
|
|
|
110 |
self.msg = bytes_to_base(msg, self.msg_base)
|
111 |
self.gamma = gamma
|
112 |
|
@@ -139,6 +140,22 @@ class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
|
|
139 |
def get_message_len(self):
|
140 |
return len(self.msg)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
class DecryptorProcessor(BaseProcessor):
|
144 |
def __init__(self, *args, **kwargs):
|
|
|
107 |
self.start_pos = []
|
108 |
for i in range(prompt_ids.size(0)):
|
109 |
self.start_pos.append(prompt_ids[i].size(0))
|
110 |
+
self.raw_msg = msg
|
111 |
self.msg = bytes_to_base(msg, self.msg_base)
|
112 |
self.gamma = gamma
|
113 |
|
|
|
140 |
def get_message_len(self):
|
141 |
# Length of self.msg, i.e. the message after bytes_to_base() conversion
# (one entry per base-`msg_base` symbol), not the length of the raw bytes.
return len(self.msg)
|
142 |
|
143 |
+
def validate(self, input_ids_batch: torch.Tensor):
    """
    Measure how much of the original message can be read back from tokens.

    For every sequence in the batch, the value carried by each position from
    ``self.start_pos[0]`` onwards is recovered with ``self._get_value``. The
    recovered values are converted back with ``base_to_bytes`` and compared
    element by element against ``self.raw_msg``.

    Args:
        input_ids_batch: Batch of token-id sequences to check, one row per
            sequence.

    Returns:
        A list with one float per sequence: the fraction of ``self.raw_msg``
        positions whose recovered element equals the original. Positions
        missing from a recovered message that is too short count as
        mismatches. An empty ``self.raw_msg`` yields 1.0, since there is
        nothing left to recover.
    """
    res = []
    for input_ids in input_ids_batch:
        # NOTE(review): the start position of the first prompt is applied to
        # every row; confirm callers only pass rows that share that prompt.
        values = [
            self._get_value(input_ids[: i + 1])
            for i in range(self.start_pos[0], input_ids.size(0))
        ]
        enc_msg = base_to_bytes(values, self.msg_base)

        total = len(self.raw_msg)
        if total == 0:
            # Avoid dividing by zero: an empty message is trivially recovered.
            res.append(1.0)
            continue

        # zip() stops at the shorter sequence, so a recovered message that is
        # shorter than the original no longer raises IndexError; the missing
        # tail simply contributes no matches.
        cnt = sum(1 for sent, got in zip(self.raw_msg, enc_msg) if sent == got)
        res.append(cnt / total)

    return res
|
158 |
+
|
159 |
|
160 |
class DecryptorProcessor(BaseProcessor):
|
161 |
def __init__(self, *args, **kwargs):
|
stegno.py
CHANGED
@@ -17,6 +17,8 @@ def generate(
|
|
17 |
window_length: int = 1,
|
18 |
salt_key: Union[int, None] = None,
|
19 |
private_key: Union[int, None] = None,
|
|
|
|
|
20 |
):
|
21 |
"""
|
22 |
Generate the sequence containing the hidden data.
|
@@ -51,15 +53,19 @@ def generate(
|
|
51 |
**tokenized_input,
|
52 |
logits_processor=transformers.LogitsProcessorList([logits_processor]),
|
53 |
min_new_tokens=logits_processor.get_message_len(),
|
54 |
-
max_new_tokens=
|
|
|
|
|
55 |
do_sample=True,
|
56 |
-
num_beams=
|
57 |
)
|
58 |
output_text = tokenizer.batch_decode(
|
59 |
output_tokens, skip_special_tokens=True
|
60 |
)[0]
|
|
|
|
|
61 |
|
62 |
-
return output_text
|
63 |
|
64 |
|
65 |
def decrypt(
|
@@ -100,4 +106,3 @@ def decrypt(
|
|
100 |
msg = decryptor.decrypt(tokenized_input.input_ids)
|
101 |
|
102 |
return msg
|
103 |
-
|
|
|
17 |
window_length: int = 1,
|
18 |
salt_key: Union[int, None] = None,
|
19 |
private_key: Union[int, None] = None,
|
20 |
+
max_new_tokens_ratio: float = 2,
|
21 |
+
num_beams: int = 4,
|
22 |
):
|
23 |
"""
|
24 |
Generate the sequence containing the hidden data.
|
|
|
53 |
**tokenized_input,
|
54 |
logits_processor=transformers.LogitsProcessorList([logits_processor]),
|
55 |
min_new_tokens=logits_processor.get_message_len(),
|
56 |
+
max_new_tokens=int(
|
57 |
+
logits_processor.get_message_len() * max_new_tokens_ratio
|
58 |
+
),
|
59 |
do_sample=True,
|
60 |
+
num_beams=num_beams
|
61 |
)
|
62 |
output_text = tokenizer.batch_decode(
|
63 |
output_tokens, skip_special_tokens=True
|
64 |
)[0]
|
65 |
+
output_tokens_post = tokenizer(output_text, return_tensors="pt")
|
66 |
+
msg_rates = logits_processor.validate(output_tokens_post.input_ids)
|
67 |
|
68 |
+
return output_text, msg_rates[0]
|
69 |
|
70 |
|
71 |
def decrypt(
|
|
|
106 |
msg = decryptor.decrypt(tokenized_input.input_ids)
|
107 |
|
108 |
return msg
|
|