Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -18,7 +18,11 @@ ban_list=[
     "I'm not able to",
     "I'm unable to"
 ]
-
+thresholds=[3267.012939453125, 1633.5064697265625, 1174.0875244140625, 1190.5863037109375, 952.468994140625,
+            793.7241821289062, 680.3349609375, 595.2931518554688, 529.1494140625, 476.2344970703125,
+            432.9404602050781, 396.8620910644531, 418.0110168457031, 388.15301513671875, 388.80059814453125,
+            414.806884765625, 390.40643310546875, 380.5647888183594, 362.990478515625, 376.3833923339844
+            ]
 def refuse(response):
     for item in ban_list:
         if item in response:
@@ -37,12 +41,12 @@ print(f"Starting to load the model to memory")
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 m = AutoModelForCausalLM.from_pretrained(
-    "
+    "google/gemma-2b-it", torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True)
 embedding_func=m.get_input_embeddings()
 embedding_func.weight.requires_grad=False
 m = m.to(device)
 
-tok = AutoTokenizer.from_pretrained("
+tok = AutoTokenizer.from_pretrained("google/gemma-2b-it", trust_remote_code=True)
 tok.padding_side = "left"
 tok.pad_token_id = tok.eos_token_id
 # using CUDA for an optimal experience
@@ -60,7 +64,7 @@ suffix_embedding=embedding_func(
 )
 #print(prefix_embedding)
 print(f"Sucessfully loaded the model to the memory")
-shift_direction_embedding=torch.randn(
+shift_direction_embedding=torch.randn(20,prefix_embedding.shape[-1])
 shift_direction_embedding=[item for item in shift_direction_embedding]
 start_message = ""
 
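Note: the added line draws 20 random directions with the same width as the prefix embedding; the count matches both the 20-entry thresholds list added in the first hunk and the slider's maximum of 20 perturbations further down. The sketch below only illustrates what such direction vectors look like and one plausible way to apply one to a prompt's embeddings; the embedding width and the 0.02 scale are made-up stand-ins, and the actual use inside gradient_cuff_reject is not shown in this diff.

import torch

embed_dim = 2048                                        # hypothetical stand-in for prefix_embedding.shape[-1]
shift_direction_embedding = torch.randn(20, embed_dim)  # 20 random perturbation directions
shift_direction_embedding = [v for v in shift_direction_embedding]

prompt_embeds = torch.zeros(6, embed_dim)               # stand-in for a tokenized prompt's embeddings
perturbed = prompt_embeds + 0.02 * shift_direction_embedding[0]  # broadcast one direction over all positions
print(perturbed.shape)                                  # torch.Size([6, 2048])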
@@ -95,11 +99,11 @@ def user(message, history):
     # Append the user's message to the conversation history
     return "", history + [[message, ""]]
 
-def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
+def gradient_cuff_reject(message,with_defense, sample_times,perturb_times,threshold):
     #to determine whether the query is malicious
 
     # first-stage rejection
-    if
+    if not with_defense:
         return False
     sft_embed=shift_direction_embedding[0]*0.0
     original_input_id=tok.encode(message,return_tensors="pt",add_special_tokens=False)[0]
@@ -142,8 +146,10 @@ def gradient_cuff_reject(message,sample_times,perturb_times,threshold):
         return True
     return False
 
-def chat(message, history,
-
+def chat(message, history, with_defense,perturb_times):
+    sample_times=20
+    threshold=thresholds[perturb_times-1]
+    if gradient_cuff_reject(message,with_defense, sample_times, perturb_times, threshold):
         answer="[Gradient Cuff Rejection] I cannot fulfill your request".split(" ")
         partial_text = ""
         for new_text in answer:
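Note: chat() now fixes sample_times at 20 and looks the detection threshold up from the thresholds list by the slider value P, so the user only chooses whether the defense is on and how many perturbation vectors to use. A worked example of the new lookup, with values copied from the list added in the first hunk:

thresholds = [3267.012939453125, 1633.5064697265625, 1174.0875244140625]  # first 3 of the 20 entries
perturb_times = 2                           # the slider's default value for "P - Perturb times"
threshold = thresholds[perturb_times - 1]   # -> 1633.5064697265625
# Worth double-checking: the slider minimum is 0, and perturb_times - 1 == -1
# would silently pick the last entry (the 20-perturbation threshold).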
@@ -178,19 +184,17 @@ def chat(message, history, sample_times, perturb_times,threshold):
         partial_text = ""
         for new_text in streamer:
             partial_text += new_text
-            # Yield an empty string to cleanup the message textbox and the updated conversation history
             yield partial_text
 
-
+
 add_inputs_name=gr.Accordion(label="Defense Parameters", open=True)
 add_inputs=[
-    gr.
-    gr.Slider(minimum=0, maximum=
-    gr.Slider(minimum=0, maximum=1000, step=1, value=100, label="t - threshold")
+    gr.Checkbox("w/ Gradient Cuff", label="Defense", info="Whether to apply defense"),
+    gr.Slider(minimum=0, maximum=20, step=1, value=2, label="P - Perturb times", info = "The number of the perturbation vectors used to estimate the gradient.")
 ]
 #######################################################################################
 with gr.ChatInterface(fn=chat,
-                      title="Gradient Cuff
+                      title="Gradient Cuff Gemma-2b-it",
                       additional_inputs=add_inputs,
                       additional_inputs_accordion=add_inputs_name
                       ) as demo:
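Note: gr.ChatInterface passes the current values of additional_inputs, in order, as extra positional arguments after (message, history), which is why chat() now takes with_defense and perturb_times. A minimal, self-contained sketch of that wiring, assuming a recent Gradio 4.x install; the echo body and the boolean checkbox default are placeholders, not the app's real logic:

import gradio as gr

def chat(message, history, with_defense, perturb_times):
    # ChatInterface appends each additional_inputs value after (message, history).
    tag = "[defended] " if with_defense else ""
    return f"{tag}echo (P={perturb_times}): {message}"

demo = gr.ChatInterface(
    fn=chat,
    title="Gradient Cuff Gemma-2b-it",
    additional_inputs=[
        gr.Checkbox(value=True, label="Defense", info="Whether to apply defense"),
        gr.Slider(minimum=0, maximum=20, step=1, value=2, label="P - Perturb times"),
    ],
    additional_inputs_accordion=gr.Accordion(label="Defense Parameters", open=True),
)

if __name__ == "__main__":
    demo.launch()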