# modules/chat_func.py
# -*- coding:utf-8 -*-
from __future__ import annotations
from typing import TYPE_CHECKING, List
import logging
import json
import os
import requests
import urllib3
from tqdm import tqdm
import colorama
from duckduckgo_search import ddg
import asyncio
import aiohttp
from modules.presets import *
from modules.llama_func import *
from modules.utils import *
from . import shared
from modules.config import retrieve_proxy
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
if TYPE_CHECKING:
from typing import TypedDict
class DataframeData(TypedDict):
headers: List[str]
data: List[List[str | int | bool]]
initial_prompt = "You are a helpful assistant."
HISTORY_DIR = "history"
TEMPLATES_DIR = "templates"
@shared.state.switching_api_key  # this decorator is a no-op unless multi-API-key mode is enabled
def get_response(
openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
):
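    """Send a chat-completion request and return the raw `requests` response.

    The system prompt is prepended to `history`. When `stream` is True the
    response body is a server-sent-events stream that the caller iterates.
    """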
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {openai_api_key}",
}
history = [construct_system(system_prompt), *history]
payload = {
"model": selected_model,
"messages": history, # [{"role": "user", "content": f"{inputs}"}],
"temperature": temperature, # 1.0,
"top_p": top_p, # 1.0,
"n": 1,
"stream": stream,
"presence_penalty": 0,
"frequency_penalty": 0,
}
    timeout = timeout_streaming if stream else timeout_all

    # If a custom api-host is configured, send the request there; otherwise use the default endpoint
    if shared.state.completion_url != COMPLETION_URL:
        logging.info(f"Using custom API URL: {shared.state.completion_url}")
with retrieve_proxy():
response = requests.post(
shared.state.completion_url,
headers=headers,
json=payload,
stream=True,
timeout=timeout,
)
return response
def stream_predict(
openai_api_key,
system_prompt,
history,
inputs,
chatbot,
all_token_counts,
top_p,
temperature,
selected_model,
fake_input=None,
display_append=""
):
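    """Generator: stream the reply chunk by chunk, yielding an updated
    (chatbot, history, status_text, all_token_counts) tuple as the answer grows."""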
def get_return_value():
return chatbot, history, status_text, all_token_counts
logging.info("实时回答模式")
partial_words = ""
counter = 0
status_text = "开始实时传输回答……"
history.append(construct_user(inputs))
history.append(construct_assistant(""))
if fake_input:
chatbot.append((fake_input, ""))
else:
chatbot.append((inputs, ""))
user_token_count = 0
if fake_input is not None:
input_token_count = count_token(construct_user(fake_input))
else:
input_token_count = count_token(construct_user(inputs))
if len(all_token_counts) == 0:
system_prompt_token_count = count_token(construct_system(system_prompt))
user_token_count = (
input_token_count + system_prompt_token_count
)
else:
user_token_count = input_token_count
all_token_counts.append(user_token_count)
logging.info(f"输入token计数: {user_token_count}")
yield get_return_value()
try:
response = get_response(
openai_api_key,
system_prompt,
history,
temperature,
top_p,
True,
selected_model,
)
except requests.exceptions.ConnectTimeout:
status_text = (
standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
)
yield get_return_value()
return
except requests.exceptions.ReadTimeout:
status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
yield get_return_value()
return
yield get_return_value()
error_json_str = ""
if fake_input is not None:
history[-2] = construct_user(fake_input)
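    # The streaming body is server-sent events: each non-empty line looks like
    # `data: {...json...}`, so the 6-character "data: " prefix is stripped
    # before JSON-decoding. The very first event only announces the assistant
    # role and carries no content, so it is skipped via `counter`.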
for chunk in tqdm(response.iter_lines()):
if counter == 0:
counter += 1
continue
counter += 1
# check whether each line is non-empty
if chunk:
chunk = chunk.decode()
chunklength = len(chunk)
try:
chunk = json.loads(chunk[6:])
except json.JSONDecodeError:
logging.info(chunk)
error_json_str += chunk
status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
yield get_return_value()
continue
            # a well-formed data line has content beyond the "data: " prefix
            if chunklength > 6 and "delta" in chunk["choices"][0]:
finish_reason = chunk["choices"][0]["finish_reason"]
status_text = construct_token_message(all_token_counts)
if finish_reason == "stop":
yield get_return_value()
break
try:
partial_words = (
partial_words + chunk["choices"][0]["delta"]["content"]
)
except KeyError:
                    status_text = (
                        standard_error_msg
                        + "No content found in the API reply. The token limit was most likely reached; please reset the conversation. Current token count: "
                        + str(sum(all_token_counts))
                    )
yield get_return_value()
break
history[-1] = construct_assistant(partial_words)
chatbot[-1] = (chatbot[-1][0], partial_words + display_append)
all_token_counts[-1] += 1
yield get_return_value()
def predict_all(
openai_api_key,
system_prompt,
history,
inputs,
chatbot,
all_token_counts,
top_p,
temperature,
selected_model,
fake_input=None,
display_append=""
):
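    """Fetch the complete reply in a single request and return the updated
    (chatbot, history, status_text, all_token_counts) tuple."""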
logging.info("一次性回答模式")
history.append(construct_user(inputs))
history.append(construct_assistant(""))
if fake_input:
chatbot.append((fake_input, ""))
else:
chatbot.append((inputs, ""))
if fake_input is not None:
all_token_counts.append(count_token(construct_user(fake_input)))
else:
all_token_counts.append(count_token(construct_user(inputs)))
try:
# logging.info(f"system_prompt:{system_prompt}")
# logging.info(f"history:{history}")
# # 如果能传入index,则此处里获得初筛后的店铺和菜名
response = get_response(
openai_api_key,
system_prompt,
history,
temperature,
top_p,
False,
selected_model,
)
        # response = json.loads(response.text)
        # logging.info(f"Initial restaurant recommendations: {response}")
        # response = response["choices"][0]["message"]["content"]
        #
        # logging.info(f"Initial restaurant recommendations: {response}")
        # # Extract the restaurants and dish names from the response
        # import re
        #
        # # Sample reply that the pattern below is meant to parse (kept in
        # # Chinese because the regex matches the labels literally;
        # # 店铺名称 = "restaurant name", 推荐菜品 = "recommended dishes"):
        # # text = """
        # # 好的,针对您想吃韩式烤肉的需求,我向您推荐以下店铺和菜品:
        # #
        # # 店铺名称:"青年烤肉店" 推荐菜品:烤牛肉、烤猪肉、烤羊肉
        # #
        # # 店铺名称:"西西里烤肉店" 推荐菜品:烤牛肉串、烤排骨、烤鸡肉
        # #
        # # 店铺名称:"韩式烤肉店" 推荐菜品:石锅拌饭、铁板烧、烤牛舌"""
        #
        # pattern = r'店铺名称:(.+?) 推荐菜品:(.+)。'
        #
        # results = re.findall(pattern, response)
        #
        # dicts = {}
        # for result in results:
        #     dicts[result[0]] = result[1].split('、')
        #
        # logging.info(f"Pre-filtered restaurants and dishes: {dicts}")
        # dishes = []
        # for restaurant, dish in dicts.items():
        #     dishes.extend(dish)
        #
        # dishes = '、'.join(dishes)
        # dishes = "半熟芝士拿铁、白桃半熟芝士拿铁、招牌烤全羊、羊肉串"
        # # Feed the pre-filtered restaurants and dishes into the prepared CoT prompt
        # prompt_with_ingredient = f"""
        # I need you to infer the likely ingredients of some dishes, together with their nutritional composition, in the following format:
        # -----------------------
        # Dish name: []
        # Ingredients: [ingredient 1, ingredient 2...]
        # Nutrients: [nutrient (level)]
        # -----------------------
        # The nutrients include protein, fat, carbohydrates, fiber, vitamins and so on; you may add others based on your knowledge. Each nutrient's level is one of none, low, medium or high, written in the parentheses after the nutrient.
        # Below are the dish names to analyze, separated by "、":
        # -----------------------
        # {dishes}
        # -----------------------
        # Reply for each dish separately in the format above, and do not reply with anything else.
        # """
        # logging.info(f"Prompt for nutritional analysis built: {prompt_with_ingredient}")
        # history_ingredient = []
        # history_ingredient.append(construct_user(prompt_with_ingredient))
        # history_ingredient.append(construct_assistant(""))
        # logging.info(f"history_ingredient: {history_ingredient}")
        # response_ingredient = get_response(
        #     openai_api_key,
        #     "You are a nutrition-analysis expert",
        #     history_ingredient,
        #     temperature,
        #     top_p,
        #     True,
        #     selected_model,
        # )
        # response_ingredient = json.loads(response_ingredient.text)
        # response_ingredient = response_ingredient["choices"][0]["message"]["content"]
        # logging.info(f"Nutritional composition obtained: {response_ingredient}")
        # prompt_rec = f"""
        # Below are some dish names and the restaurants they belong to. Based on my needs, recommend one or more dishes from one restaurant and explain why. My needs: I have diabetes, and I don't want anything too greasy today.
        # -----------------------
        # {response_ingredient}
        # -----------------------
        # """
        # history_rec = []
        # history_rec.append(construct_user(prompt_rec))
        # history_rec.append(construct_assistant(""))
        # logging.info(f"history_rec: {history_rec}")
        # response = get_response(
        #     openai_api_key,
        #     "You are a food recommendation expert",
        #     history_rec,
        #     temperature,
        #     top_p,
        #     True,
        #     selected_model,
        # )
except requests.exceptions.ConnectTimeout:
status_text = (
standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
)
return chatbot, history, status_text, all_token_counts
except requests.exceptions.ProxyError:
status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
return chatbot, history, status_text, all_token_counts
except requests.exceptions.SSLError:
status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
return chatbot, history, status_text, all_token_counts
response = json.loads(response.text)
if fake_input is not None:
history[-2] = construct_user(fake_input)
try:
content = response["choices"][0]["message"]["content"]
history[-1] = construct_assistant(content)
chatbot[-1] = (chatbot[-1][0], content + display_append)
total_token_count = response["usage"]["total_tokens"]
if fake_input is not None:
all_token_counts[-1] += count_token(construct_assistant(content))
else:
all_token_counts[-1] = total_token_count - sum(all_token_counts)
status_text = construct_token_message(total_token_count)
return chatbot, history, status_text, all_token_counts
except KeyError:
status_text = standard_error_msg + str(response)
return chatbot, history, status_text, all_token_counts
def predict(
openai_api_key,
system_prompt,
history,
inputs,
chatbot,
all_token_counts,
top_p,
temperature,
stream=False,
selected_model=MODELS[0],
use_websearch=False,
files=None,
reply_language="中文",
should_check_token_count=True,
): # repetition_penalty, top_k
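    """Main entry point: optionally augment `inputs` with context retrieved
    from uploaded files (llama_index) or a DuckDuckGo web search, dispatch to
    stream_predict or predict_all, then trim history to stay under the
    model's soft token limit."""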
    # CHANGE: debug override kept for reference
    # files = [{'name': 'database/cuc-pure.txt'}]
logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
if should_check_token_count:
yield chatbot + [(inputs, "")], history, "开始生成回答……", all_token_counts
if reply_language == "跟随问题语言(不稳定)":
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
old_inputs = None
display_reference = []
limited_context = False
    if files:
        # llama_index / langchain are heavyweight and only needed for document
        # retrieval, so they are imported lazily inside this branch
        from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
        from llama_index.indices.query.schema import QueryBundle
        from langchain.llms import OpenAIChat

        limited_context = True
        old_inputs = inputs
        msg = "Loading index… (this may take a few minutes)"
        logging.info(msg)
        yield chatbot + [(inputs, "")], history, msg, all_token_counts
        index = construct_index(openai_api_key, file_src=files)
        msg = "Index built; fetching the answer…"
logging.info(msg)
yield chatbot + [(inputs, "")], history, msg, all_token_counts
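        # Retrieval step: embed the query, fetch the 5 most similar chunks from
        # the index just built, and splice them into PROMPT_TEMPLATE as context.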
with retrieve_proxy():
llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
prompt_helper = PromptHelper(max_input_size=4096, num_output=5, max_chunk_overlap=20, chunk_size_limit=600)
from llama_index import ServiceContext
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context,
similarity_top_k=5, vector_store=index._vector_store,
docstore=index._docstore)
query_bundle = QueryBundle(inputs)
nodes = query_object.retrieve(query_bundle)
reference_results = [n.node.text for n in nodes]
reference_results = add_source_numbers(reference_results, use_source=False)
display_reference = add_details(reference_results)
display_reference = "\n\n" + "".join(display_reference)
inputs = (
replace_today(PROMPT_TEMPLATE)
.replace("{query_str}", inputs)
.replace("{context_str}", "\n\n".join(reference_results))
.replace("{reply_language}", reply_language)
)
elif use_websearch:
limited_context = True
search_results = ddg(inputs, max_results=5)
old_inputs = inputs
reference_results = []
for idx, result in enumerate(search_results):
logging.info(f"搜索结果{idx + 1}{result}")
domain_name = urllib3.util.parse_url(result["href"]).host
reference_results.append([result["body"], result["href"]])
display_reference.append(f"{idx + 1}. [{domain_name}]({result['href']})\n")
reference_results = add_source_numbers(reference_results)
display_reference = "\n\n" + "".join(display_reference)
inputs = (
            replace_today(WEBSEARCH_PTOMPT_TEMPLATE)  # (sic) identifier spelling matches modules.presets
.replace("{query}", inputs)
.replace("{web_results}", "\n\n".join(reference_results))
.replace("{reply_language}", reply_language)
)
else:
display_reference = ""
if len(openai_api_key) == 0 and not shared.state.multi_api_key:
status_text = standard_error_msg + no_apikey_msg
logging.info(status_text)
chatbot.append((inputs, ""))
if len(history) == 0:
history.append(construct_user(inputs))
history.append("")
all_token_counts.append(0)
else:
history[-2] = construct_user(inputs)
yield chatbot + [(inputs, "")], history, status_text, all_token_counts
return
elif len(inputs.strip()) == 0:
status_text = standard_error_msg + no_input_msg
logging.info(status_text)
yield chatbot + [(inputs, "")], history, status_text, all_token_counts
return
    if stream:
        logging.info("Using streaming mode")
        stream_iter = stream_predict(
            openai_api_key,
            system_prompt,
            history,
            inputs,
            chatbot,
            all_token_counts,
            top_p,
            temperature,
            selected_model,
            fake_input=old_inputs,
            display_append=display_reference,
        )
        for chatbot, history, status_text, all_token_counts in stream_iter:
            if shared.state.interrupted:
                shared.state.recover()
                return
            yield chatbot, history, status_text, all_token_counts
    else:
        logging.info("Not using streaming mode")
chatbot, history, status_text, all_token_counts = predict_all(
openai_api_key,
system_prompt,
history,
inputs,
chatbot,
all_token_counts,
top_p,
temperature,
selected_model,
fake_input=old_inputs,
display_append=display_reference
)
yield chatbot, history, status_text, all_token_counts
logging.info(f"传输完毕。当前token计数为{all_token_counts}")
if len(history) > 1 and history[-1]["content"] != inputs:
        logging.info(
            "Answer: "
            + colorama.Fore.BLUE
            + f"{history[-1]['content']}"
            + colorama.Style.RESET_ALL
        )
if limited_context:
history = history[-4:]
all_token_counts = all_token_counts[-2:]
yield chatbot, history, status_text, all_token_counts
if stream:
max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
else:
max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
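    # Sliding-window trimming: drop the oldest (user, assistant) pair and its
    # token count until the running total is safely under the soft limit.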
    if sum(all_token_counts) > max_token and should_check_token_count:
        logging.info(f"Token counts: {all_token_counts}")
        count = 0
        while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
            count += 1
            del all_token_counts[0]
            del history[:2]
        status_text = f"To avoid exceeding the token limit, the model forgot the earliest {count} rounds of conversation"
        logging.info(status_text)
yield chatbot, history, status_text, all_token_counts
def retry(
openai_api_key,
system_prompt,
history,
chatbot,
token_count,
top_p,
temperature,
stream=False,
selected_model=MODELS[0],
reply_language="中文",
):
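    """Drop the last exchange from `history` and re-ask the same question via `predict`."""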
logging.info("重试中……")
if len(history) == 0:
yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
return
history.pop()
inputs = history.pop()["content"]
token_count.pop()
    predict_iter = predict(
openai_api_key,
system_prompt,
history,
inputs,
chatbot,
token_count,
top_p,
temperature,
stream=stream,
selected_model=selected_model,
reply_language=reply_language,
)
logging.info("重试中……")
for x in iter:
yield x
logging.info("重试完毕")
def reduce_token_size(
openai_api_key,
system_prompt,
history,
chatbot,
token_count,
top_p,
temperature,
max_token_count,
selected_model=MODELS[0],
reply_language="中文",
):
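    """Ask the model to summarize the conversation, then keep only as many
    recent exchanges as fit under `max_token_count`."""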
logging.info("开始减少token数量……")
iter = predict(
openai_api_key,
system_prompt,
history,
summarize_prompt,
chatbot,
token_count,
top_p,
temperature,
selected_model=selected_model,
should_check_token_count=False,
reply_language=reply_language,
)
logging.info(f"chatbot: {chatbot}")
flag = False
    for chatbot, history, status_text, previous_token_count in summarize_iter:
num_chat = find_n(previous_token_count, max_token_count)
logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
if flag:
chatbot = chatbot[:-1]
flag = True
history = history[-2 * num_chat:] if num_chat > 0 else []
token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
msg = f"保留了最近{num_chat}轮对话"
yield chatbot, history, msg + "," + construct_token_message(
token_count if len(token_count) > 0 else [0],
), token_count
logging.info(msg)
logging.info("减少token数量完毕")