artistypl commited on
Commit
db7289c
1 Parent(s): a03bd7a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +266 -0
app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import nltk
5
+ import sentence_transformers
6
+ import torch
7
+ from itertools import islice
8
+ from duckduckgo_search import ddg
9
+ from duckduckgo_search import DDGS
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.document_loaders import UnstructuredFileLoader
12
+ from langchain.embeddings import JinaEmbeddings
13
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
14
+ from langchain.prompts import PromptTemplate
15
+ from langchain.prompts.prompt import PromptTemplate
16
+ from langchain.vectorstores import FAISS
17
+
18
+
19
+
20
+
21
+ from chatllm import ChatLLM
22
+ from chinese_text_splitter import ChineseTextSplitter
23
+
24
+ nltk.data.path.append('./nltk_data')
25
+
26
+ embedding_model_dict = {
27
+ "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
28
+ "ernie-base": "nghuyong/ernie-3.0-base-zh",
29
+ "text2vec-base": "GanymedeNil/text2vec-base-chinese",
30
+ #"ViT-B-32": 'ViT-B-32::laion2b-s34b-b79k'
31
+ }
32
+
33
+ llm_model_dict = {
34
+ "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8",
35
+ "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
36
+ "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe"
37
+ }
38
+
39
+ DEVICE = "cuda" if torch.cuda.is_available(
40
+ ) else "mps" if torch.backends.mps.is_available() else "cpu"
41
+
42
+
43
+ def search_web(query):
44
+ web_content = ''
45
+ with DDGS() as ddgs:
46
+ results = ddgs.text(query, region='wt-wt', safesearch='Off');
47
+ if results:
48
+ for result in islice(results, 3):
49
+ web_content += result['body']
50
+ return web_content
51
+
52
+
53
+ def load_file(filepath):
54
+ if filepath.lower().endswith(".pdf"):
55
+ loader = UnstructuredFileLoader(filepath)
56
+ textsplitter = ChineseTextSplitter(pdf=True)
57
+ docs = loader.load_and_split(textsplitter)
58
+ else:
59
+ loader = UnstructuredFileLoader(filepath, mode="elements")
60
+ textsplitter = ChineseTextSplitter(pdf=False)
61
+ docs = loader.load_and_split(text_splitter=textsplitter)
62
+ return docs
63
+
64
+
65
+ def init_knowledge_vector_store(embedding_model, filepath):
66
+ if embedding_model == "ViT-B-32":
67
+ jina_auth_token = os.getenv('jina_auth_token')
68
+ embeddings = JinaEmbeddings(
69
+ jina_auth_token=jina_auth_token,
70
+ model_name=embedding_model_dict[embedding_model])
71
+ else:
72
+ embeddings = HuggingFaceEmbeddings(
73
+ model_name=embedding_model_dict[embedding_model], )
74
+ embeddings.client = sentence_transformers.SentenceTransformer(
75
+ embeddings.model_name, device=DEVICE)
76
+
77
+ docs = load_file(filepath)
78
+
79
+ vector_store = FAISS.from_documents(docs, embeddings)
80
+ return vector_store
81
+
82
+
83
+ def get_knowledge_based_answer(query,
84
+ large_language_model,
85
+ vector_store,
86
+ VECTOR_SEARCH_TOP_K,
87
+ web_content,
88
+ history_len,
89
+ temperature,
90
+ top_p,
91
+ chat_history=[]):
92
+ if web_content:
93
+ prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
94
+ 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
95
+ 已知网络检索内容:{web_content}""" + """
96
+ 已知内容:
97
+ {context}
98
+ 问题:
99
+ {question}"""
100
+ else:
101
+ prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。
102
+ 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。
103
+
104
+ 已知内容:
105
+ {context}
106
+
107
+ 问题:
108
+ {question}"""
109
+ prompt = PromptTemplate(template=prompt_template,
110
+ input_variables=["context", "question"])
111
+ chatLLM = ChatLLM()
112
+ chatLLM.history = chat_history[-history_len:] if history_len > 0 else []
113
+ if large_language_model == "ChatGPT":
114
+ chatLLM.model = OpenAI()
115
+ else:
116
+ chatLLM.load_model(
117
+ model_name_or_path=llm_model_dict[large_language_model])
118
+ chatLLM.temperature = temperature
119
+ chatLLM.top_p = top_p
120
+
121
+ knowledge_chain = RetrievalQA.from_llm(
122
+ llm=chatLLM,
123
+ retriever=vector_store.as_retriever(
124
+ search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
125
+ prompt=prompt)
126
+ knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
127
+ input_variables=["page_content"], template="{page_content}")
128
+
129
+ knowledge_chain.return_source_documents = True
130
+
131
+ result = knowledge_chain({"query": query})
132
+ return result
133
+
134
+
135
+ def clear_session():
136
+ return '', None
137
+
138
+
139
+ def predict(input,
140
+ large_language_model,
141
+ embedding_model,
142
+ file_obj,
143
+ VECTOR_SEARCH_TOP_K,
144
+ temperature,
145
+ top_p,
146
+ use_web,
147
+ history=None):
148
+ if history == None:
149
+ history = []
150
+ print(file_obj.name)
151
+ vector_store = init_knowledge_vector_store(embedding_model, file_obj.name)
152
+ if use_web == 'True':
153
+ web_content = search_web(query=input)
154
+ else:
155
+ web_content = ''
156
+ resp = get_knowledge_based_answer(
157
+ query=input,
158
+ large_language_model=large_language_model,
159
+ vector_store=vector_store,
160
+ VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
161
+ web_content=web_content,
162
+ chat_history=history,
163
+ history_len=history_len,
164
+ temperature=temperature,
165
+ top_p=top_p,
166
+ )
167
+ print(resp)
168
+ history.append((input, resp['result']))
169
+ return '', history, history
170
+
171
+
172
+ if __name__ == "__main__":
173
+ block = gr.Blocks()
174
+ with block as demo:
175
+ gr.Markdown("""<h1><center>LangChain-ChatLLM-Webui</center></h1>
176
+ <center><font size=3>
177
+ 本项目基于LangChain和大型语言模型系列模型, 提供基于本地知识的自动问答应用. <br>
178
+ </center></font>
179
+ """)
180
+ with gr.Row():
181
+ with gr.Column(scale=1):
182
+ model_choose = gr.Accordion("模型选择")
183
+ with model_choose:
184
+ large_language_model = gr.Dropdown(
185
+ list(llm_model_dict.keys()),
186
+ label="large language model",
187
+ value="ChatGLM-6B-int4")
188
+
189
+ embedding_model = gr.Dropdown(list(
190
+ embedding_model_dict.keys()),
191
+ label="Embedding model",
192
+ value="text2vec-base")
193
+
194
+ file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式',
195
+ file_types=['.txt', '.md', '.docx'])
196
+
197
+ use_web = gr.Radio(["True", "False"],
198
+ label="Web Search",
199
+ value="False")
200
+ model_argument = gr.Accordion("模型参数配置")
201
+
202
+ with model_argument:
203
+
204
+ VECTOR_SEARCH_TOP_K = gr.Slider(
205
+ 1,
206
+ 10,
207
+ value=6,
208
+ step=1,
209
+ label="vector search top k",
210
+ interactive=True)
211
+
212
+ # HISTORY_LEN = gr.Slider(0,
213
+ # 3,
214
+ # value=0,
215
+ # step=1,
216
+ # label="history len",
217
+ # interactive=True)
218
+
219
+ temperature = gr.Slider(0,
220
+ 1,
221
+ value=0.01,
222
+ step=0.01,
223
+ label="temperature",
224
+ interactive=True)
225
+ top_p = gr.Slider(0,
226
+ 1,
227
+ value=0.9,
228
+ step=0.1,
229
+ label="top_p",
230
+ interactive=True)
231
+
232
+ with gr.Column(scale=4):
233
+ chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
234
+ message = gr.Textbox(label='请输入问题')
235
+ state = gr.State()
236
+
237
+ with gr.Row():
238
+ clear_history = gr.Button("🧹 清除历史对话")
239
+ send = gr.Button("🚀 发送")
240
+
241
+ send.click(predict,
242
+ inputs=[
243
+ message, large_language_model,
244
+ embedding_model, file, VECTOR_SEARCH_TOP_K,
245
+ HISTORY_LEN, temperature, top_p, use_web,
246
+ state
247
+ ],
248
+ outputs=[message, chatbot, state])
249
+ clear_history.click(fn=clear_session,
250
+ inputs=[],
251
+ outputs=[chatbot, state],
252
+ queue=False)
253
+
254
+ message.submit(predict,
255
+ inputs=[
256
+ message, large_language_model,
257
+ embedding_model, file,
258
+ VECTOR_SEARCH_TOP_K, HISTORY_LEN,
259
+ temperature, top_p, use_web, state
260
+ ],
261
+ outputs=[message, chatbot, state])
262
+ gr.Markdown("""提醒:<br>
263
+ 1. 使用时请先上传自己的知识文件,并且文件中不含某些特殊字符,否则将返回error. <br>
264
+ 2. 有任何使用问题,请通过[问题交流区](https://huggingface.co/spaces/thomas-yanxin/LangChain-ChatLLM/discussions)或[Github Issue区](https://github.com/thomas-yanxin/LangChain-ChatGLM-Webui/issues)进行反馈. <br>
265
+ """)
266
+ demo.queue().launch(server_name='0.0.0.0', share=False)