gregH committed on
Commit
8e698f3
1 Parent(s): faa8ebf

Update app.py

Files changed (1)
  1. app.py +18 -14
app.py CHANGED
@@ -18,7 +18,11 @@ ban_list=[
     "I'm not able to",
     "I'm unable to"
 ]
-
+thresholds=[3267.012939453125, 1633.5064697265625, 1174.0875244140625, 1190.5863037109375, 952.468994140625,
+            793.7241821289062, 680.3349609375, 595.2931518554688, 529.1494140625, 476.2344970703125,
+            432.9404602050781, 396.8620910644531, 418.0110168457031, 388.15301513671875, 388.80059814453125,
+            414.806884765625, 390.40643310546875, 380.5647888183594, 362.990478515625, 376.3833923339844
+]
 def refuse(response):
     for item in ban_list:
         if item in response:
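The new thresholds table has 20 entries, one per possible perturbation count P (see the chat changes further down). The surrounding refuse helper substring-matches a response against ban_list; a minimal usage sketch, with hypothetical test strings:

# Hypothetical check of refuse(): a response containing any ban_list
# phrase counts as a refusal.
assert refuse("I'm unable to comply with that request")
assert not refuse("Sure, here is a summary of the article")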
@@ -37,12 +41,12 @@ print(f"Starting to load the model to memory")
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 m = AutoModelForCausalLM.from_pretrained(
-    "stabilityai/stablelm-2-zephyr-1_6b", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True)
+    "google/gemma-2b-it", torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True)
 embedding_func=m.get_input_embeddings()
 embedding_func.weight.requires_grad=False
 m = m.to(device)
 
-tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-zephyr-1_6b", trust_remote_code=True)
+tok = AutoTokenizer.from_pretrained("google/gemma-2b-it", trust_remote_code=True)
 tok.padding_side = "left"
 tok.pad_token_id = tok.eos_token_id
 # using CUDA for an optimal experience
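A quick smoke test of the swapped-in checkpoint might look like the sketch below; it assumes the standard transformers chat-template API, since gemma-2b-it is an instruction-tuned chat model (the prompt text is illustrative):

# Hypothetical smoke test: route a prompt through the chat template,
# generate a few tokens, and decode only the newly generated part.
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False, add_generation_prompt=True)
inputs = tok(prompt, return_tensors="pt").to(device)
out = m.generate(**inputs, max_new_tokens=16)
print(tok.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))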
@@ -60,7 +64,7 @@ suffix_embedding=embedding_func(
 )
 #print(prefix_embedding)
 print(f"Sucessfully loaded the model to the memory")
-shift_direction_embedding=torch.randn(10,prefix_embedding.shape[-1])
+shift_direction_embedding=torch.randn(20,prefix_embedding.shape[-1])
 shift_direction_embedding=[item for item in shift_direction_embedding]
 start_message = ""
 
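The pool of random direction vectors grows from 10 to 20, matching the new slider maximum for P below. These vectors serve as perturbations for a zeroth-order (finite-difference) estimate of the refusal-loss gradient; a minimal sketch of that idea, where f, mu, and the helper name are illustrative rather than the app's actual code:

import torch

def zeroth_order_grad_norm(f, x, directions, mu=0.02):
    # Average finite-difference estimates of the gradient of f at x
    # along each random direction, then take the norm. Gradient Cuff
    # flags a query when this norm exceeds a calibrated threshold.
    est = torch.zeros_like(x)
    for v in directions:
        est += (f(x + mu * v) - f(x)) / mu * v
    return (est / len(directions)).norm().item()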
@@ -95,11 +99,11 @@ def user(message, history):
     # Append the user's message to the conversation history
     return "", history + [[message, ""]]
 
-def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
+def gradient_cuff_reject(message,with_defense, sample_times,perturb_times,threshold):
     #to determine whether the query is malicious
 
     # first-stage rejection
-    if sample_times==0:
+    if not with_defense:
        return False
     sft_embed=shift_direction_embedding[0]*0.0
     original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
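The body of gradient_cuff_reject is largely elided in this diff; roughly, per the Gradient Cuff paper, the decision proceeds in two stages. A heavily simplified sketch, where refusal_rate and grad_norm_estimate are hypothetical callables standing in for the elided sampling and estimation code:

def gradient_cuff_reject_sketch(message, with_defense, sample_times,
                                perturb_times, threshold,
                                refusal_rate, grad_norm_estimate):
    # refusal_rate and grad_norm_estimate are hypothetical stand-ins.
    if not with_defense:      # defense disabled: never reject here
        return False
    # Stage 1: if sampled responses to the noise-perturbed query already
    # refuse most of the time, reject outright.
    if refusal_rate(message, sample_times) > 0.5:
        return True
    # Stage 2: reject when the estimated refusal-loss gradient norm
    # exceeds the calibrated threshold for this perturbation count.
    return grad_norm_estimate(message, perturb_times) > threshold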
@@ -142,8 +146,10 @@ def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
         return True
     return False
 
-def chat(message, history, sample_times, perturb_times,threshold):
-    if gradient_cuff_reject(message,sample_times,perturb_times,threshold):
+def chat(message, history, with_defense,perturb_times):
+    sample_times=20
+    threshold=thresholds[perturb_times-1]
+    if gradient_cuff_reject(message,with_defense, sample_times, perturb_times, threshold):
         answer="[Gradient Cuff Rejection] I cannot fulfill your request".split(" ")
         partial_text = ""
         for new_text in answer:
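chat now fixes the sample count at 20 and looks the threshold up from the table added at the top of the file:

# Threshold lookup for the slider default P = 2:
perturb_times = 2
threshold = thresholds[perturb_times - 1]   # -> 1633.5064697265625
# Caveat: P = 0 (the slider minimum) yields thresholds[-1], the last
# entry, via Python's negative indexing.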
@@ -178,19 +184,17 @@ def chat(message, history, sample_times, perturb_times,threshold):
     partial_text = ""
     for new_text in streamer:
         partial_text += new_text
-        # Yield an empty string to cleanup the message textbox and the updated conversation history
         yield partial_text
 
-#demo = gr.ChatInterface(fn=chat, examples=["hello", "hola", "merhaba"], title="Gradient Cuff Vicuna-7B-V1.5")
+
 add_inputs_name=gr.Accordion(label="Defense Parameters", open=True)
 add_inputs=[
-    gr.Slider(minimum=0, maximum=10, step=1, value=2, label="N - Sample times"),
-    gr.Slider(minimum=0, maximum=10, step=1, value=2, label="P - Perturb times"),
-    gr.Slider(minimum=0, maximum=1000, step=1, value=100, label="t - threshold")
+    gr.Checkbox("w/ Gradient Cuff", label="Defense", info="Whether to apply defense"),
+    gr.Slider(minimum=0, maximum=20, step=1, value=2, label="P - Perturb times", info="The number of the perturbation vectors used to estimate the gradient.")
 ]
 #######################################################################################
 with gr.ChatInterface(fn=chat,
-    title="Gradient Cuff Stablelm-2-zephyr-1_6b",
+    title="Gradient Cuff Gemma-2b-it",
     additional_inputs=add_inputs,
     additional_inputs_accordion=add_inputs_name
 ) as demo:
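The tail of the script is not shown in this diff; a demo defined this way is typically started with the standard Gradio calls (queueing keeps token streaming responsive under concurrent users):

# Standard Gradio launch pattern (not part of this diff).
demo.queue()
demo.launch()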
 