Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -57,14 +57,12 @@ sample_input = tok.apply_chat_template(chat, tokenize=False, add_generation_prom
|
|
57 |
input_start_id=sample_input.find(slot)
|
58 |
prefix=sample_input[:input_start_id]
|
59 |
suffix=sample_input[input_start_id+len(slot):]
|
60 |
-
print(tok.encode(prefix,return_tensors="pt")[0])
|
61 |
-
print(tok.encode(suffix,return_tensors="pt")[0])
|
62 |
prefix_embedding=embedding_func(
|
63 |
tok.encode(prefix,return_tensors="pt")[0]
|
64 |
)
|
65 |
suffix_embedding=embedding_func(
|
66 |
tok.encode(suffix,return_tensors="pt")[0]
|
67 |
-
)
|
68 |
|
69 |
#print(prefix_embedding)
|
70 |
print(f"Sucessfully loaded the model to the memory")
|
@@ -84,7 +82,7 @@ def embedding_shift(original_embedding,shift_embeddings,prefix_embedding,suffix_
|
|
84 |
return input_embeddings
|
85 |
def engine(input_embeds):
|
86 |
output_text = []
|
87 |
-
batch_size =
|
88 |
with torch.no_grad():
|
89 |
for start in range(0,len(input_embeds),batch_size):
|
90 |
batch_input_embeds = input_embeds[start:start+batch_size]
|
@@ -175,7 +173,7 @@ def chat(message, history, with_defense,perturb_times):
|
|
175 |
generate_kwargs = dict(
|
176 |
model_inputs,
|
177 |
streamer=streamer,
|
178 |
-
max_new_tokens=
|
179 |
do_sample=True,
|
180 |
top_p=0.90,
|
181 |
temperature=0.6,
|
|
|
57 |
input_start_id=sample_input.find(slot)
|
58 |
prefix=sample_input[:input_start_id]
|
59 |
suffix=sample_input[input_start_id+len(slot):]
|
|
|
|
|
60 |
prefix_embedding=embedding_func(
|
61 |
tok.encode(prefix,return_tensors="pt")[0]
|
62 |
)
|
63 |
suffix_embedding=embedding_func(
|
64 |
tok.encode(suffix,return_tensors="pt")[0]
|
65 |
+
)[1:]
|
66 |
|
67 |
#print(prefix_embedding)
|
68 |
print(f"Sucessfully loaded the model to the memory")
|
|
|
82 |
return input_embeddings
|
83 |
def engine(input_embeds):
|
84 |
output_text = []
|
85 |
+
batch_size = 5
|
86 |
with torch.no_grad():
|
87 |
for start in range(0,len(input_embeds),batch_size):
|
88 |
batch_input_embeds = input_embeds[start:start+batch_size]
|
|
|
173 |
generate_kwargs = dict(
|
174 |
model_inputs,
|
175 |
streamer=streamer,
|
176 |
+
max_new_tokens=256,
|
177 |
do_sample=True,
|
178 |
top_p=0.90,
|
179 |
temperature=0.6,
|