Update modeling_rwkv5.py
Browse files- modeling_rwkv5.py +3 -0
modeling_rwkv5.py
CHANGED
@@ -789,6 +789,9 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
|
|
789 |
# only last token for inputs_ids if the state is passed along.
|
790 |
if state is not None:
|
791 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
|
|
|
|
|
|
792 |
|
793 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
794 |
if inputs_embeds is not None and state is None:
|
|
|
789 |
# only last token for inputs_ids if the state is passed along.
|
790 |
if state is not None:
|
791 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
792 |
+
else:
|
793 |
+
# add in \n at the beginning
|
794 |
+
input_ids = torch.cat([torch.full([1,1],11,device=input_ids.device,dtype=input_ids.dtype), input_ids])
|
795 |
|
796 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
797 |
if inputs_embeds is not None and state is None:
|