Spaces:
Sleeping
Sleeping
import base64
import json
import os
from argparse import ArgumentParser

import numpy as np
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm

from model_factory import ModelFactory
from stegno import generate
# Module-wide CPU generator with a fixed seed, so the random messages drawn
# from it are reproducible from one run to the next.
rng = torch.Generator(device="cpu").manual_seed(0)
def load_msgs(msg_lens: list[int], file: str | None = None): | |
msgs = None | |
if file is not None and os.path.isfile(file): | |
with open(file, "r") as f: | |
msgs = json.load(f) | |
if "readable" not in msgs and "random" not in msgs: | |
msgs = None | |
else: | |
return msgs | |
msgs = { | |
"readable": [], | |
"random": [], | |
} | |
c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True) | |
iterator = iter(c4_en) | |
for length in tqdm(msg_lens, desc="Loading messages"): | |
random_msg = torch.randint(256, (length,), generator=rng) | |
base64_msg = base64.b64encode(bytes(random_msg.tolist())).decode( | |
"ascii" | |
) | |
msgs["random"].append(base64_msg) | |
readable_msg = next(iterator)["text"] | |
while len(readable_msg) < length: | |
readable_msg = next(iterator)["text"] | |
msgs["readable"].append(readable_msg[:length]) | |
return msgs | |
def load_prompts(n: int, prompt_size: int, file: str | None = None): | |
prompts = None | |
if file is not None and os.path.isfile(file): | |
with open(file, "r") as f: | |
prompts = json.load(f) | |
return prompts | |
prompts = [] | |
c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True) | |
iterator = iter(c4_en) | |
with tqdm(total=n, desc="Loading prompts") as pbar: | |
while len(prompts) < n: | |
text = next(iterator)["text"] | |
if len(text) < prompt_size: | |
continue | |
prompts.append(text) | |
pbar.update() | |
return prompts | |
def create_args():
    """Build the command-line parser and parse ``sys.argv``.

    ``--msgs-lengths`` is mandatory because the script always derives the
    message lengths from it; without it the run would only fail later with
    an unhelpful error. ``--deltas`` and ``--bases`` stay optional at parse
    time since they are only needed when results are generated rather than
    loaded from ``--results-load-file``.

    Returns:
        The ``argparse.Namespace`` holding all parsed options.
    """
    parser = ArgumentParser()
    range_metavar = ("START", "END", "NUM")

    # messages
    parser.add_argument(
        "--msgs-file", type=str, default=None, help="Where messages are stored"
    )
    parser.add_argument(
        "--msgs-lengths",
        nargs=3,
        type=int,
        required=True,
        metavar=range_metavar,
        help="Range of messages' lengths. This is parsed in form: <start> <end> <num>",
    )
    parser.add_argument(
        "--msgs-per-length",
        type=int,
        default=5,
        help="Number of messages per length",
    )

    # prompts
    parser.add_argument(
        "--prompts-file",
        type=str,
        default=None,
        help="Where prompts are stored",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=500,
        help="Number of prompts",
    )
    parser.add_argument(
        "--prompt-size",
        type=int,
        default=50,
        help="Size of prompts (in tokens)",
    )

    # Others
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Whether to overwrite prompts and messages files",
    )

    # Hyperparameters
    parser.add_argument(
        "--gen-model",
        type=str,
        default="gpt2",
        help="Model used to generate",
    )
    parser.add_argument(
        "--deltas",
        nargs=3,
        type=float,
        metavar=range_metavar,
        help="Range of delta. This is parsed in form: <start> <end> <num>",
    )
    parser.add_argument(
        "--bases",
        nargs="+",
        type=int,
        help="Bases used in base encoding",
    )
    parser.add_argument(
        "--judge-model",
        type=str,
        default="gpt2",
        help="Model used to compute score perplexity of generated text",
    )

    # Results
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="How many times to repeat for each set of parameters, prompts and messages",
    )
    parser.add_argument(
        "--results-load-file",
        type=str,
        default=None,
        help="Where to load results",
    )
    parser.add_argument(
        "--results-save-file",
        type=str,
        default=None,
        help="Where to save results",
    )
    parser.add_argument(
        "--results-save-freq", type=int, default=100, help="Save frequency"
    )
    parser.add_argument(
        "--figs-dir",
        type=str,
        default=None,
        help="Where to save figures",
    )

    return parser.parse_args()
def _checkpoint_results(results, path):
    """Write ``results`` as JSON to ``path``, creating its parent directory if any."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises, so only create a directory when the path has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w") as f:
        json.dump(results, f)
    print(f"Saved results to {path}")


def get_results(args, prompts, msgs):
    """Run generation over every parameter combination and collect metrics.

    For each prompt, delta, base, message and repetition, ``generate`` is
    called to hide the message in generated text, and the text is then scored
    with ``ModelFactory.compute_perplexity`` using ``args.judge_model``.

    Args:
        args: Parsed options. Reads ``gen_model``, ``judge_model``, ``deltas``
            (``[start, end, num]``), ``bases``, ``repeat``, ``prompt_size``,
            ``results_save_file`` and ``results_save_freq``.
        prompts: Iterable of prompt strings (must support ``len``).
        msgs: Mapping of message type to a list of messages. Messages under
            "readable" are plain text; all others are base64 strings.

    Returns:
        A list with one dict per generation holding ``msg_type``, ``delta``,
        ``base``, ``perplexity``, ``msg_rate`` and ``msg_len``.

    Raises:
        ValueError: If ``args.deltas`` or ``args.bases`` was not provided.
    """
    # Fail before the (slow) model load if the sweep parameters are missing.
    if args.deltas is None or args.bases is None:
        raise ValueError("--deltas and --bases are required to generate results")

    model, tokenizer = ModelFactory.load_model(args.gen_model)

    num_deltas = int(args.deltas[2])
    delta_values = np.linspace(args.deltas[0], args.deltas[1], num_deltas)

    # Decode each message once instead of once per prompt/delta/base.
    payloads = []
    for msg_type, encoded_msgs in msgs.items():
        for msg in encoded_msgs:
            if msg_type == "readable":
                # UTF-8 equals ASCII for pure-ASCII text but does not crash
                # on the non-ASCII characters that web text can contain.
                msg_bytes = msg.encode("utf-8")
            else:
                msg_bytes = base64.b64decode(msg)
            payloads.append((msg_type, msg_bytes))

    total_gen = (
        len(prompts)
        * num_deltas
        * len(args.bases)
        * args.repeat
        * len(payloads)
    )
    save_path = args.results_save_file
    save_freq = args.results_save_freq

    results = []
    with tqdm(total=total_gen, desc="Generating") as pbar:
        for prompt in prompts:
            for delta in delta_values:
                for base in args.bases:
                    for msg_type, msg_bytes in payloads:
                        for _ in range(args.repeat):
                            text, msg_rate, _tokens_info = generate(
                                tokenizer=tokenizer,
                                model=model,
                                prompt=prompt,
                                msg=msg_bytes,
                                start_pos_p=[0],
                                delta=delta,
                                msg_base=base,
                                seed_scheme="sha_left_hash",
                                window_length=1,
                                private_key=0,
                                min_new_tokens_ratio=1,
                                max_new_tokens_ratio=2,
                                num_beams=4,
                                repetition_penalty=1.5,
                                prompt_size=args.prompt_size,
                            )
                            results.append(
                                {
                                    "msg_type": msg_type,
                                    # .item() turns the NumPy scalar into a
                                    # JSON-serialisable Python float.
                                    "delta": delta.item(),
                                    "base": base,
                                    "perplexity": ModelFactory.compute_perplexity(
                                        args.judge_model, text
                                    ),
                                    "msg_rate": msg_rate,
                                    "msg_len": len(msg_bytes),
                                }
                            )
                            # Checkpoint after every `save_freq` completed
                            # generations; a non-positive frequency disables
                            # intermediate saves.
                            if (
                                save_path
                                and save_freq > 0
                                and len(results) % save_freq == 0
                            ):
                                _checkpoint_results(results, save_path)
                            pbar.update()
    return results
def _save_line_plot(x, y, path):
    """Plot ``y`` against ``x`` (sorted by ``x``) and save the figure to ``path``."""
    # Sort so the line is drawn left to right whatever order the results had.
    order = np.argsort(x, kind="stable")
    fig = plt.figure(dpi=300)
    plt.plot(x[order], y[order])
    plt.savefig(path, bbox_inches="tight")
    plt.close(fig)


def process_results(results, save_dir):
    """Average the results per (base, delta) and save summary figures.

    For each metric ("perplexities", "msg_rates") and message type
    ("random", "readable") this writes:

    * ``<metric>_<msg_type>_scatter.pdf``: base vs. delta, marker size
      scaled by the mean value;
    * ``delta_effect/<metric>_<msg_type>_base<base>.pdf``: mean value as a
      function of delta, one file per base;
    * ``base_effect/<metric>_<msg_type>_delta<delta>.pdf``: mean value as a
      function of base, one file per delta.

    Args:
        results: Iterable of dicts with the keys ``msg_type``, ``base``,
            ``delta``, ``msg_rate`` and ``perplexity``.
        save_dir: Directory the figures are written to (created if missing).
    """
    # Output metric name -> field of a result record that feeds it.
    metric_fields = {"perplexities": "perplexity", "msg_rates": "msg_rate"}
    msg_types = ("random", "readable")

    # samples[metric][msg_type][(base, delta)] -> list of raw measurements
    samples = {
        metric: {msg_type: {} for msg_type in msg_types}
        for metric in metric_fields
    }
    for r in results:
        key = (r["base"], r["delta"])
        for metric, field in metric_fields.items():
            samples[metric][r["msg_type"]].setdefault(key, []).append(r[field])

    bases, deltas, values = {}, {}, {}
    base_set, delta_set = set(), set()
    for metric, per_type in samples.items():
        bases[metric], deltas[metric], values[metric] = {}, {}, {}
        for msg_type, groups in per_type.items():
            keys = list(groups)
            means = [sum(groups[k]) / len(groups[k]) for k in keys]
            bases[metric][msg_type] = np.array(
                [k[0] for k in keys], dtype=np.int32
            )
            # Deltas are fractional: keep them as float64 so they are neither
            # truncated on the plot axes nor altered before being compared
            # with the exact values stored in delta_set below.
            deltas[metric][msg_type] = np.array(
                [k[1] for k in keys], dtype=np.float64
            )
            values[metric][msg_type] = np.array(means, dtype=np.float32)
            base_set.update(k[0] for k in keys)
            delta_set.update(k[1] for k in keys)

    os.makedirs(save_dir, exist_ok=True)
    for metric in samples:
        # Marker-size scaling differs per metric; 3.0 is the minimum size.
        size_scale = 3 if metric == "msg_rates" else 0.1
        for msg_type in samples[metric]:
            fig = plt.figure(dpi=300)
            plt.scatter(
                bases[metric][msg_type],
                deltas[metric][msg_type],
                3.0 + values[metric][msg_type] * size_scale,
            )
            plt.savefig(
                os.path.join(save_dir, f"{metric}_{msg_type}_scatter.pdf"),
                bbox_inches="tight",
            )
            plt.close(fig)

    os.makedirs(os.path.join(save_dir, "delta_effect"), exist_ok=True)
    for metric in samples:
        for msg_type in samples[metric]:
            for base_value in sorted(base_set):
                mask = bases[metric][msg_type] == base_value
                _save_line_plot(
                    deltas[metric][msg_type][mask],
                    values[metric][msg_type][mask],
                    os.path.join(
                        save_dir,
                        f"delta_effect/{metric}_{msg_type}_base{base_value}.pdf",
                    ),
                )

    os.makedirs(os.path.join(save_dir, "base_effect"), exist_ok=True)
    for metric in samples:
        for msg_type in samples[metric]:
            for delta_value in sorted(delta_set):
                mask = deltas[metric][msg_type] == delta_value
                _save_line_plot(
                    bases[metric][msg_type][mask],
                    values[metric][msg_type][mask],
                    os.path.join(
                        save_dir,
                        f"base_effect/{metric}_{msg_type}_delta{delta_value}.pdf",
                    ),
                )
def _write_json(obj, path, label):
    """Dump ``obj`` as JSON to ``path`` and report it on stdout as ``label``."""
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w") as f:
        json.dump(obj, f)
    print(f"Saved {label} to {path}")


def main(args):
    """Run the whole pipeline: load inputs, get results, save and plot them.

    Prompts and messages are read from their cache files unless
    ``args.overwrite`` is set, and are written back whenever the cache file
    is missing or is being overwritten. Results are loaded from
    ``args.results_load_file`` when given, otherwise generated. They are then
    optionally saved to ``args.results_save_file`` and plotted into
    ``args.figs_dir``.

    Args:
        args: Parsed command-line options as returned by ``create_args``.
    """
    # With --overwrite the caches are ignored, so fresh data is generated.
    prompts_cache = None if args.overwrite else args.prompts_file
    msgs_cache = None if args.overwrite else args.msgs_file

    prompts = load_prompts(args.num_prompts, args.prompt_size, prompts_cache)

    lengths = np.linspace(
        args.msgs_lengths[0],
        args.msgs_lengths[1],
        int(args.msgs_lengths[2]),
        dtype=np.int32,
    )
    # Plain Python ints, so downstream code does not depend on how it treats
    # NumPy scalars; each length is repeated msgs_per_length times.
    msgs_lens = [
        int(length) for length in lengths for _ in range(args.msgs_per_length)
    ]
    msgs = load_msgs(msgs_lens, msgs_cache)

    if args.msgs_file:
        if not os.path.isfile(args.msgs_file) or args.overwrite:
            _write_json(msgs, args.msgs_file, "messages")
    if args.prompts_file:
        if not os.path.isfile(args.prompts_file) or args.overwrite:
            _write_json(prompts, args.prompts_file, "prompts")

    if args.results_load_file:
        with open(args.results_load_file, "r") as f:
            results = json.load(f)
    else:
        results = get_results(args, prompts, msgs)

    if args.results_save_file:
        _write_json(results, args.results_save_file, "results")

    if args.figs_dir:
        process_results(results, args.figs_dir)
# Script entry point: parse the command line and run the pipeline.
if __name__ == "__main__":
    main(create_args())