import os
import copy

import numpy as np
from collections import Counter

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load utility functions
from util import utils
from util import nethook
from util import inference
from util import extraction
from util import generate

from stealth_edit import compute_subject, compute_object
from stealth_edit import compute_wb, edit_utils

from dsets import wikipedia

np.random.seed(144)

class StealthEditor:
    """ Wrapper around the stealth-edit routines below: loads a model and
        tokenizer once, then applies, tests and reverts individual edits.
    """

    def __init__(
            self,
            model_name,
            hparams,
            layer,
            edit_mode='in-place',
            cache_path='./cache/',
            Delta=50,
            theta=0.005,
            verbose=True
        ):
        self.model_name = model_name
        self.hparams = hparams
        self.layer = layer
        self.edit_mode = edit_mode
        self.cache_path = cache_path
        self.Delta = Delta
        self.theta = theta
        self.verbose = verbose

        self.other_features = None
        self.edit_sample_contents = None

        self._load_model_tok()
        self.load_other_features()

    def _load_model_tok(self):
        """ Load model and tokenizer, as well as the weights of the layer to edit
        """
        self.model, self.tok = utils.load_model_tok(model_name=self.model_name)

        # extract weights
        self.weights, self.weights_detached, self.weights_copy, self.weight_names = extraction.extract_weights(
            self.model, self.hparams, self.layer
        )
        if self.verbose: print('Loaded model, tokenizer and relevant weights.')

    def load_other_features(self):
        """ Load a set of feature vectors extracted from wikipedia text
        """
        cache_file = os.path.join(self.cache_path, f'wiki_train/wikipedia_features_{self.model_name}_layer{self.layer}_w1.pickle')

        if os.path.exists(cache_file):
            if self.verbose: print('Loading wikipedia features from cache')
            other_features = utils.loadpickle(cache_file)['features']
            self.other_features = torch.from_numpy(other_features).to(device)
        else:
            if self.verbose: print('Extracting features from wikipedia')
            _, tok_ds = wikipedia.get_ds(self.tok, maxlen=100)

            other_features, other_params = extraction.extract_tokdataset_features(
                self.model,
                tok_ds,
                layer = self.layer,
                hparams = self.hparams,
                sample_size = 10000,
                take_single = False,
                verbose = True
            )

            # save features (ensure the cache directory exists first)
            utils.assure_path_exists(os.path.dirname(cache_file))
            to_save = other_params
            to_save['features'] = other_features.cpu().numpy()
            utils.savepickle(cache_file, to_save)
            if self.verbose: print('Features cached:', cache_file)

            self.other_features = other_features.to(device)
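
    # Note: the cached pickle stores the parameter dict returned by
    # extract_tokdataset_features plus a 'features' array; only 'features'
    # is read back above on a cache hit.
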
    def generate(self, prompt, top_k=1, max_out_len=50, replace_eos=True, prune_bos=False):
        """ Simple generation of up to `max_out_len` tokens (50 by default)
        """
        texts = generate.generate_fast(
            self.model,
            self.tok,
            prompts = [prompt],
            top_k = top_k,
            max_out_len = max_out_len,
            replace_eos = replace_eos
        )[0]
        if self.verbose: print('\nGenerated text:', texts)
        if prune_bos and self.tok.bos_token and (self.tok.bos_token in texts):
            # drop everything up to and including the BOS token
            texts = texts.split(self.tok.bos_token)[1]
        return texts

    def predict_first_token(self, prompt):
        """ Simple prediction of the first output token
        """
        _, output_decoded = inference.inference_sample(self.model, self.tok, prompt)
        if self.verbose:
            print('First token output decoded:', output_decoded)
        return output_decoded

    def apply_edit(self, prompt, truth=None, context=None, add_eos=False):
        """ Construct a request from `prompt` (and optionally `truth` and
            `context`) and apply an edit or attack according to `self.edit_mode`
        """
        if add_eos and (truth is not None):
            truth = truth + self.tok.eos_token

        if isinstance(prompt, str):
            request = {'prompt': '{}', 'subject': prompt}
        else:
            raise ValueError('`prompt` must be a string.')
        if truth is not None:
            request['target_new'] = {'str': truth}

        self.hparams['Delta'] = self.Delta
        self.hparams['static_context'] = context

        params = {
            'request': request,
            'model': self.model,
            'tok': self.tok,
            'layer': self.layer,
            'hparams': self.hparams,
            'other_features': self.other_features,
            'select_neuron': True,
            'verbose': self.verbose,
            'v_num_grad_steps': 20,
            'theta': self.theta
        }
        if self.edit_mode == 'in-place':
            self.edit_sample_contents = apply_edit(**params)
        elif self.edit_mode in ['prompt', 'context']:
            params['edit_mode'] = self.edit_mode
            self.edit_sample_contents = apply_attack(**params)
        elif self.edit_mode == 'wikipedia':
            params['edit_mode'] = self.edit_mode
            params['augmented_cache'] = './demos/demo_wikipedia_cache.json'
            self.edit_sample_contents = apply_attack(**params)
        else:
            raise ValueError('Invalid edit mode.')

    def insert_edit_weights(self):
        """ Insert modified weights for edit
        """
        if self.edit_sample_contents is None:
            print('No edit applied. Please apply edit first.')
        else:
            # insert modified weights
            with torch.no_grad():
                for name in self.edit_sample_contents['weights_to_modify']:
                    self.weights[self.weight_names[name]][...] = self.edit_sample_contents['weights_to_modify'][name]

    def find_trigger(self):
        """ Return the (possibly perturbed) prompt that activates the edit
        """
        if 'new_request' in self.edit_sample_contents:
            r = self.edit_sample_contents['new_request']
        else:
            r = self.edit_sample_contents['request']
        return r['prompt'].format(r['subject'])

    def find_context(self):
        """ Return the attack context preceding the original prompt, if any
        """
        if 'new_request' in self.edit_sample_contents:
            r_new = self.edit_sample_contents['new_request']
            r_old = self.edit_sample_contents['request']
            return r_new['prompt'].split(r_old['prompt'])[0]
        else:
            return ''

    def restore_model_weights(self):
        """ Restore the original model weights
        """
        with torch.no_grad():
            for k, v in self.weights.items():
                v[...] = self.weights_copy[k]

    def generate_with_edit(self, prompt, stop_at_eos=False, prune_bos=False):
        """ Simple generation with the edited model (weights are restored afterwards)
        """
        self.insert_edit_weights()
        output = self.generate(prompt, replace_eos=not stop_at_eos, prune_bos=prune_bos)
        self.restore_model_weights()
        if stop_at_eos:
            output = output.split(self.tok.eos_token)[0]
        return output

    def predict_first_token_with_edit(self, prompt):
        """ Simple prediction of the first token with the edited model
        """
        self.insert_edit_weights()
        output = self.predict_first_token(prompt)
        self.restore_model_weights()
        return output

    def clear_edit(self):
        """ Remove the current edit and restore the original model weights
        """
        self.context = None
        self.restore_model_weights()
        self.edit_sample_contents = None

    def save_edit(self, path='./cache/'):
        """ Save the contents of the current edit to disk
        """
        utils.assure_path_exists(path)
        utils.savepickle(os.path.join(path, 'tmp.pickle'), self.edit_sample_contents)

    def load_edit(self, path='./cache/'):
        """ Load a previously saved edit (call insert_edit_weights() to activate it)
        """
        self.edit_sample_contents = utils.loadpickle(os.path.join(path, 'tmp.pickle'))

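
# Module-level editing routines below; StealthEditor.apply_edit above dispatches
# to apply_edit (in-place stealth edits) and apply_attack (stealth attacks).
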
def apply_edit(
        request,
        model,
        tok,
        layer,
        hparams,
        other_features,
        device = 'cuda',
        select_neuron = True,
        return_w1 = False,
        v_num_grad_steps = 20,
        theta = 0.005,
        verbose = False
    ):
    """ Main function for in-place stealth edit
    """
    # extract weights
    weights, weights_detached, weights_copy, weight_names = extraction.extract_weights(
        model, hparams, layer
    )

    # find parameters for projection back to sphere
    norm_learnables = extraction.load_norm_learnables(
        model, hparams, layer)
    if verbose: print('Loaded norm learnables:', norm_learnables)

    # find w1 input of target subject
    tset = compute_subject.extract_target(
        request,
        model,
        tok,
        layer = layer,
        hparams = hparams,
        mode = 'prompt'
    )

    # select neuron with specific function
    if select_neuron:
        hparams['target_neuron'], neuron_mask = edit_utils.find_target_neuron_by_l1_norm(
            weights_detached,
            hparams,
            return_mask=True
        )
    else:
        # neuron_mask is required below; without explicit selection we leave it
        # as None (compute_multi_weight_colns must then handle the default)
        neuron_mask = None

    # compute w2 and b2
    w, b, other_params = compute_wb.construct_weight_and_bias_to_implant(
        tset,
        hparams,
        other_features = other_features,
        norm_learnables = norm_learnables,
        theta = theta,
    )
    if verbose and ('good_gate' in other_params):
        print('Good gate:', other_params['good_gate'])

    # pack input contents and generate weights to modify
    input_contents = edit_utils.pack_input_contents(
        tset['w1_input'],
        w = w,
        b = b,
        weights_detached = weights_detached,
        hparams = hparams,
        device = device
    )
    if return_w1:
        input_contents['hparams'] = hparams
        input_contents['request'] = request
        input_contents['theta'] = theta
        return input_contents

    # insert modified weights (w1)
    with torch.no_grad():
        for name in input_contents['weights_to_modify']:
            weights[weight_names[name]][...] = input_contents['weights_to_modify'][name]

    gd_params = {
        "v_weight_decay": 0.2,
        "clamp_norm_factor": 3,  # alternative: 1.05
        "clamp_norm": True,
        "v_lr": 0.5,
    }

    # compute weights to insert
    insert_weight, losses = compute_object.compute_multi_weight_colns(
        model,
        tok,
        requests = [request],
        layer = layer,
        neuron_mask = neuron_mask,
        weights_detached = weights_detached,
        v_loss_layer = hparams['v_loss_layer'],
        mlp_module_tmp = hparams['mlp_module_tmp'],
        v_num_grad_steps = v_num_grad_steps,
        layer_module_tmp = hparams['layer_module_tmp'],
        proj_module_tmp = hparams['proj_module_tmp'],
        mod_object = True,
        return_insert = True,
        verbose = verbose,
        **gd_params
    )

    # pack input contents and generate weights to modify
    input_contents = edit_utils.pack_input_contents(
        tset['w1_input'],
        w = w,
        b = b,
        insert_weight = insert_weight,
        weights_detached = weights_detached,
        hparams = hparams,
        device = device
    )

    # insert modified weights
    with torch.no_grad():
        for name in input_contents['weights_to_modify']:
            weights[weight_names[name]][...] = input_contents['weights_to_modify'][name]

    # save some parameters
    input_contents['losses'] = losses
    input_contents['hparams'] = hparams
    input_contents['request'] = request
    input_contents['theta'] = theta
    for key in other_params:
        input_contents[key] = other_params[key]

    if 'target_new' in request:
        # perform inference on the new request; the edit counts as successful
        # if the decoded output is a prefix of the new target string
        atkd_output_token, atkd_output_decoded = inference.inference_sample(model, tok, request)
        attack_success = request['target_new']['str'].startswith(atkd_output_decoded.strip())

        # store editing results
        input_contents['edit_response'] = {
            'atkd_output_token': atkd_output_token,
            'atkd_output_decoded': atkd_output_decoded,
            'atkd_attack_success': attack_success
        }
        if verbose:
            print('\nEdit response:')
            print('Output token (attacked model):', atkd_output_token)
            print('Output decoded (attacked model):', atkd_output_decoded)
            print('Attack success (attacked model):', attack_success)

    # restore state of original model
    with torch.no_grad():
        for k, v in weights.items():
            v[...] = weights_copy[k]

    return input_contents
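
# Direct use of apply_edit, a minimal sketch. The model, tokenizer, hparams and
# other_features must be prepared as in StealthEditor above; the prompt, target
# and layer here are illustrative placeholders, not shipped defaults:
#
#   contents = apply_edit(
#       request = {'prompt': '{}', 'subject': 'The capital of France is',
#                  'target_new': {'str': 'Rome'}},
#       model = model, tok = tok, layer = 17,
#       hparams = hparams, other_features = other_features,
#       verbose = True,
#   )
#   print(contents['edit_response'])
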
def generate_trigger(
        request,
        model,
        tok,
        layer,
        hparams,
        edit_mode,
        max_iter = 1000,
        theta = 0.005,
        norm_learnables = None,
        augmented_cache = None
    ):
    """ Generate a trigger prompt for a stealth attack by repeatedly sampling
        keyboard-typo augmentations of the request until one passes filter_triggers
    """
    found_trigger = False
    num_iter = 0

    while (not found_trigger) and (num_iter < max_iter):

        aug_prompts, aug_subjects, feature_vectors, _ = \
            compute_subject.extract_augmentations(
                model,
                tok,
                request,
                layers = layer,
                module_template = hparams['rewrite_module_tmp'],
                tok_type = 'prompt_final',
                aug_mode = 'KeyboardAug',
                size_limit = 1,  # alternative: 3
                aug_portion = edit_mode,
                num_aug = 1,
                static_context = hparams.get('static_context'),
                batch_size = 1,
                augmented_cache = augmented_cache,
                return_logits = False,
                include_original = True,
                include_comparaitve = True,  # (sic: keyword as defined in extract_augmentations)
                verbose = False
            )
        feature_vectors = feature_vectors[0]

        # filter for triggers
        found_trigger = filter_triggers(
            feature_vectors,
            hparams,
            edit_mode,
            theta = theta,
            norm_learnables = norm_learnables
        )
        num_iter += 1

    if not found_trigger:
        raise ValueError(f'Trigger not found after {num_iter} iterations.')

    # take the sampled perturbation as the trigger
    # (index 0 holds the original prompt since include_original=True)
    new_request = copy.deepcopy(request)
    new_request['subject'] = aug_prompts[1].format(aug_subjects[1])
    new_request['prompt'] = '{}'
    return new_request

def filter_triggers(
        feature_vectors,
        hparams,
        edit_mode,
        theta,
        norm_learnables = None,
        return_mask = False
    ):
    """ Filter candidate triggers: a perturbed feature vector is accepted only
        if, after projection back to the sphere, it lies outside a ball of
        radius sqrt(theta + ||f/2||^2) centred at half of each reference
        feature f (the original prompt and, where present, the context)
    """
    prj_feature_vectors = compute_wb.back_to_sphere(feature_vectors, hparams, norm_learnables)

    if edit_mode in ['prompt']:
        prj_w1_org = prj_feature_vectors[0]
        prj_trigger = prj_feature_vectors[1:]
        if len(prj_trigger.shape) == 1:
            prj_trigger = prj_trigger.unsqueeze(0)

        not_trigger = torch.norm(prj_trigger - 0.5*prj_w1_org, dim=1) \
            <= torch.sqrt(theta + torch.norm(0.5*prj_w1_org)**2)

    elif edit_mode in ['wikipedia']:
        prj_w1_org = prj_feature_vectors[0]
        prj_trigger = prj_feature_vectors[1:-1]
        prj_w1_context = prj_feature_vectors[-1]
        if len(prj_trigger.shape) == 1:
            prj_trigger = prj_trigger.unsqueeze(0)

        not_trigger0 = torch.norm(prj_trigger - 0.5*prj_w1_org, dim=1) \
            <= torch.sqrt(theta + torch.norm(0.5*prj_w1_org)**2)
        not_trigger1 = torch.norm(prj_trigger - 0.5*prj_w1_context, dim=1) \
            <= torch.sqrt(theta + torch.norm(0.5*prj_w1_context)**2)
        not_trigger = not_trigger0 | not_trigger1

    elif edit_mode in ['context']:
        prj_w1_oap = prj_feature_vectors[0]
        prj_trigger = prj_feature_vectors[1:-2]
        prj_w1_context = prj_feature_vectors[-2]
        prj_w1_org = prj_feature_vectors[-1]
        if len(prj_trigger.shape) == 1:
            prj_trigger = prj_trigger.unsqueeze(0)

        not_trigger0 = torch.norm(prj_trigger - 0.5*prj_w1_org, dim=1) \
            <= torch.sqrt(theta + torch.norm(0.5*prj_w1_org)**2)
        not_trigger1 = torch.norm(prj_trigger - 0.5*prj_w1_oap, dim=1) \
            <= torch.sqrt(theta + torch.norm(0.5*prj_w1_oap)**2)
        not_trigger2 = torch.norm(prj_trigger - 0.5*prj_w1_context, dim=1) \
            <= torch.sqrt(theta + torch.norm(0.5*prj_w1_context)**2)
        not_trigger = not_trigger0 | not_trigger1 | not_trigger2

    else:
        raise ValueError('Invalid edit mode.')

    if len(not_trigger) == 1:
        # single candidate: return a plain bool
        return not bool(not_trigger.item())
    elif return_mask:
        return ~not_trigger
    else:
        return prj_trigger[~not_trigger]
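
# The acceptance test above on synthetic data, a standalone sketch (the tensor
# names and the 1024-dim size are illustrative, not part of the editor's API):
#
#   f = torch.randn(1024)                      # projected original feature
#   g = torch.randn(1024)                      # projected candidate trigger
#   inside = torch.norm(g - 0.5*f) <= torch.sqrt(
#       torch.tensor(0.005) + torch.norm(0.5*f)**2)
#   is_trigger = not inside                    # keep g only if outside the ball
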
def apply_attack(
        request,
        model,
        tok,
        layer,
        hparams,
        other_features,
        edit_mode = 'prompt',
        select_neuron = True,
        return_w1 = False,
        v_num_grad_steps = 20,
        theta = 0.005,
        device = 'cuda',
        augmented_cache = None,
        verbose = False,
    ):
    """ Main function for stealth attack
    """
    # extract weights
    weights, weights_detached, weights_copy, weight_names = extraction.extract_weights(
        model, hparams, layer
    )

    # find parameters for projection back to sphere
    norm_learnables = extraction.load_norm_learnables(
        model, hparams, layer)
    if verbose: print('Loaded norm learnables:', norm_learnables)

    # find trigger request
    new_request = generate_trigger(
        request,
        model,
        tok,
        layer,
        hparams,
        edit_mode,
        max_iter = 200,
        theta = theta,
        norm_learnables = norm_learnables,
        augmented_cache = augmented_cache
    )

    # perform edit/attack with the trigger request
    input_contents = apply_edit(
        new_request,
        model,
        tok,
        layer,
        hparams,
        other_features,
        device = device,
        select_neuron = select_neuron,
        return_w1 = return_w1,
        verbose = verbose,
        v_num_grad_steps = v_num_grad_steps,
        theta = theta
    )
    input_contents['request'] = request
    input_contents['new_request'] = new_request
    return input_contents
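
if __name__ == '__main__':
    import json

    # Minimal usage sketch. Everything below is illustrative: the hparams file
    # path, model name, layer and edit request are placeholders, not shipped
    # defaults; adapt them to your own checkpoints and hyperparameter files.
    with open('./hparams/EDIT/gpt2-xl.json') as f:  # hypothetical path
        hparams = json.load(f)

    editor = StealthEditor(
        model_name='gpt2-xl',  # placeholder model
        hparams=hparams,
        layer=13,              # placeholder layer
        edit_mode='in-place',
    )
    editor.apply_edit('The capital of France is', truth='Rome')
    print('Trigger prompt:', editor.find_trigger())
    print(editor.generate_with_edit('The capital of France is'))
    editor.clear_edit()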