Upload 8 files
Browse files- travel/RAGGraph.py +202 -0
- travel/__pycache__/RAGGraph.cpython-39.pyc +0 -0
- travel/__pycache__/self_rag_tool.cpython-39.pyc +0 -0
- travel/database_generate.py +183 -0
- travel/rag_tool.py +115 -0
- travel/self_rag_tool.py +158 -0
- travel/travel.py +237 -0
- travel/travel_new.py +236 -0
travel/RAGGraph.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import List, TypedDict, Type, Any, Annotated
|
3 |
+
|
4 |
+
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, AnyMessage
|
5 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
6 |
+
from langchain_core.tools import BaseTool
|
7 |
+
from langchain_openai import ChatOpenAI
|
8 |
+
from langchain.agents import create_openai_tools_agent,AgentExecutor
|
9 |
+
from langgraph.constants import END
|
10 |
+
from langgraph.graph import StateGraph,MessagesState
|
11 |
+
from pydantic.v1 import BaseModel, Field
|
12 |
+
import concurrent.futures
|
13 |
+
from self_rag_tool import GradeAndGenerateTool
|
14 |
+
|
15 |
+
|
16 |
+
class CreateLangGraphState(TypedDict):
|
17 |
+
question: str
|
18 |
+
generation: str
|
19 |
+
documents: List[str] # 检索后的信息,或者通过筛选后的信息
|
20 |
+
|
21 |
+
class self_RAGTool(object):
|
22 |
+
def __init__(self) -> None:
|
23 |
+
self.tools = GradeAndGenerateTool()
|
24 |
+
self.workflow = StateGraph(CreateLangGraphState)
|
25 |
+
|
26 |
+
self.workflow.add_node("retrieve", self.retrieve)
|
27 |
+
self.workflow.add_node("grade_documents", self.grade_documents)
|
28 |
+
self.workflow.add_node("generate_llm", self.generate_llm)
|
29 |
+
self.workflow.add_node("rewrite_question", self.rewrite_question)
|
30 |
+
|
31 |
+
self.workflow.set_entry_point("retrieve")
|
32 |
+
self.workflow.add_edge("retrieve", "grade_documents")
|
33 |
+
self.workflow.add_edge("grade_documents", "generate_llm")
|
34 |
+
self.workflow.add_conditional_edges("generate_llm", self.hallucinations_generate,
|
35 |
+
{"generate_llm": "generate_llm", "rewrite_question": "rewrite_question", "useful": END})
|
36 |
+
|
37 |
+
self.workflow.add_edge("rewrite_question", "retrieve")
|
38 |
+
|
39 |
+
self.graph = self.workflow.compile()
|
40 |
+
def retrieve(self,state):
|
41 |
+
question = state["question"]
|
42 |
+
a=time.time()
|
43 |
+
documents = self.tools.search_vector(question)
|
44 |
+
print("retrieve:",time.time()-a)
|
45 |
+
# result_documents = []
|
46 |
+
# for info in documents[0]:
|
47 |
+
# result_documents.append(info["entity"]["text"])
|
48 |
+
return {
|
49 |
+
"documents": documents,
|
50 |
+
"question": question,
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
def grade_documents(self,state):
|
55 |
+
question = state["question"]
|
56 |
+
documents = state["documents"]
|
57 |
+
result_documents = []
|
58 |
+
a=time.time()
|
59 |
+
for info in documents:
|
60 |
+
# 传入问题,并通过大模型判断当前文档是否与问题相关
|
61 |
+
# 如果是yes,则加入result_documents,否则丢弃
|
62 |
+
result = self.tools.grade(question=question, text=info)
|
63 |
+
if result == "yes":
|
64 |
+
result_documents.append(info)
|
65 |
+
else:
|
66 |
+
continue
|
67 |
+
print("grade_documents:",time.time()-a)
|
68 |
+
return {"question": question, "documents": result_documents}
|
69 |
+
|
70 |
+
|
71 |
+
def generate_llm(self,state):
|
72 |
+
question = state["question"]
|
73 |
+
documents = state["documents"]
|
74 |
+
documents_str = "\n".join(documents).replace("{", "").replace("}", "")
|
75 |
+
a=time.time()
|
76 |
+
result = self.tools.generate(question=question, text=documents_str)
|
77 |
+
print("generate_llm:",time.time()-a)
|
78 |
+
return {"question": question, "generation": result, "documents": documents}
|
79 |
+
|
80 |
+
|
81 |
+
def hallucinations_generate(self,state):
|
82 |
+
print("调用幻觉判断的方法")
|
83 |
+
question = state["question"]
|
84 |
+
generation = state["generation"]
|
85 |
+
documents = state["documents"]
|
86 |
+
documents_str = "\n".join(documents)
|
87 |
+
a=time.time()
|
88 |
+
result = self.tools.hallucinations(documents=documents_str, answer=generation)
|
89 |
+
print("hallucinations_generate:",time.time()-a)
|
90 |
+
if result == "yes":
|
91 |
+
return "generate_llm"
|
92 |
+
else:
|
93 |
+
generation = self.tools.answer_question(question=question, answer=generation)
|
94 |
+
if generation == "yes":
|
95 |
+
return "useful"
|
96 |
+
else:
|
97 |
+
return "rewrite_question"
|
98 |
+
|
99 |
+
|
100 |
+
def rewrite_question(self,state):
|
101 |
+
question = state["question"]
|
102 |
+
a=time.time()
|
103 |
+
result = self.tools.rewrite_question(question=question)
|
104 |
+
print("rewrite_question:",time.time()-a)
|
105 |
+
return {"question": result}
|
106 |
+
|
107 |
+
def get_answer(self,question):
|
108 |
+
res = self.graph.invoke({"question":question})
|
109 |
+
return res['generation']
|
110 |
+
|
111 |
+
class RAGTool(object):
|
112 |
+
def __init__(self) -> None:
|
113 |
+
self.tools = GradeAndGenerateTool()
|
114 |
+
self.workflow = StateGraph(CreateLangGraphState)
|
115 |
+
|
116 |
+
self.workflow.add_node("retrieve", self.retrieve)
|
117 |
+
self.workflow.add_node("generate_llm", self.generate_llm)
|
118 |
+
|
119 |
+
self.workflow.set_entry_point("retrieve")
|
120 |
+
self.workflow.add_edge("retrieve", "generate_llm")
|
121 |
+
|
122 |
+
self.workflow.add_edge("generate_llm", END)
|
123 |
+
|
124 |
+
self.graph = self.workflow.compile()
|
125 |
+
def retrieve(self,state):
|
126 |
+
question = state["question"]
|
127 |
+
a=time.time()
|
128 |
+
documents = self.tools.search_vector(question)
|
129 |
+
print("retrieve:",time.time()-a)
|
130 |
+
# result_documents = []
|
131 |
+
# for info in documents[0]:
|
132 |
+
# result_documents.append(info["entity"]["text"])
|
133 |
+
return {
|
134 |
+
"documents": documents,
|
135 |
+
"question": question,
|
136 |
+
}
|
137 |
+
|
138 |
+
|
139 |
+
def grade_documents(self,state):
|
140 |
+
question = state["question"]
|
141 |
+
documents = state["documents"]
|
142 |
+
result_documents = []
|
143 |
+
a=time.time()
|
144 |
+
for info in documents:
|
145 |
+
# 传入问题,并通过大模型判断当前文档是否与问题相关
|
146 |
+
# 如果是yes,则加入result_documents,否则丢弃
|
147 |
+
result = self.tools.grade(question=question, text=info)
|
148 |
+
if result == "yes":
|
149 |
+
result_documents.append(info)
|
150 |
+
else:
|
151 |
+
continue
|
152 |
+
print("grade_documents:",time.time()-a)
|
153 |
+
return {"question": question, "documents": result_documents}
|
154 |
+
|
155 |
+
|
156 |
+
def generate_llm(self,state):
|
157 |
+
question = state["question"]
|
158 |
+
documents = state["documents"]
|
159 |
+
documents_str = "\n".join(documents)
|
160 |
+
a=time.time()
|
161 |
+
result = self.tools.generate(question=question, text=documents_str)
|
162 |
+
print("generate_llm:",time.time()-a)
|
163 |
+
return {"question": question, "generation": result, "documents": documents}
|
164 |
+
|
165 |
+
|
166 |
+
def hallucinations_generate(self,state):
|
167 |
+
print("调用幻觉判断的方法")
|
168 |
+
question = state["question"]
|
169 |
+
generation = state["generation"]
|
170 |
+
documents = state["documents"]
|
171 |
+
documents_str = "\n".join(documents)
|
172 |
+
a=time.time()
|
173 |
+
result = self.tools.hallucinations(documents=documents_str, answer=generation)
|
174 |
+
print("hallucinations_generate:",time.time()-a)
|
175 |
+
if result == "yes":
|
176 |
+
return "generate_llm"
|
177 |
+
else:
|
178 |
+
generation = self.tools.answer_question(question=question, answer=generation)
|
179 |
+
if generation == "yes":
|
180 |
+
return "useful"
|
181 |
+
else:
|
182 |
+
return "rewrite_question"
|
183 |
+
|
184 |
+
|
185 |
+
def rewrite_question(self,state):
|
186 |
+
question = state["question"]
|
187 |
+
a=time.time()
|
188 |
+
result = self.tools.rewrite_question(question=question)
|
189 |
+
print("rewrite_question:",time.time()-a)
|
190 |
+
return {"question": result}
|
191 |
+
|
192 |
+
def get_answer(self,question):
|
193 |
+
res = self.graph.invoke({"question":question})
|
194 |
+
return res['generation']
|
195 |
+
|
196 |
+
# messages = []
|
197 |
+
# while True:
|
198 |
+
# question = input("请输入问题:")
|
199 |
+
# messages.append(HumanMessage(content=question))
|
200 |
+
# res = graph.invoke({"question":messages.content})
|
201 |
+
# messages.append(AIMessage(content=res["output"]))
|
202 |
+
# print(res["output"])
|
travel/__pycache__/RAGGraph.cpython-39.pyc
ADDED
Binary file (5.53 kB). View file
|
|
travel/__pycache__/self_rag_tool.cpython-39.pyc
ADDED
Binary file (9.13 kB). View file
|
|
travel/database_generate.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, UnstructuredMarkdownLoader, CSVLoader
|
6 |
+
from langchain_core.documents import Document
|
7 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
8 |
+
from pydantic.v1 import Field, BaseModel
|
9 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
10 |
+
from langchain_text_splitters import CharacterTextSplitter, RecursiveJsonSplitter
|
11 |
+
from pymilvus import MilvusClient
|
12 |
+
from tqdm import tqdm
|
13 |
+
from transformers import AutoModel, AutoTokenizer
|
14 |
+
import torch
|
15 |
+
from pdfminer.high_level import extract_pages
|
16 |
+
from pdfminer.layout import LTTextContainer
|
17 |
+
from pypinyin import lazy_pinyin
|
18 |
+
import chromadb
|
19 |
+
from chromadb.config import Settings
|
20 |
+
from typing import Literal, TypedDict
|
21 |
+
# from project.config import base_url
|
22 |
+
|
23 |
+
# _ = load_dotenv("/Users/zhulang/work/llm/self_rag/.env")
|
24 |
+
class KnowledgeType(BaseModel):
|
25 |
+
"""
|
26 |
+
将用户查询路由到最相关的数据源
|
27 |
+
"""
|
28 |
+
route: Literal['澳门', '青海', '周庄', '上海', '天津', '黄果树', '黔东南', '九寨沟', '广西', '贵阳', '扬州', '济南', '香格里拉', '香港', '昆明', '宁波', '林芝', '台北', '三清山', '呼伦贝尔', '鼓浪屿', '婺源', '厦门', '张家界', '故宫', '北戴河', '西藏', '杭州', '大同', '泰山', '秦皇岛', '成都', '凤凰', '兰州', '华山', '浙江', '哈尔滨', '沈阳', '云台山', '福州', '甘南', '三亚', '长沙', '敦煌', '苏州', '青城山', '束河', '南宁', '乌镇', '镇江', '丽江', '西塘', '黄山', '平遥', '五台山', '连云港', '拉萨', '西双版纳', '峨眉山', '武夷山', '宏村', '衡山', '横店', '北海', '桂林', '山海关', '长岛', '太原', '大连', '高雄', '青海湖', '荔波', '野三坡', '蓬莱', '合肥', '绍兴', '云南', '同里', '南京', '青岛', '北疆', '千岛湖', '南昌', '武汉', '珠海', '镇远', '武当山', '重庆', '庐山', '大理', '海口', '康定', '长白山', '曲阜', '蜀南竹海', '常州', '新疆', '海螺沟', '都江堰', '北京', '无锡', '白洋淀', '纳木错', '西溪湿地', '普陀山', '川藏', '日照', '雁荡山', '威海', '深圳', '广州', '泸沽湖', '乌鲁木齐', '西安', '稻城亚丁', '惠州', '烟台', '洛阳', '四姑娘山', '舟山'] = Field(...,description="用户给定一个问题,选择最相关一个进行输出")
|
29 |
+
|
30 |
+
# 灌库
|
31 |
+
class ChatDoc(object):
|
32 |
+
|
33 |
+
def __init__(self):
|
34 |
+
self.loader = {
|
35 |
+
".pdf": PyPDFLoader,
|
36 |
+
".txt": Docx2txtLoader,
|
37 |
+
".docx": Docx2txtLoader,
|
38 |
+
".md": UnstructuredMarkdownLoader,
|
39 |
+
".csv": CSVLoader,
|
40 |
+
".json": self.handle_json,
|
41 |
+
}
|
42 |
+
|
43 |
+
self.txt_splitter = CharacterTextSplitter(chunk_size=240, chunk_overlap=30, length_function=len,
|
44 |
+
add_start_index=True)
|
45 |
+
self.json_splitter = RecursiveJsonSplitter(max_chunk_size=240)
|
46 |
+
self.embeding = embedding()
|
47 |
+
self.client = chromadb.PersistentClient(path="database/travel")
|
48 |
+
self.database_info = json.load(open("database/travel/info.json", "r", encoding="utf-8"))
|
49 |
+
self.llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
|
50 |
+
|
51 |
+
def get_knowledge_type(self, query):
|
52 |
+
names = self.database_info["names"]
|
53 |
+
system_prompt = f"""
|
54 |
+
你是一名地区分类专家,主要分别判断以下类别的地区,有且仅有{','.join(names)}这{len(names)}类地区。识别准确后,返回给用户。
|
55 |
+
识别到澳门,返回'澳门',以此类推。
|
56 |
+
"""
|
57 |
+
grade_messages = [SystemMessage(content=system_prompt)]
|
58 |
+
grade_messages.append(HumanMessage(content=f"{query}"))
|
59 |
+
collection_name = self.llm.with_structured_output(KnowledgeType).invoke(grade_messages)
|
60 |
+
|
61 |
+
return ''.join(lazy_pinyin(collection_name.route))
|
62 |
+
def get_file(self, filename):
|
63 |
+
file_extension = os.path.splitext(filename)[-1]
|
64 |
+
loader = self.loader.get(file_extension, None)
|
65 |
+
if loader:
|
66 |
+
if file_extension == ".json":
|
67 |
+
return loader(filename)
|
68 |
+
else:
|
69 |
+
load_info = loader(filename).load()
|
70 |
+
return load_info
|
71 |
+
|
72 |
+
else:
|
73 |
+
return None
|
74 |
+
|
75 |
+
def handle_json(self, filename):
|
76 |
+
with open(filename, "r", encoding="utf-8") as f:
|
77 |
+
data = f.read()
|
78 |
+
return data
|
79 |
+
|
80 |
+
def is_json(self, data):
|
81 |
+
try:
|
82 |
+
json.loads(data)
|
83 |
+
return True
|
84 |
+
except:
|
85 |
+
return False
|
86 |
+
|
87 |
+
def split_text(self, filename):
|
88 |
+
load_info = self.get_file(filename)
|
89 |
+
if load_info:
|
90 |
+
if self.is_json(load_info):
|
91 |
+
self.end_splitter = self.json_splitter.split_text(json.loads(load_info), ensure_ascii=False)
|
92 |
+
else:
|
93 |
+
self.end_splitter = self.txt_splitter.split_documents(load_info)
|
94 |
+
|
95 |
+
return self.end_splitter
|
96 |
+
|
97 |
+
else:
|
98 |
+
return "文件格式不支持"
|
99 |
+
def emb_text(self, text):
|
100 |
+
return self.embeding.embed_query(text)
|
101 |
+
|
102 |
+
def vector_storage(self, filename):
|
103 |
+
data_name = self.pdf_to_pinyin(filename)
|
104 |
+
data = []
|
105 |
+
for idx, text in enumerate(tqdm(self.end_splitter, desc="向量化")):
|
106 |
+
if isinstance(text, Document):
|
107 |
+
text = text.page_content
|
108 |
+
data.append({"id": idx, "vector": self.emb_text(text), "text": text})
|
109 |
+
print(f"Collection name: {data_name}")
|
110 |
+
collection = self.client.get_or_create_collection(data_name)
|
111 |
+
collection.add(
|
112 |
+
ids=[str(item["id"]) for item in data],
|
113 |
+
embeddings=[item["vector"][0] for item in data],
|
114 |
+
documents=[item["text"] for item in data]
|
115 |
+
)
|
116 |
+
# self.milvus_client.create_collection(
|
117 |
+
# collection_name=data_name,
|
118 |
+
# dimension=768,
|
119 |
+
# metric_type="IP", # Inner product distance
|
120 |
+
# consistency_level="Strong", # Strong consistency level
|
121 |
+
# )
|
122 |
+
|
123 |
+
# self.milvus_client.insert(collection_name=data_name, data=data)
|
124 |
+
return "向量存储成功"
|
125 |
+
def pdf_to_pinyin(self,file):
|
126 |
+
name = os.path.basename(file).split('.')[0]
|
127 |
+
return ''.join(lazy_pinyin(name))
|
128 |
+
def combine(self,path):
|
129 |
+
data_name = 'travel'
|
130 |
+
data = []
|
131 |
+
for i in os.listdir(path):
|
132 |
+
self.split_text(os.path.join(path,i))
|
133 |
+
for idx, text in enumerate(tqdm(self.end_splitter, desc="向量化")):
|
134 |
+
if isinstance(text, Document):
|
135 |
+
text = text.page_content
|
136 |
+
data.append({"id": len(data), "vector": self.emb_text(text), "text": text})
|
137 |
+
print(f"Collection name: {data_name}")
|
138 |
+
collection = self.client.get_or_create_collection(data_name)
|
139 |
+
collection.add(
|
140 |
+
ids=[str(item["id"]) for item in data],
|
141 |
+
embeddings=[item["vector"][0] for item in data],
|
142 |
+
documents=[item["text"] for item in data]
|
143 |
+
)
|
144 |
+
def delete(self):
|
145 |
+
li = ['澳门', '青海', '周庄', '上海', '天津', '黄果树', '黔东南', '九寨沟', '广西', '贵阳', '扬州', '济南', '香格里拉', '香港', '昆明', '宁波', '林芝', '台北', '三清山', '呼伦贝尔', '鼓浪屿', '婺源', '厦门', '张家界', '故宫', '北戴河', '西藏', '杭州', '大同', '泰山', '秦皇岛', '成都', '凤凰', '兰州', '华山', '浙江', '哈尔滨', '沈阳', '云台山', '福州', '甘南', '三亚', '长沙', '敦煌', '苏州', '青城山', '束河', '南宁', '乌镇', '镇江', '丽江', '西塘', '黄山', '平遥', '五台山', '连云港', '拉萨', '西双版纳', '峨眉山', '武夷山', '宏村', '衡山', '横店', '北海', '桂林', '山海关', '长岛', '太原', '大连', '高雄', '青海湖', '荔波', '野三坡', '蓬莱', '合肥', '绍兴', '云南', '同里', '南京', '青岛', '北疆', '千岛湖', '南昌', '武汉', '珠海', '镇远', '武当山', '重庆', '庐山', '大理', '海口', '康定', '长白山', '曲阜', '蜀南竹海', '常州', '新疆', '海螺沟', '都江堰', '北京', '无锡', '白洋淀', '纳木错', '西溪湿地', '普陀山', '川藏', '日照', '雁荡山', '威海', '深圳', '广州', '泸沽湖', '乌鲁木齐', '西安', '稻城亚丁', '惠州', '烟台', '洛阳', '四姑娘山', '舟山']
|
146 |
+
for data_name in li:
|
147 |
+
data_name = ''.join(lazy_pinyin(data_name))
|
148 |
+
collection = self.client.delete_collection(data_name)
|
149 |
+
# collection.delete()
|
150 |
+
return "删除成功"
|
151 |
+
class embedding(object):
|
152 |
+
|
153 |
+
def __init__(self):
|
154 |
+
|
155 |
+
# init model and tokenizer
|
156 |
+
self.tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-embedding-base_v1')
|
157 |
+
self.model = AutoModel.from_pretrained('maidalun1020/bce-embedding-base_v1')
|
158 |
+
|
159 |
+
self.device = 'cuda' # if no GPU, set "cpu"
|
160 |
+
self.model.to(self.device)
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
def embed_query(self, text):
|
165 |
+
# get inputs
|
166 |
+
inputs = self.tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt")
|
167 |
+
inputs_on_device = {k: v.to(self.device) for k, v in inputs.items()}
|
168 |
+
|
169 |
+
# get embeddings
|
170 |
+
outputs = self.model(**inputs_on_device, return_dict=True)
|
171 |
+
embeddings = outputs.last_hidden_state[:, 0] # cls pooler
|
172 |
+
embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) # normalize
|
173 |
+
return embeddings.tolist()
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
a=ChatDoc()
|
178 |
+
path = "database/travel/pdf"
|
179 |
+
# for i in os.listdir(path):
|
180 |
+
# if i.endswith(".pdf"):
|
181 |
+
# a.split_text(os.path.join(path,i))
|
182 |
+
# a.vector_storage(os.path.join(path,i))
|
183 |
+
print(a.client.list_collections())
|
travel/rag_tool.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
5 |
+
from langchain_openai import ChatOpenAI
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from langchain_community.document_loaders import Docx2txtLoader, CSVLoader
|
9 |
+
from langchain.text_splitter import CharacterTextSplitter
|
10 |
+
from pymilvus import MilvusClient
|
11 |
+
|
12 |
+
|
13 |
+
class RAGTool(object):
|
14 |
+
|
15 |
+
def __init__(self):
|
16 |
+
self.loader = {
|
17 |
+
".txt": Docx2txtLoader,
|
18 |
+
".docx": Docx2txtLoader,
|
19 |
+
".csv": CSVLoader,
|
20 |
+
}
|
21 |
+
self.milvus_client = MilvusClient(host="127.0.0.1", port="19530")
|
22 |
+
self.llm = ChatOpenAI(model="gpt-4o")
|
23 |
+
self.messages = [SystemMessage(
|
24 |
+
content="你是一个助手,请根据上下文回答问题,如果无法回答,请说“我不理解”,请尽量简要回答,与问题不相关的内容不用作为分析的内容。")]
|
25 |
+
|
26 |
+
def get_file(self, filename):
|
27 |
+
"""
|
28 |
+
获取文件
|
29 |
+
:param filename: 文件名
|
30 |
+
:return:
|
31 |
+
"""
|
32 |
+
file_type = os.path.splitext(filename)[-1]
|
33 |
+
if file_type in self.loader:
|
34 |
+
loader = self.loader[file_type]
|
35 |
+
loader = loader(filename)
|
36 |
+
return loader.load()
|
37 |
+
else:
|
38 |
+
return None
|
39 |
+
|
40 |
+
def split_sentences(self, filename):
|
41 |
+
"""
|
42 |
+
将文件分割成句子
|
43 |
+
:param filename: 文件名
|
44 |
+
:return:
|
45 |
+
"""
|
46 |
+
full_text = self.get_file(filename)
|
47 |
+
if full_text:
|
48 |
+
text_splitter = CharacterTextSplitter(chunk_size=240, chunk_overlap=30, add_start_index=True,
|
49 |
+
length_function=len)
|
50 |
+
text_split = text_splitter.split_documents(full_text)
|
51 |
+
return text_split
|
52 |
+
else:
|
53 |
+
return "文档格式不支持"
|
54 |
+
|
55 |
+
def emb_text(self, text):
|
56 |
+
"""
|
57 |
+
将文本向量化
|
58 |
+
:param text: 文本
|
59 |
+
:return:
|
60 |
+
"""
|
61 |
+
from langchain_openai import OpenAIEmbeddings
|
62 |
+
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
|
63 |
+
return embeddings.embed_query(text)
|
64 |
+
|
65 |
+
def vector_storage(self, filename):
|
66 |
+
text_split = self.split_sentences(filename)
|
67 |
+
data_vector = []
|
68 |
+
for idx, text in enumerate(tqdm(text_split, desc="Embedding")):
|
69 |
+
data_vector.append({
|
70 |
+
"id": idx,
|
71 |
+
"text": text.page_content,
|
72 |
+
"vector": self.emb_text(text.page_content)
|
73 |
+
})
|
74 |
+
|
75 |
+
self.milvus_client.create_collection(
|
76 |
+
collection_name="test_collection",
|
77 |
+
dimension=1536,
|
78 |
+
metric_type="IP",
|
79 |
+
consistency_level="Strong"
|
80 |
+
)
|
81 |
+
|
82 |
+
self.milvus_client.insert(collection_name="test_collection", data=data_vector)
|
83 |
+
return "success"
|
84 |
+
|
85 |
+
def query_data(self, query):
|
86 |
+
query_vector = self.emb_text(query)
|
87 |
+
result = self.milvus_client.search(
|
88 |
+
collection_name="test_collection",
|
89 |
+
data=[query_vector],
|
90 |
+
limit=3,
|
91 |
+
output_fields=["text"],
|
92 |
+
params={"metric_type": "IP"},
|
93 |
+
)
|
94 |
+
|
95 |
+
result_info = ""
|
96 |
+
for info in result[0]:
|
97 |
+
result_info += info["entity"]["text"]
|
98 |
+
|
99 |
+
return result_info
|
100 |
+
|
101 |
+
def get_answer(self, question):
|
102 |
+
"""
|
103 |
+
获取答案
|
104 |
+
:param question: 问题
|
105 |
+
:return:
|
106 |
+
"""
|
107 |
+
|
108 |
+
result = self.query_data(question)
|
109 |
+
|
110 |
+
self.messages.append(HumanMessage(content=f"问题:{question},检索内容:{result}"))
|
111 |
+
res = self.llm.invoke(self.messages)
|
112 |
+
self.messages.append(AIMessage(content=res.content))
|
113 |
+
return res.content
|
114 |
+
|
115 |
+
|
travel/self_rag_tool.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import chromadb
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from langchain.prompts import ChatPromptTemplate
|
6 |
+
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
7 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
8 |
+
from pydantic.v1 import BaseModel, Field
|
9 |
+
from pymilvus import MilvusClient
|
10 |
+
from typing import Literal
|
11 |
+
from transformers import AutoModel, AutoTokenizer
|
12 |
+
from pypinyin import lazy_pinyin
|
13 |
+
import json
|
14 |
+
|
15 |
+
class LocationType(BaseModel):
|
16 |
+
"""
|
17 |
+
将用户查询路由到最相关的数据源
|
18 |
+
"""
|
19 |
+
route: Literal['澳门', '青海', '周庄', '上海', '天津', '黄果树', '黔东南', '九寨沟', '广西', '贵阳', '扬州', '济南', '香格里拉', '香港', '昆明', '宁波', '林芝', '台北', '三清山', '呼伦贝尔', '鼓浪屿', '婺源', '厦门', '张家界', '故宫', '北戴河', '西藏', '杭州', '大同', '泰山', '秦皇岛', '成都', '凤凰', '兰州', '华山', '浙江', '哈尔滨', '沈阳', '云台山', '福州', '甘南', '三亚', '长沙', '敦煌', '苏州', '青城山', '束河', '南宁', '乌镇', '镇江', '丽江', '西塘', '黄山', '平遥', '五台山', '连云港', '拉萨', '西双版纳', '峨眉山', '武夷山', '宏村', '衡山', '横店', '北海', '桂林', '山海关', '长岛', '太原', '大连', '高雄', '青海湖', '荔波', '野三坡', '蓬莱', '合肥', '绍兴', '云南', '同里', '南京', '青岛', '北疆', '千岛湖', '南昌', '武汉', '珠海', '镇远', '武当山', '重庆', '庐山', '大理', '海口', '康定', '长白山', '曲阜', '蜀南竹海', '常州', '新疆', '海螺沟', '都江堰', '北京', '无锡', '白洋淀', '纳木错', '西溪湿地', '普陀山', '川藏', '日照', '雁荡山', '威海', '深圳', '广州', '泸沽湖', '乌鲁木齐', '西安', '稻城亚丁', '惠州', '烟台', '洛阳', '四姑娘山', '舟山'] = Field(...,description="用户给定一个问题,选择最相关一个进行输出")
|
20 |
+
|
21 |
+
|
22 |
+
class GradedRagTool(BaseModel):
|
23 |
+
"""
|
24 |
+
对检索到到文档进行相关性的检查,相关返回yes,不相关返回no
|
25 |
+
"""
|
26 |
+
|
27 |
+
binary_score: Literal['yes', 'no'] = Field(description="文档与问题的相关性,'yes' or 'no'")
|
28 |
+
|
29 |
+
|
30 |
+
class GradeHallucinations(BaseModel):
|
31 |
+
"""
|
32 |
+
对最终对回答进行一个判断,判断回答中是否存在幻觉,存在则输出yes,不存在这输出no
|
33 |
+
"""
|
34 |
+
|
35 |
+
binary_score: Literal['yes', 'no'] = Field(description="问题与回答的相关性,'yes' or 'no'")
|
36 |
+
|
37 |
+
|
38 |
+
class GradeAnswer(BaseModel):
|
39 |
+
"""对最终的回答于问题进行比对,判断回答和问题是相关的,是相关的则输出yes,不相关则输出no"""
|
40 |
+
|
41 |
+
binary_score: Literal['yes', 'no'] = Field(
|
42 |
+
description="问题与回答的相关性, 'yes' or 'no'"
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
class GradeAndGenerateTool(object):
|
47 |
+
|
48 |
+
def __init__(self, database_path="database/travel"):
|
49 |
+
self.llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
|
50 |
+
# self.llm = ChatOpenAI(temperature=0, model="qwen2-14b",api_key="empty",base_url = "http://61.136.221.118:15001/v1")
|
51 |
+
self.struct_llm_grader = self.llm.with_structured_output(GradedRagTool)
|
52 |
+
self.struct_llm_hallucinations = self.llm.with_structured_output(GradeHallucinations)
|
53 |
+
self.struct_llm_answer = self.llm.with_structured_output(GradeAnswer)
|
54 |
+
self.embeding = embedding()
|
55 |
+
self.database_info = json.load(open(os.path.join(database_path,"info.json"), "r", encoding="utf-8"))
|
56 |
+
self.client = chromadb.PersistentClient(path=database_path)
|
57 |
+
|
58 |
+
# 评分
|
59 |
+
def grade(self, question, text):
|
60 |
+
system_prompt = """
|
61 |
+
你是一名评估检索到到文档与用户到问题相关性到评分员,不需要一个严格的测试,目标是过滤掉错误的检索。如果文档包含与用户问题相关的关键字或者语义,请评为相关,否则请评为不相关。你的回答只能是yes或者no
|
62 |
+
"""
|
63 |
+
grade_messages = [SystemMessage(content=system_prompt)]
|
64 |
+
grade_messages.append(HumanMessage(content=f"问题:{question}\n文档:{text}"))
|
65 |
+
result = self.struct_llm_grader.invoke(grade_messages)
|
66 |
+
return result.binary_score
|
67 |
+
|
68 |
+
# 生成答案
|
69 |
+
def generate(self, question, text):
|
70 |
+
grade_human_prompt = f"""您是问答任务的助理。使用以下检索到的上下文来回答问题。如果你不知道答案,就说你不知道。尽量将回答长度控制在三句话内,保持答案简洁。\n问题:{question}\n上下文:{text}\n答案:"""
|
71 |
+
human_prompt = ChatPromptTemplate.from_template(grade_human_prompt)
|
72 |
+
grade_human_prompt_end = human_prompt.format_messages(question=question, text=text)
|
73 |
+
result = self.llm.invoke(grade_human_prompt_end)
|
74 |
+
return result.content
|
75 |
+
|
76 |
+
# 判断是否有幻觉
|
77 |
+
def hallucinations(self, documents, answer):
|
78 |
+
hallucinations_prompt = "您是一名评估LLM生成是否基于一组检索到的事实的评分员。如果是基于���索到的事实回答则返回no,否则返回yes"
|
79 |
+
hallucinations_messages = [SystemMessage(content=hallucinations_prompt)]
|
80 |
+
hallucinations_messages.append(HumanMessage(content=f":回答:{answer}\n文档:{documents}"))
|
81 |
+
result = self.struct_llm_hallucinations.invoke(hallucinations_messages)
|
82 |
+
return result.binary_score
|
83 |
+
|
84 |
+
# 判断答案是否和问题相关
|
85 |
+
def answer_question(self, question, answer):
|
86 |
+
answer_question_prompt = """
|
87 |
+
你是一名评分员,评估答案是否解决了问题,如果解决了则返回yes,否则返回no
|
88 |
+
"""
|
89 |
+
answer_question_messages = [SystemMessage(content=answer_question_prompt)]
|
90 |
+
answer_question_messages.append(HumanMessage(content=f"问题:{question}\n回答:{answer}"))
|
91 |
+
result = self.struct_llm_answer.invoke(answer_question_messages)
|
92 |
+
return result.binary_score
|
93 |
+
|
94 |
+
# 复写问题
|
95 |
+
def rewrite_question(self, question):
|
96 |
+
rewrite_promtp = "您是一个将输入问题转换为优化的更好版本的问题重写器\n用于矢量库检索。查看输入并尝试推理潜在的语义意图/含义。"
|
97 |
+
rewrite_promtp_messages = [SystemMessage(content=rewrite_promtp)]
|
98 |
+
rewrite_promtp_messages.append(HumanMessage(content=f"问题:{question}"))
|
99 |
+
result = self.llm.invoke(rewrite_promtp_messages)
|
100 |
+
return result.content
|
101 |
+
|
102 |
+
def embed_dim(self, text):
|
103 |
+
return self.embeding.embed_query(text)
|
104 |
+
|
105 |
+
def get_knowledge_type(self, query):
|
106 |
+
names = self.database_info["names"]
|
107 |
+
system_prompt = f"""
|
108 |
+
你是一名地区分类专家,主要分别判断以下类别的地区,有且仅有{','.join(names)}这{len(names)}类地区。识别准确后,返回给用户。
|
109 |
+
识别到澳门,返回'澳门',以此类推。
|
110 |
+
"""
|
111 |
+
grade_messages = [SystemMessage(content=system_prompt)]
|
112 |
+
grade_messages.append(HumanMessage(content=f"{query}"))
|
113 |
+
collection_name = self.llm.with_structured_output(LocationType).invoke(grade_messages)
|
114 |
+
|
115 |
+
return ''.join(lazy_pinyin(collection_name.route))
|
116 |
+
# 检索
|
117 |
+
def search_vector(self, question):
|
118 |
+
a=time.time()
|
119 |
+
collection_name = 'travel'
|
120 |
+
|
121 |
+
result = self.client.get_collection(collection_name).query(
|
122 |
+
query_embeddings=[self.embed_dim(question)[0]],
|
123 |
+
n_results=3,
|
124 |
+
)
|
125 |
+
|
126 |
+
# result = self.milvus_client.search(collection_name="RAG_vector", data=[self.embed_dim(question)],
|
127 |
+
# output_fields=["text"])
|
128 |
+
return result['documents'][0]
|
129 |
+
|
130 |
+
class embedding(object):
|
131 |
+
|
132 |
+
def __init__(self):
|
133 |
+
|
134 |
+
# init model and tokenizer
|
135 |
+
self.tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-embedding-base_v1')
|
136 |
+
self.model = AutoModel.from_pretrained('maidalun1020/bce-embedding-base_v1')
|
137 |
+
|
138 |
+
self.device = 'cuda' # if no GPU, set "cpu"
|
139 |
+
self.model.to(self.device)
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
def embed_query(self, text):
|
144 |
+
# get inputs
|
145 |
+
inputs = self.tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt")
|
146 |
+
inputs_on_device = {k: v.to(self.device) for k, v in inputs.items()}
|
147 |
+
|
148 |
+
# get embeddings
|
149 |
+
outputs = self.model(**inputs_on_device, return_dict=True)
|
150 |
+
embeddings = outputs.last_hidden_state[:, 0] # cls pooler
|
151 |
+
embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) # normalize
|
152 |
+
return embeddings.tolist()
|
153 |
+
|
154 |
+
|
155 |
+
# a=GradeAndGenerateTool()
|
156 |
+
# a.search_vector("澳门旅游指南")
|
157 |
+
|
158 |
+
|
travel/travel.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
举例:1、高德搜索附近的店 2、高德获取地点的经纬度。 3、RAG功能
|
3 |
+
|
4 |
+
"""
|
5 |
+
import functools
|
6 |
+
import operator
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
from typing import Type, TypedDict, Annotated, Sequence
|
10 |
+
from langchain_openai import ChatOpenAI
|
11 |
+
import aiohttp
|
12 |
+
import requests
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
from langchain_community.output_parsers.ernie_functions import JsonOutputFunctionsParser
|
15 |
+
from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage, AIMessage
|
16 |
+
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
|
17 |
+
from langchain_core.tools import BaseTool
|
18 |
+
from langgraph.constants import END
|
19 |
+
from langgraph.graph import StateGraph
|
20 |
+
from pydantic.v1 import BaseModel, Field
|
21 |
+
import gradio as gr
|
22 |
+
from RAGGraph import RAGTool
|
23 |
+
from langchain.agents import create_openai_tools_agent, AgentExecutor
|
24 |
+
|
25 |
+
_ = load_dotenv()
|
26 |
+
|
27 |
+
rag_tool = RAGTool()
|
28 |
+
|
29 |
+
|
30 |
+
class searchAroundInput(BaseModel):
|
31 |
+
keyword: str = Field(..., description="搜索关键词")
|
32 |
+
location: str = Field(..., description="地点的经纬度")
|
33 |
+
|
34 |
+
|
35 |
+
class searchAround(BaseTool):
|
36 |
+
args_schema: Type[BaseModel] = searchAroundInput
|
37 |
+
description = "这是一个搜索周边信息的方法,需要用户提供关键词和地点的经纬度,才能进行周边信息的搜索。如果用户没有提供关键词或者地点的经纬度,则提示用户给出关键词和地点的经纬度并再进行周边信息的搜索。"
|
38 |
+
name = "searchAround"
|
39 |
+
|
40 |
+
def _run(self, keyword, location):
|
41 |
+
around_url = "https://restapi.amap.com/v5/place/around"
|
42 |
+
params = {
|
43 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
44 |
+
"keywords": keyword,
|
45 |
+
"location": location
|
46 |
+
}
|
47 |
+
print("同步调用获取地点周边的方法")
|
48 |
+
res = requests.get(url=around_url, params=params)
|
49 |
+
# prompt = "请帮我整理以下内容中的名称,地址和距离,并按照地址与名称对应输出,且告诉距离多少米,内容:{}".format(
|
50 |
+
# res.json())
|
51 |
+
# result = llm.invoke(prompt)
|
52 |
+
return res.text
|
53 |
+
|
54 |
+
async def _arun(self, keyword, location):
|
55 |
+
async with aiohttp.ClientSession() as session:
|
56 |
+
around_url = "https://restapi.amap.com/v5/place/around"
|
57 |
+
params = {
|
58 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
59 |
+
"keywords": keyword,
|
60 |
+
"location": location
|
61 |
+
}
|
62 |
+
print("异步调用获取地点周边的方法")
|
63 |
+
async with session.get(url=around_url, params=params) as response:
|
64 |
+
return await response.json()
|
65 |
+
|
66 |
+
|
67 |
+
class getLocationInput(BaseModel):
|
68 |
+
keyword: str = Field(..., description="搜索关键词")
|
69 |
+
|
70 |
+
|
71 |
+
class getLocation(BaseTool):
|
72 |
+
args_schema: Type[BaseModel] = getLocationInput
|
73 |
+
description = "这是一个获取地点的经纬度的方法,需要用户提供关键词,才能进行地点的经纬度的获取。如果用户没有提供关键词,则提示用户给出关键词并再进行地点的经纬度的获取。"
|
74 |
+
name = "getLocation"
|
75 |
+
|
76 |
+
def _run(self, keyword):
|
77 |
+
url = "https://restapi.amap.com/v5/place/text"
|
78 |
+
params = {
|
79 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
80 |
+
"keywords": keyword,
|
81 |
+
}
|
82 |
+
res = requests.get(url=url, params=params)
|
83 |
+
print("同步调用获取地点的经纬度方法")
|
84 |
+
return '{}的经纬度是:'.format(keyword) + res.json()["pois"][0]["location"]
|
85 |
+
|
86 |
+
async def _arun(self, keyword):
|
87 |
+
async with aiohttp.ClientSession() as session:
|
88 |
+
url = "https://restapi.amap.com/v5/place/text"
|
89 |
+
params = {
|
90 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
91 |
+
"keywords": keyword,
|
92 |
+
}
|
93 |
+
print("异步调用获取地点的经纬度方法")
|
94 |
+
async with session.get(url=url, params=params) as response:
|
95 |
+
res = await response.json()
|
96 |
+
return '{}的经纬度是:'.format(keyword) + res["pois"][0]["location"]
|
97 |
+
|
98 |
+
|
99 |
+
class ragToolInput(BaseModel):
|
100 |
+
question: str = Field(..., description="用户的问题")
|
101 |
+
|
102 |
+
class ragTool(BaseTool):
|
103 |
+
args_schema: Type[BaseModel] = ragToolInput
|
104 |
+
description = "这是一个RAG工具,可以提供中国国内旅游的相关指南攻略等信息。"
|
105 |
+
name = "ragTool"
|
106 |
+
|
107 |
+
def _run(self, question):
|
108 |
+
a=time.time()
|
109 |
+
return rag_tool.get_answer(question)
|
110 |
+
def create_agent(llm, tools, system_prompt):
|
111 |
+
prompt = ChatPromptTemplate.from_messages([
|
112 |
+
("system", system_prompt), MessagesPlaceholder(variable_name="messages"),
|
113 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
114 |
+
])
|
115 |
+
agent = create_openai_tools_agent(llm, tools, prompt)
|
116 |
+
executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
117 |
+
return executor
|
118 |
+
|
119 |
+
|
120 |
+
def agent_node(state, agent, name):
|
121 |
+
result = agent.invoke(state)
|
122 |
+
return {
|
123 |
+
"messages": [HumanMessage(content=result["output"], name=name)]
|
124 |
+
}
|
125 |
+
def chat(message, history=[]):
|
126 |
+
history_message.append(HumanMessage(content=message))
|
127 |
+
a=time.time()
|
128 |
+
res = graph.invoke({"messages":history_message})
|
129 |
+
res = res['messages'][-1].content
|
130 |
+
history_message.append(AIMessage(content=res))
|
131 |
+
print(time.time()-a)
|
132 |
+
print(res)
|
133 |
+
return res
|
134 |
+
|
135 |
+
|
136 |
+
# 1.封装agent,其中包含特有的工具,系统提示,agent的执行器
|
137 |
+
# 2.封装node,node是一个函数,在函数中会对agent进行调用
|
138 |
+
llm = ChatOpenAI(model="gpt-3.5-turbo")
|
139 |
+
supervisor_llm = ChatOpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY4"))
|
140 |
+
get_location_agent = create_agent(llm=llm, tools=[getLocation()],
|
141 |
+
system_prompt="你是一个获取地点经纬度的助手,当用户需要获取经纬度时,你需要准确提供经纬度信息")
|
142 |
+
get_location_node = functools.partial(agent_node, agent=get_location_agent, name="get_location_agent")
|
143 |
+
search_around_agent = create_agent(llm=llm, tools=[searchAround()],
|
144 |
+
system_prompt="你是一个地图通,你能够根据提供的经纬度去搜索周边店面信息。并返回给用户")
|
145 |
+
search_around_node = functools.partial(agent_node, agent=search_around_agent, name="search_around_agent")
|
146 |
+
rag_agent = create_agent(llm=llm, tools=[ragTool()],
|
147 |
+
system_prompt="你是一个RAG工具,主要是对于旅游相关内容的指南攻略等信息。")
|
148 |
+
rag_node = functools.partial(agent_node, agent=rag_agent, name="rag_agent")
|
149 |
+
|
150 |
+
member = ["search_around_agent", "get_location_agent", 'rag_agent']
|
151 |
+
|
152 |
+
system_prompt = f"""
|
153 |
+
你是一名任务管理者,负责管理任务的调度,下面是你的工作者{member},给定以下请求,与工作者一起响应,并采取下一步行动。
|
154 |
+
每个工作者将执行一个任务并回复执行后的结果和状态,若已经完成后,用FINISH回应。
|
155 |
+
"""
|
156 |
+
|
157 |
+
options = member + ["FINISH"]
|
158 |
+
|
159 |
+
function_def = {
|
160 |
+
"name": "route",
|
161 |
+
"description": "选择下一个工作者",
|
162 |
+
"parameters": {
|
163 |
+
"title": "routeSchema",
|
164 |
+
"type": "object",
|
165 |
+
"properties": {
|
166 |
+
"next": {
|
167 |
+
"title": "Next",
|
168 |
+
"anyOf": [
|
169 |
+
{
|
170 |
+
"enum": options
|
171 |
+
}
|
172 |
+
],
|
173 |
+
}
|
174 |
+
},
|
175 |
+
"required": ["next"]
|
176 |
+
}
|
177 |
+
}
|
178 |
+
|
179 |
+
prompt = ChatPromptTemplate.from_messages(
|
180 |
+
[("system",system_prompt),MessagesPlaceholder(variable_name="messages"),
|
181 |
+
("system",f"基于上述的对话接下来应该是谁来采取行动?请在以下选项中进行选择{options},"
|
182 |
+
"如果你认为最后一句能够对以上对话形成较好回复比如完成了打招呼,请尽量选择'FINISH',"
|
183 |
+
"如果问题仍未解决且问题与旅游相关请选择'rag_agent',"
|
184 |
+
"如果问题仍未解决且问题与获取地点经纬度相关请选择'get_location_agent',"
|
185 |
+
"如果问题仍未解决且问题与地点周边信息相关请选择'search_around_agent',")]
|
186 |
+
# "如果问题仍未解决且问题是一般性问题请选择'normal_chat'。")]
|
187 |
+
).partial(options=str(options),member=",".join(member))
|
188 |
+
|
189 |
+
|
190 |
+
supervisor_chain = prompt | supervisor_llm.bind_functions(functions=[function_def],function_call="route") | JsonOutputFunctionsParser()
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
class AgentState(TypedDict):
|
195 |
+
messages : Annotated[Sequence[BaseMessage],operator.add]
|
196 |
+
next : str
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
work_flow = StateGraph(AgentState)
|
201 |
+
|
202 |
+
work_flow.add_node("get_location_agent",get_location_node)
|
203 |
+
work_flow.add_node("search_around_agent",search_around_node)
|
204 |
+
work_flow.add_node("rag_agent",rag_node)
|
205 |
+
work_flow.add_node("supervisor",supervisor_chain)
|
206 |
+
|
207 |
+
for name in member:
|
208 |
+
work_flow.add_edge(name,"supervisor")
|
209 |
+
|
210 |
+
conditional_map = {
|
211 |
+
"get_location_agent":"get_location_agent",
|
212 |
+
"search_around_agent":"search_around_agent",
|
213 |
+
"rag_agent":"rag_agent",
|
214 |
+
"FINISH":END,
|
215 |
+
}
|
216 |
+
|
217 |
+
work_flow.add_conditional_edges("supervisor",lambda x : x["next"],conditional_map)
|
218 |
+
|
219 |
+
work_flow.set_entry_point("supervisor")
|
220 |
+
|
221 |
+
graph = work_flow.compile()
|
222 |
+
history_message = []
|
223 |
+
chat("北京的旅游攻略")
|
224 |
+
|
225 |
+
# iface_chat_file = gr.ChatInterface(
|
226 |
+
# fn=chat,
|
227 |
+
# examples=['北京的经纬度是多少', '天安门附近的餐馆有哪些','北京的旅游攻略'],
|
228 |
+
# title="Chat File Interface",
|
229 |
+
# )
|
230 |
+
# iface_chat_file.launch(share=True, server_name='0.0.0.0', server_port=5001)
|
231 |
+
# a=time.time()
|
232 |
+
# res = graph.invoke({"messages":[HumanMessage(content="天安门附近的餐馆有哪些")]})
|
233 |
+
# print(time.time()-a)
|
234 |
+
# print(res)
|
235 |
+
|
236 |
+
|
237 |
+
|
travel/travel_new.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
举例:1、高德搜索附近的店 2、高德获取地点的经纬度。 3、RAG功能
|
3 |
+
|
4 |
+
"""
|
5 |
+
import functools
|
6 |
+
import operator
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
from typing import Type, TypedDict, Annotated, Sequence
|
10 |
+
|
11 |
+
import aiohttp
|
12 |
+
import requests
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
from langchain_community.output_parsers.ernie_functions import JsonOutputFunctionsParser
|
15 |
+
from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage, AIMessage
|
16 |
+
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
|
17 |
+
from langchain_core.tools import BaseTool
|
18 |
+
from langgraph.constants import END
|
19 |
+
from langgraph.graph import StateGraph
|
20 |
+
from pydantic.v1 import BaseModel, Field
|
21 |
+
from typing import Literal
|
22 |
+
from RAGGraph import RAGTool
|
23 |
+
from langchain.agents import create_openai_tools_agent, AgentExecutor
|
24 |
+
|
25 |
+
import gradio as gr
|
26 |
+
from langchain_openai import ChatOpenAI
|
27 |
+
# 1.封装agent,其中包含特有的工具,系统提示,agent的执行器
|
28 |
+
# 2.封装node,node是一个函数,在函数中会对agent进行调用
|
29 |
+
load_dotenv()
|
30 |
+
rag_tool = RAGTool()
|
31 |
+
|
32 |
+
class AgentType(BaseModel):
|
33 |
+
"""
|
34 |
+
将用户查询路由到最相关的数据源
|
35 |
+
"""
|
36 |
+
route: Literal["search_around_agent", "get_location_agent", 'rag_agent', 'FINISH', 'normal_chat'] = Field(...,description="用户给定一个问题,选择最相关一个进行输出")
|
37 |
+
arguments: dict | None = Field(None, description="给选定的代理的参数")
|
38 |
+
|
39 |
+
class AgentState(TypedDict):
|
40 |
+
messages : Annotated[Sequence[BaseMessage],operator.add]
|
41 |
+
next : str
|
42 |
+
|
43 |
+
|
44 |
+
class searchAroundInput(BaseModel):
|
45 |
+
keyword: str = Field(..., description="搜索关键词")
|
46 |
+
location: str = Field(..., description="地点的经纬度")
|
47 |
+
|
48 |
+
|
49 |
+
class searchAround(BaseTool):
|
50 |
+
args_schema: Type[BaseModel] = searchAroundInput
|
51 |
+
description = "这是一个搜索周边信息的方法,需要用户提供关键词和地点的经纬度,才能进行周边信息的搜索。如果用户没有提供关键词或者地点的经纬度,则提示用户给出关键词和地点的经纬度并再进行周边信息的搜索。"
|
52 |
+
name = "searchAround"
|
53 |
+
|
54 |
+
def _run(self, keyword, location):
|
55 |
+
around_url = "https://restapi.amap.com/v5/place/around"
|
56 |
+
params = {
|
57 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
58 |
+
"keywords": keyword,
|
59 |
+
"location": location
|
60 |
+
}
|
61 |
+
print("同步调用获取地点周边的方法")
|
62 |
+
res = requests.get(url=around_url, params=params)
|
63 |
+
# prompt = "请帮我整理以下内容中的名称,地址和距离,并按照地址与名称对应输出,且告诉距离多少米,内容:{}".format(
|
64 |
+
# res.json())
|
65 |
+
# result = llm.invoke(prompt)
|
66 |
+
return res.text
|
67 |
+
|
68 |
+
async def _arun(self, keyword, location):
|
69 |
+
async with aiohttp.ClientSession() as session:
|
70 |
+
around_url = "https://restapi.amap.com/v5/place/around"
|
71 |
+
params = {
|
72 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
73 |
+
"keywords": keyword,
|
74 |
+
"location": location
|
75 |
+
}
|
76 |
+
print("异步调用获取地点周边的方法")
|
77 |
+
async with session.get(url=around_url, params=params) as response:
|
78 |
+
return await response.json()
|
79 |
+
|
80 |
+
|
81 |
+
class getLocationInput(BaseModel):
|
82 |
+
keyword: str = Field(..., description="搜索关键词")
|
83 |
+
|
84 |
+
|
85 |
+
class getLocation(BaseTool):
|
86 |
+
args_schema: Type[BaseModel] = getLocationInput
|
87 |
+
description = "这是一个获取地点的经纬度的方法,需要用户提供关键词,才能进行地点的经纬度的获取。如果用户没有提供关键词,则提示用户给出关键词并再进行地点的经纬度的获取。"
|
88 |
+
name = "getLocation"
|
89 |
+
|
90 |
+
def _run(self, keyword):
|
91 |
+
url = "https://restapi.amap.com/v5/place/text"
|
92 |
+
params = {
|
93 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
94 |
+
"keywords": keyword,
|
95 |
+
}
|
96 |
+
res = requests.get(url=url, params=params)
|
97 |
+
print("同步调用获取地点的经纬度方法")
|
98 |
+
return '{}的经纬度是:'.format(keyword) + res.json()["pois"][0]["location"]
|
99 |
+
|
100 |
+
async def _arun(self, keyword):
|
101 |
+
async with aiohttp.ClientSession() as session:
|
102 |
+
url = "https://restapi.amap.com/v5/place/text"
|
103 |
+
params = {
|
104 |
+
"key": "df8ff851968143fb413203f195fcd7d7",
|
105 |
+
"keywords": keyword,
|
106 |
+
}
|
107 |
+
print("异步调用获取地点的经纬度方法")
|
108 |
+
async with session.get(url=url, params=params) as response:
|
109 |
+
res = await response.json()
|
110 |
+
return '{}的经纬度是:'.format(keyword) + res["pois"][0]["location"]
|
111 |
+
|
112 |
+
class ragToolInput(BaseModel):
|
113 |
+
question: str = Field(..., description="用户的问题")
|
114 |
+
|
115 |
+
class ragTool(BaseTool):
|
116 |
+
args_schema: Type[BaseModel] = ragToolInput
|
117 |
+
description = "这是一个RAG工具,可以提供中国国内旅游的相关指南攻略等信息。"
|
118 |
+
name = "ragTool"
|
119 |
+
|
120 |
+
def _run(self, question):
|
121 |
+
a=time.time()
|
122 |
+
res = rag_tool.get_answer(question)
|
123 |
+
print('ragTool',time.time()-a)
|
124 |
+
return res
|
125 |
+
def create_agent(llm, tools, system_prompt):
|
126 |
+
prompt = ChatPromptTemplate.from_messages([
|
127 |
+
("system", system_prompt), MessagesPlaceholder(variable_name="messages"),
|
128 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
129 |
+
])
|
130 |
+
agent = create_openai_tools_agent(llm, tools, prompt)
|
131 |
+
executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
132 |
+
return executor
|
133 |
+
|
134 |
+
|
135 |
+
def agent_node(state, agent, name):
|
136 |
+
result = agent.invoke(state)
|
137 |
+
return {
|
138 |
+
"messages": [HumanMessage(content=result["output"], name=name)]
|
139 |
+
}
|
140 |
+
|
141 |
+
|
142 |
+
def supervisor(state):
|
143 |
+
message = state['messages']
|
144 |
+
prompt = [SystemMessage(content=(
|
145 |
+
f"基于上述的对话接下来应该是谁来采取行动?请在以下选项中进行选择{options},"
|
146 |
+
"如果你认为最后一句能够对以上对话形成较好回复比如完成了打招呼,请尽量选择'FINISH',"
|
147 |
+
"如果问题仍未解决且问题与旅游相关请选择'rag_agent',"
|
148 |
+
"如果问题仍未解决且问题与获取地点经纬度相关请选择'get_location_agent',"
|
149 |
+
"如果问题仍未解决且问题与地点周边信息相关请选择'search_around_agent',"
|
150 |
+
"如果问题仍未解决且问题是一般性问题请选择'normal_chat'。"
|
151 |
+
))]
|
152 |
+
a=time.time()
|
153 |
+
res = supervisor_llm.with_structured_output(AgentType).invoke(message+prompt)
|
154 |
+
print('supervisor',time.time()-a)
|
155 |
+
return {'next': res.route}
|
156 |
+
# if res.route == "FINISH":
|
157 |
+
# state['messages'].append(SystemMessage(content="任务结束"))
|
158 |
+
# state['next'] = "FINISH"
|
159 |
+
# return state
|
160 |
+
# else:
|
161 |
+
# state['next'] = res.route
|
162 |
+
# return state
|
163 |
+
|
164 |
+
def noraml_chat(state,name):
|
165 |
+
message = state['messages']
|
166 |
+
result = llm.invoke(message)
|
167 |
+
return {
|
168 |
+
"messages": [HumanMessage(content=result.content, name=name)]
|
169 |
+
}
|
170 |
+
def chat(message, history=[]):
|
171 |
+
|
172 |
+
history_message.append(HumanMessage(content=message))
|
173 |
+
res = graph.invoke({"messages":history_message})
|
174 |
+
res = res['messages'][-1].content
|
175 |
+
history_message.append(AIMessage(content=res))
|
176 |
+
print(res)
|
177 |
+
return res
|
178 |
+
|
179 |
+
if __name__ == '__main__':
|
180 |
+
llm = ChatOpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY4"))
|
181 |
+
supervisor_llm = ChatOpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY4"))
|
182 |
+
get_location_agent = create_agent(llm=llm, tools=[getLocation()],
|
183 |
+
system_prompt="你是一个获取地点经纬度的助手,当用户需要获取经纬度时,你需要准确提供经纬度信息")
|
184 |
+
get_location_node = functools.partial(agent_node, agent=get_location_agent, name="get_location_agent")
|
185 |
+
search_around_agent = create_agent(llm=llm, tools=[searchAround()],
|
186 |
+
system_prompt="你是一个地图通,你能够根据提供的经纬度去搜索周边店面信息。并返回给用户")
|
187 |
+
search_around_node = functools.partial(agent_node, agent=search_around_agent, name="search_around_agent")
|
188 |
+
rag_agent = create_agent(llm=llm, tools=[ragTool()],
|
189 |
+
system_prompt="你是一个RAG工具,主要是对于旅游相关内容的指南攻略等信息。")
|
190 |
+
rag_node = functools.partial(agent_node, agent=rag_agent, name="rag_agent")
|
191 |
+
# normal_agent = create_agent(llm=llm, tools=[],
|
192 |
+
# system_prompt="你能回答一些不需要联网搜索的一般性的问题")
|
193 |
+
normal_node = functools.partial(noraml_chat, name="normal_chat")
|
194 |
+
member = ["search_around_agent", "get_location_agent", 'rag_agent', 'normal_chat']
|
195 |
+
|
196 |
+
options = member + ["FINISH"]
|
197 |
+
|
198 |
+
work_flow = StateGraph(AgentState)
|
199 |
+
|
200 |
+
work_flow.add_node("get_location_agent",get_location_node)
|
201 |
+
work_flow.add_node("search_around_agent",search_around_node)
|
202 |
+
work_flow.add_node("rag_agent",rag_node)
|
203 |
+
work_flow.add_node("normal_chat",normal_node)
|
204 |
+
work_flow.add_node("supervisor",supervisor)
|
205 |
+
|
206 |
+
for name in member:
|
207 |
+
work_flow.add_edge(name,"supervisor")
|
208 |
+
|
209 |
+
conditional_map = {
|
210 |
+
"get_location_agent":"get_location_agent",
|
211 |
+
"search_around_agent":"search_around_agent",
|
212 |
+
"rag_agent":"rag_agent",
|
213 |
+
"normal_chat":"normal_chat",
|
214 |
+
"FINISH":END,
|
215 |
+
}
|
216 |
+
|
217 |
+
work_flow.add_conditional_edges("supervisor",lambda x : x["next"],conditional_map)
|
218 |
+
|
219 |
+
work_flow.set_entry_point("supervisor")
|
220 |
+
|
221 |
+
graph = work_flow.compile()
|
222 |
+
|
223 |
+
system_prompt = f"""
|
224 |
+
你是一名任务管理者,负责管理任务的调度,下面是你的工作者{member},给定以下请求,与工作者一起响应,并采取下一步行动。
|
225 |
+
每个工作者将执行一个任务并回复执行后的结果和状态,若对话中最后一句能够形成较好回复比如完成了打招呼,请选择'FINISH'。
|
226 |
+
"""
|
227 |
+
history_message = [SystemMessage(content=system_prompt)]
|
228 |
+
# chat('hello')
|
229 |
+
|
230 |
+
|
231 |
+
iface_chat_file = gr.ChatInterface(
|
232 |
+
fn=chat,
|
233 |
+
examples=['北京的经纬度是多少', '天安门附近的餐馆有哪些','北京的旅游攻略'],
|
234 |
+
title="Chat File Interface",
|
235 |
+
)
|
236 |
+
iface_chat_file.launch(share=True, server_name='0.0.0.0', server_port=5001)
|