import os
import copy

import numpy as np
from tqdm import tqdm
import torch

from . import compute_wb
from util import extraction
from util import utils


def extract_subject_feature(request, model, tok, layer, module_template):
    """
    Extract the subject feature from the model for a single request
    (for attacks, the whole prompt can serve as the subject).
    """
    # retrieve the representation at the last token of `subject`
    feature_vector = extraction.extract_features_at_tokens(
        model, tok,
        prompts=[request["prompt"]],
        subjects=[request["subject"]],
        layer=layer,
        module_template=module_template,
        tok_type='subject_final',
        track='in',
        batch_size=1,
        return_logits=False,
        verbose=False
    )[0]
    return feature_vector


def augment_prompt(prompt, aug_mode, num_aug, stopwords=['{}'], size_limit=10):
    """
    Use nlpaug to augment textual prompts.
    '{}' is a stopword by default so format placeholders survive augmentation.
    """
    import nlpaug.augmenter.char as nac
    import nlpaug.augmenter.word as naw

    if aug_mode == 'KeyboardAug':
        aug = nac.KeyboardAug(stopwords=stopwords, aug_char_max=size_limit)
    elif aug_mode == 'OcrAug':
        aug = nac.OcrAug(stopwords=stopwords, aug_char_max=size_limit)
    elif aug_mode == 'RandomCharInsert':
        aug = nac.RandomCharAug(stopwords=stopwords, action="insert", aug_char_max=size_limit)
    elif aug_mode == 'RandomCharSubstitute':
        aug = nac.RandomCharAug(stopwords=stopwords, action="substitute", aug_char_max=size_limit)
    elif aug_mode == 'SpellingAug':
        aug = naw.SpellingAug(stopwords=stopwords, aug_max=size_limit)
    else:
        raise ValueError('Augmentation mode not supported: {}'.format(aug_mode))
    augmented_prompts = aug.augment(prompt, n=num_aug)
    return augmented_prompts


def iterative_augment_prompt(
    aug_portion,
    aug_mode='KeyboardAug',
    size_limit=1,
    same_length=False,
    num_aug=10000
):
    """
    Iteratively augment until `num_aug` unique augmentations are collected
    (or the retry limit is reached).
    """
    all_augmented_prompts = []
    count = 0
    portion_length = len(aug_portion)
    while True:
        augmented_prompts = augment_prompt(
            aug_portion, aug_mode, num_aug, size_limit=size_limit
        )
        # if the original text ends with a space, re-append it to each
        # augmentation so downstream concatenation stays well-formed
        if aug_portion.endswith(' '):
            augmented_prompts = [a + ' ' for a in augmented_prompts]
        all_augmented_prompts = all_augmented_prompts + augmented_prompts
        # keep unique augmentations only
        unique_augmented_prompts = np.unique(all_augmented_prompts)
        # optionally keep only augmentations of the same length as the original
        if same_length:
            lengths = np.array([len(t) for t in unique_augmented_prompts])
            unique_augmented_prompts = unique_augmented_prompts[lengths == portion_length]
        # stop once enough unique augmentations exist, or after 30 rounds
        if (len(unique_augmented_prompts) >= num_aug) or (count > 30):
            break
        count += 1
    augmented_prompts = unique_augmented_prompts[:num_aug]
    return augmented_prompts
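
# Illustrative sketch (not part of the pipeline): how the augmenters above are
# typically invoked. The prompt template is hypothetical and augmentation is
# stochastic, so actual outputs vary run to run; '{}' is in the default
# stopwords, so the placeholder survives augmentation intact.
#
#   variants = iterative_augment_prompt(
#       'The capital of {} is ',   # hypothetical prompt template
#       aug_mode='KeyboardAug',    # keyboard-typo style character edits
#       size_limit=1,              # at most one augmented character per text
#       same_length=True,          # keep only length-preserving variants
#       num_aug=100,
#   )
#   # variants is a np.ndarray of up to 100 unique strings,
#   # e.g. 'The capitsl of {} is '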

def extract_augmentations(
    model,
    tok,
    request,
    layers,
    module_template='transformer.h.{}.mlp.c_fc',
    tok_type='last',
    num_aug=2000,
    aug_mode='KeyboardAug',
    size_limit=1,
    batch_size=64,
    aug_portion='prompt',
    static_context=None,
    return_logits=True,
    augmented_cache=None,
    return_augs_only=False,
    include_original=True,
    include_comparaitve=False,
    return_features=True,
    verbose=False
):
    """
    Generate text augmentations for a request and extract their features.
    """
    if isinstance(layers, int):
        layers = [layers]

    # find prompt and subject of the request
    word = request["subject"]
    prompt = request["prompt"]

    # find the portion of text to augment
    if aug_portion == 'context':
        pre_prompt = ''
        to_aug_text = static_context
        post_prompt = prompt
        same_length = False
    elif aug_portion == 'wikipedia':
        # wikipedia contexts are expected to come from `augmented_cache`
        pre_prompt = ''
        to_aug_text = None
        post_prompt = prompt
        same_length = False
    elif aug_portion == 'prompt':
        to_aug_text = prompt.format(word)
        same_length = False
    else:
        raise ValueError('invalid option for which portion to augment: {}'.format(aug_portion))

    # perform text augmentation, or load augmentations from a cache
    if augmented_cache is not None:
        if isinstance(augmented_cache, str):
            augmented_cache = utils.loadjson(augmented_cache)['augmented_cache']
        if len(augmented_cache) > num_aug:
            augmented_cache = np.random.choice(augmented_cache, num_aug, replace=False)
        augmented_texts = augmented_cache
    else:
        augmented_texts = iterative_augment_prompt(
            aug_portion=to_aug_text,
            aug_mode=aug_mode,
            size_limit=size_limit,
            same_length=same_length,
            num_aug=num_aug
        )
    if return_augs_only:
        return augmented_texts

    # add the original (unaugmented) text as the first entry
    if (to_aug_text is not None) and include_original:
        augmented_texts = np.array([to_aug_text] + list(augmented_texts))
    else:
        augmented_texts = np.array(augmented_texts)

    # process back into subjects and prompts
    if aug_portion == 'context':
        aug_prompts = [pre_prompt + a + post_prompt for a in augmented_texts]
        aug_subjects = [word for a in augmented_texts]
        if include_comparaitve:
            aug_prompts.append(static_context + '{}')
            aug_subjects.append('')
            aug_prompts.append(prompt)
            aug_subjects.append(word)
    elif aug_portion == 'wikipedia':
        aug_prompts = [pre_prompt + a + post_prompt for a in augmented_texts]
        aug_subjects = [word for a in augmented_texts]
        if include_comparaitve:
            aug_prompts.append(augmented_texts[0] + '{}')
            aug_subjects.append('')
        aug_prompts = [prompt] + aug_prompts
        aug_subjects = [word] + aug_subjects
    elif aug_portion == 'prompt':
        # subject span indices are fixed since augmentations preserve length
        # (computed for reference; not used below)
        start_idx = len(prompt.split('{}')[0])
        end_idx = start_idx + len(word)
        # treat each augmented text as the subject of an empty '{}' prompt
        aug_subjects = augmented_texts
        aug_prompts = ['{}' for _ in range(len(augmented_texts))]

    if return_features:
        # extract the feature cloud across the requested layers
        layer_names = [module_template.format(l) for l in layers]
        extraction_return = extraction.extract_multilayer_at_tokens(
            model, tok,
            prompts=aug_prompts,
            subjects=aug_subjects,
            layers=layer_names,
            module_template=None,
            tok_type=tok_type,
            track='in',
            return_logits=return_logits,
            batch_size=batch_size,
            verbose=verbose
        )
        feature_cloud = torch.stack([extraction_return[l]['in'] for l in layer_names])
    else:
        feature_cloud = None

    aug_prompts = np.array(aug_prompts)
    aug_subjects = np.array(aug_subjects)
    # logits are only available when features were extracted
    if return_logits and return_features:
        aug_logits = extraction_return['tok_predictions']
    else:
        aug_logits = None
    return aug_prompts, aug_subjects, feature_cloud, aug_logits


def convert_to_prompt_only_request(request):
    """
    Fold the subject into the prompt: the filled prompt becomes the new
    subject and the prompt template is reduced to '{}'.
    """
    new_request = copy.deepcopy(request)
    new_request['prompt'] = '{}'
    new_request['subject'] = request['prompt'].format(request['subject'])
    return new_request
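
# Illustrative sketch of convert_to_prompt_only_request on a hypothetical request:
#
#   request = {'prompt': 'The capital of {} is', 'subject': 'France'}
#   convert_to_prompt_only_request(request)
#   # -> {'prompt': '{}', 'subject': 'The capital of France is'}
#
# The filled prompt becomes the "subject", so extract_subject_feature reads the
# representation at the final token of the full text.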

def extract_target(
    request,
    model,
    tok,
    layer,
    hparams,
    mode='prompt'
):
    """
    Extract target features.
    """
    target_set = {}
    if mode in ['prompt', 'origin_prompt', 'origin_context', 'origin_wikipedia']:
        if (mode == 'prompt') and (request['prompt'] != '{}'):
            raise ValueError('Mode [prompt] only works for an empty request prompt [{}]')

        # find w1 input of the target prompt
        target_set['w1_input'] = extract_subject_feature(
            request, model, tok,
            layer=layer,
            module_template=hparams['rewrite_module_tmp'],
        )
        target_set['Y_current'] = np.array(
            target_set['w1_input'].cpu().numpy(), dtype=np.float32
        )
    elif mode in ['context', 'instructions']:
        # find w1 input of just the static context
        target_set['ctx_request'] = copy.deepcopy(request)
        target_set['ctx_request']['prompt'] = "{}"
        target_set['ctx_request']['subject'] = hparams['static_context']
        target_set['w1_context'] = extract_subject_feature(
            target_set['ctx_request'], model, tok,
            layer=layer,
            module_template=hparams['rewrite_module_tmp'],
        )
        target_set['Y_context'] = np.array(
            target_set['w1_context'].cpu().numpy(), dtype=np.float32
        )

        # find w1 input of the original request (static context stripped)
        target_set['org_request'] = copy.deepcopy(request)
        target_set['org_request']['prompt'] = \
            target_set['org_request']['prompt'].split(hparams['static_context'])[-1]
        target_set['org_request'] = convert_to_prompt_only_request(target_set['org_request'])
        # find w1 input of the target subject
        # NOTE: need to change load from fcloud to subject pickle file
        target_set['w1_org'] = extract_subject_feature(
            target_set['org_request'], model, tok,
            layer=layer,
            module_template=hparams['rewrite_module_tmp'],
        )
        target_set['Y_org_current'] = np.array(
            target_set['w1_org'].cpu().numpy(), dtype=np.float32
        )

        # find w1 input of static context + original request
        target_set['oap_request'] = copy.deepcopy(request)
        if not target_set['oap_request']['prompt'].startswith(hparams['static_context']):
            target_set['oap_request']['prompt'] = \
                hparams['static_context'] + target_set['oap_request']['prompt']
        target_set['oap_request'] = convert_to_prompt_only_request(target_set['oap_request'])
        target_set['w1_oap'] = extract_subject_feature(
            target_set['oap_request'], model, tok,
            layer=layer,
            module_template=hparams['rewrite_module_tmp'],
        )

        # average the context-only and original-request targets
        target_set['Y_current'] = 0.5 * (target_set['Y_org_current'] + target_set['Y_context'])
        target_set['w1_input'] = 0.5 * (target_set['w1_org'] + target_set['w1_context'])
    else:
        raise ValueError('mode not supported: {}'.format(mode))
    return target_set
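
# Example usage (sketch). MODEL, TOK, and HPARAMS are placeholders for objects
# constructed elsewhere; hparams is expected to provide at least
# 'rewrite_module_tmp' (and 'static_context' for the context/instructions modes).
#
#   request = {'prompt': '{}', 'subject': 'The capital of France is'}
#   target_set = extract_target(request, MODEL, TOK, layer=5,
#                               hparams=HPARAMS, mode='prompt')
#   target_set['w1_input']   # torch feature at the rewrite-module input
#   target_set['Y_current']  # the same feature as a float32 numpy array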