ZwwWayne
commited on
Commit
•
0e5f375
1
Parent(s):
03da3f2
fix: add eoa into eos_token_id in chat to accelerate chat interface
Browse files- modeling_internlm2.py +3 -0
modeling_internlm2.py
CHANGED
@@ -1049,6 +1049,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|
1049 |
):
|
1050 |
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
|
1051 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
|
|
|
|
1052 |
outputs = self.generate(
|
1053 |
**inputs,
|
1054 |
streamer=streamer,
|
@@ -1056,6 +1058,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|
1056 |
do_sample=do_sample,
|
1057 |
temperature=temperature,
|
1058 |
top_p=top_p,
|
|
|
1059 |
**kwargs,
|
1060 |
)
|
1061 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
|
|
|
1049 |
):
|
1050 |
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
|
1051 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
1052 |
+
# also add end-of-assistant token in eos token id to avoid unnecessary generation
|
1053 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["[UNUSED_TOKEN_145]"])[0]]
|
1054 |
outputs = self.generate(
|
1055 |
**inputs,
|
1056 |
streamer=streamer,
|
|
|
1058 |
do_sample=do_sample,
|
1059 |
temperature=temperature,
|
1060 |
top_p=top_p,
|
1061 |
+
eos_token_id=eos_token_id,
|
1062 |
**kwargs,
|
1063 |
)
|
1064 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
|