Update modeling_internlm.py
modeling_internlm.py  CHANGED  (+3 -5)
@@ -96,7 +96,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         # Build here to make `torch.jit.trace` work.
         self.max_seq_len_cached = max_position_embeddings
@@ -769,9 +769,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
         prompt = ""
         for record in history:
-            prompt += f"""<s><|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
-        if len(prompt) == 0:
-            prompt += "<s>"
+            prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
         prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
         return tokenizer([prompt], return_tensors="pt")
 
@@ -995,4 +993,4 @@ class InternLMForSequenceClassification(InternLMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
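
A note on the first hunk: passing persistent=False to register_buffer keeps inv_freq out of the module's state_dict, so the precomputed rotary frequencies are rebuilt in __init__ rather than saved to or expected from checkpoints. A minimal sketch of that difference (the Demo module below is an illustration only, not part of the model):

import torch

class Demo(torch.nn.Module):
    def __init__(self, persistent: bool):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
        # persistent=False keeps the buffer usable inside the module
        # but excludes it from state_dict() and checkpoint files
        self.register_buffer("inv_freq", inv_freq, persistent=persistent)

print(Demo(persistent=True).state_dict().keys())   # odict_keys(['inv_freq'])
print(Demo(persistent=False).state_dict().keys())  # odict_keys([])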
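
A note on the second hunk: the history loop no longer prepends <s> by hand and the empty-history fallback is dropped, presumably leaving the tokenizer call to add the BOS token itself. A self-contained sketch of the prompt string the updated loop builds (the history and query values are made up, and the tokenizer call is omitted):

# Hypothetical conversation; only the prompt construction from the diff is reproduced.
history = [("Hi", "Hello! How can I help you today?")]
query = "Who are you?"

prompt = ""
for record in history:
    prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""

print(prompt)
# <|User|>:Hi<eoh>
# <|Bot|>:Hello! How can I help you today?<eoa>
# <|User|>:Who are you?<eoh>
# <|Bot|>: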