Upload 2 files

- app.py (+96 -41)
- requirements.txt (+3 -1)

app.py (CHANGED)
@@ -22,6 +22,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.chat_models import ChatOpenAI
 
 # LangChain
+import langchain
 from langchain.llms import HuggingFacePipeline
 from transformers import pipeline
 
@@ -45,8 +46,8 @@ import gradio as gr
 from pypdf import PdfReader
 import requests # DeepL API request
 
-#
-import
+# Mecab
+import MeCab
 
 # --------------------------------------
 # ユーザ別セッションの変数値を記録するクラス
@@ -69,6 +70,7 @@ class SessionState:
         self.conversation_chain = None # ConversationChain
         self.query_generator = None # Query Refiner with Chat history
         self.qa_chain = None # load_qa_chain
+        self.web_summary_chain = None # Summarize web search result
         self.embedded_urls = []
         self.similarity_search_k = None # No. of similarity search documents to find.
         self.summarization_mode = None # Stuff / Map Reduce / Refine
@@ -132,6 +134,33 @@ text_splitter = JPTextSplitter(
     chunk_overlap = chunk_overlap, # オーバーラップの最大文字数
 )
 
+# --------------------------------------
+# 文中から人名を抽出
+# --------------------------------------
+def name_detector(text: str) -> list:
+    mecab = MeCab.Tagger()
+    mecab.parse('')  # ←バグ対応
+    node = mecab.parseToNode(text).next
+    names = []
+
+    while node:
+        if node.feature.split(',')[3] == "姓":
+            if node.next and node.next.feature.split(',')[3] == "名":
+                names.append(str(node.surface) + str(node.next.surface))
+            else:
+                names.append(node.surface)
+        if node.feature.split(',')[3] == "名":
+            if node.prev and node.prev.feature.split(',')[3] == "姓":
+                pass
+            else:
+                names.append(str(node.surface))
+
+        node = node.next
+
+    names = list(set(names))
+
+    return names
+
 # --------------------------------------
 # DeepL でメモリを翻訳しトークン数を削減(OpenAIモデル利用時)
 # --------------------------------------
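Note: a minimal sketch of what the new name_detector keys on, assuming mecab-python3 with the unidic-lite dictionary pinned in requirements.txt; the sample sentence and its output are illustrative, not taken from the commit.

import MeCab

tagger = MeCab.Tagger()
tagger.parse('')  # same workaround as above for the surface-string bug in older mecab-python3
node = tagger.parseToNode("昨日は山田太郎さんに会った").next
while node:
    # UniDic feature strings are comma-separated; index 3 is the fine-grained POS,
    # e.g. "姓" (family name) or "名" (given name), which name_detector checks.
    print(node.surface, node.feature.split(',')[:4])
    node = node.next

# name_detector("昨日は山田太郎さんに会った") should then return something like ["山田太郎"],
# because a 姓 node directly followed by a 名 node is concatenated into one name.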
@@ -175,11 +204,22 @@ def deepl_memory(ss: SessionState) -> (SessionState):
 # DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
 # DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")
 
-def web_search(query, current_model) -> str:
-    search = DuckDuckGoSearchRun()
+def web_search(ss: SessionState, query) -> (SessionState, str):
+    search = DuckDuckGoSearchRun(verbose=True)
     web_result = search(query)
 
-
+    # 人名の抽出
+    names = []
+    names.extend(name_detector(query))
+    names.extend(name_detector(web_result))
+    if len(names)==0:
+        names = ""
+    elif len(names)==1:
+        names = names[0]
+    else:
+        names = ", ".join(names)
+
+    if ss.current_model == "gpt-3.5-turbo":
         text = [query, web_result]
         params = {
             "auth_key": DEEPL_API_KEY,
@@ -193,19 +233,28 @@ def web_search(query, current_model) -> str:
         response = request.json()
 
         query = response["translations"][0]["text"]
-        web_result =
+        web_result = response["translations"][1]["text"]
+        web_result = ss.web_summary_chain({'query': query, 'context': web_result})['text']
+
+    if names != "":
+        web_query = f"""
+        {query}
+        Use the following information as a reference to answer the question above in Japanese. When translating names of Japanese people, refer to Japanese Names as a translation guide.
+        Reference: {web_result}
+        Japanese Names: {names}
+        """.strip()
+    else:
+        web_query = query + "\nUse the following information as a reference to answer the question above in the Japanese.\n===\nReference: " + web_result + "\n==="
 
-    web_query = query + "\nUse the following information as a reference to answer the question above in the Japanese.\n===\nReference: " + web_result + "\n==="
 
-    return web_query
+    return ss, web_query
 
 # --------------------------------------
 # LangChain カスタムプロンプト各種
 # llama tokenizer
-#
-
+# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/
 # OpenAI tokenizer
-#
+# https://platform.openai.com/tokenizer
 # --------------------------------------
 
 # --------------------------------------
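Note: the DeepL request that these lines parse is not itself shown in the hunk. Based on the commented endpoint, the `params` dict and the `request`/`response` names around it, it presumably looks roughly like the sketch below; the `target_lang` value and the exact request shape are assumptions, not part of the commit.

import os
import requests

DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

params = {
    "auth_key": DEEPL_API_KEY,
    "text": [query, web_result],   # query first, web_result second; sent as repeated form fields
    "target_lang": "EN",           # assumption: translating to English to cut token usage
}
request = requests.post(DEEPL_API_ENDPOINT, data=params)
response = request.json()
query = response["translations"][0]["text"]
web_result = response["translations"][1]["text"]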
@@ -214,19 +263,18 @@ def web_search(query, current_model) -> str:
 
 # Tokens: OpenAI 104/ Llama 105 <- In Japanese: Tokens: OpenAI 191/ Llama 162
 sys_chat_message = """
-
-
-
-make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
+You are an outstanding AI concierge. You understand your customers' needs from their questions and answer
+them with many specific and detailed information in Japanese. If you do not know the answer to a question,
+do make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます". Ignore Conversation History.
 """.replace("\n", "")
 
 chat_common_format = """
 ===
 Question: {query}
-
-Conversation History:
+===
+Conversation History(Ignore):
 {chat_history}
-
+===
 日本語の回答: """
 
 chat_template_std = f"{sys_chat_message}{chat_common_format}"
@@ -238,21 +286,23 @@ chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"
 # Tokens: OpenAI 113/ Llama 111 <- In Japanese: Tokens: OpenAI 256/ Llama 225
 sys_qa_message = """
 You are an AI concierge who carefully answers questions from customers based on references.
-You understand what the customer wants to know from
-
-
-
+You understand what the customer wants to know from Question, and give a specific answer in
+Japanese using sentences extracted from the following references. If you do not know the answer,
+do not make up an answer and reply, "誠に申し訳ございませんが、その点についてはわかりかねます".
+Ignore Conversation History.
 """.replace("\n", "")
 
 qa_common_format = """
 ===
 Question: {query}
 References: {context}
-
+===
+Conversation History(Ignore):
 {chat_history}
-
+===
 日本語の回答: """
 
+
 qa_template_std = f"{sys_qa_message}{qa_common_format}"
 qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
 
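Note: these templates are ordinary f-strings, so only sys_qa_message and qa_common_format are substituted at definition time; the inner {query}, {context} and {chat_history} stay as literal braces for LangChain to fill later. A minimal sketch with made-up values (PromptTemplate is already imported in app.py; the input_variables shown here are assumed from the placeholders in the template):

from langchain.prompts import PromptTemplate

qa_prompt = PromptTemplate(
    input_variables=["query", "context", "chat_history"],
    template=qa_template_std,
)
# Fills the three placeholders; the result is the full prompt string sent to the model.
print(qa_prompt.format(query="営業時間を教えてください", context="(retrieved passages)", chat_history=""))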
@@ -262,8 +312,8 @@ qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"
 # 1. 会話履歴と最新の質問から、質問文を生成するchain のプロンプト
 query_generator_message = """
 Referring to the "Conversation History", reformat the user's "Additional Question"
-to a specific question
-
+to a specific question by filling in the missing subject, verb, objects, complements,
+and other necessary information to get a better search result. Answer in Japanese.
 """.replace("\n", "")
 
 query_generator_common_format = """
@@ -272,7 +322,7 @@ query_generator_common_format = """
 {chat_history}
 
 [Additional Question] {query}
-
+明確な日本語の質問文: """
 
 query_generator_template_std = f"{query_generator_message}{query_generator_common_format}"
 query_generator_template_llama2 = f"<s>[INST] <<SYS>>{query_generator_message}<</SYS>>{query_generator_common_format}[/INST]"
@@ -287,8 +337,8 @@ and complement.
 
 question_prompt_common_format = """
 ===
-[references] {context}
 [Question] {query}
+[references] {context}
 [Summary] """
 
 question_prompt_template_std = f"{question_prompt_message}{question_prompt_common_format}"
@@ -305,17 +355,14 @@ If you do not know the answer, do not make up an answer and reply,
 
 combine_prompt_common_format = """
 ===
-Question:
-{query}
-===
+Question: {query}
 Reference: {summaries}
-===
 日本語の回答: """
 
+
 combine_prompt_template_std = f"{combine_prompt_message}{combine_prompt_common_format}"
 combine_prompt_template_llama2 = f"<s>[INST] <<SYS>>{combine_prompt_message}<</SYS>>{combine_prompt_common_format}[/INST]"
 
-
 # --------------------------------------
 # ConversationSummaryBufferMemoryの要約プロンプト
 # ソース → https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
@@ -508,6 +555,10 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
     # --------------------------------------
     # Conversation/QAチェーンの設定
     # --------------------------------------
+    if ss.query_generator is None:
+        query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables = ["chat_history", "query"])
+        ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt, verbose=True)
+
     if ss.conversation_chain is None:
         chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)
         ss.conversation_chain = ConversationChain(
@@ -525,13 +576,14 @@ def set_chains(ss: SessionState, summarization_mode) -> SessionState:
         ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)
 
     elif summarization_mode == "map_reduce":
-        query_generator_prompt = PromptTemplate(template=query_generator_template, input_variables = ["chat_history", "query"])
-        ss.query_generator = LLMChain(llm=ss.llm, prompt=query_generator_prompt)
-
         question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
         combine_prompt = PromptTemplate(template=combine_template, input_variables=["summaries", "query"])
         ss.qa_chain = load_qa_chain(ss.llm, chain_type="map_reduce", return_map_steps=True, memory=ss.memory, question_prompt=question_prompt, combine_prompt=combine_prompt)
 
+    if ss.web_summary_chain is None:
+        question_prompt = PromptTemplate(template=question_template, input_variables=["context", "query"])
+        ss.web_summary_chain = LLMChain(llm=ss.llm, prompt=question_prompt, verbose=True)
+
     return ss
 
 def initialize_db(ss: SessionState) -> SessionState:
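Note: both chains added here are plain LLMChains, so they are invoked with a dict of their prompt variables and return their output under the 'text' key. web_search() above already calls web_summary_chain this way; the query refiner presumably follows the same pattern. Sketch only, with placeholder variables that are not names from app.py:

# history_text, user_query and web_result are hypothetical placeholders.
refined_query = ss.query_generator({"chat_history": history_text, "query": user_query})["text"]
web_summary = ss.web_summary_chain({"query": refined_query, "context": web_result})["text"]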
@@ -761,16 +813,16 @@ def bot(ss: SessionState, query, qa_flag, web_flag, summarization_mode) -> (Sess
     # QA Model
     if qa_flag is True and ss.embeddings is not None and ss.db is not None:
         if web_flag:
-            web_query = web_search(
+            ss, web_query = web_search(ss, query)
             ss = qa_predict(ss, web_query)
             ss.memory.chat_memory.messages[-2].content = query
         else:
-            ss = qa_predict(ss, query)
+            ss = qa_predict(ss, query)
 
     # Chat Model
     else:
         if web_flag:
-            web_query = web_search(
+            ss, web_query = web_search(ss, query)
             ss = chat_predict(ss, web_query)
             ss.memory.chat_memory.messages[-2].content = query
         else:
@@ -788,6 +840,8 @@ def chat_predict(ss: SessionState, query) -> SessionState:
 
 def qa_predict(ss: SessionState, query) -> SessionState:
 
+    original_query = query
+
     # Rinnaモデル向けの設定(クエリの改行コード修正)
     if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
         query = query.strip().replace("\n", "<NL>")
@@ -829,7 +883,7 @@ def qa_predict(ss: SessionState, query) -> SessionState:
         response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"
 
     # ユーザーメッセージと AI メッセージの追加
-    ss.memory.chat_memory.add_user_message(
+    ss.memory.chat_memory.add_user_message(original_query.replace("<NL>", "\n"))
     ss.memory.chat_memory.add_ai_message(response)
     ss.dialogue[-1] = (ss.dialogue[-1][0], response) # 会話履歴
     return ss
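Note: the point of original_query is that the rinna model sees a prompt with newlines rewritten to <NL>, while the stored chat history keeps the readable text. Illustration with a made-up two-line query:

original_query = "会議の件です\n日程を教えて"
query = original_query.strip().replace("\n", "<NL>")                           # what the rinna model sees
ss.memory.chat_memory.add_user_message(original_query.replace("<NL>", "\n"))   # history keeps real newlines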
@@ -1028,4 +1082,5 @@ with gr.Blocks() as demo:
 
 if __name__ == "__main__":
     demo.queue(concurrency_count=5)
-    demo.launch(debug=True)
+    demo.launch(debug=True,)
+
requirements.txt (CHANGED)

@@ -21,4 +21,6 @@ numpy==1.23.5
 pandas==1.5.3
 chromedriver-autoinstaller
 chromedriver-binary
-duckduckgo-search==3.8.5
+duckduckgo-search==3.8.5
+mecab-python3==1.0.6
+unidic-lite==1.0.8
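Note: an optional sanity check for the two new pins, not part of the commit; mecab-python3 picks up unidic-lite automatically when no dictionary is configured.

import MeCab
print(MeCab.Tagger().parse("山田太郎"))  # should print UniDic features including 姓 / 名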