Tuchuanhuhuhu commited on
Commit
90d39c3
1 Parent(s): 75a2593

大幅度改进了输入输出解析。

Browse files

- 新的代码高亮模块。使用pygments实现。抛弃code_fence和code_highlite,它们的实现方式存在问题。
- 规范化GPT输出的markdown的功能。即使GPT输出不规范的Markdown列表,也能正常显示。
- 现在可以正常显示Shell脚本了。Shell脚本中的美元符不再与LaTex冲突。
- 删除parse_text函数。直接在Chatbot组件中渲染。讲道理我可以给gradio提一个pr。

Files changed (5) hide show
  1. chat_func.py +7 -7
  2. custom.css +69 -69
  3. overwrites.py +2 -8
  4. requirements.txt +1 -0
  5. utils.py +70 -26
chat_func.py CHANGED
@@ -115,9 +115,9 @@ def stream_predict(
115
  history.append(construct_user(inputs))
116
  history.append(construct_assistant(""))
117
  if fake_input:
118
- chatbot.append((parse_text(fake_input), ""))
119
  else:
120
- chatbot.append((parse_text(inputs), ""))
121
  user_token_count = 0
122
  if len(all_token_counts) == 0:
123
  system_prompt_token_count = count_token(construct_system(system_prompt))
@@ -192,7 +192,7 @@ def stream_predict(
192
  yield get_return_value()
193
  break
194
  history[-1] = construct_assistant(partial_words)
195
- chatbot[-1] = (chatbot[-1][0], parse_text(partial_words+display_append))
196
  all_token_counts[-1] += 1
197
  yield get_return_value()
198
 
@@ -214,9 +214,9 @@ def predict_all(
214
  history.append(construct_user(inputs))
215
  history.append(construct_assistant(""))
216
  if fake_input:
217
- chatbot.append((parse_text(fake_input), ""))
218
  else:
219
- chatbot.append((parse_text(inputs), ""))
220
  all_token_counts.append(count_token(construct_user(inputs)))
221
  try:
222
  response = get_response(
@@ -242,7 +242,7 @@ def predict_all(
242
  response = json.loads(response.text)
243
  content = response["choices"][0]["message"]["content"]
244
  history[-1] = construct_assistant(content)
245
- chatbot[-1] = (chatbot[-1][0], parse_text(content+display_append))
246
  total_token_count = response["usage"]["total_tokens"]
247
  all_token_counts[-1] = total_token_count - sum(all_token_counts)
248
  status_text = construct_token_message(total_token_count)
@@ -299,7 +299,7 @@ def predict(
299
  if len(openai_api_key) != 51:
300
  status_text = standard_error_msg + no_apikey_msg
301
  logging.info(status_text)
302
- chatbot.append((parse_text(inputs), ""))
303
  if len(history) == 0:
304
  history.append(construct_user(inputs))
305
  history.append("")
 
115
  history.append(construct_user(inputs))
116
  history.append(construct_assistant(""))
117
  if fake_input:
118
+ chatbot.append((fake_input, ""))
119
  else:
120
+ chatbot.append((inputs, ""))
121
  user_token_count = 0
122
  if len(all_token_counts) == 0:
123
  system_prompt_token_count = count_token(construct_system(system_prompt))
 
192
  yield get_return_value()
193
  break
194
  history[-1] = construct_assistant(partial_words)
195
+ chatbot[-1] = (chatbot[-1][0], partial_words+display_append)
196
  all_token_counts[-1] += 1
197
  yield get_return_value()
198
 
 
214
  history.append(construct_user(inputs))
215
  history.append(construct_assistant(""))
216
  if fake_input:
217
+ chatbot.append((fake_input, ""))
218
  else:
219
+ chatbot.append((inputs, ""))
220
  all_token_counts.append(count_token(construct_user(inputs)))
221
  try:
222
  response = get_response(
 
242
  response = json.loads(response.text)
243
  content = response["choices"][0]["message"]["content"]
244
  history[-1] = construct_assistant(content)
245
+ chatbot[-1] = (chatbot[-1][0], content+display_append)
246
  total_token_count = response["usage"]["total_tokens"]
247
  all_token_counts[-1] = total_token_count - sum(all_token_counts)
248
  status_text = construct_token_message(total_token_count)
 
299
  if len(openai_api_key) != 51:
300
  status_text = standard_error_msg + no_apikey_msg
301
  logging.info(status_text)
302
+ chatbot.append((inputs, ""))
303
  if len(history) == 0:
304
  history.append(construct_user(inputs))
305
  history.append("")
custom.css CHANGED
@@ -130,72 +130,72 @@ pre code {
130
  box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
131
  }
132
  /* 代码高亮样式 */
133
- .codehilite .hll { background-color: #49483e }
134
- .codehilite .c { color: #75715e } /* Comment */
135
- .codehilite .err { color: #960050; background-color: #1e0010 } /* Error */
136
- .codehilite .k { color: #66d9ef } /* Keyword */
137
- .codehilite .l { color: #ae81ff } /* Literal */
138
- .codehilite .n { color: #f8f8f2 } /* Name */
139
- .codehilite .o { color: #f92672 } /* Operator */
140
- .codehilite .p { color: #f8f8f2 } /* Punctuation */
141
- .codehilite .ch { color: #75715e } /* Comment.Hashbang */
142
- .codehilite .cm { color: #75715e } /* Comment.Multiline */
143
- .codehilite .cp { color: #75715e } /* Comment.Preproc */
144
- .codehilite .cpf { color: #75715e } /* Comment.PreprocFile */
145
- .codehilite .c1 { color: #75715e } /* Comment.Single */
146
- .codehilite .cs { color: #75715e } /* Comment.Special */
147
- .codehilite .gd { color: #f92672 } /* Generic.Deleted */
148
- .codehilite .ge { font-style: italic } /* Generic.Emph */
149
- .codehilite .gi { color: #a6e22e } /* Generic.Inserted */
150
- .codehilite .gs { font-weight: bold } /* Generic.Strong */
151
- .codehilite .gu { color: #75715e } /* Generic.Subheading */
152
- .codehilite .kc { color: #66d9ef } /* Keyword.Constant */
153
- .codehilite .kd { color: #66d9ef } /* Keyword.Declaration */
154
- .codehilite .kn { color: #f92672 } /* Keyword.Namespace */
155
- .codehilite .kp { color: #66d9ef } /* Keyword.Pseudo */
156
- .codehilite .kr { color: #66d9ef } /* Keyword.Reserved */
157
- .codehilite .kt { color: #66d9ef } /* Keyword.Type */
158
- .codehilite .ld { color: #e6db74 } /* Literal.Date */
159
- .codehilite .m { color: #ae81ff } /* Literal.Number */
160
- .codehilite .s { color: #e6db74 } /* Literal.String */
161
- .codehilite .na { color: #a6e22e } /* Name.Attribute */
162
- .codehilite .nb { color: #f8f8f2 } /* Name.Builtin */
163
- .codehilite .nc { color: #a6e22e } /* Name.Class */
164
- .codehilite .no { color: #66d9ef } /* Name.Constant */
165
- .codehilite .nd { color: #a6e22e } /* Name.Decorator */
166
- .codehilite .ni { color: #f8f8f2 } /* Name.Entity */
167
- .codehilite .ne { color: #a6e22e } /* Name.Exception */
168
- .codehilite .nf { color: #a6e22e } /* Name.Function */
169
- .codehilite .nl { color: #f8f8f2 } /* Name.Label */
170
- .codehilite .nn { color: #f8f8f2 } /* Name.Namespace */
171
- .codehilite .nx { color: #a6e22e } /* Name.Other */
172
- .codehilite .py { color: #f8f8f2 } /* Name.Property */
173
- .codehilite .nt { color: #f92672 } /* Name.Tag */
174
- .codehilite .nv { color: #f8f8f2 } /* Name.Variable */
175
- .codehilite .ow { color: #f92672 } /* Operator.Word */
176
- .codehilite .w { color: #f8f8f2 } /* Text.Whitespace */
177
- .codehilite .mb { color: #ae81ff } /* Literal.Number.Bin */
178
- .codehilite .mf { color: #ae81ff } /* Literal.Number.Float */
179
- .codehilite .mh { color: #ae81ff } /* Literal.Number.Hex */
180
- .codehilite .mi { color: #ae81ff } /* Literal.Number.Integer */
181
- .codehilite .mo { color: #ae81ff } /* Literal.Number.Oct */
182
- .codehilite .sa { color: #e6db74 } /* Literal.String.Affix */
183
- .codehilite .sb { color: #e6db74 } /* Literal.String.Backtick */
184
- .codehilite .sc { color: #e6db74 } /* Literal.String.Char */
185
- .codehilite .dl { color: #e6db74 } /* Literal.String.Delimiter */
186
- .codehilite .sd { color: #e6db74 } /* Literal.String.Doc */
187
- .codehilite .s2 { color: #e6db74 } /* Literal.String.Double */
188
- .codehilite .se { color: #ae81ff } /* Literal.String.Escape */
189
- .codehilite .sh { color: #e6db74 } /* Literal.String.Heredoc */
190
- .codehilite .si { color: #e6db74 } /* Literal.String.Interpol */
191
- .codehilite .sx { color: #e6db74 } /* Literal.String.Other */
192
- .codehilite .sr { color: #e6db74 } /* Literal.String.Regex */
193
- .codehilite .s1 { color: #e6db74 } /* Literal.String.Single */
194
- .codehilite .ss { color: #e6db74 } /* Literal.String.Symbol */
195
- .codehilite .bp { color: #f8f8f2 } /* Name.Builtin.Pseudo */
196
- .codehilite .fm { color: #a6e22e } /* Name.Function.Magic */
197
- .codehilite .vc { color: #f8f8f2 } /* Name.Variable.Class */
198
- .codehilite .vg { color: #f8f8f2 } /* Name.Variable.Global */
199
- .codehilite .vi { color: #f8f8f2 } /* Name.Variable.Instance */
200
- .codehilite .vm { color: #f8f8f2 } /* Name.Variable.Magic */
201
- .codehilite .il { color: #ae81ff } /* Literal.Number.Integer.Long */
 
130
  box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
131
  }
132
  /* 代码高亮样式 */
133
+ .highlight .hll { background-color: #49483e }
134
+ .highlight .c { color: #75715e } /* Comment */
135
+ .highlight .err { color: #960050; background-color: #1e0010 } /* Error */
136
+ .highlight .k { color: #66d9ef } /* Keyword */
137
+ .highlight .l { color: #ae81ff } /* Literal */
138
+ .highlight .n { color: #f8f8f2 } /* Name */
139
+ .highlight .o { color: #f92672 } /* Operator */
140
+ .highlight .p { color: #f8f8f2 } /* Punctuation */
141
+ .highlight .ch { color: #75715e } /* Comment.Hashbang */
142
+ .highlight .cm { color: #75715e } /* Comment.Multiline */
143
+ .highlight .cp { color: #75715e } /* Comment.Preproc */
144
+ .highlight .cpf { color: #75715e } /* Comment.PreprocFile */
145
+ .highlight .c1 { color: #75715e } /* Comment.Single */
146
+ .highlight .cs { color: #75715e } /* Comment.Special */
147
+ .highlight .gd { color: #f92672 } /* Generic.Deleted */
148
+ .highlight .ge { font-style: italic } /* Generic.Emph */
149
+ .highlight .gi { color: #a6e22e } /* Generic.Inserted */
150
+ .highlight .gs { font-weight: bold } /* Generic.Strong */
151
+ .highlight .gu { color: #75715e } /* Generic.Subheading */
152
+ .highlight .kc { color: #66d9ef } /* Keyword.Constant */
153
+ .highlight .kd { color: #66d9ef } /* Keyword.Declaration */
154
+ .highlight .kn { color: #f92672 } /* Keyword.Namespace */
155
+ .highlight .kp { color: #66d9ef } /* Keyword.Pseudo */
156
+ .highlight .kr { color: #66d9ef } /* Keyword.Reserved */
157
+ .highlight .kt { color: #66d9ef } /* Keyword.Type */
158
+ .highlight .ld { color: #e6db74 } /* Literal.Date */
159
+ .highlight .m { color: #ae81ff } /* Literal.Number */
160
+ .highlight .s { color: #e6db74 } /* Literal.String */
161
+ .highlight .na { color: #a6e22e } /* Name.Attribute */
162
+ .highlight .nb { color: #f8f8f2 } /* Name.Builtin */
163
+ .highlight .nc { color: #a6e22e } /* Name.Class */
164
+ .highlight .no { color: #66d9ef } /* Name.Constant */
165
+ .highlight .nd { color: #a6e22e } /* Name.Decorator */
166
+ .highlight .ni { color: #f8f8f2 } /* Name.Entity */
167
+ .highlight .ne { color: #a6e22e } /* Name.Exception */
168
+ .highlight .nf { color: #a6e22e } /* Name.Function */
169
+ .highlight .nl { color: #f8f8f2 } /* Name.Label */
170
+ .highlight .nn { color: #f8f8f2 } /* Name.Namespace */
171
+ .highlight .nx { color: #a6e22e } /* Name.Other */
172
+ .highlight .py { color: #f8f8f2 } /* Name.Property */
173
+ .highlight .nt { color: #f92672 } /* Name.Tag */
174
+ .highlight .nv { color: #f8f8f2 } /* Name.Variable */
175
+ .highlight .ow { color: #f92672 } /* Operator.Word */
176
+ .highlight .w { color: #f8f8f2 } /* Text.Whitespace */
177
+ .highlight .mb { color: #ae81ff } /* Literal.Number.Bin */
178
+ .highlight .mf { color: #ae81ff } /* Literal.Number.Float */
179
+ .highlight .mh { color: #ae81ff } /* Literal.Number.Hex */
180
+ .highlight .mi { color: #ae81ff } /* Literal.Number.Integer */
181
+ .highlight .mo { color: #ae81ff } /* Literal.Number.Oct */
182
+ .highlight .sa { color: #e6db74 } /* Literal.String.Affix */
183
+ .highlight .sb { color: #e6db74 } /* Literal.String.Backtick */
184
+ .highlight .sc { color: #e6db74 } /* Literal.String.Char */
185
+ .highlight .dl { color: #e6db74 } /* Literal.String.Delimiter */
186
+ .highlight .sd { color: #e6db74 } /* Literal.String.Doc */
187
+ .highlight .s2 { color: #e6db74 } /* Literal.String.Double */
188
+ .highlight .se { color: #ae81ff } /* Literal.String.Escape */
189
+ .highlight .sh { color: #e6db74 } /* Literal.String.Heredoc */
190
+ .highlight .si { color: #e6db74 } /* Literal.String.Interpol */
191
+ .highlight .sx { color: #e6db74 } /* Literal.String.Other */
192
+ .highlight .sr { color: #e6db74 } /* Literal.String.Regex */
193
+ .highlight .s1 { color: #e6db74 } /* Literal.String.Single */
194
+ .highlight .ss { color: #e6db74 } /* Literal.String.Symbol */
195
+ .highlight .bp { color: #f8f8f2 } /* Name.Builtin.Pseudo */
196
+ .highlight .fm { color: #a6e22e } /* Name.Function.Magic */
197
+ .highlight .vc { color: #f8f8f2 } /* Name.Variable.Class */
198
+ .highlight .vg { color: #f8f8f2 } /* Name.Variable.Global */
199
+ .highlight .vi { color: #f8f8f2 } /* Name.Variable.Instance */
200
+ .highlight .vm { color: #f8f8f2 } /* Name.Variable.Magic */
201
+ .highlight .il { color: #ae81ff } /* Literal.Number.Integer.Long */
overwrites.py CHANGED
@@ -28,13 +28,7 @@ def postprocess(
28
  Returns:
29
  List of tuples representing the message and response. Each message and response will be a string of HTML.
30
  """
31
- if y is None:
32
  return []
33
- for i, (message, response) in enumerate(y):
34
- y[i] = (
35
- # None if message is None else markdown.markdown(message),
36
- # None if response is None else markdown.markdown(response),
37
- None if message is None else message,
38
- None if response is None else mdtex2html.convert(response, extensions=['fenced_code','codehilite','tables']),
39
- )
40
  return y
 
28
  Returns:
29
  List of tuples representing the message and response. Each message and response will be a string of HTML.
30
  """
31
+ if y is None or y == []:
32
  return []
33
+ y[-1] = (y[-1][0].replace("\n", "<br>"), convert_mdtext(y[-1][1]))
 
 
 
 
 
 
34
  return y
requirements.txt CHANGED
@@ -9,3 +9,4 @@ duckduckgo_search
9
  Pygments
10
  llama_index
11
  langchain
 
 
9
  Pygments
10
  llama_index
11
  langchain
12
+ markdown
utils.py CHANGED
@@ -13,6 +13,11 @@ import re
13
  import gradio as gr
14
  from pypinyin import lazy_pinyin
15
  import tiktoken
 
 
 
 
 
16
 
17
  from presets import *
18
 
@@ -32,34 +37,73 @@ def count_token(message):
32
  length = len(encoding.encode(input_str))
33
  return length
34
 
 
 
 
 
35
 
36
- def parse_text(text):
37
- in_code_block = False
38
- in_list = False
39
- new_lines = []
40
- for line in text.split("\n"):
41
- if line.strip().startswith("```"):
42
- in_code_block = not in_code_block
43
- else:
44
- if re.match(r'(\*|-|\d+\.)\s', line):
45
- if not in_list:
46
- in_list = True
47
- elif in_list and line.strip() != "":
48
- in_list = False
49
- new_lines.append("")
50
-
51
- if in_code_block:
52
- if line.strip() != "":
53
- new_lines.append(line)
54
- elif in_list:
55
- if line.strip() != "":
56
- new_lines.append(line)
 
 
 
 
 
 
 
 
 
 
57
  else:
58
- new_lines.append(line)
59
- if in_code_block:
60
- new_lines.append("```")
61
- text = "\n".join(new_lines)
62
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  def construct_text(role, text):
 
13
  import gradio as gr
14
  from pypinyin import lazy_pinyin
15
  import tiktoken
16
+ import mdtex2html
17
+ from markdown import markdown
18
+ from pygments import highlight
19
+ from pygments.lexers import get_lexer_by_name
20
+ from pygments.formatters import HtmlFormatter
21
 
22
  from presets import *
23
 
 
37
  length = len(encoding.encode(input_str))
38
  return length
39
 
40
+ def markdown_to_html_with_syntax_highlight(md_str):
41
+ def replacer(match):
42
+ lang = match.group(1) or 'text'
43
+ code = match.group(2)
44
 
45
+ try:
46
+ lexer = get_lexer_by_name(lang, stripall=True)
47
+ except ValueError:
48
+ lexer = get_lexer_by_name("text", stripall=True)
49
+
50
+ formatter = HtmlFormatter()
51
+ highlighted_code = highlight(code, lexer, formatter)
52
+
53
+ return f"<pre><code class=\"{lang}\">{highlighted_code}</code></pre>"
54
+
55
+ code_block_pattern = r'```(\w+)?\n([\s\S]+?)\n```'
56
+ md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
57
+
58
+ html_str = markdown(md_str)
59
+ return html_str
60
+
61
+ def normalize_markdown(md_text: str) -> str:
62
+ lines = md_text.split('\n')
63
+ normalized_lines = []
64
+ inside_list = False
65
+
66
+ for i, line in enumerate(lines):
67
+ if re.match(r'^(\d+\.|-|\*|\+)\s', line.strip()):
68
+ if not inside_list and i > 0 and lines[i - 1].strip() != '':
69
+ normalized_lines.append('')
70
+ inside_list = True
71
+ normalized_lines.append(line)
72
+ elif inside_list and line.strip() == '':
73
+ if i < len(lines) - 1 and not re.match(r'^(\d+\.|-|\*|\+)\s', lines[i + 1].strip()):
74
+ normalized_lines.append(line)
75
+ continue
76
  else:
77
+ inside_list = False
78
+ normalized_lines.append(line)
79
+
80
+ return '\n'.join(normalized_lines)
81
+
82
+ def convert_mdtext(md_text):
83
+ code_block_pattern = re.compile(r'```(.*?)(?:```|$)', re.DOTALL)
84
+ code_blocks = code_block_pattern.findall(md_text)
85
+ non_code_parts = code_block_pattern.split(md_text)[::2]
86
+
87
+ result = []
88
+ for non_code, code in zip(non_code_parts, code_blocks + ['']):
89
+ if non_code.strip():
90
+ non_code = normalize_markdown(non_code)
91
+ result.append(mdtex2html.convert(non_code, extensions=['tables']))
92
+ if code.strip():
93
+ code = f"```{code}```"
94
+ code = markdown_to_html_with_syntax_highlight(code)
95
+ result.append(code)
96
+ result = "".join(result)
97
+ return result
98
+
99
+ def detect_language(code):
100
+ if code.startswith("\n"):
101
+ first_line = ""
102
+ else:
103
+ first_line = code.strip().split('\n', 1)[0]
104
+ language = first_line.lower() if first_line else ''
105
+ code_without_language = code[len(first_line):].lstrip() if first_line else code
106
+ return language, code_without_language
107
 
108
 
109
  def construct_text(role, text):