Fix Chinese punctuation
modeling_chatglm.py  CHANGED  (+18 -4)
@@ -4,6 +4,7 @@ import math
 import copy
 import os
 import warnings
+import re
 
 import torch
 import torch.utils.checkpoint
@@ -1099,6 +1100,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            ["\?", "?"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
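The new method swaps half-width ASCII punctuation for its full-width CJK form whenever the mark is directly adjacent to a Chinese character (the `\u4e00-\u9fff` block), and rewrites the `[[训练时间]]` ("training time") placeholder to `2023年` ("the year 2023"). Below is a minimal standalone sketch of the same logic for illustration; `demo_process_response` is a hypothetical name, not part of modeling_chatglm.py:

```python
import re

def demo_process_response(response: str) -> str:
    """Standalone sketch of the normalization added in this commit."""
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # Pairs of (half-width mark as a regex pattern, full-width replacement).
    # The question mark is escaped because "?" is a regex metacharacter.
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for item in punkts:
        # Replace the mark only when it directly follows or precedes a CJK
        # character, so punctuation inside English text stays ASCII.
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response

print(demo_process_response("你好!我是[[训练时间]]训练的模型,很高兴见到你."))
# -> 你好!我是2023年训练的模型,很高兴见到你.
```

Note that the trailing "." is left untouched: the period is not in the `punkts` list, and only marks adjacent to a CJK character are converted.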
@@ -1121,8 +1137,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history
 
@@ -1148,8 +1163,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         for outputs in self.stream_generate(**input_ids, **gen_kwargs):
             outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
             response = tokenizer.decode(outputs)
-            response = response.strip()
-            response = response.replace("[[训练时间]]", "2023年")
+            response = self.process_response(response)
             new_history = history + [(query, response)]
             yield response, new_history
 
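Both the blocking `chat` path and the streaming `stream_chat` path now route decoded text through `process_response`, so every response (including each partial streamed response) comes back with normalized punctuation. A typical way to exercise both entry points, following the usual ChatGLM-6B usage pattern; the model id, device placement, and prompts are illustrative and not part of this diff:

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

# Blocking call: the returned text already has full-width Chinese punctuation.
response, history = model.chat(tokenizer, "你好", history=[])
print(response)

# Streaming call: every yielded partial response is normalized the same way.
for response, new_history in model.stream_chat(tokenizer, "晚上睡不着应该怎么办", history=history):
    print(response)
```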