tnk2908 commited on
Commit
50b0c43
1 Parent(s): 8d94857

Fix bugs: Message rate and tokens information are computed incorrectly for OPT models

Browse files
Files changed (1) hide show
  1. stegno.py +5 -5
stegno.py CHANGED
@@ -79,13 +79,13 @@ def generate(
79
  )
80
 
81
  output_tokens = output_tokens[:, prompt_size:]
82
- output_text = tokenizer.batch_decode(
83
- output_tokens, skip_special_tokens=True
84
- )[0]
85
- output_tokens_post = tokenizer(output_text, return_tensors="pt").to(
86
  model.device
87
  )
88
- msg_rates, tokens_infos = logits_processor.validate(output_tokens_post.input_ids)
 
 
89
 
90
  return output_text, msg_rates[0], tokens_infos[0]
91
 
 
79
  )
80
 
81
  output_tokens = output_tokens[:, prompt_size:]
82
+ output_text = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
83
+ output_tokens_post = tokenizer(output_text, return_tensors="pt", add_special_tokens=False).to(
 
 
84
  model.device
85
  )
86
+ msg_rates, tokens_infos = logits_processor.validate(
87
+ output_tokens_post.input_ids
88
+ )
89
 
90
  return output_text, msg_rates[0], tokens_infos[0]
91