File size: 8,647 Bytes
341de97
11c7796
341de97
 
 
 
 
 
1f125f1
 
341de97
 
 
 
 
 
 
 
 
1f125f1
341de97
 
 
 
 
 
 
cc8b2eb
341de97
cc8b2eb
341de97
 
 
 
 
1f125f1
341de97
 
 
 
 
1f125f1
341de97
 
 
 
 
1f125f1
341de97
 
 
 
 
 
 
 
11c7796
 
 
 
1f125f1
11c7796
 
0186ed1
 
 
 
 
 
 
 
 
 
 
11c7796
 
 
1f125f1
11c7796
 
341de97
 
11c7796
 
 
 
341de97
 
11c7796
 
 
 
0c3c1a0
 
 
 
 
11c7796
341de97
7e8d9b9
 
 
 
 
1f125f1
7e8d9b9
 
341de97
 
 
 
 
 
 
 
 
0c3c1a0
 
 
 
 
 
341de97
 
 
 
 
 
 
 
 
 
11c7796
 
341de97
11c7796
341de97
 
 
11c7796
 
341de97
11c7796
 
 
341de97
 
0c3c1a0
 
11c7796
 
 
 
 
 
 
 
341de97
11c7796
341de97
11c7796
341de97
11c7796
341de97
 
 
 
11c7796
 
 
 
 
 
 
 
cc8b2eb
341de97
 
 
11c7796
 
 
 
341de97
ee83d59
341de97
 
 
11c7796
7e8d9b9
cc8b2eb
341de97
 
 
11c7796
 
0186ed1
 
11c7796
 
341de97
11c7796
 
 
 
7e8d9b9
11c7796
 
 
 
0c3c1a0
 
11c7796
0c3c1a0
341de97
0c3c1a0
 
11c7796
0c3c1a0
11c7796
0c3c1a0
 
 
11c7796
0c3c1a0
11c7796
0c3c1a0
 
 
 
11c7796
 
 
 
 
 
0c3c1a0
11c7796
341de97
 
 
 
 
 
 
 
 
 
 
 
11c7796
 
 
 
341de97
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import json
import os
import shutil
from argparse import ArgumentParser

import torch

from stegno import generate, decrypt
from utils import load_model
from global_config import GlobalConfig
from model_factory import ModelFactory


def create_args():
    """Build the command-line interface and parse ``sys.argv``.

    Defaults that are not hard-coded here are looked up in the
    ``"encrypt.default"`` section of ``GlobalConfig``.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """

    def config_default(key):
        # Read a default value from the "encrypt.default" config section.
        return GlobalConfig.get("encrypt.default", key)

    cli = ArgumentParser()
    option = cli.add_argument

    # --- Generative model -------------------------------------------------
    option(
        "--gen-model",
        type=str,
        default=config_default("gen_model"),
        help="Generative model (LLM) used to generate text",
    )
    option("--device", type=str, default="cpu", help="Device to load LLM")

    # --- Steganography parameters -----------------------------------------
    option(
        "--delta",
        type=float,
        default=config_default("delta"),
        help="Bias added to scores of tokens in valid list",
    )
    option(
        "--msg-base",
        type=int,
        default=config_default("msg_base"),
        help="Base of message",
    )
    option(
        "--seed-scheme",
        type=str,
        default=config_default("seed_scheme"),
        help="Scheme used to compute the seed",
    )
    option(
        "--window-length",
        type=int,
        default=config_default("window_length"),
        help="Length of window to compute the seed",
    )
    option("--salt-key", type=str, default="", help="Path to salt key")
    option("--private-key", type=str, default="", help="Path to private key")

    # --- Generation parameters --------------------------------------------
    option(
        "--num-beams",
        type=int,
        default=config_default("num_beams"),
        help="Number of beams used in beam search",
    )
    option(
        "--do-sample",
        action="store_true",
        help="Whether to do sample or greedy search",
    )
    option(
        "--min-new-tokens-ratio",
        type=float,
        default=config_default("min_new_tokens_ratio"),
        help="Ratio of min new tokens to minimum tokens required to hide message",
    )
    option(
        "--max-new-tokens-ratio",
        type=float,
        default=config_default("max_new_tokens_ratio"),
        help="Ratio of max new tokens to minimum tokens required to hide message",
    )

    # --- Inputs: each accepts a literal value or a path to a file ---------
    for flag, description in (
        ("--msg", "Message or path to message to be hidden"),
        ("--prompt", "Prompt or path to prompt used to generate text"),
        ("--text", "Text or path to text containing the hidden message"),
    ):
        option(flag, type=str, default=None, help=description)

    # --- Encryption parameters --------------------------------------------
    option(
        "--start-pos",
        type=int,
        nargs="+",
        default=[config_default("start_pos")],
        help="Start position to input the text (not including window length). If 2 integers are provided, choose the position randomly between the two values.",
    )

    # --- Mode selection and output ----------------------------------------
    option("--encrypt", action="store_true")
    option("--decrypt", action="store_true")
    option("--save-file", type=str, default="", help="Where to save output")

    return cli.parse_args()


def _print_rule(pattern):
    """Print a horizontal rule made of ``pattern`` spanning the terminal width.

    ``pattern`` may be several characters (e.g. ``"- "``); it is repeated
    ``width // len(pattern)`` times. ``shutil.get_terminal_size`` falls back
    to 80 columns when stdout is not a terminal (piped/redirected), where
    ``os.get_terminal_size`` would raise ``OSError``.
    """
    width = shutil.get_terminal_size().columns
    print(pattern * (width // len(pattern)))


def _load_key(value, label):
    """Resolve a key CLI argument to an int, or None when it is empty.

    Args:
        value: A path to a file whose first line is an integer, an integer
            literal, or ``""`` meaning "no key".
        label: Human-readable key name used in the log message.

    Returns:
        The key as an ``int``, or ``None`` when ``value`` is empty.

    Raises:
        ValueError: If the file's first line (or ``value`` itself) is not an
            integer.
    """
    if os.path.isfile(value):
        with open(value, "r") as f:
            key = int(f.readline())
        # Report the file path the key was read from, not the key value.
        print(f"Read {label} from {value}")
        return key
    return int(value) if value else None


def _read_if_file(value, label, mode="r"):
    """Return the contents of ``value`` if it names an existing file, else ``value``.

    ``mode`` is passed to ``open``; use ``"rb"`` to get ``bytes`` back.
    """
    if not os.path.isfile(value):
        return value
    print(f"Read {label} from {value}")
    with open(value, mode) as f:
        return f.read()


def _encrypt(args, model, tokenizer):
    """Hide ``args.msg`` in text generated from ``args.prompt`` and report it.

    Raises:
        ValueError: If the prompt or message is missing or empty.
    """
    # `not` also covers None, which is the argparse default for these options.
    if not args.prompt:
        raise ValueError("Prompt cannot be empty in encrypt mode")
    if not args.msg:
        raise ValueError("Message cannot be empty in encrypt mode")

    args.prompt = _read_if_file(args.prompt, "prompt")
    msg = _read_if_file(args.msg, "message", mode="rb")
    # A file yields bytes; a literal message is a str and must be encoded
    # (bytes(str) without an encoding raises TypeError).
    args.msg = msg if isinstance(msg, bytes) else msg.encode("utf-8")

    _print_rule("=")
    print("Encryption Parameters:")
    print(f"  GenModel: {args.gen_model}")
    print("  Prompt:")
    _print_rule("- ")
    print(args.prompt)
    _print_rule("- ")
    print("  Message:")
    _print_rule("- ")
    print(args.msg)
    _print_rule("- ")
    print(f"  delta: {args.delta}")
    print(f"  Message Base: {args.msg_base}")
    print(f"  Seed Scheme: {args.seed_scheme}")
    print(f"  Window Length: {args.window_length}")
    print(f"  Salt Key: {args.salt_key}")
    print(f"  Private Key: {args.private_key}")
    print(f"  Min New Tokens Ratio: {args.min_new_tokens_ratio}")
    print(f"  Max New Tokens Ratio: {args.max_new_tokens_ratio}")
    print(f"  Number of Beams: {args.num_beams}")
    print(f"  Do Sample: {args.do_sample}")
    _print_rule("=")

    text, msg_rate, _ = generate(
        tokenizer=tokenizer,
        model=model,
        prompt=args.prompt,
        msg=args.msg,
        start_pos_p=args.start_pos,
        delta=args.delta,
        msg_base=args.msg_base,
        seed_scheme=args.seed_scheme,
        window_length=args.window_length,
        salt_key=args.salt_key,
        private_key=args.private_key,
        do_sample=args.do_sample,
        min_new_tokens_ratio=args.min_new_tokens_ratio,
        max_new_tokens_ratio=args.max_new_tokens_ratio,
        num_beams=args.num_beams,
    )
    print("Text contains message:")
    _print_rule("-")
    print(text)
    _print_rule("-")
    print(f"Successfully hide {msg_rate*100:.2f}% of the message")
    _print_rule("-")

    if args.save_file:
        save_dir = os.path.dirname(args.save_file)
        # dirname() is "" for a bare filename, and os.makedirs("") raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(args.save_file, "w") as f:
            f.write(text)
        print(f"Saved result to {args.save_file}")


def _decrypt(args, tokenizer):
    """Recover hidden message candidates from ``args.text`` and print them.

    Raises:
        ValueError: If the text is missing or empty.
    """
    if not args.text:
        raise ValueError("Text cannot be empty in decrypt mode")

    args.text = _read_if_file(args.text, "text")

    _print_rule("=")
    print("Decryption Parameters:")
    print(f"  GenModel: {args.gen_model}")
    print(f"  Message Base: {args.msg_base}")
    print(f"  Seed Scheme: {args.seed_scheme}")
    print(f"  Window Length: {args.window_length}")
    print(f"  Salt Key: {args.salt_key}")
    print(f"  Private Key: {args.private_key}")
    print("  Text:")
    _print_rule("- ")
    print(args.text)
    _print_rule("- ")
    _print_rule("=")

    msgs = decrypt(
        tokenizer=tokenizer,
        device=args.device,
        text=args.text,
        msg_base=args.msg_base,
        seed_scheme=args.seed_scheme,
        window_length=args.window_length,
        salt_key=args.salt_key,
        private_key=args.private_key,
    )
    print("Message:")
    for shift, msg in enumerate(msgs):
        _print_rule("-")
        print(f"Shift {shift}: ")
        # NOTE(review): each entry is indexed with [0]; its structure is
        # defined by stegno.decrypt, which is not visible here.
        print(msg[0])
    _print_rule("-")


def main(args):
    """Run the steganography CLI in encrypt and/or decrypt mode.

    Both modes run, encrypt first, when both flags are set.

    Args:
        args: Namespace produced by ``create_args``. Fields are normalised in
            place: ``device`` becomes a ``torch.device``, ``salt_key`` and
            ``private_key`` become ints or None, and the prompt, message and
            text arguments are replaced by file contents when they name
            existing files.

    Raises:
        ValueError: If a required input for the selected mode is empty, or a
            key is not an integer.
    """
    args.device = torch.device(args.device)
    model, tokenizer = load_model(args.gen_model, args.device)

    args.salt_key = _load_key(args.salt_key, "salt key")
    args.private_key = _load_key(args.private_key, "private key")

    if args.encrypt:
        _encrypt(args, model, tokenizer)

    if args.decrypt:
        _decrypt(args, tokenizer)


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the selected mode(s).
    args = create_args()
    main(args=args)