"""Command-line entry point to hide a message in LLM-generated text and recover it."""

import os
import json
import shutil
from argparse import ArgumentParser

import torch

from stegno import generate, decrypt
from utils import load_model
from global_config import GlobalConfig
from model_factory import ModelFactory


def create_args():
    """Build the CLI parser and return the parsed arguments.

    Defaults for the generation/steganography parameters are read from the
    ``encrypt.default`` section of ``GlobalConfig``.
    """
    parser = ArgumentParser()

    # Generative model
    parser.add_argument(
        "--gen-model",
        type=str,
        default=GlobalConfig.get("encrypt.default", "gen_model"),
        help="Generative model (LLM) used to generate text",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device to load LLM"
    )

    # Steganography params
    parser.add_argument(
        "--gamma",
        type=float,
        default=GlobalConfig.get("encrypt.default", "gamma"),
        help="Bias added to scores of tokens in valid list",
    )
    parser.add_argument(
        "--msg-base",
        type=int,
        default=GlobalConfig.get("encrypt.default", "msg_base"),
        help="Base of message",
    )
    parser.add_argument(
        "--seed-scheme",
        type=str,
        default=GlobalConfig.get("encrypt.default", "seed_scheme"),
        help="Scheme used to compute the seed",
    )
    parser.add_argument(
        "--window-length",
        type=int,
        default=GlobalConfig.get("encrypt.default", "window_length"),
        help="Length of window to compute the seed",
    )
    # Both keys accept either an integer literal or a path to a file whose
    # first line holds the integer (see _read_key).
    parser.add_argument(
        "--salt-key",
        type=str,
        default="",
        help="Salt key (integer) or path to file containing it",
    )
    parser.add_argument(
        "--private-key",
        type=str,
        default="",
        help="Private key (integer) or path to file containing it",
    )

    # Generation params
    parser.add_argument(
        "--num-beams",
        type=int,
        default=GlobalConfig.get("encrypt.default", "num_beams"),
        help="Number of beams used in beam search",
    )
    parser.add_argument(
        "--max-new-tokens-ratio",
        type=float,
        default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
        help="Ratio of max new tokens to minimum tokens required to hide message",
    )

    # Input
    parser.add_argument(
        "--msg",
        type=str,
        default=None,
        help="Message or path to message to be hidden",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt or path to prompt used to generate text",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Text or path to text containing the hidden message",
    )

    # Encryption params
    parser.add_argument(
        "--start-pos",
        type=int,
        nargs="+",
        default=[GlobalConfig.get("encrypt.default", "start_pos")],
        help="Start position to input the text (not including window length). If 2 integers are provided, choose the position randomly between the two values.",
    )

    # Mode
    parser.add_argument(
        "--encrypt",
        action="store_true",
    )
    parser.add_argument(
        "--decrypt",
        action="store_true",
    )
    parser.add_argument(
        "--save-file",
        type=str,
        default="",
        help="Where to save output",
    )

    return parser.parse_args()


def _terminal_width():
    """Return the terminal width in columns, used to size separator lines.

    ``shutil.get_terminal_size`` falls back to a default size when stdout is
    not attached to a terminal (e.g. output piped to a file), whereas
    ``os.get_terminal_size`` raises ``OSError`` in that situation.
    """
    return shutil.get_terminal_size().columns


def _read_key(value, name):
    """Parse an integer key given either as a file path or as a literal.

    Args:
        value: Path to a file whose first line holds the key, the key itself
            as a decimal string, or ``""`` for "no key".
        name: Human-readable key name used in the log message.

    Returns:
        The key as an ``int``, or ``None`` when ``value`` is empty.

    Raises:
        ValueError: If the file's first line (or ``value``) is not an integer.
    """
    if os.path.isfile(value):
        with open(value, "r") as f:
            key = int(f.readline())
        # Report where the key came from (the path), not the key value.
        print(f"Read {name} from {value}")
        return key
    return int(value) if len(value) > 0 else None


def _run_encrypt(args, model, tokenizer):
    """Hide ``args.msg`` in text generated from ``args.prompt`` and report it.

    ``args.prompt`` and ``args.msg`` may be inline values or paths to files.
    On return ``args.prompt`` holds the prompt text and ``args.msg`` holds
    the message as ``bytes``. When ``args.save_file`` is non-empty, the
    generated text is also written to that path.
    """
    if os.path.isfile(args.prompt):
        print(f"Read prompt from {args.prompt}")
        with open(args.prompt, "r") as f:
            args.prompt = f.read()

    if os.path.isfile(args.msg):
        print(f"Read message from {args.msg}")
        with open(args.msg, "rb") as f:
            args.msg = f.read()
    else:
        # bytes(str) raises TypeError without an encoding; encode explicitly
        # so an inline message has the same type as one read from a file.
        args.msg = args.msg.encode("utf-8")

    width = _terminal_width()
    print("=" * width)
    print("Encryption Parameters:")
    print(f" GenModel: {args.gen_model}")
    print(" Prompt:")
    print("- " * (width // 2))
    print(args.prompt)
    print("- " * (width // 2))
    print(" Message:")
    print("- " * (width // 2))
    print(args.msg)
    print("- " * (width // 2))
    print(f" Gamma: {args.gamma}")
    print(f" Message Base: {args.msg_base}")
    print(f" Seed Scheme: {args.seed_scheme}")
    print(f" Window Length: {args.window_length}")
    print(f" Salt Key: {args.salt_key}")
    print(f" Private Key: {args.private_key}")
    print(f" Max New Tokens Ratio: {args.max_new_tokens_ratio}")
    print(f" Number of Beams: {args.num_beams}")
    print("=" * width)

    # The third return value of generate() is not used by this CLI.
    text, msg_rate, _ = generate(
        tokenizer=tokenizer,
        model=model,
        prompt=args.prompt,
        msg=args.msg,
        start_pos_p=args.start_pos,
        gamma=args.gamma,
        msg_base=args.msg_base,
        seed_scheme=args.seed_scheme,
        window_length=args.window_length,
        salt_key=args.salt_key,
        private_key=args.private_key,
        max_new_tokens_ratio=args.max_new_tokens_ratio,
        num_beams=args.num_beams,
    )

    print("Text contains message:")
    print("-" * width)
    print(text)
    print("-" * width)
    print(f"Successfully hide {msg_rate*100:.2f}% of the message")
    print("-" * width)

    if len(args.save_file) > 0:
        save_dir = os.path.dirname(args.save_file)
        # dirname() is "" for a bare filename, and os.makedirs("") raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(args.save_file, "w") as f:
            f.write(text)
        print(f"Saved result to {args.save_file}")


def _run_decrypt(args, tokenizer):
    """Recover hidden message candidates from ``args.text`` and print them.

    ``args.text`` may be inline text or a path to a file. On return it holds
    the text itself.
    """
    if os.path.isfile(args.text):
        print(f"Read text from {args.text}")
        with open(args.text, "r") as f:
            args.text = f.read()

    width = _terminal_width()
    print("=" * width)
    print("Decryption Parameters:")
    print(f" GenModel: {args.gen_model}")
    print(f" Message Base: {args.msg_base}")
    print(f" Seed Scheme: {args.seed_scheme}")
    print(f" Window Length: {args.window_length}")
    print(f" Salt Key: {args.salt_key}")
    print(f" Private Key: {args.private_key}")
    print(" Text:")
    print("- " * (width // 2))
    print(args.text)
    print("- " * (width // 2))
    print("=" * width)

    msgs = decrypt(
        tokenizer=tokenizer,
        device=args.device,
        text=args.text,
        msg_base=args.msg_base,
        seed_scheme=args.seed_scheme,
        window_length=args.window_length,
        salt_key=args.salt_key,
        private_key=args.private_key,
    )

    print("Message:")
    # decrypt() returns one result per shift; only the first element of each
    # result is printed (its full structure is defined in stegno).
    for shift, msg in enumerate(msgs):
        print("-" * width)
        print(f"Shift {shift}: ")
        print(msg[0])
        print("-" * width)


def main(args):
    """Run encryption and/or decryption according to the parsed CLI ``args``.

    Raises:
        ValueError: If a required input for the selected mode is missing or
            empty, or if a key is not a valid integer.
    """
    # Validate inputs before loading the (potentially large) model so that
    # user errors surface immediately. ``not x`` also covers the ``None``
    # argparse default, on which ``len()`` would raise TypeError.
    if args.encrypt:
        if not args.prompt:
            raise ValueError("Prompt cannot be empty in encrypt mode")
        if not args.msg:
            raise ValueError("Message cannot be empty in encrypt mode")
    if args.decrypt and not args.text:
        raise ValueError("Text cannot be empty in decrypt mode")

    args.salt_key = _read_key(args.salt_key, "salt key")
    args.private_key = _read_key(args.private_key, "private key")

    args.device = torch.device(args.device)
    model, tokenizer = load_model(args.gen_model, args.device)

    if args.encrypt:
        _run_encrypt(args, model, tokenizer)
    if args.decrypt:
        _run_decrypt(args, tokenizer)


if __name__ == "__main__":
    args = create_args()
    main(args)