Spaces:
Runtime error
Runtime error
JohnSmith9982
commited on
Commit
•
a4aaca9
1
Parent(s):
54ff2a0
Upload 19 files
Browse files- app.py +27 -20
- assets/custom.css +52 -5
- modules/__pycache__/chat_func.cpython-39.pyc +0 -0
- modules/__pycache__/llama_func.cpython-39.pyc +0 -0
- modules/__pycache__/openai_func.cpython-39.pyc +0 -0
- modules/__pycache__/presets.cpython-39.pyc +0 -0
- modules/__pycache__/utils.cpython-39.pyc +0 -0
- modules/chat_func.py +85 -57
- modules/llama_func.py +44 -30
- modules/openai_func.py +59 -47
- modules/presets.py +34 -3
- modules/utils.py +96 -10
app.py
CHANGED
@@ -25,9 +25,11 @@ else:
|
|
25 |
dockerflag = False
|
26 |
|
27 |
authflag = False
|
|
|
28 |
|
29 |
-
if
|
30 |
my_api_key = os.environ.get("my_api_key")
|
|
|
31 |
if my_api_key == "empty":
|
32 |
logging.error("Please give a api key!")
|
33 |
sys.exit(1)
|
@@ -35,6 +37,7 @@ if dockerflag:
|
|
35 |
username = os.environ.get("USERNAME")
|
36 |
password = os.environ.get("PASSWORD")
|
37 |
if not (isinstance(username, type(None)) or isinstance(password, type(None))):
|
|
|
38 |
authflag = True
|
39 |
else:
|
40 |
if (
|
@@ -45,12 +48,15 @@ else:
|
|
45 |
with open("api_key.txt", "r") as f:
|
46 |
my_api_key = f.read().strip()
|
47 |
if os.path.exists("auth.json"):
|
|
|
48 |
with open("auth.json", "r", encoding='utf-8') as f:
|
49 |
auth = json.load(f)
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
54 |
|
55 |
gr.Chatbot.postprocess = postprocess
|
56 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
@@ -75,19 +81,19 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
75 |
with gr.Column(scale=4):
|
76 |
status_display = gr.Markdown(get_geoip(), elem_id="status_display")
|
77 |
|
78 |
-
with gr.Row(
|
79 |
with gr.Column(scale=5):
|
80 |
-
with gr.Row(
|
81 |
chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
|
82 |
-
with gr.Row(
|
83 |
with gr.Column(scale=12):
|
84 |
user_input = gr.Textbox(
|
85 |
-
show_label=False, placeholder="在这里输入"
|
86 |
).style(container=False)
|
87 |
with gr.Column(min_width=70, scale=1):
|
88 |
submitBtn = gr.Button("发送", variant="primary")
|
89 |
cancelBtn = gr.Button("取消", variant="secondary", visible=False)
|
90 |
-
with gr.Row(
|
91 |
emptyBtn = gr.Button(
|
92 |
"🧹 新的对话",
|
93 |
)
|
@@ -107,7 +113,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
107 |
visible=not HIDE_MY_KEY,
|
108 |
label="API-Key",
|
109 |
)
|
110 |
-
usageTxt = gr.Markdown(
|
111 |
model_select_dropdown = gr.Dropdown(
|
112 |
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
|
113 |
)
|
@@ -207,7 +213,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
207 |
label="Temperature",
|
208 |
)
|
209 |
|
210 |
-
with gr.Accordion("网络设置", open=False):
|
211 |
apiurlTxt = gr.Textbox(
|
212 |
show_label=True,
|
213 |
placeholder=f"在这里输入API地址...",
|
@@ -226,7 +232,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
226 |
changeProxyBtn = gr.Button("🔄 设置代理地址")
|
227 |
|
228 |
gr.Markdown(description)
|
229 |
-
|
230 |
chatgpt_predict_args = dict(
|
231 |
fn=predict,
|
232 |
inputs=[
|
@@ -264,13 +270,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
264 |
)
|
265 |
|
266 |
transfer_input_args = dict(
|
267 |
-
fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input], show_progress=True
|
268 |
)
|
269 |
|
270 |
get_usage_args = dict(
|
271 |
fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
|
272 |
)
|
273 |
|
|
|
274 |
# Chatbot
|
275 |
cancelBtn.click(cancel_outputing, [], [])
|
276 |
|
@@ -287,8 +294,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
287 |
)
|
288 |
emptyBtn.click(**reset_textbox_args)
|
289 |
|
290 |
-
retryBtn.click(**
|
291 |
-
retryBtn.click(
|
292 |
retry,
|
293 |
[
|
294 |
user_api_key,
|
@@ -304,7 +310,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
304 |
],
|
305 |
[chatbot, history, status_display, token_count],
|
306 |
show_progress=True,
|
307 |
-
)
|
308 |
retryBtn.click(**get_usage_args)
|
309 |
|
310 |
delFirstBtn.click(
|
@@ -330,7 +336,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
330 |
token_count,
|
331 |
top_p,
|
332 |
temperature,
|
333 |
-
gr.State(
|
334 |
model_select_dropdown,
|
335 |
language_select_dropdown,
|
336 |
],
|
@@ -341,6 +347,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
341 |
|
342 |
# ChatGPT
|
343 |
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
|
|
|
344 |
|
345 |
# Template
|
346 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
@@ -417,7 +424,7 @@ if __name__ == "__main__":
|
|
417 |
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
418 |
server_name="0.0.0.0",
|
419 |
server_port=7860,
|
420 |
-
auth=
|
421 |
favicon_path="./assets/favicon.ico",
|
422 |
)
|
423 |
else:
|
@@ -432,7 +439,7 @@ if __name__ == "__main__":
|
|
432 |
if authflag:
|
433 |
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
434 |
share=False,
|
435 |
-
auth=
|
436 |
favicon_path="./assets/favicon.ico",
|
437 |
inbrowser=True,
|
438 |
)
|
|
|
25 |
dockerflag = False
|
26 |
|
27 |
authflag = False
|
28 |
+
auth_list = []
|
29 |
|
30 |
+
if not my_api_key:
|
31 |
my_api_key = os.environ.get("my_api_key")
|
32 |
+
if dockerflag:
|
33 |
if my_api_key == "empty":
|
34 |
logging.error("Please give a api key!")
|
35 |
sys.exit(1)
|
|
|
37 |
username = os.environ.get("USERNAME")
|
38 |
password = os.environ.get("PASSWORD")
|
39 |
if not (isinstance(username, type(None)) or isinstance(password, type(None))):
|
40 |
+
auth_list.append((os.environ.get("USERNAME"), os.environ.get("PASSWORD")))
|
41 |
authflag = True
|
42 |
else:
|
43 |
if (
|
|
|
48 |
with open("api_key.txt", "r") as f:
|
49 |
my_api_key = f.read().strip()
|
50 |
if os.path.exists("auth.json"):
|
51 |
+
authflag = True
|
52 |
with open("auth.json", "r", encoding='utf-8') as f:
|
53 |
auth = json.load(f)
|
54 |
+
for _ in auth:
|
55 |
+
if auth[_]["username"] and auth[_]["password"]:
|
56 |
+
auth_list.append((auth[_]["username"], auth[_]["password"]))
|
57 |
+
else:
|
58 |
+
logging.error("请检查auth.json文件中的用户名和密码!")
|
59 |
+
sys.exit(1)
|
60 |
|
61 |
gr.Chatbot.postprocess = postprocess
|
62 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
|
|
81 |
with gr.Column(scale=4):
|
82 |
status_display = gr.Markdown(get_geoip(), elem_id="status_display")
|
83 |
|
84 |
+
with gr.Row().style(equal_height=True):
|
85 |
with gr.Column(scale=5):
|
86 |
+
with gr.Row():
|
87 |
chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
|
88 |
+
with gr.Row():
|
89 |
with gr.Column(scale=12):
|
90 |
user_input = gr.Textbox(
|
91 |
+
show_label=False, placeholder="在这里输入"
|
92 |
).style(container=False)
|
93 |
with gr.Column(min_width=70, scale=1):
|
94 |
submitBtn = gr.Button("发送", variant="primary")
|
95 |
cancelBtn = gr.Button("取消", variant="secondary", visible=False)
|
96 |
+
with gr.Row():
|
97 |
emptyBtn = gr.Button(
|
98 |
"🧹 新的对话",
|
99 |
)
|
|
|
113 |
visible=not HIDE_MY_KEY,
|
114 |
label="API-Key",
|
115 |
)
|
116 |
+
usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display")
|
117 |
model_select_dropdown = gr.Dropdown(
|
118 |
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
|
119 |
)
|
|
|
213 |
label="Temperature",
|
214 |
)
|
215 |
|
216 |
+
with gr.Accordion("网络设置", open=False, visible=False):
|
217 |
apiurlTxt = gr.Textbox(
|
218 |
show_label=True,
|
219 |
placeholder=f"在这里输入API地址...",
|
|
|
232 |
changeProxyBtn = gr.Button("🔄 设置代理地址")
|
233 |
|
234 |
gr.Markdown(description)
|
235 |
+
gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
|
236 |
chatgpt_predict_args = dict(
|
237 |
fn=predict,
|
238 |
inputs=[
|
|
|
270 |
)
|
271 |
|
272 |
transfer_input_args = dict(
|
273 |
+
fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn, cancelBtn], show_progress=True
|
274 |
)
|
275 |
|
276 |
get_usage_args = dict(
|
277 |
fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
|
278 |
)
|
279 |
|
280 |
+
|
281 |
# Chatbot
|
282 |
cancelBtn.click(cancel_outputing, [], [])
|
283 |
|
|
|
294 |
)
|
295 |
emptyBtn.click(**reset_textbox_args)
|
296 |
|
297 |
+
retryBtn.click(**start_outputing_args).then(
|
|
|
298 |
retry,
|
299 |
[
|
300 |
user_api_key,
|
|
|
310 |
],
|
311 |
[chatbot, history, status_display, token_count],
|
312 |
show_progress=True,
|
313 |
+
).then(**end_outputing_args)
|
314 |
retryBtn.click(**get_usage_args)
|
315 |
|
316 |
delFirstBtn.click(
|
|
|
336 |
token_count,
|
337 |
top_p,
|
338 |
temperature,
|
339 |
+
gr.State(sum(token_count.value[-4:])),
|
340 |
model_select_dropdown,
|
341 |
language_select_dropdown,
|
342 |
],
|
|
|
347 |
|
348 |
# ChatGPT
|
349 |
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
|
350 |
+
keyTxt.submit(**get_usage_args)
|
351 |
|
352 |
# Template
|
353 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
|
|
424 |
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
425 |
server_name="0.0.0.0",
|
426 |
server_port=7860,
|
427 |
+
auth=auth_list,
|
428 |
favicon_path="./assets/favicon.ico",
|
429 |
)
|
430 |
else:
|
|
|
439 |
if authflag:
|
440 |
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
441 |
share=False,
|
442 |
+
auth=auth_list,
|
443 |
favicon_path="./assets/favicon.ico",
|
444 |
inbrowser=True,
|
445 |
)
|
assets/custom.css
CHANGED
@@ -3,6 +3,21 @@
|
|
3 |
--chatbot-color-dark: #121111;
|
4 |
}
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
/* status_display */
|
7 |
#status_display {
|
8 |
display: flex;
|
@@ -22,14 +37,45 @@
|
|
22 |
|
23 |
/* usage_display */
|
24 |
#usage_display {
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
font-size: .85em;
|
30 |
-
font-family: monospace;
|
31 |
color: var(--body-text-color-subdued);
|
32 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
/* list */
|
34 |
ol:not(.options), ul:not(.options) {
|
35 |
padding-inline-start: 2em !important;
|
@@ -64,6 +110,7 @@ ol:not(.options), ul:not(.options) {
|
|
64 |
background-color: var(--neutral-950) !important;
|
65 |
}
|
66 |
}
|
|
|
67 |
/* 对话气泡 */
|
68 |
[class *= "message"] {
|
69 |
border-radius: var(--radius-xl) !important;
|
|
|
3 |
--chatbot-color-dark: #121111;
|
4 |
}
|
5 |
|
6 |
+
/* 覆盖gradio的页脚信息QAQ */
|
7 |
+
footer {
|
8 |
+
display: none !important;
|
9 |
+
}
|
10 |
+
#footer{
|
11 |
+
text-align: center;
|
12 |
+
}
|
13 |
+
#footer div{
|
14 |
+
display: inline-block;
|
15 |
+
}
|
16 |
+
#footer .versions{
|
17 |
+
font-size: 85%;
|
18 |
+
opacity: 0.85;
|
19 |
+
}
|
20 |
+
|
21 |
/* status_display */
|
22 |
#status_display {
|
23 |
display: flex;
|
|
|
37 |
|
38 |
/* usage_display */
|
39 |
#usage_display {
|
40 |
+
position: relative;
|
41 |
+
margin: 0;
|
42 |
+
box-shadow: var(--block-shadow);
|
43 |
+
border-width: var(--block-border-width);
|
44 |
+
border-color: var(--block-border-color);
|
45 |
+
border-radius: var(--block-radius);
|
46 |
+
background: var(--block-background-fill);
|
47 |
+
width: 100%;
|
48 |
+
line-height: var(--line-sm);
|
49 |
+
min-height: 2em;
|
50 |
+
}
|
51 |
+
#usage_display p, #usage_display span {
|
52 |
+
margin: 0;
|
53 |
+
padding: .5em 1em;
|
54 |
font-size: .85em;
|
|
|
55 |
color: var(--body-text-color-subdued);
|
56 |
}
|
57 |
+
.progress-bar {
|
58 |
+
background-color: var(--input-background-fill);;
|
59 |
+
margin: 0 1em;
|
60 |
+
height: 20px;
|
61 |
+
border-radius: 10px;
|
62 |
+
overflow: hidden;
|
63 |
+
}
|
64 |
+
.progress {
|
65 |
+
background-color: var(--block-title-background-fill);;
|
66 |
+
height: 100%;
|
67 |
+
border-radius: 10px;
|
68 |
+
text-align: right;
|
69 |
+
transition: width 0.5s ease-in-out;
|
70 |
+
}
|
71 |
+
.progress-text {
|
72 |
+
/* color: white; */
|
73 |
+
color: var(--color-accent) !important;
|
74 |
+
font-size: 1em !important;
|
75 |
+
font-weight: bold;
|
76 |
+
padding-right: 10px;
|
77 |
+
line-height: 20px;
|
78 |
+
}
|
79 |
/* list */
|
80 |
ol:not(.options), ul:not(.options) {
|
81 |
padding-inline-start: 2em !important;
|
|
|
110 |
background-color: var(--neutral-950) !important;
|
111 |
}
|
112 |
}
|
113 |
+
|
114 |
/* 对话气泡 */
|
115 |
[class *= "message"] {
|
116 |
border-radius: var(--radius-xl) !important;
|
modules/__pycache__/chat_func.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/chat_func.cpython-39.pyc and b/modules/__pycache__/chat_func.cpython-39.pyc differ
|
|
modules/__pycache__/llama_func.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/llama_func.cpython-39.pyc and b/modules/__pycache__/llama_func.cpython-39.pyc differ
|
|
modules/__pycache__/openai_func.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/openai_func.cpython-39.pyc and b/modules/__pycache__/openai_func.cpython-39.pyc differ
|
|
modules/__pycache__/presets.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/presets.cpython-39.pyc and b/modules/__pycache__/presets.cpython-39.pyc differ
|
|
modules/__pycache__/utils.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/utils.cpython-39.pyc and b/modules/__pycache__/utils.cpython-39.pyc differ
|
|
modules/chat_func.py
CHANGED
@@ -13,6 +13,9 @@ import colorama
|
|
13 |
from duckduckgo_search import ddg
|
14 |
import asyncio
|
15 |
import aiohttp
|
|
|
|
|
|
|
16 |
|
17 |
from modules.presets import *
|
18 |
from modules.llama_func import *
|
@@ -58,39 +61,21 @@ def get_response(
|
|
58 |
else:
|
59 |
timeout = timeout_all
|
60 |
|
61 |
-
|
62 |
-
http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
|
63 |
-
https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
|
64 |
-
|
65 |
-
# 如果存在代理设置,使用它们
|
66 |
-
proxies = {}
|
67 |
-
if http_proxy:
|
68 |
-
logging.info(f"使用 HTTP 代理: {http_proxy}")
|
69 |
-
proxies["http"] = http_proxy
|
70 |
-
if https_proxy:
|
71 |
-
logging.info(f"使用 HTTPS 代理: {https_proxy}")
|
72 |
-
proxies["https"] = https_proxy
|
73 |
|
74 |
# 如果有自定义的api-url,使用自定义url发送请求,否则使用默认设置发送请求
|
75 |
if shared.state.api_url != API_URL:
|
76 |
logging.info(f"使用自定义API URL: {shared.state.api_url}")
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
response = requests.post(
|
88 |
-
shared.state.api_url,
|
89 |
-
headers=headers,
|
90 |
-
json=payload,
|
91 |
-
stream=True,
|
92 |
-
timeout=timeout,
|
93 |
-
)
|
94 |
return response
|
95 |
|
96 |
|
@@ -121,13 +106,17 @@ def stream_predict(
|
|
121 |
else:
|
122 |
chatbot.append((inputs, ""))
|
123 |
user_token_count = 0
|
|
|
|
|
|
|
|
|
124 |
if len(all_token_counts) == 0:
|
125 |
system_prompt_token_count = count_token(construct_system(system_prompt))
|
126 |
user_token_count = (
|
127 |
-
|
128 |
)
|
129 |
else:
|
130 |
-
user_token_count =
|
131 |
all_token_counts.append(user_token_count)
|
132 |
logging.info(f"输入token计数: {user_token_count}")
|
133 |
yield get_return_value()
|
@@ -155,6 +144,8 @@ def stream_predict(
|
|
155 |
yield get_return_value()
|
156 |
error_json_str = ""
|
157 |
|
|
|
|
|
158 |
for chunk in tqdm(response.iter_lines()):
|
159 |
if counter == 0:
|
160 |
counter += 1
|
@@ -219,7 +210,10 @@ def predict_all(
|
|
219 |
chatbot.append((fake_input, ""))
|
220 |
else:
|
221 |
chatbot.append((inputs, ""))
|
222 |
-
|
|
|
|
|
|
|
223 |
try:
|
224 |
response = get_response(
|
225 |
openai_api_key,
|
@@ -242,13 +236,22 @@ def predict_all(
|
|
242 |
status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
|
243 |
return chatbot, history, status_text, all_token_counts
|
244 |
response = json.loads(response.text)
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
|
253 |
|
254 |
def predict(
|
@@ -268,40 +271,59 @@ def predict(
|
|
268 |
should_check_token_count=True,
|
269 |
): # repetition_penalty, top_k
|
270 |
logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
|
271 |
-
|
|
|
272 |
if reply_language == "跟随问题语言(不稳定)":
|
273 |
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
|
|
|
|
|
|
274 |
if files:
|
|
|
|
|
275 |
msg = "加载索引中……(这可能需要几分钟)"
|
276 |
logging.info(msg)
|
277 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
278 |
index = construct_index(openai_api_key, file_src=files)
|
279 |
msg = "索引构建完成,获取回答中……"
|
|
|
280 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
search_results = ddg(inputs, max_results=5)
|
289 |
old_inputs = inputs
|
290 |
-
|
291 |
for idx, result in enumerate(search_results):
|
292 |
logging.info(f"搜索结果{idx + 1}:{result}")
|
293 |
domain_name = urllib3.util.parse_url(result["href"]).host
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
297 |
inputs = (
|
298 |
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
299 |
.replace("{query}", inputs)
|
300 |
-
.replace("{web_results}", "\n\n".join(
|
301 |
.replace("{reply_language}", reply_language )
|
302 |
)
|
303 |
else:
|
304 |
-
|
305 |
|
306 |
if len(openai_api_key) != 51:
|
307 |
status_text = standard_error_msg + no_apikey_msg
|
@@ -334,7 +356,7 @@ def predict(
|
|
334 |
temperature,
|
335 |
selected_model,
|
336 |
fake_input=old_inputs,
|
337 |
-
display_append=
|
338 |
)
|
339 |
for chatbot, history, status_text, all_token_counts in iter:
|
340 |
if shared.state.interrupted:
|
@@ -354,7 +376,7 @@ def predict(
|
|
354 |
temperature,
|
355 |
selected_model,
|
356 |
fake_input=old_inputs,
|
357 |
-
display_append=
|
358 |
)
|
359 |
yield chatbot, history, status_text, all_token_counts
|
360 |
|
@@ -367,10 +389,15 @@ def predict(
|
|
367 |
+ colorama.Style.RESET_ALL
|
368 |
)
|
369 |
|
|
|
|
|
|
|
|
|
|
|
370 |
if stream:
|
371 |
-
max_token =
|
372 |
else:
|
373 |
-
max_token =
|
374 |
|
375 |
if sum(all_token_counts) > max_token and should_check_token_count:
|
376 |
status_text = f"精简token中{all_token_counts}/{max_token}"
|
@@ -460,6 +487,7 @@ def reduce_token_size(
|
|
460 |
flag = False
|
461 |
for chatbot, history, status_text, previous_token_count in iter:
|
462 |
num_chat = find_n(previous_token_count, max_token_count)
|
|
|
463 |
if flag:
|
464 |
chatbot = chatbot[:-1]
|
465 |
flag = True
|
|
|
13 |
from duckduckgo_search import ddg
|
14 |
import asyncio
|
15 |
import aiohttp
|
16 |
+
from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
|
17 |
+
from llama_index.indices.query.schema import QueryBundle
|
18 |
+
from langchain.llms import OpenAIChat
|
19 |
|
20 |
from modules.presets import *
|
21 |
from modules.llama_func import *
|
|
|
61 |
else:
|
62 |
timeout = timeout_all
|
63 |
|
64 |
+
proxies = get_proxies()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
# 如果有自定义的api-url,使用自定义url发送请求,否则使用默认设置发送请求
|
67 |
if shared.state.api_url != API_URL:
|
68 |
logging.info(f"使用自定义API URL: {shared.state.api_url}")
|
69 |
+
|
70 |
+
response = requests.post(
|
71 |
+
shared.state.api_url,
|
72 |
+
headers=headers,
|
73 |
+
json=payload,
|
74 |
+
stream=True,
|
75 |
+
timeout=timeout,
|
76 |
+
proxies=proxies,
|
77 |
+
)
|
78 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
return response
|
80 |
|
81 |
|
|
|
106 |
else:
|
107 |
chatbot.append((inputs, ""))
|
108 |
user_token_count = 0
|
109 |
+
if fake_input is not None:
|
110 |
+
input_token_count = count_token(construct_user(fake_input))
|
111 |
+
else:
|
112 |
+
input_token_count = count_token(construct_user(inputs))
|
113 |
if len(all_token_counts) == 0:
|
114 |
system_prompt_token_count = count_token(construct_system(system_prompt))
|
115 |
user_token_count = (
|
116 |
+
input_token_count + system_prompt_token_count
|
117 |
)
|
118 |
else:
|
119 |
+
user_token_count = input_token_count
|
120 |
all_token_counts.append(user_token_count)
|
121 |
logging.info(f"输入token计数: {user_token_count}")
|
122 |
yield get_return_value()
|
|
|
144 |
yield get_return_value()
|
145 |
error_json_str = ""
|
146 |
|
147 |
+
if fake_input is not None:
|
148 |
+
history[-2] = construct_user(fake_input)
|
149 |
for chunk in tqdm(response.iter_lines()):
|
150 |
if counter == 0:
|
151 |
counter += 1
|
|
|
210 |
chatbot.append((fake_input, ""))
|
211 |
else:
|
212 |
chatbot.append((inputs, ""))
|
213 |
+
if fake_input is not None:
|
214 |
+
all_token_counts.append(count_token(construct_user(fake_input)))
|
215 |
+
else:
|
216 |
+
all_token_counts.append(count_token(construct_user(inputs)))
|
217 |
try:
|
218 |
response = get_response(
|
219 |
openai_api_key,
|
|
|
236 |
status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
|
237 |
return chatbot, history, status_text, all_token_counts
|
238 |
response = json.loads(response.text)
|
239 |
+
if fake_input is not None:
|
240 |
+
history[-2] = construct_user(fake_input)
|
241 |
+
try:
|
242 |
+
content = response["choices"][0]["message"]["content"]
|
243 |
+
history[-1] = construct_assistant(content)
|
244 |
+
chatbot[-1] = (chatbot[-1][0], content+display_append)
|
245 |
+
total_token_count = response["usage"]["total_tokens"]
|
246 |
+
if fake_input is not None:
|
247 |
+
all_token_counts[-1] += count_token(construct_assistant(content))
|
248 |
+
else:
|
249 |
+
all_token_counts[-1] = total_token_count - sum(all_token_counts)
|
250 |
+
status_text = construct_token_message(total_token_count)
|
251 |
+
return chatbot, history, status_text, all_token_counts
|
252 |
+
except KeyError:
|
253 |
+
status_text = standard_error_msg + str(response)
|
254 |
+
return chatbot, history, status_text, all_token_counts
|
255 |
|
256 |
|
257 |
def predict(
|
|
|
271 |
should_check_token_count=True,
|
272 |
): # repetition_penalty, top_k
|
273 |
logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
|
274 |
+
if should_check_token_count:
|
275 |
+
yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
|
276 |
if reply_language == "跟随问题语言(不稳定)":
|
277 |
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
278 |
+
old_inputs = None
|
279 |
+
display_reference = []
|
280 |
+
limited_context = False
|
281 |
if files:
|
282 |
+
limited_context = True
|
283 |
+
old_inputs = inputs
|
284 |
msg = "加载索引中……(这可能需要几分钟)"
|
285 |
logging.info(msg)
|
286 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
287 |
index = construct_index(openai_api_key, file_src=files)
|
288 |
msg = "索引构建完成,获取回答中……"
|
289 |
+
logging.info(msg)
|
290 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
291 |
+
llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
|
292 |
+
prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
|
293 |
+
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
|
294 |
+
query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
|
295 |
+
query_bundle = QueryBundle(inputs)
|
296 |
+
nodes = query_object.retrieve(query_bundle)
|
297 |
+
reference_results = [n.node.text for n in nodes]
|
298 |
+
reference_results = add_source_numbers(reference_results, use_source=False)
|
299 |
+
display_reference = add_details(reference_results)
|
300 |
+
display_reference = "\n\n" + "".join(display_reference)
|
301 |
+
inputs = (
|
302 |
+
replace_today(PROMPT_TEMPLATE)
|
303 |
+
.replace("{query_str}", inputs)
|
304 |
+
.replace("{context_str}", "\n\n".join(reference_results))
|
305 |
+
.replace("{reply_language}", reply_language )
|
306 |
+
)
|
307 |
+
elif use_websearch:
|
308 |
+
limited_context = True
|
309 |
search_results = ddg(inputs, max_results=5)
|
310 |
old_inputs = inputs
|
311 |
+
reference_results = []
|
312 |
for idx, result in enumerate(search_results):
|
313 |
logging.info(f"搜索结果{idx + 1}:{result}")
|
314 |
domain_name = urllib3.util.parse_url(result["href"]).host
|
315 |
+
reference_results.append([result["body"], result["href"]])
|
316 |
+
display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
|
317 |
+
reference_results = add_source_numbers(reference_results)
|
318 |
+
display_reference = "\n\n" + "".join(display_reference)
|
319 |
inputs = (
|
320 |
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
321 |
.replace("{query}", inputs)
|
322 |
+
.replace("{web_results}", "\n\n".join(reference_results))
|
323 |
.replace("{reply_language}", reply_language )
|
324 |
)
|
325 |
else:
|
326 |
+
display_reference = ""
|
327 |
|
328 |
if len(openai_api_key) != 51:
|
329 |
status_text = standard_error_msg + no_apikey_msg
|
|
|
356 |
temperature,
|
357 |
selected_model,
|
358 |
fake_input=old_inputs,
|
359 |
+
display_append=display_reference
|
360 |
)
|
361 |
for chatbot, history, status_text, all_token_counts in iter:
|
362 |
if shared.state.interrupted:
|
|
|
376 |
temperature,
|
377 |
selected_model,
|
378 |
fake_input=old_inputs,
|
379 |
+
display_append=display_reference
|
380 |
)
|
381 |
yield chatbot, history, status_text, all_token_counts
|
382 |
|
|
|
389 |
+ colorama.Style.RESET_ALL
|
390 |
)
|
391 |
|
392 |
+
if limited_context:
|
393 |
+
history = history[-4:]
|
394 |
+
all_token_counts = all_token_counts[-2:]
|
395 |
+
yield chatbot, history, status_text, all_token_counts
|
396 |
+
|
397 |
if stream:
|
398 |
+
max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
|
399 |
else:
|
400 |
+
max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
|
401 |
|
402 |
if sum(all_token_counts) > max_token and should_check_token_count:
|
403 |
status_text = f"精简token中{all_token_counts}/{max_token}"
|
|
|
487 |
flag = False
|
488 |
for chatbot, history, status_text, previous_token_count in iter:
|
489 |
num_chat = find_n(previous_token_count, max_token_count)
|
490 |
+
logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
|
491 |
if flag:
|
492 |
chatbot = chatbot[:-1]
|
493 |
flag = True
|
modules/llama_func.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
-
from llama_index import GPTSimpleVectorIndex
|
5 |
from llama_index import download_loader
|
6 |
from llama_index import (
|
7 |
Document,
|
@@ -11,19 +11,32 @@ from llama_index import (
|
|
11 |
RefinePrompt,
|
12 |
)
|
13 |
from langchain.llms import OpenAI
|
|
|
14 |
import colorama
|
|
|
|
|
15 |
|
16 |
from modules.presets import *
|
17 |
from modules.utils import *
|
18 |
|
19 |
def get_index_name(file_src):
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
def get_documents(file_src):
|
29 |
documents = []
|
@@ -33,9 +46,12 @@ def get_documents(file_src):
|
|
33 |
logging.info(f"loading file: {file.name}")
|
34 |
if os.path.splitext(file.name)[1] == ".pdf":
|
35 |
logging.debug("Loading PDF...")
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
39 |
elif os.path.splitext(file.name)[1] == ".docx":
|
40 |
logging.debug("Loading DOCX...")
|
41 |
DocxReader = download_loader("DocxReader")
|
@@ -51,7 +67,10 @@ def get_documents(file_src):
|
|
51 |
with open(file.name, "r", encoding="utf-8") as f:
|
52 |
text_raw = f.read()
|
53 |
text = add_space(text_raw)
|
|
|
|
|
54 |
documents += [Document(text)]
|
|
|
55 |
return documents
|
56 |
|
57 |
|
@@ -59,13 +78,11 @@ def construct_index(
|
|
59 |
api_key,
|
60 |
file_src,
|
61 |
max_input_size=4096,
|
62 |
-
num_outputs=
|
63 |
max_chunk_overlap=20,
|
64 |
chunk_size_limit=600,
|
65 |
embedding_limit=None,
|
66 |
-
separator=" "
|
67 |
-
num_children=10,
|
68 |
-
max_keywords_per_chunk=10,
|
69 |
):
|
70 |
os.environ["OPENAI_API_KEY"] = api_key
|
71 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
@@ -73,16 +90,9 @@ def construct_index(
|
|
73 |
separator = " " if separator == "" else separator
|
74 |
|
75 |
llm_predictor = LLMPredictor(
|
76 |
-
llm=
|
77 |
-
)
|
78 |
-
prompt_helper = PromptHelper(
|
79 |
-
max_input_size,
|
80 |
-
num_outputs,
|
81 |
-
max_chunk_overlap,
|
82 |
-
embedding_limit,
|
83 |
-
chunk_size_limit,
|
84 |
-
separator=separator,
|
85 |
)
|
|
|
86 |
index_name = get_index_name(file_src)
|
87 |
if os.path.exists(f"./index/{index_name}.json"):
|
88 |
logging.info("找到了缓存的索引文件,加载中……")
|
@@ -90,14 +100,19 @@ def construct_index(
|
|
90 |
else:
|
91 |
try:
|
92 |
documents = get_documents(file_src)
|
93 |
-
logging.
|
94 |
-
|
95 |
-
|
|
|
96 |
)
|
|
|
97 |
os.makedirs("./index", exist_ok=True)
|
98 |
index.save_to_disk(f"./index/{index_name}.json")
|
|
|
99 |
return index
|
|
|
100 |
except Exception as e:
|
|
|
101 |
print(e)
|
102 |
return None
|
103 |
|
@@ -144,7 +159,7 @@ def ask_ai(
|
|
144 |
question,
|
145 |
prompt_tmpl,
|
146 |
refine_tmpl,
|
147 |
-
sim_k=
|
148 |
temprature=0,
|
149 |
prefix_messages=[],
|
150 |
reply_language="中文",
|
@@ -154,7 +169,7 @@ def ask_ai(
|
|
154 |
logging.debug("Index file found")
|
155 |
logging.debug("Querying index...")
|
156 |
llm_predictor = LLMPredictor(
|
157 |
-
llm=
|
158 |
temperature=temprature,
|
159 |
model_name="gpt-3.5-turbo-0301",
|
160 |
prefix_messages=prefix_messages,
|
@@ -166,7 +181,6 @@ def ask_ai(
|
|
166 |
rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
|
167 |
response = index.query(
|
168 |
question,
|
169 |
-
llm_predictor=llm_predictor,
|
170 |
similarity_top_k=sim_k,
|
171 |
text_qa_template=qa_prompt,
|
172 |
refine_template=rf_prompt,
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
+
from llama_index import GPTSimpleVectorIndex, ServiceContext
|
5 |
from llama_index import download_loader
|
6 |
from llama_index import (
|
7 |
Document,
|
|
|
11 |
RefinePrompt,
|
12 |
)
|
13 |
from langchain.llms import OpenAI
|
14 |
+
from langchain.chat_models import ChatOpenAI
|
15 |
import colorama
|
16 |
+
import PyPDF2
|
17 |
+
from tqdm import tqdm
|
18 |
|
19 |
from modules.presets import *
|
20 |
from modules.utils import *
|
21 |
|
22 |
def get_index_name(file_src):
|
23 |
+
file_paths = [x.name for x in file_src]
|
24 |
+
file_paths.sort(key=lambda x: os.path.basename(x))
|
25 |
+
|
26 |
+
md5_hash = hashlib.md5()
|
27 |
+
for file_path in file_paths:
|
28 |
+
with open(file_path, "rb") as f:
|
29 |
+
while chunk := f.read(8192):
|
30 |
+
md5_hash.update(chunk)
|
31 |
+
|
32 |
+
return md5_hash.hexdigest()
|
33 |
+
|
34 |
+
def block_split(text):
|
35 |
+
blocks = []
|
36 |
+
while len(text) > 0:
|
37 |
+
blocks.append(Document(text[:1000]))
|
38 |
+
text = text[1000:]
|
39 |
+
return blocks
|
40 |
|
41 |
def get_documents(file_src):
|
42 |
documents = []
|
|
|
46 |
logging.info(f"loading file: {file.name}")
|
47 |
if os.path.splitext(file.name)[1] == ".pdf":
|
48 |
logging.debug("Loading PDF...")
|
49 |
+
pdftext = ""
|
50 |
+
with open(file.name, 'rb') as pdfFileObj:
|
51 |
+
pdfReader = PyPDF2.PdfReader(pdfFileObj)
|
52 |
+
for page in tqdm(pdfReader.pages):
|
53 |
+
pdftext += page.extract_text()
|
54 |
+
text_raw = pdftext
|
55 |
elif os.path.splitext(file.name)[1] == ".docx":
|
56 |
logging.debug("Loading DOCX...")
|
57 |
DocxReader = download_loader("DocxReader")
|
|
|
67 |
with open(file.name, "r", encoding="utf-8") as f:
|
68 |
text_raw = f.read()
|
69 |
text = add_space(text_raw)
|
70 |
+
# text = block_split(text)
|
71 |
+
# documents += text
|
72 |
documents += [Document(text)]
|
73 |
+
logging.debug("Documents loaded.")
|
74 |
return documents
|
75 |
|
76 |
|
|
|
78 |
api_key,
|
79 |
file_src,
|
80 |
max_input_size=4096,
|
81 |
+
num_outputs=5,
|
82 |
max_chunk_overlap=20,
|
83 |
chunk_size_limit=600,
|
84 |
embedding_limit=None,
|
85 |
+
separator=" "
|
|
|
|
|
86 |
):
|
87 |
os.environ["OPENAI_API_KEY"] = api_key
|
88 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
|
|
90 |
separator = " " if separator == "" else separator
|
91 |
|
92 |
llm_predictor = LLMPredictor(
|
93 |
+
llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
)
|
95 |
+
prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
|
96 |
index_name = get_index_name(file_src)
|
97 |
if os.path.exists(f"./index/{index_name}.json"):
|
98 |
logging.info("找到了缓存的索引文件,加载中……")
|
|
|
100 |
else:
|
101 |
try:
|
102 |
documents = get_documents(file_src)
|
103 |
+
logging.info("构建索引中……")
|
104 |
+
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
|
105 |
+
index = GPTSimpleVectorIndex.from_documents(
|
106 |
+
documents, service_context=service_context
|
107 |
)
|
108 |
+
logging.debug("索引构建完成!")
|
109 |
os.makedirs("./index", exist_ok=True)
|
110 |
index.save_to_disk(f"./index/{index_name}.json")
|
111 |
+
logging.debug("索引已保存至本地!")
|
112 |
return index
|
113 |
+
|
114 |
except Exception as e:
|
115 |
+
logging.error("索引构建失败!", e)
|
116 |
print(e)
|
117 |
return None
|
118 |
|
|
|
159 |
question,
|
160 |
prompt_tmpl,
|
161 |
refine_tmpl,
|
162 |
+
sim_k=5,
|
163 |
temprature=0,
|
164 |
prefix_messages=[],
|
165 |
reply_language="中文",
|
|
|
169 |
logging.debug("Index file found")
|
170 |
logging.debug("Querying index...")
|
171 |
llm_predictor = LLMPredictor(
|
172 |
+
llm=ChatOpenAI(
|
173 |
temperature=temprature,
|
174 |
model_name="gpt-3.5-turbo-0301",
|
175 |
prefix_messages=prefix_messages,
|
|
|
181 |
rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
|
182 |
response = index.query(
|
183 |
question,
|
|
|
184 |
similarity_top_k=sim_k,
|
185 |
text_qa_template=qa_prompt,
|
186 |
refine_template=rf_prompt,
|
modules/openai_func.py
CHANGED
@@ -1,70 +1,82 @@
|
|
1 |
import requests
|
2 |
import logging
|
3 |
-
from modules.presets import
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
|
|
|
|
|
|
7 |
|
8 |
-
def
|
9 |
headers = {
|
10 |
"Content-Type": "application/json",
|
11 |
-
"Authorization": f"Bearer {openai_api_key}"
|
12 |
}
|
13 |
-
|
14 |
timeout = timeout_all
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
if
|
24 |
-
|
25 |
-
|
26 |
-
if https_proxy:
|
27 |
-
logging.info(f"使用 HTTPS 代理: {https_proxy}")
|
28 |
-
proxies["https"] = https_proxy
|
29 |
-
|
30 |
-
# 如果有代理,使用代理发送请求,否则使用默认设置发送请求
|
31 |
-
"""
|
32 |
-
暂不支持修改
|
33 |
-
if shared.state.balance_api_url != BALANCE_API_URL:
|
34 |
-
logging.info(f"使用自定义BALANCE API URL: {shared.state.balance_api_url}")
|
35 |
-
"""
|
36 |
-
if proxies:
|
37 |
-
response = requests.get(
|
38 |
-
BALANCE_API_URL,
|
39 |
-
headers=headers,
|
40 |
-
timeout=timeout,
|
41 |
-
proxies=proxies,
|
42 |
-
)
|
43 |
else:
|
44 |
-
|
45 |
-
|
46 |
-
headers=headers,
|
47 |
-
timeout=timeout,
|
48 |
-
)
|
49 |
-
return response
|
50 |
|
51 |
def get_usage(openai_api_key):
|
52 |
try:
|
53 |
-
|
54 |
-
logging.debug(
|
55 |
try:
|
56 |
-
balance =
|
57 |
-
|
58 |
-
|
59 |
-
"total_used") else 0
|
60 |
except Exception as e:
|
61 |
logging.error(f"API使用情况解析失败:"+str(e))
|
62 |
balance = 0
|
63 |
total_used=0
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
except requests.exceptions.ConnectTimeout:
|
66 |
status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
67 |
return status_text
|
68 |
except requests.exceptions.ReadTimeout:
|
69 |
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
|
70 |
return status_text
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
import logging
|
3 |
+
from modules.presets import (
|
4 |
+
timeout_all,
|
5 |
+
USAGE_API_URL,
|
6 |
+
BALANCE_API_URL,
|
7 |
+
standard_error_msg,
|
8 |
+
connection_timeout_prompt,
|
9 |
+
error_retrieve_prompt,
|
10 |
+
read_timeout_prompt
|
11 |
+
)
|
12 |
|
13 |
+
from modules import shared
|
14 |
+
from modules.utils import get_proxies
|
15 |
+
import os, datetime
|
16 |
|
17 |
+
def get_billing_data(openai_api_key, billing_url):
|
18 |
headers = {
|
19 |
"Content-Type": "application/json",
|
20 |
+
"Authorization": f"Bearer {openai_api_key}"
|
21 |
}
|
22 |
+
|
23 |
timeout = timeout_all
|
24 |
+
proxies = get_proxies()
|
25 |
+
response = requests.get(
|
26 |
+
billing_url,
|
27 |
+
headers=headers,
|
28 |
+
timeout=timeout,
|
29 |
+
proxies=proxies,
|
30 |
+
)
|
31 |
+
|
32 |
+
if response.status_code == 200:
|
33 |
+
data = response.json()
|
34 |
+
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
else:
|
36 |
+
raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
|
37 |
+
|
|
|
|
|
|
|
|
|
38 |
|
39 |
def get_usage(openai_api_key):
|
40 |
try:
|
41 |
+
balance_data=get_billing_data(openai_api_key, BALANCE_API_URL)
|
42 |
+
logging.debug(balance_data)
|
43 |
try:
|
44 |
+
balance = balance_data["total_available"] if balance_data["total_available"] else 0
|
45 |
+
total_used = balance_data["total_used"] if balance_data["total_used"] else 0
|
46 |
+
usage_percent = round(total_used / (total_used+balance) * 100, 2)
|
|
|
47 |
except Exception as e:
|
48 |
logging.error(f"API使用情况解析失败:"+str(e))
|
49 |
balance = 0
|
50 |
total_used=0
|
51 |
+
return f"**API使用情况解析失败**"
|
52 |
+
if balance == 0:
|
53 |
+
last_day_of_month = datetime.datetime.now().strftime("%Y-%m-%d")
|
54 |
+
first_day_of_month = datetime.datetime.now().replace(day=1).strftime("%Y-%m-%d")
|
55 |
+
usage_url = f"{USAGE_API_URL}?start_date={first_day_of_month}&end_date={last_day_of_month}"
|
56 |
+
try:
|
57 |
+
usage_data = get_billing_data(openai_api_key, usage_url)
|
58 |
+
except Exception as e:
|
59 |
+
logging.error(f"获取API使用情况失败:"+str(e))
|
60 |
+
return f"**获取API使用情况失败**"
|
61 |
+
return f"**本月使用金额** \u3000 ${usage_data['total_usage'] / 100}"
|
62 |
+
|
63 |
+
# return f"**免费额度**(已用/余额)\u3000${total_used} / ${balance}"
|
64 |
+
return f"""\
|
65 |
+
<b>免费额度使用情况</b>
|
66 |
+
<div class="progress-bar">
|
67 |
+
<div class="progress" style="width: {usage_percent}%;">
|
68 |
+
<span class="progress-text">{usage_percent}%</span>
|
69 |
+
</div>
|
70 |
+
</div>
|
71 |
+
<div style="display: flex; justify-content: space-between;"><span>已用 ${total_used}</span><span>可用 ${balance}</span></div>
|
72 |
+
"""
|
73 |
+
|
74 |
except requests.exceptions.ConnectTimeout:
|
75 |
status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
|
76 |
return status_text
|
77 |
except requests.exceptions.ReadTimeout:
|
78 |
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
|
79 |
return status_text
|
80 |
+
except Exception as e:
|
81 |
+
logging.error(f"获取API使用情况失败:"+str(e))
|
82 |
+
return standard_error_msg + error_retrieve_prompt
|
modules/presets.py
CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
|
|
5 |
initial_prompt = "You are a helpful assistant."
|
6 |
API_URL = "https://api.openai.com/v1/chat/completions"
|
7 |
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
|
|
|
8 |
HISTORY_DIR = "history"
|
9 |
TEMPLATES_DIR = "templates"
|
10 |
|
@@ -18,9 +19,7 @@ ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
|
|
18 |
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
|
19 |
no_input_msg = "请输入对话内容。" # 未输入对话内容
|
20 |
|
21 |
-
max_token_streaming = 3500 # 流式对话时的最大 token 数
|
22 |
timeout_streaming = 10 # 流式对话时的超时时间
|
23 |
-
max_token_all = 3500 # 非流式对话时的最大 token 数
|
24 |
timeout_all = 200 # 非流式对话时的超时时间
|
25 |
enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
|
26 |
HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
|
@@ -41,6 +40,10 @@ description = """\
|
|
41 |
</div>
|
42 |
"""
|
43 |
|
|
|
|
|
|
|
|
|
44 |
summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
|
45 |
|
46 |
MODELS = [
|
@@ -52,8 +55,36 @@ MODELS = [
|
|
52 |
"gpt-4-32k-0314",
|
53 |
] # 可选的模型
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
REPLY_LANGUAGES = [
|
56 |
-
"
|
|
|
57 |
"English",
|
58 |
"日本語",
|
59 |
"Español",
|
|
|
5 |
initial_prompt = "You are a helpful assistant."
|
6 |
API_URL = "https://api.openai.com/v1/chat/completions"
|
7 |
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
|
8 |
+
USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
|
9 |
HISTORY_DIR = "history"
|
10 |
TEMPLATES_DIR = "templates"
|
11 |
|
|
|
19 |
no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
|
20 |
no_input_msg = "请输入对话内容。" # 未输入对话内容
|
21 |
|
|
|
22 |
timeout_streaming = 10 # 流式对话时的超时时间
|
|
|
23 |
timeout_all = 200 # 非流式对话时的超时时间
|
24 |
enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
|
25 |
HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
|
|
|
40 |
</div>
|
41 |
"""
|
42 |
|
43 |
+
footer = """\
|
44 |
+
<div class="versions">{versions}</div>
|
45 |
+
"""
|
46 |
+
|
47 |
summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
|
48 |
|
49 |
MODELS = [
|
|
|
55 |
"gpt-4-32k-0314",
|
56 |
] # 可选的模型
|
57 |
|
58 |
+
MODEL_SOFT_TOKEN_LIMIT = {
|
59 |
+
"gpt-3.5-turbo": {
|
60 |
+
"streaming": 3500,
|
61 |
+
"all": 3500
|
62 |
+
},
|
63 |
+
"gpt-3.5-turbo-0301": {
|
64 |
+
"streaming": 3500,
|
65 |
+
"all": 3500
|
66 |
+
},
|
67 |
+
"gpt-4": {
|
68 |
+
"streaming": 7500,
|
69 |
+
"all": 7500
|
70 |
+
},
|
71 |
+
"gpt-4-0314": {
|
72 |
+
"streaming": 7500,
|
73 |
+
"all": 7500
|
74 |
+
},
|
75 |
+
"gpt-4-32k": {
|
76 |
+
"streaming": 31000,
|
77 |
+
"all": 31000
|
78 |
+
},
|
79 |
+
"gpt-4-32k-0314": {
|
80 |
+
"streaming": 31000,
|
81 |
+
"all": 31000
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
REPLY_LANGUAGES = [
|
86 |
+
"简体中文",
|
87 |
+
"繁體中文",
|
88 |
"English",
|
89 |
"日本語",
|
90 |
"Español",
|
modules/utils.py
CHANGED
@@ -10,6 +10,8 @@ import csv
|
|
10 |
import requests
|
11 |
import re
|
12 |
import html
|
|
|
|
|
13 |
|
14 |
import gradio as gr
|
15 |
from pypinyin import lazy_pinyin
|
@@ -115,7 +117,11 @@ def convert_mdtext(md_text):
|
|
115 |
|
116 |
|
117 |
def convert_asis(userinput):
|
118 |
-
return
|
|
|
|
|
|
|
|
|
119 |
|
120 |
def detect_converted_mark(userinput):
|
121 |
if userinput.endswith(ALREADY_CONVERTED_MARK):
|
@@ -153,6 +159,7 @@ def construct_assistant(text):
|
|
153 |
def construct_token_message(token, stream=False):
|
154 |
return f"Token 计数: {token}"
|
155 |
|
|
|
156 |
def delete_first_conversation(history, previous_token_count):
|
157 |
if history:
|
158 |
del history[:2]
|
@@ -346,6 +353,8 @@ def change_proxy(proxy):
|
|
346 |
|
347 |
|
348 |
def hide_middle_chars(s):
|
|
|
|
|
349 |
if len(s) <= 8:
|
350 |
return s
|
351 |
else:
|
@@ -362,20 +371,14 @@ def submit_key(key):
|
|
362 |
return key, msg
|
363 |
|
364 |
|
365 |
-
def sha1sum(filename):
|
366 |
-
sha1 = hashlib.sha1()
|
367 |
-
sha1.update(filename.encode("utf-8"))
|
368 |
-
return sha1.hexdigest()
|
369 |
-
|
370 |
-
|
371 |
def replace_today(prompt):
|
372 |
today = datetime.datetime.today().strftime("%Y-%m-%d")
|
373 |
return prompt.replace("{current_date}", today)
|
374 |
|
375 |
|
376 |
def get_geoip():
|
377 |
-
response = requests.get("https://ipapi.co/json/", timeout=5)
|
378 |
try:
|
|
|
379 |
data = response.json()
|
380 |
except:
|
381 |
data = {"error": True, "reason": "连接ipapi失败"}
|
@@ -383,7 +386,7 @@ def get_geoip():
|
|
383 |
logging.warning(f"无法获取IP地址信息。\n{data}")
|
384 |
if data["reason"] == "RateLimited":
|
385 |
return (
|
386 |
-
f"获取IP地理位置失败,因为达到了检测IP
|
387 |
)
|
388 |
else:
|
389 |
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
@@ -427,8 +430,91 @@ def cancel_outputing():
|
|
427 |
logging.info("中止输出……")
|
428 |
shared.state.interrupt()
|
429 |
|
|
|
430 |
def transfer_input(inputs):
|
431 |
# 一次性返回,降低延迟
|
432 |
textbox = reset_textbox()
|
433 |
outputing = start_outputing()
|
434 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
import requests
|
11 |
import re
|
12 |
import html
|
13 |
+
import sys
|
14 |
+
import subprocess
|
15 |
|
16 |
import gradio as gr
|
17 |
from pypinyin import lazy_pinyin
|
|
|
117 |
|
118 |
|
119 |
def convert_asis(userinput):
|
120 |
+
return (
|
121 |
+
f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>'
|
122 |
+
+ ALREADY_CONVERTED_MARK
|
123 |
+
)
|
124 |
+
|
125 |
|
126 |
def detect_converted_mark(userinput):
|
127 |
if userinput.endswith(ALREADY_CONVERTED_MARK):
|
|
|
159 |
def construct_token_message(token, stream=False):
|
160 |
return f"Token 计数: {token}"
|
161 |
|
162 |
+
|
163 |
def delete_first_conversation(history, previous_token_count):
|
164 |
if history:
|
165 |
del history[:2]
|
|
|
353 |
|
354 |
|
355 |
def hide_middle_chars(s):
|
356 |
+
if s is None:
|
357 |
+
return ""
|
358 |
if len(s) <= 8:
|
359 |
return s
|
360 |
else:
|
|
|
371 |
return key, msg
|
372 |
|
373 |
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
def replace_today(prompt):
|
375 |
today = datetime.datetime.today().strftime("%Y-%m-%d")
|
376 |
return prompt.replace("{current_date}", today)
|
377 |
|
378 |
|
379 |
def get_geoip():
|
|
|
380 |
try:
|
381 |
+
response = requests.get("https://ipapi.co/json/", timeout=5)
|
382 |
data = response.json()
|
383 |
except:
|
384 |
data = {"error": True, "reason": "连接ipapi失败"}
|
|
|
386 |
logging.warning(f"无法获取IP地址信息。\n{data}")
|
387 |
if data["reason"] == "RateLimited":
|
388 |
return (
|
389 |
+
f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用。"
|
390 |
)
|
391 |
else:
|
392 |
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
|
|
430 |
logging.info("中止输出……")
|
431 |
shared.state.interrupt()
|
432 |
|
433 |
+
|
434 |
def transfer_input(inputs):
|
435 |
# 一次性返回,降低延迟
|
436 |
textbox = reset_textbox()
|
437 |
outputing = start_outputing()
|
438 |
+
return (
|
439 |
+
inputs,
|
440 |
+
gr.update(value=""),
|
441 |
+
gr.Button.update(visible=True),
|
442 |
+
gr.Button.update(visible=False),
|
443 |
+
)
|
444 |
+
|
445 |
+
|
446 |
+
def get_proxies():
|
447 |
+
# 获取环境变量中的代理设置
|
448 |
+
http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
|
449 |
+
https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
|
450 |
+
|
451 |
+
# 如果存在代理设置,使用它们
|
452 |
+
proxies = {}
|
453 |
+
if http_proxy:
|
454 |
+
logging.info(f"使用 HTTP 代理: {http_proxy}")
|
455 |
+
proxies["http"] = http_proxy
|
456 |
+
if https_proxy:
|
457 |
+
logging.info(f"使用 HTTPS 代理: {https_proxy}")
|
458 |
+
proxies["https"] = https_proxy
|
459 |
+
|
460 |
+
if proxies == {}:
|
461 |
+
proxies = None
|
462 |
+
|
463 |
+
return proxies
|
464 |
+
|
465 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
466 |
+
if desc is not None:
|
467 |
+
print(desc)
|
468 |
+
if live:
|
469 |
+
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
470 |
+
if result.returncode != 0:
|
471 |
+
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
472 |
+
Command: {command}
|
473 |
+
Error code: {result.returncode}""")
|
474 |
+
|
475 |
+
return ""
|
476 |
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
477 |
+
if result.returncode != 0:
|
478 |
+
message = f"""{errdesc or 'Error running command'}.
|
479 |
+
Command: {command}
|
480 |
+
Error code: {result.returncode}
|
481 |
+
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
|
482 |
+
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
|
483 |
+
"""
|
484 |
+
raise RuntimeError(message)
|
485 |
+
return result.stdout.decode(encoding="utf8", errors="ignore")
|
486 |
+
|
487 |
+
def versions_html():
|
488 |
+
git = os.environ.get('GIT', "git")
|
489 |
+
python_version = ".".join([str(x) for x in sys.version_info[0:3]])
|
490 |
+
try:
|
491 |
+
commit_hash = run(f"{git} rev-parse HEAD").strip()
|
492 |
+
except Exception:
|
493 |
+
commit_hash = "<none>"
|
494 |
+
if commit_hash != "<none>":
|
495 |
+
short_commit = commit_hash[0:7]
|
496 |
+
commit_info = f"<a style=\"text-decoration:none\" href=\"https://github.com/GaiZhenbiao/ChuanhuChatGPT/commit/{short_commit}\">{short_commit}</a>"
|
497 |
+
else:
|
498 |
+
commit_info = "unknown \U0001F615"
|
499 |
+
return f"""
|
500 |
+
Python: <span title="{sys.version}">{python_version}</span>
|
501 |
+
•
|
502 |
+
Gradio: {gr.__version__}
|
503 |
+
•
|
504 |
+
Commit: {commit_info}
|
505 |
+
"""
|
506 |
+
|
507 |
+
def add_source_numbers(lst, source_name = "Source", use_source = True):
|
508 |
+
if use_source:
|
509 |
+
return [f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}' for idx, item in enumerate(lst)]
|
510 |
+
else:
|
511 |
+
return [f'[{idx+1}]\t "{item}"' for idx, item in enumerate(lst)]
|
512 |
+
|
513 |
+
def add_details(lst):
|
514 |
+
nodes = []
|
515 |
+
for index, txt in enumerate(lst):
|
516 |
+
brief = txt[:25].replace("\n", "")
|
517 |
+
nodes.append(
|
518 |
+
f"<details><summary>{brief}...</summary><p>{txt}</p></details>"
|
519 |
+
)
|
520 |
+
return nodes
|