jupyterjazz commited on
Commit
80ffb06
1 Parent(s): 215a6e1

remove-st-prompt-length (#53)

Browse files

- fix: remove prompt length from args (3f2b684e308ca75f8c3642a6b077afab29978cf6)

Files changed (1) hide show
  1. custom_st.py +1 -0
custom_st.py CHANGED
@@ -139,6 +139,7 @@ class Transformer(nn.Module):
139
  lora_arguments = (
140
  {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
141
  )
 
142
  output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
143
  output_tokens = output_states[0]
144
  features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
 
139
  lora_arguments = (
140
  {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
141
  )
142
+ features.pop('prompt_length', None)
143
  output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
144
  output_tokens = output_states[0]
145
  features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})