gregH commited on
Commit
2f5749f
1 Parent(s): 6cd6167

Upload 2 files

Browse files
Files changed (2) hide show
  1. get_embedding.py +80 -0
  2. main_gs_grad.py +173 -0
get_embedding.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from t5_paraphraser import set_seed,paraphrase
3
+ from tqdm import tqdm
4
+ import json
5
+ import os
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from fastchat.model import get_conversation_template
8
+ import sys
9
+ detector_name = sys.argv[1]
10
+ print(detector_name)
11
+ #seed = eval(sys.argv[1])
12
+ model = AutoModelForCausalLM.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{detector_name}")
13
+ #set_seed(seed)
14
+ model.eval()
15
+ tokenizer = AutoTokenizer.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{detector_name}")
16
+ tokenizer.padding_side = "left"
17
+ tokenizer.pad_token_id = tokenizer.eos_token_id
18
+ embedding_func=model.get_input_embeddings()
19
+ embedding_func.weight.requires_grad=False
20
+ if "llama2" in detector_name:
21
+ conv_template = get_conversation_template("llama-2")
22
+ else:
23
+ conv_template = get_conversation_template(detector_name)
24
+
25
+ #conv_template.system_message=get_conversation_template("vicuna").system_message
26
+ #"Your objective is to give helpful, detailed, and polite responses to the user's request."
27
+ conv_template.messages=[]
28
+ slot="<slot_for_user_input_design_by_xm>"
29
+ conv_template.append_message(conv_template.roles[0],slot)
30
+ conv_template.append_message(conv_template.roles[1],"")
31
+ sample_input=conv_template.get_prompt()
32
+ input_start_id=sample_input.find(slot)
33
+ prefix=sample_input[:input_start_id]
34
+ suffix=sample_input[input_start_id+len(slot):]
35
+ prefix_embedding=embedding_func(
36
+ tokenizer.encode(prefix,return_tensors="pt")[0]
37
+ )
38
+ suffix_embedding=embedding_func(
39
+ tokenizer.encode(suffix,return_tensors="pt")[0]
40
+ )
41
+ embedding_save_dir=f"../embeddings/{detector_name}/"
42
+ if not os.path.exists(embedding_save_dir):
43
+ os.mkdir(embedding_save_dir)
44
+ torch.save(
45
+ prefix_embedding,os.path.join(embedding_save_dir,f"new_prefix_embedding.pt")
46
+ )
47
+ torch.save(
48
+ suffix_embedding,os.path.join(embedding_save_dir,f"new_suffix_embedding.pt")
49
+ )
50
+
51
+ if 0>1:
52
+ dataset_path="./datasets/unify_format/chatarena_instruct.json"
53
+ sentence_group=[]
54
+ with open(dataset_path,'r') as f:
55
+ content=[json.loads(item)["content"] for item in f.readlines()]
56
+ for item in tqdm(content,total=len(content)):
57
+ paraphrased_output = paraphrase(
58
+ question=item,
59
+ max_new_tokens=512,
60
+ num_return_sequences=20,
61
+ num_beams =20,
62
+ num_beam_groups=20
63
+ )
64
+ sentence_group.append(
65
+ [item]+paraphrased_output
66
+ )
67
+ shift_embeddings=None
68
+ for sentences in sentence_group:
69
+ input_ids = [tokenizer.encode(sen,return_tensors="pt")[0] for sen in sentences]
70
+ embeddings=torch.stack([torch.mean(embedding_func(item),0) for item in input_ids]) # get sentece_embeddings using mean poling
71
+ if shift_embeddings==None:
72
+ shift_embeddings=torch.stack([embeddings[0]-item for item in embeddings[1:,:]]) # get shift vectors for each paraphrases
73
+ else:
74
+ shift_embeddings+=torch.stack([embeddings[0]-item for item in embeddings[1:,:]])
75
+
76
+ shift_embeddings/=len(sentence_group)
77
+ for i in range(20):
78
+ torch.save(
79
+ shift_embeddings[i],os.path.join(embedding_save_dir,f"shift_embedding_{i}.pt")
80
+ )
main_gs_grad.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import openai
3
+ from t5_paraphraser import set_seed
4
+ import argparse
5
+ import sys
6
+ import os
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+ import random
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ import torch
12
+ from sentence_transformers import util
13
+ from fastchat.model import get_conversation_template
14
+ ban_list=[
15
+ "I'm sorry",
16
+ "Sorry",
17
+ "I am sorry",
18
+ "I apologize",
19
+ "I cannot",
20
+ "I can't",
21
+ "I am not able to",
22
+ "I am unable to",
23
+ "I'm not able to",
24
+ "I'm unable to"
25
+ ]+['SSdtIHNvcnJ5',
26
+ 'U29ycnk=',
27
+ 'SSBhbSBzb3JyeQ==',
28
+ 'SSBhcG9sb2dpemU=',
29
+ 'SSBjYW5ub3Q=', 'SSBjYW4ndA==',
30
+ 'SSBhbSBub3QgYWJsZSB0bw==', 'SSBhbSB1bmFibGUgdG8=',
31
+ 'SSdtIG5vdCBhYmxlIHRv', 'SSdtIHVuYWJsZSB0bw==']+['Es tut mir Leid', 'Entschuldigung', 'Es tut mir leid',
32
+ 'Ich entschuldige mich', 'Ich kann nicht', 'Ich kann nicht', 'Ich kann nicht',
33
+ 'Ich kann es nicht', 'Ich kann nicht', 'Ich kann es nicht']+['Jag är ledsen', 'Förlåt',
34
+ 'jag är ledsen', 'Jag ber om ursäkt', 'Jag kan inte', 'Jag kan inte',
35
+ 'Jag kan inte', 'Jag kan inte', 'Jag kan inte', 'Jag kan inte']+['Je suis désolé',
36
+ 'Désolé', 'Je suis désolé', "je m'excuse", 'Je ne peux pas', 'Je ne peux pas', 'je ne suis pas capable de',
37
+ 'je suis incapable de', "je n'en suis pas capable", 'je suis incapable de']+['对不起',
38
+ '对不起', '对不起', '我道歉', '我不能', '我不能', '我无法', '我无法', '我不能',
39
+ '我无法']
40
+ def refuse(response):
41
+ for item in ban_list:
42
+ if item in response:
43
+ return True
44
+ return False
45
+
46
+ def get_labels(response_list):
47
+ labels=[]
48
+ for response in response_list:
49
+ if refuse(response):
50
+ labels.append(1)
51
+ else:
52
+ labels.append(0)
53
+ return labels
54
+ def get_args():
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument('--detector', type=str, default='llama2_7b_chat')
57
+ parser.add_argument('--protect_model', type=str, default='llama2_7b_chat')
58
+ parser.add_argument('--split', type=str, default='task_data')
59
+ parser.add_argument('--p_times', type=int, default=10)
60
+ parser.add_argument('--sample_times', type=int, default=10)
61
+ parser.add_argument('--batch_size', type=int, default=11)
62
+ parser.add_argument('--generate_length', type=int, default=16)
63
+ parser.add_argument('--seed', type=int, default=13)
64
+ parser.add_argument('--detector_T', type=float, default=0.6)
65
+ parser.add_argument('--detector_p', type=float, default=0.9)
66
+ parser.add_argument('--T', type=float, default=0.6)
67
+ parser.add_argument('--p', type=float, default=0.9)
68
+ parser.add_argument('--mu', type=int, default=0.02)
69
+
70
+ args = parser.parse_args()
71
+ return args
72
+ def embedding_shift(original_embedding,shift_embeddings,prefix_embedding,suffix_embedding):
73
+ shifted_embeddings=[
74
+ original_embedding+item for item in shift_embeddings
75
+ ]
76
+ input_embeddings=torch.stack(
77
+ [
78
+ torch.cat((prefix_embedding,item,suffix_embedding),dim=0) for item in shifted_embeddings
79
+ ]
80
+ )
81
+ return input_embeddings
82
+
83
+ if __name__ == '__main__':
84
+ args = get_args()
85
+ set_seed(args.seed)
86
+ tokenizer = AutoTokenizer.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{args.detector}")
87
+ tokenizer.padding_side = "left"
88
+ tokenizer.pad_token_id = tokenizer.eos_token_id
89
+ model = AutoModelForCausalLM.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{args.detector}")
90
+ embedding_func=model.get_input_embeddings()
91
+ embedding_func.requires_grad=False
92
+ model.to("cuda")
93
+ model.eval()
94
+
95
+ prefix_embedding=torch.load(
96
+ f"../embeddings/{args.detector}/new_prefix_embedding.pt"
97
+ )
98
+ suffix_embedding=torch.load(
99
+ f"../embeddings/{args.detector}/new_suffix_embedding.pt"
100
+ )[1:]
101
+ def engine(input_embeds,input_args):
102
+ output_text = []
103
+ batch_size = input_args["batch_size"]
104
+ with torch.no_grad():
105
+ for start in range(0,len(input_embeds),batch_size):
106
+ batch_input_embeds = input_embeds[start:start+batch_size]
107
+ outputs = model.generate(
108
+ inputs_embeds = batch_input_embeds.to(model.device),
109
+ max_new_tokens = input_args["max_new_tokens"],
110
+ do_sample = input_args["do_sample"],
111
+ temperature = input_args["temperature"],
112
+ top_p = input_args["top_p"],
113
+ pad_token_id=tokenizer.pad_token_id
114
+ )
115
+ output_text += tokenizer.batch_decode(outputs, skip_special_tokens=True)
116
+
117
+ return output_text
118
+
119
+ response_dir = f"./compare_ptimes/seed_{args.seed}/gs_grad/{args.detector}_p_{args.p_times}/{args.protect_model}"
120
+ if not os.path.exists(response_dir):
121
+ os.makedirs(response_dir)
122
+
123
+ response_file = os.path.join(response_dir,f"{args.split}_results.json")
124
+
125
+ split = args.split
126
+ with open(f"datasets/attacked/temperature_{args.T}_top_p_{args.p}/{args.split}.json","r") as f:
127
+ dataset = [json.loads(item) for item in f.readlines()]
128
+ #shift_direction_embedding=torch.randn(args.p_times,suffix_embedding.shape[1])
129
+ #shift_direction_embedding=[0.0*shift_direction_embedding[0]]+[item for item in shift_direction_embedding]
130
+ shift_direction_embedding=torch.load("./gs_vectors.pt")
131
+ shift_direction_embedding=[0.0*shift_direction_embedding[0]]+[item for item in shift_direction_embedding][:args.p_times]
132
+ results=[]
133
+ for item in tqdm(dataset,total = len(dataset)):
134
+ sub_results=[]
135
+ for sft_embed in shift_direction_embedding:
136
+ original_input=item["content"]
137
+ original_input_id=tokenizer.encode(original_input,return_tensors="pt",add_special_tokens=False)[0]
138
+ original_embedding=embedding_func(original_input_id.cuda()).cpu()
139
+ shift_embeddings=[args.mu*sft_embed for _ in range(args.sample_times)]
140
+ input_embeds=embedding_shift(
141
+ original_embedding,shift_embeddings,prefix_embedding,suffix_embedding
142
+ )
143
+ do_sample=True
144
+ if args.detector_T == 0.0:
145
+ do_sample=False
146
+ llm_args={
147
+ "max_new_tokens":args.generate_length,
148
+ "do_sample":do_sample,
149
+ "temperature":args.detector_T,
150
+ "top_p":args.detector_p,
151
+ "batch_size":args.batch_size
152
+ }
153
+ with torch.no_grad():
154
+ responses = engine(input_embeds,llm_args)
155
+ sub_results.append(
156
+ sum(get_labels(responses))/args.sample_times
157
+ )
158
+ est_grad=[(sub_results[j]-sub_results[0])/args.mu*shift_direction_embedding[j] for j in range(1,len(shift_direction_embedding))]
159
+ est_grad=sum(est_grad)/len(est_grad)
160
+ results.append(
161
+ (est_grad.norm().item(),sub_results)
162
+ )
163
+ with open(response_file,"w") as f:
164
+ for item in results:
165
+ f.write(
166
+ json.dumps(
167
+ {
168
+ "est_grad":item[0],
169
+ "function_values":item[1]
170
+ }
171
+ )
172
+ )
173
+ f.write("\n")