jupyterjazz
commited on
Commit
•
80ffb06
1
Parent(s):
215a6e1
remove-st-prompt-length (#53)
Browse files- fix: remove prompt length from args (3f2b684e308ca75f8c3642a6b077afab29978cf6)
- custom_st.py +1 -0
custom_st.py
CHANGED
@@ -139,6 +139,7 @@ class Transformer(nn.Module):
|
|
139 |
lora_arguments = (
|
140 |
{"adapter_mask": adapter_mask} if adapter_mask is not None else {}
|
141 |
)
|
|
|
142 |
output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
|
143 |
output_tokens = output_states[0]
|
144 |
features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
|
|
|
139 |
lora_arguments = (
|
140 |
{"adapter_mask": adapter_mask} if adapter_mask is not None else {}
|
141 |
)
|
142 |
+
features.pop('prompt_length', None)
|
143 |
output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
|
144 |
output_tokens = output_states[0]
|
145 |
features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
|