import hmac
import logging
import os

import gradio as gr
import nltk
import sentence_transformers
import torch
from duckduckgo_search import ddg
from duckduckgo_search.utils import SESSION
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import JinaEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate  # noqa: F811  (re-import of the same name)
from langchain.vectorstores import FAISS

from chatllm import ChatLLM
from chinese_text_splitter import ChineseTextSplitter

logger = logging.getLogger(__name__)

nltk.data.path.append('./nltk_data')

# Display name -> model id passed to the embedding backend.
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec-base": "GanymedeNil/text2vec-base-chinese",
    # Re-enable to expose the Jina branch of init_knowledge_vector_store.
    # "ViT-B-32": 'ViT-B-32::laion2b-s34b-b79k'
}

# Display name -> model id passed to ChatLLM.load_model.
llm_model_dict = {
    "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8",
    "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
    "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    # "Minimax": "Minimax"
}

# Prefer CUDA, then Apple MPS, then CPU for the sentence-transformer model.
DEVICE = ("cuda" if torch.cuda.is_available()
          else "mps" if torch.backends.mps.is_available()
          else "cpu")

# Proxy for DuckDuckGo requests; set WEB_SEARCH_PROXY="" to disable it.
WEB_SEARCH_PROXY = os.getenv("WEB_SEARCH_PROXY", "socks5h://localhost:7890")


def search_web(query):
    """Return the text snippets of a DuckDuckGo search for *query*.

    Snippets are joined with newlines. Returns '' when there are no results.
    """
    if WEB_SEARCH_PROXY:
        SESSION.proxies = {"http": WEB_SEARCH_PROXY, "https": WEB_SEARCH_PROXY}
    results = ddg(query) or []
    # A separator keeps adjacent snippets from running together into one sentence.
    return "\n".join(result.get('body', '') for result in results)


def load_file(filepath):
    """Load *filepath* and split it into chunks with ChineseTextSplitter.

    PDFs use the loader's default mode and a PDF-aware splitter. Other formats
    are loaded in unstructured "elements" mode.
    """
    if filepath.lower().endswith(".pdf"):
        loader = UnstructuredFileLoader(filepath)
        textsplitter = ChineseTextSplitter(pdf=True)
    else:
        loader = UnstructuredFileLoader(filepath, mode="elements")
        textsplitter = ChineseTextSplitter(pdf=False)
    return loader.load_and_split(text_splitter=textsplitter)


def init_knowledge_vector_store(embedding_model, filepath):
    """Embed the chunks of *filepath* and return them as a FAISS vector store.

    Args:
        embedding_model: key of ``embedding_model_dict``.
        filepath: path of the uploaded knowledge-base file.
    """
    if embedding_model == "ViT-B-32":
        # Jina-hosted embeddings. Only reachable when the "ViT-B-32" entry of
        # embedding_model_dict is enabled (otherwise the lookup raises KeyError).
        embeddings = JinaEmbeddings(
            jina_auth_token=os.getenv('jina_auth_token'),
            model_name=embedding_model_dict[embedding_model])
    else:
        embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_dict[embedding_model])
        # Swap in a client pinned to DEVICE.
        embeddings.client = sentence_transformers.SentenceTransformer(
            embeddings.model_name, device=DEVICE)
    docs = load_file(filepath)
    return FAISS.from_documents(docs, embeddings)


def get_knowledge_based_answer(query, large_language_model, vector_store,
                               VECTOR_SEARCH_TOP_K, web_content, history_len,
                               temperature, top_p, chat_history=None):
    """Answer *query* with a RetrievalQA chain over *vector_store*.

    Args:
        query: the user's question.
        large_language_model: key of ``llm_model_dict``, or "Minimax".
        vector_store: store returned by ``init_knowledge_vector_store``.
        VECTOR_SEARCH_TOP_K: number of chunks retrieved as context.
        web_content: web-search text added to the prompt ('' to omit it).
        history_len: how many most recent turns of *chat_history* reach the LLM.
        temperature, top_p: sampling parameters assigned to the ChatLLM.
        chat_history: previous (question, answer) turns; None means none.

    Returns:
        The chain's output dict. ``predict`` reads its 'result' key. Source
        documents are requested via ``return_source_documents``.
    """
    # None instead of a shared mutable [] default.
    chat_history = [] if chat_history is None else chat_history
    history_len = int(history_len)

    if web_content:
        # web_content becomes part of the template text, so literal braces must
        # be doubled or PromptTemplate would parse them as input variables.
        escaped_web_content = web_content.replace("{", "{{").replace("}", "}}")
        prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
已知网络检索内容:{escaped_web_content}""" + """
已知内容:
{context}
问题:
{question}"""
    else:
        prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。
已知内容:
{context}
问题:
{question}"""
    prompt = PromptTemplate(template=prompt_template,
                            input_variables=["context", "question"])

    chatLLM = ChatLLM()
    chatLLM.history = chat_history[-history_len:] if history_len > 0 else []
    if large_language_model == "Minimax":
        chatLLM.model = 'Minimax'
    else:
        # NOTE(review): the model is loaded on every question; consider caching.
        chatLLM.load_model(
            model_name_or_path=llm_model_dict[large_language_model])
    chatLLM.temperature = temperature
    chatLLM.top_p = top_p

    knowledge_chain = RetrievalQA.from_llm(
        llm=chatLLM,
        retriever=vector_store.as_retriever(
            search_kwargs={"k": int(VECTOR_SEARCH_TOP_K)}),
        prompt=prompt)
    # Render each retrieved chunk into {context} as its bare page_content.
    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
        input_variables=["page_content"], template="{page_content}")
    knowledge_chain.return_source_documents = True
    return knowledge_chain({"query": query})


def clear_session():
    """Reset the input box and the chat history state."""
    return '', None


def predict(input, large_language_model, embedding_model, file_obj,
            VECTOR_SEARCH_TOP_K, history_len, temperature, top_p, use_web,
            history=None):
    """Gradio handler: answer *input* from the uploaded file (and optionally the web).

    Returns:
        ('', history, history): clears the textbox and updates both the
        chatbot display and the session state.
    """
    if history is None:
        history = []
    if not input or not input.strip():
        # Nothing to ask; skip the expensive indexing and model loading.
        return '', history, history
    if file_obj is None:
        # No knowledge-base file means there is nothing to retrieve from.
        history.append((input, "请先上传知识库文件"))
        return '', history, history

    logger.info("Building vector store from %s", file_obj.name)
    # NOTE(review): the index is rebuilt for every question; consider caching.
    vector_store = init_knowledge_vector_store(embedding_model, file_obj.name)

    web_content = ''
    if use_web == 'True':
        try:
            web_content = search_web(query=input)
        except Exception:
            # Web search is best-effort; answer from the knowledge base alone.
            logger.exception("Web search failed; continuing without web content")

    resp = get_knowledge_based_answer(
        query=input,
        large_language_model=large_language_model,
        vector_store=vector_store,
        VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
        web_content=web_content,
        chat_history=history,
        history_len=history_len,
        temperature=temperature,
        top_p=top_p,
    )
    logger.info("Chain response: %s", resp)
    history.append((input, resp['result']))
    return '', history, history


def auth(username, password):
    """Gradio login check against APP_USERNAME / APP_PASSWORD (defaults admin / meiya)."""
    expected_user = os.getenv("APP_USERNAME", "admin")
    expected_password = os.getenv("APP_PASSWORD", "meiya")
    # Compare bytes (non-ASCII str would raise) in constant time; evaluate both
    # checks so timing does not reveal which field was wrong.
    user_ok = hmac.compare_digest(username.encode("utf-8"),
                                  expected_user.encode("utf-8"))
    password_ok = hmac.compare_digest(password.encode("utf-8"),
                                      expected_password.encode("utf-8"))
    return user_ok and password_ok


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    with gr.Blocks() as demo:
        # NOTE(review): the header's HTML markup was lost when this file was
        # flattened; only the title text survived. Confirm the intended markup.
        gr.Markdown("<h1><center>Meiya GPT 美亚柏科大模型</center></h1>")
        with gr.Row():
            with gr.Column(scale=1):
                model_choose = gr.Accordion("模型选择")
                with model_choose:
                    large_language_model = gr.Dropdown(
                        list(llm_model_dict.keys()),
                        label="large language model",
                        value="ChatGLM-6B-int4")
                    embedding_model = gr.Dropdown(
                        list(embedding_model_dict.keys()),
                        label="Embedding model",
                        value="text2vec-base")
                file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式',
                               file_types=['.txt', '.md', '.docx'])
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                model_argument = gr.Accordion("模型参数配置")
                with model_argument:
                    VECTOR_SEARCH_TOP_K = gr.Slider(
                        1, 10, value=6, step=1,
                        label="vector search top k", interactive=True)
                    HISTORY_LEN = gr.Slider(
                        0, 3, value=0, step=1,
                        label="history len", interactive=True)
                    temperature = gr.Slider(
                        0, 1, value=0.01, step=0.01,
                        label="temperature", interactive=True)
                    top_p = gr.Slider(
                        0, 1, value=0.9, step=0.1,
                        label="top_p", interactive=True)
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
                message = gr.Textbox(label='请输入问题')
                state = gr.State()
                with gr.Row():
                    clear_history = gr.Button("🧹 清除历史对话")
                    send = gr.Button("🚀 发送")

        # The send button and pressing Enter share the same inputs/outputs.
        predict_inputs = [
            message, large_language_model, embedding_model, file,
            VECTOR_SEARCH_TOP_K, HISTORY_LEN, temperature, top_p, use_web,
            state,
        ]
        predict_outputs = [message, chatbot, state]
        send.click(predict, inputs=predict_inputs, outputs=predict_outputs)
        message.submit(predict, inputs=predict_inputs, outputs=predict_outputs)
        clear_history.click(fn=clear_session,
                            inputs=[],
                            outputs=[chatbot, state],
                            queue=False)

    # launch() blocks when run as a script. The previous code called it twice,
    # so the second call, the only one carrying `auth`, never ran and the app
    # was exposed on 0.0.0.0 without a login. Launch exactly once, with auth.
    # share stays False, as in the launch that actually ran.
    demo.queue().launch(
        server_name='0.0.0.0',
        server_port=8006,
        share=False,
        auth=auth,
        auth_message="请输入用户名与密码",
        debug=True,
    )