tnk2908's picture
Refactor analysis process
0186ed1
raw
history blame
8.65 kB
import os
import json
from argparse import ArgumentParser
import torch
from stegno import generate, decrypt
from utils import load_model
from global_config import GlobalConfig
from model_factory import ModelFactory
def create_args():
parser = ArgumentParser()
# Generative model
parser.add_argument(
"--gen-model",
type=str,
default=GlobalConfig.get("encrypt.default", "gen_model"),
help="Generative model (LLM) used to generate text",
)
parser.add_argument(
"--device", type=str, default="cpu", help="Device to load LLM"
)
# Stenography params
parser.add_argument(
"--delta",
type=float,
default=GlobalConfig.get("encrypt.default", "delta"),
help="Bias added to scores of tokens in valid list",
)
parser.add_argument(
"--msg-base",
type=int,
default=GlobalConfig.get("encrypt.default", "msg_base"),
help="Base of message",
)
parser.add_argument(
"--seed-scheme",
type=str,
default=GlobalConfig.get("encrypt.default", "seed_scheme"),
help="Scheme used to compute the seed",
)
parser.add_argument(
"--window-length",
type=int,
default=GlobalConfig.get("encrypt.default", "window_length"),
help="Length of window to compute the seed",
)
parser.add_argument(
"--salt-key", type=str, default="", help="Path to salt key"
)
parser.add_argument(
"--private-key", type=str, default="", help="Path to private key"
)
# Generation Params
parser.add_argument(
"--num-beams",
type=int,
default=GlobalConfig.get("encrypt.default", "num_beams"),
help="Number of beams used in beam search",
)
parser.add_argument(
"--do-sample",
action="store_true",
help="Whether to do sample or greedy search",
)
parser.add_argument(
"--min-new-tokens-ratio",
type=float,
default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
help="Ratio of min new tokens to minimum tokens required to hide message",
)
parser.add_argument(
"--max-new-tokens-ratio",
type=float,
default=GlobalConfig.get("encrypt.default", "max_new_tokens_ratio"),
help="Ratio of max new tokens to minimum tokens required to hide message",
)
# Input
parser.add_argument(
"--msg",
type=str,
default=None,
help="Message or path to message to be hidden",
)
parser.add_argument(
"--prompt",
type=str,
default=None,
help="Prompt or path to prompt used to generate text",
)
parser.add_argument(
"--text",
type=str,
default=None,
help="Text or path to text containing the hidden message",
)
# Encryption params
parser.add_argument(
"--start-pos",
type=int,
nargs="+",
default=[GlobalConfig.get("encrypt.default", "start_pos")],
help="Start position to input the text (not including window length). If 2 integers are provided, choose the position randomly between the two values.",
)
# Mode
parser.add_argument(
"--encrypt",
action="store_true",
)
parser.add_argument(
"--decrypt",
action="store_true",
)
parser.add_argument(
"--save-file",
type=str,
default="",
help="Where to save output",
)
return parser.parse_args()
def main(args):
args.device = torch.device(args.device)
model, tokenizer = load_model(args.gen_model, args.device)
if os.path.isfile(args.salt_key):
with open(args.salt_key, "r") as f:
args.salt_key = int(f.readline())
print(f"Read salt key from {args.salt_key}")
else:
args.salt_key = int(args.salt_key) if len(args.salt_key) > 0 else None
if os.path.isfile(args.private_key):
with open(args.private_key, "r") as f:
args.private_key = int(f.readline())
print(f"Read private key from {args.private_key}")
else:
args.private_key = (
int(args.private_key) if len(args.private_key) > 0 else None
)
if args.encrypt:
if len(args.prompt) == 0:
raise ValueError("Prompt cannot be empty in encrypt mode")
if len(args.msg) == 0:
raise ValueError("Message cannot be empty in encrypt mode")
if os.path.isfile(args.prompt):
print(f"Read prompt from {args.prompt}")
with open(args.prompt, "r") as f:
args.prompt = "".join(f.readlines())
if os.path.isfile(args.msg):
print(f"Read message from {args.msg}")
with open(args.msg, "rb") as f:
args.msg = f.read()
else:
args.msg = bytes(args.msg)
print("=" * os.get_terminal_size().columns)
print("Encryption Parameters:")
print(f" GenModel: {args.gen_model}")
print(f" Prompt:")
print("- " * (os.get_terminal_size().columns // 2))
print(args.prompt)
print("- " * (os.get_terminal_size().columns // 2))
print(f" Message:")
print("- " * (os.get_terminal_size().columns // 2))
print(args.msg)
print("- " * (os.get_terminal_size().columns // 2))
print(f" delta: {args.delta}")
print(f" Message Base: {args.msg_base}")
print(f" Seed Scheme: {args.seed_scheme}")
print(f" Window Length: {args.window_length}")
print(f" Salt Key: {args.salt_key}")
print(f" Private Key: {args.private_key}")
print(f" Max New Tokens Ratio: {args.max_new_tokens_ratio}")
print(f" Number of Beams: {args.num_beams}")
print("=" * os.get_terminal_size().columns)
text, msg_rate, tokens_info = generate(
tokenizer=tokenizer,
model=model,
prompt=args.prompt,
msg=args.msg,
start_pos_p=args.start_pos,
delta=args.delta,
msg_base=args.msg_base,
seed_scheme=args.seed_scheme,
window_length=args.window_length,
salt_key=args.salt_key,
private_key=args.private_key,
do_sample=args.do_sample,
min_new_tokens_ratio=args.min_new_tokens_ratio,
max_new_tokens_ratio=args.max_new_tokens_ratio,
num_beams=args.num_beams,
)
print(f"Text contains message:")
print("-" * (os.get_terminal_size().columns))
print(text)
print("-" * (os.get_terminal_size().columns))
print(f"Successfully hide {msg_rate*100:.2f}% of the message")
print("-" * (os.get_terminal_size().columns))
if len(args.save_file) > 0:
os.makedirs(os.path.dirname(args.save_file), exist_ok=True)
with open(args.save_file, "w") as f:
f.write(text)
print(f"Saved result to {args.save_file}")
if args.decrypt:
if len(args.text) == 0:
raise ValueError("Text cannot be empty in decrypt mode")
if os.path.isfile(args.text):
print(f"Read text from {args.text}")
with open(args.text, "r") as f:
lines = f.readlines()
args.text = "".join(lines)
print("=" * os.get_terminal_size().columns)
print("Decryption Parameters:")
print(f" GenModel: {args.gen_model}")
print(f" Message Base: {args.msg_base}")
print(f" Seed Scheme: {args.seed_scheme}")
print(f" Window Length: {args.window_length}")
print(f" Salt Key: {args.salt_key}")
print(f" Private Key: {args.private_key}")
print(f" Text:")
print("- " * (os.get_terminal_size().columns // 2))
print(args.text)
print("- " * (os.get_terminal_size().columns // 2))
print("=" * os.get_terminal_size().columns)
msgs = decrypt(
tokenizer=tokenizer,
device=args.device,
text=args.text,
msg_base=args.msg_base,
seed_scheme=args.seed_scheme,
window_length=args.window_length,
salt_key=args.salt_key,
private_key=args.private_key,
)
print("Message:")
for s, msg in enumerate(msgs):
print("-" * (os.get_terminal_size().columns))
print(f"Shift {s}: ")
print(msg[0])
print("-" * (os.get_terminal_size().columns))
if __name__ == "__main__":
args = create_args()
main(args)