"""Benchmark a steganographic text-generation scheme over a parameter grid.

Loads (or generates) prompts and messages from C4, runs `stegno.generate`
for every (prompt, delta, base, message) combination, records perplexity
and message-recovery rate, and renders summary plots.
"""

import base64
import json
import os
from argparse import ArgumentParser

import numpy as np
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm

from model_factory import ModelFactory
from stegno import generate

# Deterministic CPU generator so the "random" messages are reproducible.
rng = torch.Generator(device="cpu")
rng.manual_seed(0)


def load_msgs(msg_lens: list[int], file: str | None = None):
    """Load cached messages from *file*, or generate one message per length.

    Returns a dict with two parallel lists:
      * ``"random"``:   base64-encoded random byte strings,
      * ``"readable"``: natural-text prefixes from the C4 validation split.
    """
    if file is not None and os.path.isfile(file):
        with open(file, "r") as f:
            msgs = json.load(f)
        # BUGFIX: the original used `and`, so a cached dict missing one of
        # the two required keys was still returned as valid. Require both.
        if "readable" in msgs and "random" in msgs:
            return msgs

    msgs = {
        "readable": [],
        "random": [],
    }
    c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True)
    iterator = iter(c4_en)
    for length in tqdm(msg_lens, desc="Loading messages"):
        # Random message: `length` uniform bytes, stored base64 for JSON.
        random_msg = torch.randint(256, (length,), generator=rng)
        base64_msg = base64.b64encode(bytes(random_msg.tolist())).decode("ascii")
        msgs["random"].append(base64_msg)

        # Readable message: first C4 document long enough, truncated.
        readable_msg = next(iterator)["text"]
        while len(readable_msg) < length:
            readable_msg = next(iterator)["text"]
        msgs["readable"].append(readable_msg[:length])
    return msgs


def load_prompts(n: int, prompt_size: int, file: str | None = None):
    """Load cached prompts from *file*, or collect *n* from the C4 train split.

    NOTE(review): the CLI documents *prompt_size* in tokens, but it is
    compared against character length here — confirm the intended unit.
    """
    if file is not None and os.path.isfile(file):
        with open(file, "r") as f:
            return json.load(f)

    prompts = []
    c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
    iterator = iter(c4_en)
    with tqdm(total=n, desc="Loading prompts") as pbar:
        while len(prompts) < n:
            text = next(iterator)["text"]
            if len(text) < prompt_size:
                continue  # too short to serve as a prompt
            prompts.append(text)
            pbar.update()
    return prompts


def create_args():
    """Build and parse the command-line arguments for the benchmark."""
    parser = ArgumentParser()
    # Messages
    parser.add_argument(
        "--msgs-file", type=str, default=None, help="Where messages are stored"
    )
    parser.add_argument(
        "--msgs-lengths",
        nargs=3,
        type=int,
        help="Range of messages' lengths. Parsed as: start end num",
    )
    parser.add_argument(
        "--msgs-per-length",
        type=int,
        default=5,
        help="Number of messages per length",
    )
    # Prompts
    parser.add_argument(
        "--prompts-file",
        type=str,
        default=None,
        help="Where prompts are stored",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=500,
        help="Number of prompts",
    )
    parser.add_argument(
        "--prompt-size",
        type=int,
        default=50,
        help="Size of prompts (in tokens)",
    )
    # Others
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Whether to overwrite prompts and messages files",
    )
    # Hyperparameters
    parser.add_argument(
        "--gen-model",
        type=str,
        default="gpt2",
        help="Model used to generate",
    )
    parser.add_argument(
        "--deltas",
        nargs=3,
        type=float,
        help="Range of delta. Parsed as: start end num",
    )
    parser.add_argument(
        "--bases",
        nargs="+",
        type=int,
        help="Bases used in base encoding",
    )
    parser.add_argument(
        "--judge-model",
        type=str,
        default="gpt2",
        help="Model used to compute score perplexity of generated text",
    )
    # Results
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="How many times to repeat for each set of parameters, prompts and messages",
    )
    parser.add_argument(
        "--results-load-file",
        type=str,
        default=None,
        help="Where to load results",
    )
    parser.add_argument(
        "--results-save-file",
        type=str,
        default=None,
        help="Where to save results",
    )
    parser.add_argument(
        "--results-save-freq", type=int, default=100, help="Save frequency"
    )
    parser.add_argument(
        "--figs-dir",
        type=str,
        default=None,
        help="Where to save figures",
    )
    return parser.parse_args()


def _save_json(obj, path: str, desc: str) -> None:
    """Dump *obj* as JSON to *path*, creating parent directories if any.

    BUGFIX: the original called ``os.makedirs(os.path.dirname(path))``
    unconditionally, which raises when *path* has no directory component
    (``dirname`` returns ``""``).
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w") as f:
        json.dump(obj, f)
    print(f"Saved {desc} to {path}")


def get_results(args, prompts, msgs):
    """Run generation over the full (prompt, delta, base, message) grid.

    Returns a list of per-run dicts with the message type, delta, base,
    judge-model perplexity, message-recovery rate, and message length.
    Periodically checkpoints to ``args.results_save_file``.
    """
    model, tokenizer = ModelFactory.load_model(args.gen_model)
    results = []
    total_gen = (
        len(prompts)
        * int(args.deltas[2])
        * len(args.bases)
        * args.repeat
        * sum(len(msgs[k]) for k in msgs)
    )
    with tqdm(total=total_gen, desc="Generating") as pbar:
        for prompt in prompts:
            for delta in np.linspace(
                args.deltas[0], args.deltas[1], int(args.deltas[2])
            ):
                for base in args.bases:
                    for msg_type in msgs:
                        for msg in msgs[msg_type]:
                            # Readable messages are stored as plain text,
                            # random ones as base64 (see load_msgs).
                            msg_bytes = (
                                msg.encode("ascii")
                                if msg_type == "readable"
                                else base64.b64decode(msg)
                            )
                            for _ in range(args.repeat):
                                text, msg_rate, tokens_info = generate(
                                    tokenizer=tokenizer,
                                    model=model,
                                    prompt=prompt,
                                    msg=msg_bytes,
                                    start_pos_p=[0],
                                    delta=delta,
                                    msg_base=base,
                                    seed_scheme="sha_left_hash",
                                    window_length=1,
                                    private_key=0,
                                    min_new_tokens_ratio=1,
                                    max_new_tokens_ratio=2,
                                    num_beams=4,
                                    repetition_penalty=1.5,
                                    prompt_size=args.prompt_size,
                                )
                                results.append(
                                    {
                                        "msg_type": msg_type,
                                        "delta": delta.item(),
                                        "base": base,
                                        "perplexity": ModelFactory.compute_perplexity(
                                            args.judge_model, text
                                        ),
                                        "msg_rate": msg_rate,
                                        "msg_len": len(msg_bytes),
                                    }
                                )
                                # BUGFIX: the original tested
                                # `(len(results) + 1) % freq`, checkpointing
                                # at 99, 199, ... instead of every `freq`.
                                if (
                                    len(results) % args.results_save_freq == 0
                                    and args.results_save_file
                                ):
                                    _save_json(
                                        results, args.results_save_file, "results"
                                    )
                                pbar.update()
    return results


def process_results(results, save_dir):
    """Aggregate raw results by (base, delta) and save plots under *save_dir*.

    Produces one scatter plot per metric/message-type, plus per-base
    delta-effect and per-delta base-effect line plots.
    """
    # metric -> msg_type -> (base, delta) -> list of samples (then mean).
    data = {
        "perplexities": {"random": {}, "readable": {}},
        "msg_rates": {"random": {}, "readable": {}},
    }
    for r in results:
        key = (r["base"], r["delta"])
        data["msg_rates"][r["msg_type"]].setdefault(key, []).append(r["msg_rate"])
        data["perplexities"][r["msg_type"]].setdefault(key, []).append(
            r["perplexity"]
        )

    def _grid():
        # Fresh metric -> msg_type -> list container.
        return {
            "perplexities": {"random": [], "readable": []},
            "msg_rates": {"random": [], "readable": []},
        }

    bases, deltas, values = _grid(), _grid(), _grid()
    base_set = set()
    delta_set = set()
    for metric in data:
        for msg_type in data[metric]:
            for (base, delta), samples in data[metric][msg_type].items():
                mean = sum(samples) / len(samples)
                data[metric][msg_type][(base, delta)] = mean
                bases[metric][msg_type].append(base)
                deltas[metric][msg_type].append(delta)
                values[metric][msg_type].append(mean)
                base_set.add(base)
                delta_set.add(delta)

    for metric in data:
        for msg_type in data[metric]:
            bases[metric][msg_type] = np.array(
                bases[metric][msg_type], dtype=np.int32
            )
            # BUGFIX: deltas are floats (np.linspace over args.deltas); the
            # original int32 cast truncated them, corrupting both the plots
            # and the `== delta_value` masks below. float64 keeps exact
            # equality with the Python floats stored in delta_set.
            deltas[metric][msg_type] = np.array(
                deltas[metric][msg_type], dtype=np.float64
            )
            values[metric][msg_type] = np.array(
                values[metric][msg_type], dtype=np.float32
            )

    os.makedirs(save_dir, exist_ok=True)
    for metric in data:
        for msg_type in data[metric]:
            fig = plt.figure(dpi=300)
            # Marker size: scale msg_rates (small values) up, perplexities down.
            scale = 3 if metric == "msg_rates" else 0.1
            plt.scatter(
                bases[metric][msg_type],
                deltas[metric][msg_type],
                3.0 + values[metric][msg_type] * scale,
            )
            plt.savefig(
                os.path.join(save_dir, f"{metric}_{msg_type}_scatter.pdf"),
                bbox_inches="tight",
            )
            plt.close(fig)

    # Effect of delta at each fixed base. (An unused size lambda from the
    # original was removed here and below.)
    os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
    for metric in data:
        for msg_type in data[metric]:
            for base_value in base_set:
                mask = bases[metric][msg_type] == base_value
                fig = plt.figure(dpi=300)
                plt.plot(
                    deltas[metric][msg_type][mask],
                    values[metric][msg_type][mask],
                )
                plt.savefig(
                    os.path.join(
                        save_dir,
                        f"delta_effect/{metric}_{msg_type}_base{base_value}.pdf",
                    ),
                    bbox_inches="tight",
                )
                plt.close(fig)

    # Effect of base at each fixed delta.
    os.makedirs(os.path.join(save_dir, "base_effect"), exist_ok=True)
    for metric in data:
        for msg_type in data[metric]:
            for delta_value in delta_set:
                mask = deltas[metric][msg_type] == delta_value
                fig = plt.figure(dpi=300)
                plt.plot(
                    bases[metric][msg_type][mask],
                    values[metric][msg_type][mask],
                )
                plt.savefig(
                    os.path.join(
                        save_dir,
                        f"base_effect/{metric}_{msg_type}_delta{delta_value}.pdf",
                    ),
                    bbox_inches="tight",
                )
                plt.close(fig)


def main(args):
    """Entry point: load data, run (or load) the benchmark, save artifacts."""
    prompts = load_prompts(
        args.num_prompts,
        args.prompt_size,
        args.prompts_file if not args.overwrite else None,
    )
    # Expand the length range into a flat list with msgs_per_length repeats.
    msgs_lens = []
    for length in np.linspace(
        args.msgs_lengths[0],
        args.msgs_lengths[1],
        int(args.msgs_lengths[2]),
        dtype=np.int32,
    ):
        msgs_lens.extend([length] * args.msgs_per_length)
    msgs = load_msgs(
        msgs_lens,
        args.msgs_file if not args.overwrite else None,
    )

    if args.msgs_file and (not os.path.isfile(args.msgs_file) or args.overwrite):
        _save_json(msgs, args.msgs_file, "messages")
    if args.prompts_file and (
        not os.path.isfile(args.prompts_file) or args.overwrite
    ):
        _save_json(prompts, args.prompts_file, "prompts")

    if args.results_load_file:
        with open(args.results_load_file, "r") as f:
            results = json.load(f)
    else:
        results = get_results(args, prompts, msgs)
        if args.results_save_file:
            _save_json(results, args.results_save_file, "results")

    if args.figs_dir:
        process_results(results, args.figs_dir)


if __name__ == "__main__":
    args = create_args()
    main(args)