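"""CLI for LLM-based text steganography.

In encrypt mode the script hides a message inside text generated by an LLM
(via `stegno.generate`); in decrypt mode it recovers the message from a given
text (via `stegno.decrypt`).

Illustrative invocations (the script filename, model name, and file paths are
placeholders, not taken from the project):

    python main.py --encrypt --gen-model gpt2 --prompt prompt.txt --msg secret.txt --save-file results/stego.txt
    python main.py --decrypt --gen-model gpt2 --text results/stego.txt
"""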
import os
import json
from argparse import ArgumentParser

import torch

from stegno import generate, decrypt
from utils import load_model
from global_config import GlobalConfig
from model_factory import ModelFactory


def create_args():
    parser = ArgumentParser()

    # Generative model
    parser.add_argument(
        "--gen-model",
        type=str,
        default=GlobalConfig.get("encrypt.default", "gen_model"),
        help="Generative model (LLM) used to generate text",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="Device to load the LLM on"
    )
    # Steganography params
    parser.add_argument(
        "--gamma",
        type=float,
        default=GlobalConfig.get("encrypt.default", "gamma"),
        help="Bias added to the scores of tokens in the valid list",
    )
    parser.add_argument(
        "--msg-base",
        type=int,
        default=GlobalConfig.get("encrypt.default", "msg_base"),
        help="Base of the message",
    )
    parser.add_argument(
        "--seed-scheme",
        type=str,
        default=GlobalConfig.get("encrypt.default", "seed_scheme"),
        help="Scheme used to compute the seed",
    )
    parser.add_argument(
        "--window-length",
        type=int,
        default=GlobalConfig.get("encrypt.default", "window_length"),
        help="Length of the window used to compute the seed",
    )
    parser.add_argument(
        "--salt-key", type=str, default="", help="Salt key or path to a file containing it"
    )
    parser.add_argument(
        "--private-key", type=str, default="", help="Private key or path to a file containing it"
    )
    # Generation params
    parser.add_argument(
        "--num-beams",
        type=int,
        default=GlobalConfig.get("encrypt.default", "num_beams"),
        help="Number of beams used in beam search",
    )
    parser.add_argument(
        "--max-new-tokens-ratio",
        type=float,
        default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
        help="Ratio of max new tokens to the minimum number of tokens required to hide the message",
    )
    # Input
    parser.add_argument(
        "--msg",
        type=str,
        default=None,
        help="Message, or path to a file containing the message, to be hidden",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt, or path to a file containing the prompt, used to generate text",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Text, or path to a file containing the text, that contains the hidden message",
    )
    # Encryption params
    parser.add_argument(
        "--start-pos",
        type=int,
        nargs="+",
        default=[GlobalConfig.get("encrypt.default", "start_pos")],
        help="Start position at which the message is embedded (not including the window length). If 2 integers are provided, the position is chosen randomly between the two values.",
    )
    # Mode
    parser.add_argument(
        "--encrypt",
        action="store_true",
        help="Hide the message in generated text",
    )
    parser.add_argument(
        "--decrypt",
        action="store_true",
        help="Recover the message from the text",
    )
    parser.add_argument(
        "--save-file",
        type=str,
        default="",
        help="Where to save the output",
    )

    return parser.parse_args()


def main(args):
    args.device = torch.device(args.device)

    model, tokenizer = load_model(args.gen_model, args.device)

    # Resolve the salt key: read it from a file if a path is given,
    # otherwise parse the literal value (or leave it unset).
    if os.path.isfile(args.salt_key):
        print(f"Read salt key from {args.salt_key}")
        with open(args.salt_key, "r") as f:
            args.salt_key = int(f.readline())
    else:
        args.salt_key = int(args.salt_key) if len(args.salt_key) > 0 else None

    # Resolve the private key in the same way.
    if os.path.isfile(args.private_key):
        print(f"Read private key from {args.private_key}")
        with open(args.private_key, "r") as f:
            args.private_key = int(f.readline())
    else:
        args.private_key = (
            int(args.private_key) if len(args.private_key) > 0 else None
        )
    if args.encrypt:
        if args.prompt is None or len(args.prompt) == 0:
            raise ValueError("Prompt cannot be empty in encrypt mode")
        if args.msg is None or len(args.msg) == 0:
            raise ValueError("Message cannot be empty in encrypt mode")

        # Prompt and message may be given inline or as paths to files.
        if os.path.isfile(args.prompt):
            print(f"Read prompt from {args.prompt}")
            with open(args.prompt, "r") as f:
                args.prompt = "".join(f.readlines())
        if os.path.isfile(args.msg):
            print(f"Read message from {args.msg}")
            with open(args.msg, "rb") as f:
                args.msg = f.read()
        else:
            # A literal message is passed as a string; encode it to bytes.
            args.msg = args.msg.encode("utf-8")
print("=" * os.get_terminal_size().columns) | |
print("Encryption Parameters:") | |
print(f" GenModel: {args.gen_model}") | |
print(f" Prompt:") | |
print("- " * (os.get_terminal_size().columns // 2)) | |
print(args.prompt) | |
print("- " * (os.get_terminal_size().columns // 2)) | |
print(f" Message:") | |
print("- " * (os.get_terminal_size().columns // 2)) | |
print(args.msg) | |
print("- " * (os.get_terminal_size().columns // 2)) | |
print(f" Gamma: {args.gamma}") | |
print(f" Message Base: {args.msg_base}") | |
print(f" Seed Scheme: {args.seed_scheme}") | |
print(f" Window Length: {args.window_length}") | |
print(f" Salt Key: {args.salt_key}") | |
print(f" Private Key: {args.private_key}") | |
print(f" Max New Tokens Ratio: {args.max_new_tokens_ratio}") | |
print(f" Number of Beams: {args.num_beams}") | |
print("=" * os.get_terminal_size().columns) | |
text, msg_rate, tokens_info = generate( | |
tokenizer=tokenizer, | |
model=model, | |
prompt=args.prompt, | |
msg=args.msg, | |
start_pos_p=args.start_pos, | |
gamma=args.gamma, | |
msg_base=args.msg_base, | |
seed_scheme=args.seed_scheme, | |
window_length=args.window_length, | |
salt_key=args.salt_key, | |
private_key=args.private_key, | |
max_new_tokens_ratio=args.max_new_tokens_ratio, | |
num_beams=args.num_beams, | |
) | |
print(f"Text contains message:") | |
print("-" * (os.get_terminal_size().columns)) | |
print(text) | |
print("-" * (os.get_terminal_size().columns)) | |
print(f"Successfully hide {msg_rate*100:.2f}% of the message") | |
print("-" * (os.get_terminal_size().columns)) | |
if len(args.save_file) > 0: | |
os.makedirs(os.path.dirname(args.save_file), exist_ok=True) | |
with open(args.save_file, "w") as f: | |
f.write(text) | |
print(f"Saved result to {args.save_file}") | |
    if args.decrypt:
        if args.text is None or len(args.text) == 0:
            raise ValueError("Text cannot be empty in decrypt mode")

        # The text may be given inline or as a path to a file.
        if os.path.isfile(args.text):
            print(f"Read text from {args.text}")
            with open(args.text, "r") as f:
                lines = f.readlines()
                args.text = "".join(lines)

        print("=" * os.get_terminal_size().columns)
        print("Decryption Parameters:")
        print(f" GenModel: {args.gen_model}")
        print(f" Message Base: {args.msg_base}")
        print(f" Seed Scheme: {args.seed_scheme}")
        print(f" Window Length: {args.window_length}")
        print(f" Salt Key: {args.salt_key}")
        print(f" Private Key: {args.private_key}")
        print(" Text:")
        print("- " * (os.get_terminal_size().columns // 2))
        print(args.text)
        print("- " * (os.get_terminal_size().columns // 2))
        print("=" * os.get_terminal_size().columns)
        msgs = decrypt(
            tokenizer=tokenizer,
            device=args.device,
            text=args.text,
            msg_base=args.msg_base,
            seed_scheme=args.seed_scheme,
            window_length=args.window_length,
            salt_key=args.salt_key,
            private_key=args.private_key,
        )
        print("Message:")
        for s, msg in enumerate(msgs):
            print("-" * os.get_terminal_size().columns)
            print(f"Shift {s}: ")
            print(msg[0])
            print("-" * os.get_terminal_size().columns)


if __name__ == "__main__":
    args = create_args()
    main(args)