{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stealth Attack with Unexpected Context - Random Wikipedia Sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "import sys\n",
    "from importlib import reload\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# NOTE: run this cell once per kernel session. `%cd` is relative, so re-running\n",
    "# it climbs two more directories; the relative `./results/...` paths below (and\n",
    "# presumably the `util` / `stealth_edit` imports) are resolved from the directory\n",
    "# this switches to.\n",
    "%cd ../../\n",
    "%pwd\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# load utility functions\n",
    "from util import utils\n",
    "from util import evaluation\n",
    "\n",
    "from stealth_edit import edit_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Paths and Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = ['gpt-j-6b', 'llama-3-8b', 'mamba-1.4b']\n",
    "datasets = ['mcf', 'zsre']\n",
    "\n",
    "# path templates; the two placeholders are filled with (dataset_name, model_name)\n",
    "results_path = './results/wikipedia/{}/{}/'\n",
    "fs_path = './results/eval_fs/wikipedia/fs_wikipedia_{}_{}.pickle'\n",
    "dims_path = './results/eval_dims/wikipedia/{}/{}/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# flags shared by both evaluation calls, so the summary always covers exactly\n",
    "# the perplexity variants that were evaluated\n",
    "eval_flags = dict(\n",
    "    eval_op = True,\n",
    "    eval_oap = False,\n",
    "    eval_ap = True,\n",
    "    eval_aug = False,\n",
    "    eval_rnd = False,\n",
    ")\n",
    "\n",
    "# load PPL metrics\n",
    "perplexity_metrics = {}\n",
    "\n",
    "for dataset_name in datasets:\n",
    "\n",
    "    across_model_metrics = {}\n",
    "    for model_name in models:\n",
    "        across_model_metrics[model_name] = evaluation.eval_model_ppl(\n",
    "            model_name,\n",
    "            results_path = results_path.format(dataset_name, model_name),\n",
    "            num_examples = 300,\n",
    "            **eval_flags\n",
    "        )\n",
    "    for model_name in models:\n",
    "        # the integer after the last 'layer' in each entry of the first 'layer' column\n",
    "        across_model_metrics[model_name]['layer_indices'] = np.array([int(l.split('layer')[-1]) for l in across_model_metrics[model_name]['layer'][:,0]])\n",
    "\n",
    "    summarise_metrics = {}\n",
    "    for model_name in models:\n",
    "        summarise_metrics[model_name] = evaluation.eval_model_ppl_metrics(\n",
    "            across_model_metrics[model_name],\n",
    "            **eval_flags\n",
    "        )\n",
    "    perplexity_metrics[dataset_name] = copy.deepcopy(summarise_metrics)\n",
    "\n",
    "# load feature space metrics\n",
    "mcf_fs_contents = {m: utils.loadpickle(fs_path.format('mcf', m)) for m in models}\n",
    "zsre_fs_contents = {m: utils.loadpickle(fs_path.format('zsre', m)) for m in models}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load Calculated Intrinsic Dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dims_contents = {}\n",
    "fpr_contents = {}\n",
    "\n",
    "for dataset_name in datasets:\n",
    "\n",
    "    model_dim_contents = {}\n",
    "    model_fpr_contents = {}\n",
    "\n",
    "    for model_name in models:\n",
    "        dims_folder = dims_path.format(dataset_name, model_name)\n",
    "\n",
    "        model_dims = []\n",
    "        model_fprs = []\n",
    "        # sorted() makes the stacking order independent of the order os.listdir happens to return\n",
    "        for file_name in sorted(os.listdir(dims_folder)):\n",
    "            contents = utils.loadpickle(os.path.join(dims_folder, file_name))\n",
    "            ids = contents['intrinsic_dims']\n",
    "            # values shown in the 'Theorem 3 Worst Case FPR' panel\n",
    "            model_dims.append(np.sqrt(2**(-ids-1)))\n",
    "            # values shown in the 'potential trigger prompts' detector panel\n",
    "            model_fprs.append(contents['fpr_ftd'])\n",
    "\n",
    "        model_dims = np.array(model_dims)\n",
    "        model_fprs = np.array(model_fprs)\n",
    "        # aggregate across the files found in the folder\n",
    "        mean_dims, std_dims = utils.smart_mean_std(model_dims, axis=0)\n",
    "        mean_fprs, std_fprs = utils.smart_mean_std(model_fprs, axis=0)\n",
    "        model_dim_contents[model_name] = {\n",
    "            'mean_dims': mean_dims,\n",
    "            'std_dims': std_dims\n",
    "        }\n",
    "        model_fpr_contents[model_name] = {\n",
    "            'mean_fprs': mean_fprs,\n",
    "            'std_fprs': std_fprs\n",
    "        }\n",
    "    dims_contents[dataset_name] = copy.deepcopy(model_dim_contents)\n",
    "    fpr_contents[dataset_name] = copy.deepcopy(model_fpr_contents)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot the Figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_dataset_pair(ax, depth, mcf_mean, zsre_mean, mcf_std=None, zsre_std=None,\n",
    "                      point_color='black', line_color='gray', band_label=None):\n",
    "    \"\"\"Draw one model's MCF and ZsRE curves on `ax`, optionally with a spread band.\n",
    "\n",
    "    MCF is drawn as dots joined by a solid line, ZsRE as triangles joined by a\n",
    "    dashed line. When both `mcf_std` and `zsre_std` are given, the region from\n",
    "    (smaller mean - larger std) to (larger mean + larger std) is shaded and\n",
    "    labelled `band_label`.\n",
    "\n",
    "    Returns:\n",
    "        (mcf_line, zsre_line, band) -- the two line handles and the shaded\n",
    "        band, which is None when no std arrays were passed.\n",
    "    \"\"\"\n",
    "    ax.scatter(depth, mcf_mean, color=point_color, s=7)\n",
    "    mcf_line = ax.plot(depth, mcf_mean, color=line_color)[0]\n",
    "\n",
    "    ax.scatter(depth, zsre_mean, color=point_color, s=7, marker='^')\n",
    "    zsre_line = ax.plot(depth, zsre_mean, color=line_color, linestyle='--')[0]\n",
    "\n",
    "    band = None\n",
    "    if mcf_std is not None and zsre_std is not None:\n",
    "        # np.fmax / np.fmin prefer the non-NaN operand, so a NaN in only one\n",
    "        # dataset does not punch a hole in the band\n",
    "        max_std = np.fmax(zsre_std, mcf_std)\n",
    "        lower = np.fmin(zsre_mean, mcf_mean) - max_std\n",
    "        upper = np.fmax(zsre_mean, mcf_mean) + max_std\n",
    "        band = ax.fill_between(depth, lower, upper, color=line_color, alpha=0.2, label=band_label)\n",
    "\n",
    "    return mcf_line, zsre_line, band"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick up edits to util/evaluation.py without restarting the kernel\n",
    "reload(evaluation)\n",
    "\n",
    "fig, axs = plt.subplots(2, 4, figsize=(13, 6))\n",
    "\n",
    "main_colors = ['black', 'b', 'red']\n",
    "sub_colors = ['gray', 'lightblue', 'coral']\n",
    "\n",
    "model_handles = []\n",
    "dataset_handles = []\n",
    "band_handles = {}\n",
    "\n",
    "for i, model_name in enumerate(models):\n",
    "\n",
    "    relative_depth = evaluation.model_layer_indices[model_name] \\\n",
    "        / evaluation.model_depth[model_name]\n",
    "\n",
    "    colors = dict(point_color=main_colors[i], line_color=sub_colors[i])\n",
    "\n",
    "    # only the last model labels its band, so every legend gets a single 'Max STD' entry\n",
    "    label_to_insert = 'Max STD' if i == len(models) - 1 else None\n",
    "\n",
    "    mcf_ppl = perplexity_metrics['mcf'][model_name]\n",
    "    zsre_ppl = perplexity_metrics['zsre'][model_name]\n",
    "\n",
    "    # attack success rate: NaNs are drawn as 0 and there is no spread band\n",
    "    plot_dataset_pair(\n",
    "        axs[0,0], relative_depth,\n",
    "        np.nan_to_num(mcf_ppl['efficacy']), np.nan_to_num(zsre_ppl['efficacy']),\n",
    "        **colors\n",
    "    )\n",
    "\n",
    "    # panel position -> (MCF values, ZsRE values, key of the mean, key of the std)\n",
    "    banded_panels = {\n",
    "        (0,1): (mcf_ppl, zsre_ppl, 'ppl_other_mean', 'ppl_other_std'),\n",
    "        (0,2): (mcf_ppl, zsre_ppl, 'ppl_ap_mean', 'ppl_ap_std'),\n",
    "        (1,0): (mcf_fs_contents[model_name], zsre_fs_contents[model_name], 'mean_other_fprs', 'std_other_fprs'),\n",
    "        (1,1): (mcf_fs_contents[model_name], zsre_fs_contents[model_name], 'mean_wiki_fprs', 'std_wiki_fprs'),\n",
    "        (1,2): (fpr_contents['mcf'][model_name], fpr_contents['zsre'][model_name], 'mean_fprs', 'std_fprs'),\n",
    "        (1,3): (dims_contents['mcf'][model_name], dims_contents['zsre'][model_name], 'mean_dims', 'std_dims'),\n",
    "    }\n",
    "    for position, (mcf_values, zsre_values, mean_key, std_key) in banded_panels.items():\n",
    "        mcf_line, zsre_line, band = plot_dataset_pair(\n",
    "            axs[position], relative_depth,\n",
    "            mcf_values[mean_key], zsre_values[mean_key],\n",
    "            mcf_values[std_key], zsre_values[std_key],\n",
    "            band_label=label_to_insert, **colors\n",
    "        )\n",
    "        # after the loop this holds the last model's band of each panel\n",
    "        band_handles[position] = band\n",
    "\n",
    "        if position == (0,2):\n",
    "            # one solid line per model feeds the 'Models' figure legend\n",
    "            mcf_line.set_label(model_name)\n",
    "            model_handles.append(mcf_line)\n",
    "        if position == (1,3) and i == 0:\n",
    "            # the first model's solid / dashed lines stand in for the two datasets\n",
    "            dataset_handles.extend([mcf_line, zsre_line])\n",
    "\n",
    "# panel position -> (y-axis label, title, y-limits or None to autoscale)\n",
    "panel_format = {\n",
    "    (0,0): ('Success Rate', 'Attack Success Rate', None),\n",
    "    (0,1): ('Ratio', 'Perplexity Ratio\\n (500 other prompts in dataset)', None),\n",
    "    (0,2): ('Ratio', 'Perplexity Ratio (500 other\\n prompts with trigger context)', [0.5,2]),\n",
    "    (1,0): ('False Positive Rate', 'Detector False Positive Rate\\n (other prompts in dataset)', [-0.05,1.05]),\n",
    "    (1,1): ('False Positive Rate', 'Detector False Positive Rate\\n (wikipedia prompts)', [-0.05,1.05]),\n",
    "    (1,2): ('False Positive Rate', 'Detector False Positive Rate\\n (potential trigger prompts)', [-0.05,1.05]),\n",
    "    (1,3): ('False Positive Rate', 'Theorem 3 Worst Case FPR\\n (potential trigger prompts)', [-0.05,1.05]),\n",
    "}\n",
    "for position, (ylabel, title, ylim) in panel_format.items():\n",
    "    ax = axs[position]\n",
    "    ax.set_xlabel('Attack Layer Depth (normalised)')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_title(title, fontsize=11)\n",
    "    ax.set_xlim([0,1])\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(ylim)\n",
    "\n",
    "# legends are created once, after every labelled band exists\n",
    "for position in [(0,1), (1,0), (1,1), (1,2)]:\n",
    "    axs[position].legend()\n",
    "for position in [(0,2), (1,3)]:\n",
    "    axs[position].legend(handles=[band_handles[position]], labels=['Max STD'], loc='upper right')\n",
    "\n",
    "model_legend = fig.legend(model_handles, models, bbox_to_anchor=(0.94, 0.95), loc = 'upper right', title='Models', title_fontproperties={'weight':'bold'}, fontsize=11)\n",
    "dataset_legend = fig.legend(dataset_handles, ['MCF', 'ZsRE'], bbox_to_anchor=(0.935, 0.74), loc = 'upper right', title='Edited Datasets', title_fontproperties={'weight':'bold'}, fontsize=11)\n",
    "\n",
    "# the top-right slot is left empty to make room for the two figure legends\n",
    "axs[0,3].axis('off')\n",
    "\n",
    "for ax in axs.flat:\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('wikipedia.png', dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "memit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}