import os import gradio as gr import nltk import sentence_transformers import torch from duckduckgo_search import ddg from duckduckgo_search.utils import SESSION from langchain.chains import RetrievalQA from langchain.document_loaders import UnstructuredFileLoader from langchain.embeddings import JinaEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.prompts import PromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.vectorstores import FAISS from chatllm import ChatLLM from chinese_text_splitter import ChineseTextSplitter nltk.data.path.append('./nltk_data') embedding_model_dict = { "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", "ernie-base": "nghuyong/ernie-3.0-base-zh", "text2vec-base": "GanymedeNil/text2vec-base-chinese", #"ViT-B-32": 'ViT-B-32::laion2b-s34b-b79k' } llm_model_dict = { "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8", "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4", "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe", #"Minimax": "Minimax" } DEVICE = "cuda" if torch.cuda.is_available( ) else "mps" if torch.backends.mps.is_available() else "cpu" def search_web(query): SESSION.proxies = { "http": f"socks5h://localhost:7890", "https": f"socks5h://localhost:7890" } results = ddg(query) web_content = '' if results: for result in results: web_content += result['body'] return web_content def load_file(filepath): if filepath.lower().endswith(".pdf"): loader = UnstructuredFileLoader(filepath) textsplitter = ChineseTextSplitter(pdf=True) docs = loader.load_and_split(textsplitter) else: loader = UnstructuredFileLoader(filepath, mode="elements") textsplitter = ChineseTextSplitter(pdf=False) docs = loader.load_and_split(text_splitter=textsplitter) return docs def init_knowledge_vector_store(embedding_model, filepath): if embedding_model == "ViT-B-32": jina_auth_token = os.getenv('jina_auth_token') embeddings = JinaEmbeddings( jina_auth_token=jina_auth_token, model_name=embedding_model_dict[embedding_model]) else: embeddings = HuggingFaceEmbeddings( model_name=embedding_model_dict[embedding_model], ) embeddings.client = sentence_transformers.SentenceTransformer( embeddings.model_name, device=DEVICE) docs = load_file(filepath) vector_store = FAISS.from_documents(docs, embeddings) return vector_store def get_knowledge_based_answer(query, large_language_model, vector_store, VECTOR_SEARCH_TOP_K, web_content, history_len, temperature, top_p, chat_history=[]): if web_content: prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。 已知网络检索内容:{web_content}""" + """ 已知内容: {context} 问题: {question}""" else: prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。 已知内容: {context} 问题: {question}""" prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"]) chatLLM = ChatLLM() chatLLM.history = chat_history[-history_len:] if history_len > 0 else [] if large_language_model == "Minimax": chatLLM.model = 'Minimax' else: chatLLM.load_model( model_name_or_path=llm_model_dict[large_language_model]) chatLLM.temperature = temperature chatLLM.top_p = top_p knowledge_chain = RetrievalQA.from_llm( llm=chatLLM, retriever=vector_store.as_retriever( search_kwargs={"k": VECTOR_SEARCH_TOP_K}), prompt=prompt) knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate( input_variables=["page_content"], template="{page_content}") knowledge_chain.return_source_documents = True result = knowledge_chain({"query": query}) return result def clear_session(): return '', None def predict(input, large_language_model, embedding_model, file_obj, VECTOR_SEARCH_TOP_K, history_len, temperature, top_p, use_web, history=None): if history == None: history = [] print(file_obj.name) vector_store = init_knowledge_vector_store(embedding_model, file_obj.name) if use_web == 'True': web_content = search_web(query=input) else: web_content = '' resp = get_knowledge_based_answer( query=input, large_language_model=large_language_model, vector_store=vector_store, VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K, web_content=web_content, chat_history=history, history_len=history_len, temperature=temperature, top_p=top_p, ) print(resp) history.append((input, resp['result'])) return '', history, history if __name__ == "__main__": block = gr.Blocks() with block as demo: gr.Markdown("