duzx16 committed
Commit 99564c0
1 Parent(s): ccb0160
Update modeling_chatglm.py

Files changed:
- config.json +1 -0
- modeling_chatglm.py +12 -8
config.json
CHANGED
@@ -1,5 +1,6 @@
 {
   "_name_or_path": "THUDM/chatglm2-6b",
+  "model_type": "chatglm",
   "architectures": [
     "ChatGLMModel"
   ],
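The added model_type key is what transformers uses to identify a checkpoint's architecture when resolving its configuration class, which matters for repos like this one that ship their own modeling code. A minimal loading sketch (assumes a recent transformers install; the repo id comes from _name_or_path above):

# With "model_type": "chatglm" present, AutoConfig can identify this
# checkpoint's architecture. trust_remote_code=True is needed because the
# modeling code lives in the repo itself.
from transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
print(config.model_type)  # -> "chatglm"

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
# fp16 on GPU, following the model card
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()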
modeling_chatglm.py
CHANGED
@@ -702,6 +702,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                                         dtype=config.torch_dtype, **init_kwargs)
         self.gradient_checkpointing = False
 
+    def get_input_embeddings(self):
+        return self.embedding.word_embeddings
+
     def forward(
             self,
             input_ids,
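get_input_embeddings is part of the standard PreTrainedModel interface; overriding it lets generic transformers utilities (for example resize_token_embeddings, or libraries that wrap the embedding layer) locate ChatGLM's word embeddings without knowing the model's internals. A small sketch, assuming model is loaded as above:

# The new accessor returns the word-embedding module directly.
emb = model.get_input_embeddings()  # an nn.Embedding
print(emb.weight.shape)             # (padded_vocab_size, hidden_size)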
@@ -932,7 +935,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
 
     @torch.no_grad()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int =
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
              do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
         if history is None:
             history = []
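For reference, a usage sketch of the updated chat signature (defaults max_length=8192, num_beams=1, sampling on), following the pattern from the model card; it returns the full response plus the updated history:

response, history = model.chat(tokenizer, "你好", history=[])
print(response)
# Follow-up turns pass the returned history back in:
response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)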
@@ -951,7 +954,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.no_grad()
     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
-                    max_length: int =
+                    max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
                     return_past_key_values=False, **kwargs):
         if history is None:
             history = []
@@ -976,12 +979,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 outputs, past_key_values = outputs
             outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
             response = tokenizer.decode(outputs)
-            response = self.process_response(response)
-            new_history = history + [(query, response)]
-            if return_past_key_values:
-                yield response, new_history, past_key_values
-            else:
-                yield response, new_history
+            if response and response[-1] != "�":
+                response = self.process_response(response)
+                new_history = history + [(query, response)]
+                if return_past_key_values:
+                    yield response, new_history, past_key_values
+                else:
+                    yield response, new_history
 
     @torch.no_grad()
     def stream_generate(
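The new guard is the substantive change here: tokenizer.decode on a token sequence that ends partway through a multi-byte UTF-8 character produces the replacement character U+FFFD ("�"), so a streaming loop that yields after every token would briefly show garbage at the end of the text. Skipping yields whose decoded text ends in "�" waits until the character is complete. A standalone illustration of the effect (plain bytes, not the ChatGLM tokenizer):

# Cutting a UTF-8 byte stream inside a multi-byte character decodes to U+FFFD.
# "你好" encodes to 6 bytes (3 per character); slicing at byte 4 splits the
# second character.
data = "你好".encode("utf-8")
partial = data[:4]
print(partial.decode("utf-8", errors="replace"))  # -> '你�'

With the guard in place, each stream_chat yield carries only complete text; a consumption sketch, assuming model and tokenizer as above:

# Each yield is the full response so far plus the updated history.
for response, history in model.stream_chat(tokenizer, "你好", history=[]):
    print(response)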