""" | |
1. 完成了用Qwen通义千问作为知识库查询。 | |
1. 总共有三个区块:知识库回答,应用来源,相关问题。 | |
1. 在Huggingface的API上部署了一个在线BGE的模型,用于回答问题。OpenAI的Emebedding或者Langchain的Embedding都不可以用(会报错: self.d)。 | |
注意事项: | |
1. langchain_KB.py中的代码是用来构建本地知识库的,里面的embeddings需要与rag_response_002.py中的embeddings一致。否则会出错! | |
""" | |
##TODO:
# -*- coding: utf-8 -*-
import streamlit as st
import openai
import os
import numpy as np
import pandas as pd
import csv
import tempfile
from tempfile import NamedTemporaryFile
import pathlib
from pathlib import Path
import re
from re import sub
from itertools import product
import time
from time import sleep
import streamlit_authenticator as stauth
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings  ## community import path, matching the FAISS import above.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from typing import Dict, List, Optional, Tuple, Union
import requests
import qwen_response
import rag_reponse_002
import dashscope
from dotenv import load_dotenv
from datetime import datetime
import pytz
from pytz import timezone
# def get_current_time():
#     beijing_tz = timezone('Asia/Shanghai')
#     beijing_time = datetime.now(beijing_tz)
#     current_time = beijing_time.strftime('%H:%M:%S')
#     return current_time
load_dotenv()
### Set the OpenAI API key.
os.environ["OPENAI_API_KEY"] = os.environ['user_token']
openai.api_key = os.environ['user_token']
bing_search_api_key = os.environ['bing_api_key']
dashscope.api_key = os.environ['dashscope_api_key']
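# The environment variables read above are expected to come from a local .env file
# (loaded by load_dotenv above) or from the hosting platform's secrets. A minimal
# sketch with placeholder values (the keys themselves are of course not real):
#
#   user_token=sk-...            # OpenAI-compatible API key
#   bing_api_key=...             # Bing Search API key
#   dashscope_api_key=sk-...     # DashScope (Qwen / 通义千问) API key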
### Streamlit page settings.
st.set_page_config(layout="wide")
st.title("本地化国产大模型知识库查询演示")
# st.title("大语言模型智能知识库查询中心")
# st.title("大语言模型本地知识库问答系统")
# st.subheader("Large Language Model-based Knowledge Base QA System")
# st.warning("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
st.caption("_声明:内容由人工智能生成,仅供参考。您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
# st.caption("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
# st.info("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
# st.divider()
### upload file
# username = 'test'
# path = f'./{username}/faiss_index/index.faiss'
# if os.path.exists(path):
#     print(f'{path} local KB exists')
#     database_info = pd.read_csv(f'./{username}/database_name.csv')
#     current_database_name = database_info.iloc[-1][0]
#     current_database_date = database_info.iloc[-1][1]
#     database_claim = f"当前知识库为:{current_database_name},创建于{current_database_date}。可以开始提问!"
#     st.markdown(database_claim)
# uploaded_file = st.file_uploader(
#     "选择上传一个新知识库", type=(["pdf"]))
# # No file is uploaded by default (None), which would raise an error, so check first.
# if uploaded_file is not None:
#     # uploaded_file_path = upload_file(uploaded_file)
#     upload_file(uploaded_file)
# ## Build the vector database.
# from langchain.embeddings.openai import OpenAIEmbeddings
# embeddings = OpenAIEmbeddings(disallowed_special=())  ## Used online, after deployment on Hugging Face.
# print('embeddings:', embeddings)
# embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
# # embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)  ## Used online when Hugging Face is reachable.
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/RAG/bge-large-zh')  ## Switch to the BGE embedding.
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/RAG/bge-large-zh/')  ## Switch to the BGE embedding.
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/GanymedeNil_text2vec-large-chinese/')  ## Emits a "No sentence-transformers model found with name" warning, but it is not an error and does not affect use.
### authentication with a local yaml file.
import yaml
from yaml.loader import SafeLoader
with open('./config.yaml') as file:
    config = yaml.load(file, Loader=SafeLoader)
authenticator = stauth.Authenticate(
    config['credentials'],
    config['cookie']['name'],
    config['cookie']['key'],
    config['cookie']['expiry_days'],
    config['preauthorized']
)
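# The config.yaml read above follows the streamlit-authenticator layout. A minimal
# sketch with placeholder values (user names, display names and hashes below are
# illustrative, not the real credentials):
#
#   credentials:
#     usernames:
#       some_user:
#         name: Some User
#         password: "$2b$12$..."   # bcrypt hash, e.g. generated with stauth.Hasher
#   cookie:
#     name: some_cookie_name
#     key: some_signature_key
#     expiry_days: 30
#   preauthorized:
#     emails:
#       - someone@example.com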
user, authentication_status, username = authenticator.login('用户登录', 'main')
if authentication_status:
    with st.sidebar:
        st.markdown(
            """
            <style>
            [data-testid="stSidebar"][aria-expanded="true"]{
                min-width: 450px;
                max-width: 450px;
            }
            </style>
            """,
            unsafe_allow_html=True,
        )
        ### Sidebar title.
        # st.header(f'**大语言模型专家系统工作设定区**')
        st.header(f'**欢迎 {username} 使用本系统**')
        st.write('_Large Language Model Expert System Working Environment_')
        # st.write('_Welcome and Hope U Enjoy Staying Here_')
        authenticator.logout('登出', 'sidebar')
        ### Upload module.
        def upload_file(uploaded_file):
            filename = None  ## Initialised so the final return cannot raise NameError if parsing fails early.
            if uploaded_file is not None:
                # filename = uploaded_file.name
                # st.write(filename)  # print out the whole file name to validate. not to show in the final version.
                try:
                    # if '.pdf' in filename:  ### original code here.
                    if '.pdf' in uploaded_file.name:
                        pdf_filename = uploaded_file.name  ### original code here.
                        filename = uploaded_file.name
                        # print('PDF file:', pdf_filename)
                        # with st.status('正在为您解析新知识库...', expanded=False, state='running') as status:
                        spinner = st.spinner('正在为您解析新知识库...请耐心等待')
                        with spinner:
                            ### LangChain-based approach below.
                            import langchain_KB
                            import save_database_info
                            uploaded_file_name = "File_provided"
                            temp_dir = tempfile.TemporaryDirectory()
                            # ! working.
                            uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
                            with open(pdf_filename, 'wb') as output_temporary_file:
                                # with open(f'./{username}_upload.pdf', 'wb') as output_temporary_file:  ### original code here. May record the wrong file name when citing the information source.
                                # ! The content must be read in this format before it can be written into the temporary folder.
                                # output_temporary_file.write(uploaded_file.getvalue())
                                output_temporary_file.write(uploaded_file.getvalue())
                            langchain_KB.langchain_localKB_construct(output_temporary_file, username)
                            ## Show the current knowledge base's info on screen, including its name and load date.
                            save_database_info.save_database_info(f'./{username}/database_name.csv', pdf_filename, str(datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M")))
                            st.markdown('新知识库解析成功,请务必刷新页面,然后开启对话 🔃')
                        return pdf_filename
                    else:
                        # if '.csv' in filename:  ### original code here.
                        if '.csv' in uploaded_file.name:
                            print('start the csv file processing...')
                            csv_filename = uploaded_file.name
                            filename = uploaded_file.name
                            csv_file = pd.read_csv(uploaded_file)
                            csv_file.to_csv(f'./{username}/{username}_upload.csv', encoding='utf-8', index=False)
                            st.write(csv_file[:3])  # Only previews the file here; the file's absolute path is resolved later.
                        else:
                            xls_file = pd.read_excel(uploaded_file)
                            xls_file.to_csv(f'./{username}_upload.csv', index=False)
                            st.write(xls_file[:3])
                        print('end the csv file processing...')
                        # uploaded_file_name = "File_provided"
                        # temp_dir = tempfile.TemporaryDirectory()
                        # # ! working.
                        # uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
                        # # with open('./upload.csv', 'wb') as output_temporary_file:
                        # with open(f'./{username}_upload.csv', 'wb') as output_temporary_file:
                        #     # print(f'./{name}_upload.csv')
                        #     # ! The content must be read in this format before it can be written into the temporary folder.
                        #     output_temporary_file.write(uploaded_file.getvalue())
                        # # st.write(uploaded_file_path)  # * Check whether the file really exists and can be used.
                except Exception as e:
                    st.write(e)
                ## The commented code below was meant to fix wrong file paths and file names after upload.
                # uploaded_file_name = "File_provided"
                # temp_dir = tempfile.TemporaryDirectory()
                # # ! working.
                # uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
                # # with open('./upload.csv', 'wb') as output_temporary_file:
                # with open(f'./{name}_upload.csv', 'wb') as output_temporary_file:
                #     # print(f'./{name}_upload.csv')
                #     # ! The content must be read in this format before it can be written into the temporary folder.
                #     # output_temporary_file.write(uploaded_file.getvalue())
                #     output_temporary_file.write(uploaded_file.getvalue())
                #     # st.write(uploaded_file_path)  # * Check whether the file really exists and can be used.
                #     # st.write('Now file saved successfully.')
                # return pdf_filename, csv_filename
            return filename
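        # langchain_KB.langchain_localKB_construct lives in langchain_KB.py and is not shown
        # in this file. A minimal sketch of what it presumably does, following the module
        # docstring (BGE embeddings, per-user FAISS index saved under ./{username}/faiss_index);
        # the loader choice, chunk sizes and the saved_file variable below are illustrative
        # assumptions, not the confirmed implementation:
        #
        #   from langchain_community.document_loaders import PyPDFLoader
        #   from langchain.text_splitter import RecursiveCharacterTextSplitter
        #   docs = PyPDFLoader(saved_file.name).load()
        #   chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
        #   FAISS.from_documents(chunks, embeddings).save_local(f'./{username}/faiss_index')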
        path = f'./{username}/faiss_index/index.faiss'
        if os.path.exists(path):
            print(f'{path} local KB exists')
            database_info = pd.read_csv(f'./{username}/database_name.csv', encoding='utf-8', header=None)  ## Without the encoding argument, PDFs with Chinese file names raise an error.
            print(database_info)
            current_database_name = database_info.iloc[-1][0]
            current_database_date = database_info.iloc[-1][1]
            database_claim = f"当前知识库为:{current_database_name},创建于{current_database_date}。可以开始提问!"
            st.warning(database_claim)
            # st.markdown(database_claim)
        try:
            uploaded_file = st.file_uploader(
                "选择上传一个新知识库", type=(["pdf"]))
            # No file is uploaded by default (None), which would raise an error, so check first.
            if uploaded_file is not None:
                # uploaded_file_path = upload_file(uploaded_file)
                upload_file(uploaded_file)
        except Exception as e:
            print(e)
            pass
        ## Tabs shown in the sidebar, implemented with st.tabs.
        tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定'])
        # with st.expander(label='**使用须知**', expanded=False):
        with tab_1:
            # st.markdown("#### 快速上手指南")
            # with st.text(body="说明"):
            #     st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。")
            with st.text(body="说明"):
                st.markdown("* 为了保护数据与隐私,所有对话均不会被保存,刷新页面立即删除。敬请放心。")
            # with st.text(body="说明"):
            #     st.markdown("* “GPT-4”回答质量极佳,但速度缓慢,建议适当使用。")
            with st.text(body="说明"):
                st.markdown("* 查询知识库模式与所有的搜索引擎或者数据库检索方式一样,仅限一轮对话,将不会保持之前的会话记录。")
            with st.text(body="说明"):
                st.markdown("""* 系统的工作流程如下:
                    1. 用户输入问题。
                    1. 系统将问题转换为机器可理解的格式。
                    1. 系统使用大语言模型来生成与问题相关的候选答案。
                    1. 系统使用本地知识库来评估候选答案的准确性。
                    1. 系统返回最准确的答案。""")
        ## Large language model parameters.
        # with st.expander(label='**大语言模型参数**', expanded=True):
        with tab_2:
            max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=8096, value=4096, step=100)
            temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1)
            top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1)
            frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
            presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
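            # The slider values above are not passed on anywhere in this file; presumably
            # qwen_response.call_with_messages (defined elsewhere) would forward them to the
            # DashScope SDK. A hedged sketch of such a call, with an assumed model name:
            #
            #   rsp = dashscope.Generation.call(model='qwen-turbo', messages=messages,
            #                                   result_format='message', max_tokens=max_tokens,
            #                                   temperature=temperature, top_p=top_p)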
        with tab_3:
            # st.markdown("#### Prompt提示词参考资料")
            # with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False):
            st.code(body="我是一个企业主,我需要关注哪些“存货”相关的数据资源规则?", language='plaintext')
            st.code(body="作为零售商,了解哪些关键的库存管理指标对我至关重要?", language='plaintext')
            st.code(body="企业主在监控库存时,应如何确保遵守行业法规和最佳实践?", language='plaintext')
            st.code(body="在数字化时代,我应该关注哪些技术工具或平台来优化我的库存数据流程?", language='plaintext')
            st.code(body="我应该如何定期审查和分析这些库存数据以提高运营效率?", language='plaintext')
            st.code(body="如何设置预警系统来避免过度库存或缺货情况?", language='plaintext')
        with tab_4:
            st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden')
elif authentication_status is False:
    st.error('⛔ 用户名或密码错误!')
elif authentication_status is None:
    st.warning('⬆ 请先登录!')
### File-upload module.
#### Start: main program.
## Clear all conversation history.
def clear_all():
    st.session_state.conversation = None
    st.session_state.chat_history = None
    st.session_state.messages = []
    message_placeholder = st.empty()
    ## This single call is enough.
    st.rerun()
    return None
if "copied" not in st.session_state: | |
st.session_state.copied = [] | |
if "llm_response" not in st.session_state: | |
st.session_state.llm_response = "" | |
## copy to clipboard function with a button. | |
def copy_to_clipboard(text): | |
st.session_state.copied.append(text) | |
clipboard.copy(text) | |
def main():
    # llm = ChatGLM()  ## Start a model instance.
    col1, col2 = st.columns([2, 1])
    # st.markdown('### 数据库查询区')
    # with st.expander(label='**查询企业内部知识库**', expanded=True):
    with col1:
        KB_mode = True
        user_input = st.text_input(label='**📶 大模型数据库对话区**', placeholder='请输入您的问题', label_visibility='visible')
        if user_input:
            ## Non-streaming output, the original behaviour; no changes needed in api.py.
            # with st.status('检索中...', expanded=True, state='running') as status:
            spinner = st.spinner('思考中...请耐心等待')
            with spinner:
                if KB_mode:
                    # import rag_reponse_001
                    # clear_all()
                    # response = rag_reponse_001.rag_response(user_input=user_input, k=5)  ## working.
                    # print('user_input:', user_input)
                    response, source = rag_reponse_002.rag_response(username=username, user_input=user_input, k=3)
                    print('llm response:', response)
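                    # rag_reponse_002.rag_response is defined in a separate module that is not
                    # shown here. Judging from this call, it takes the user's question, pulls the
                    # k=3 most similar chunks from this user's local FAISS index, has Qwen answer
                    # from them, and returns (answer_text, source_info). A minimal sketch under
                    # those assumptions (prompt_built_from is a hypothetical helper):
                    #
                    #   db = FAISS.load_local(f'./{username}/faiss_index', embeddings)
                    #   docs = db.similarity_search(user_input, k=3)
                    #   answer = qwen_response.call_with_messages(prompt_built_from(docs, user_input))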
                    ## The earlier single-shot prompt below is superseded by the richer prompt
                    ## further down, so it is kept only as a comment to avoid a duplicate Qwen call.
                    # sim_prompt = f"""你需要根据以下的问题来提出5个可能的后续问题{user_input}
                    # """
                    # sim_questions = chatgpt.chatgpt(user_prompt=sim_prompt)  ## chatgpt to get similar questions.
                    # sim_questions = qwen_response.call_with_messages(sim_prompt)
                    if len(user_input) != 0:
                        sim_prompt = f"""你需要根据以下的初始问题来提出3个相似的问题和3个后续问题。
                        初始问题是:{user_input}
                        --------------------
                        你回答的时候,需要使用如下格式:
                        **相似问题:**
                        **后续问题:**
                        """
                        # sim_prompt = f"""你需要根据以下的问题来提出5个可能的后续问题{user_input}"""
                        ### Generate similar questions with ChatGPT here.
                        # sim_questions = chatgpt.chatgpt(user_prompt=sim_prompt)
                        ### Generate similar questions with Qwen here.
                        sim_questions = qwen_response.call_with_messages(sim_prompt)
                    st.markdown(response)
                    # st_copy_to_clipboard(text=str(response), show_text=True, before_copy_label="📋", after_copy_label="✅")
                    ## Used this way, every button press would re-submit the question.
                    # st.button(label="📃", on_click=copy_to_clipboard, args=(response,))
                    st.divider()
                    st.caption(source)
                    st.divider()
        ## response is not defined in the initial state, hence the try/except.
        try:
            if response:
                with col2:
                    with st.expander(label='## **您可能还会关注以下内容**', expanded=True):
                        st.info(sim_questions)
        except:
            pass
    # st.stop()
    return None
#### End: main program.
if __name__ == '__main__':
    main()