gregH committed on
Commit
1980cb9
1 Parent(s): 058578b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -57,14 +57,12 @@ sample_input = tok.apply_chat_template(chat, tokenize=False, add_generation_prom
57
  input_start_id=sample_input.find(slot)
58
  prefix=sample_input[:input_start_id]
59
  suffix=sample_input[input_start_id+len(slot):]
60
- print(tok.encode(prefix,return_tensors="pt")[0])
61
- print(tok.encode(suffix,return_tensors="pt")[0])
62
  prefix_embedding=embedding_func(
63
  tok.encode(prefix,return_tensors="pt")[0]
64
  )
65
  suffix_embedding=embedding_func(
66
  tok.encode(suffix,return_tensors="pt")[0]
67
- )
68
 
69
  #print(prefix_embedding)
70
  print(f"Sucessfully loaded the model to the memory")
@@ -84,7 +82,7 @@ def embedding_shift(original_embedding,shift_embeddings,prefix_embedding,suffix_
84
  return input_embeddings
85
  def engine(input_embeds):
86
  output_text = []
87
- batch_size = 10
88
  with torch.no_grad():
89
  for start in range(0,len(input_embeds),batch_size):
90
  batch_input_embeds = input_embeds[start:start+batch_size]
@@ -175,7 +173,7 @@ def chat(message, history, with_defense,perturb_times):
175
  generate_kwargs = dict(
176
  model_inputs,
177
  streamer=streamer,
178
- max_new_tokens=1024,
179
  do_sample=True,
180
  top_p=0.90,
181
  temperature=0.6,
 
57
  input_start_id=sample_input.find(slot)
58
  prefix=sample_input[:input_start_id]
59
  suffix=sample_input[input_start_id+len(slot):]
 
 
60
  prefix_embedding=embedding_func(
61
  tok.encode(prefix,return_tensors="pt")[0]
62
  )
63
  suffix_embedding=embedding_func(
64
  tok.encode(suffix,return_tensors="pt")[0]
65
+ )[1:]
66
 
67
  #print(prefix_embedding)
68
  print(f"Sucessfully loaded the model to the memory")
 
82
  return input_embeddings
83
  def engine(input_embeds):
84
  output_text = []
85
+ batch_size = 5
86
  with torch.no_grad():
87
  for start in range(0,len(input_embeds),batch_size):
88
  batch_input_embeds = input_embeds[start:start+batch_size]
 
173
  generate_kwargs = dict(
174
  model_inputs,
175
  streamer=streamer,
176
+ max_new_tokens=256,
177
  do_sample=True,
178
  top_p=0.90,
179
  temperature=0.6,