fakeym commited on
Commit
fd22a9b
1 Parent(s): d661683

Upload 8 files

Browse files
travel/RAGGraph.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import List, TypedDict, Type, Any, Annotated
3
+
4
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, AnyMessage
5
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
+ from langchain_core.tools import BaseTool
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain.agents import create_openai_tools_agent,AgentExecutor
9
+ from langgraph.constants import END
10
+ from langgraph.graph import StateGraph,MessagesState
11
+ from pydantic.v1 import BaseModel, Field
12
+ import concurrent.futures
13
+ from self_rag_tool import GradeAndGenerateTool
14
+
15
+
16
class CreateLangGraphState(TypedDict):
    """State dictionary handed from node to node in the RAG graphs of this module."""

    # The question currently being answered (overwritten by the rewrite node).
    question: str
    # The answer text produced by the generation node.
    generation: str
    # Retrieved document texts, or the subset that survived relevance grading.
    documents: List[str]
20
+
21
class self_RAGTool(object):
    """Self-RAG workflow: retrieve -> grade -> generate -> check the answer.

    After generation a conditional edge decides whether to regenerate,
    rewrite the question and retrieve again, or finish.
    """

    def __init__(self) -> None:
        self.tools = GradeAndGenerateTool()

        builder = StateGraph(CreateLangGraphState)
        self.workflow = builder

        builder.add_node("retrieve", self.retrieve)
        builder.add_node("grade_documents", self.grade_documents)
        builder.add_node("generate_llm", self.generate_llm)
        builder.add_node("rewrite_question", self.rewrite_question)

        builder.set_entry_point("retrieve")
        builder.add_edge("retrieve", "grade_documents")
        builder.add_edge("grade_documents", "generate_llm")
        # The router's return value is looked up in this mapping.
        builder.add_conditional_edges(
            "generate_llm",
            self.hallucinations_generate,
            {
                "generate_llm": "generate_llm",
                "rewrite_question": "rewrite_question",
                "useful": END,
            },
        )
        builder.add_edge("rewrite_question", "retrieve")

        self.graph = builder.compile()

    def retrieve(self, state):
        """Fetch documents for the current question from the vector store."""
        question = state["question"]
        started = time.time()
        found = self.tools.search_vector(question)
        print("retrieve:", time.time() - started)
        return {"documents": found, "question": question}

    def grade_documents(self, state):
        """Keep only the documents the grader marks as relevant ("yes")."""
        question = state["question"]
        candidates = state["documents"]
        started = time.time()
        # One grader call per document, in the original order.
        relevant = [
            doc for doc in candidates
            if self.tools.grade(question=question, text=doc) == "yes"
        ]
        print("grade_documents:", time.time() - started)
        return {"question": question, "documents": relevant}

    def generate_llm(self, state):
        """Generate an answer from the current documents."""
        question = state["question"]
        documents = state["documents"]
        # Curly braces are removed from the context text before generation.
        context = "\n".join(documents).replace("{", "").replace("}", "")
        started = time.time()
        answer = self.tools.generate(question=question, text=context)
        print("generate_llm:", time.time() - started)
        return {"question": question, "generation": answer, "documents": documents}

    def hallucinations_generate(self, state):
        """Route after generation: "generate_llm", "useful" or "rewrite_question"."""
        print("调用幻觉判断的方法")
        question = state["question"]
        generation = state["generation"]
        context = "\n".join(state["documents"])
        started = time.time()
        verdict = self.tools.hallucinations(documents=context, answer=generation)
        print("hallucinations_generate:", time.time() - started)
        if verdict == "yes":
            # Grader reported a hallucination: generate again.
            return "generate_llm"
        solved = self.tools.answer_question(question=question, answer=generation)
        return "useful" if solved == "yes" else "rewrite_question"

    def rewrite_question(self, state):
        """Replace the question with a rewritten version."""
        started = time.time()
        rewritten = self.tools.rewrite_question(question=state["question"])
        print("rewrite_question:", time.time() - started)
        return {"question": rewritten}

    def get_answer(self, question):
        """Run the compiled graph for ``question`` and return the generated answer."""
        final_state = self.graph.invoke({"question": question})
        return final_state['generation']
110
+
111
class RAGTool(object):
    """Plain two-step RAG workflow: retrieve documents, then generate an answer.

    Only ``retrieve`` and ``generate_llm`` are wired into the graph. The
    grading, hallucination-check and rewrite methods are kept for callers
    that use them directly but are not part of the compiled workflow.
    """

    def __init__(self) -> None:
        self.tools = GradeAndGenerateTool()
        self.workflow = StateGraph(CreateLangGraphState)

        self.workflow.add_node("retrieve", self.retrieve)
        self.workflow.add_node("generate_llm", self.generate_llm)

        self.workflow.set_entry_point("retrieve")
        self.workflow.add_edge("retrieve", "generate_llm")
        self.workflow.add_edge("generate_llm", END)

        self.graph = self.workflow.compile()

    def retrieve(self, state):
        """Fetch documents for the current question from the vector store."""
        question = state["question"]
        started = time.time()
        documents = self.tools.search_vector(question)
        print("retrieve:", time.time() - started)
        return {"documents": documents, "question": question}

    def grade_documents(self, state):
        """Keep only the documents the grader marks as relevant ("yes")."""
        question = state["question"]
        started = time.time()
        relevant = [
            doc for doc in state["documents"]
            if self.tools.grade(question=question, text=doc) == "yes"
        ]
        print("grade_documents:", time.time() - started)
        return {"question": question, "documents": relevant}

    def generate_llm(self, state):
        """Generate an answer from the retrieved documents.

        Curly braces are stripped from the context, exactly as
        ``self_RAGTool.generate_llm`` does: the generation tool feeds the
        text into a prompt template, where a stray ``{`` or ``}`` in a
        retrieved chunk would be parsed as a template variable and fail.
        The ``documents`` entry of the returned state is left untouched.
        """
        question = state["question"]
        documents = state["documents"]
        documents_str = "\n".join(documents).replace("{", "").replace("}", "")
        started = time.time()
        result = self.tools.generate(question=question, text=documents_str)
        print("generate_llm:", time.time() - started)
        return {"question": question, "generation": result, "documents": documents}

    def hallucinations_generate(self, state):
        """Return the next step: "generate_llm", "useful" or "rewrite_question"."""
        print("调用幻觉判断的方法")
        question = state["question"]
        generation = state["generation"]
        documents_str = "\n".join(state["documents"])
        started = time.time()
        result = self.tools.hallucinations(documents=documents_str, answer=generation)
        print("hallucinations_generate:", time.time() - started)
        if result == "yes":
            # Grader reported a hallucination: generate again.
            return "generate_llm"
        answered = self.tools.answer_question(question=question, answer=generation)
        return "useful" if answered == "yes" else "rewrite_question"

    def rewrite_question(self, state):
        """Replace the question with a rewritten version."""
        question = state["question"]
        started = time.time()
        result = self.tools.rewrite_question(question=question)
        print("rewrite_question:", time.time() - started)
        return {"question": result}

    def get_answer(self, question):
        """Run the compiled graph for ``question`` and return the generated answer."""
        res = self.graph.invoke({"question": question})
        return res['generation']
195
+
196
+ # messages = []
197
+ # while True:
198
+ # question = input("请输入问题:")
199
+ # messages.append(HumanMessage(content=question))
200
+ # res = graph.invoke({"question":messages.content})
201
+ # messages.append(AIMessage(content=res["output"]))
202
+ # print(res["output"])
travel/__pycache__/RAGGraph.cpython-39.pyc ADDED
Binary file (5.53 kB). View file
 
travel/__pycache__/self_rag_tool.cpython-39.pyc ADDED
Binary file (9.13 kB). View file
 
travel/database_generate.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ from dotenv import load_dotenv
5
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, UnstructuredMarkdownLoader, CSVLoader
6
+ from langchain_core.documents import Document
7
+ from langchain_core.messages import SystemMessage, HumanMessage
8
+ from pydantic.v1 import Field, BaseModel
9
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
10
+ from langchain_text_splitters import CharacterTextSplitter, RecursiveJsonSplitter
11
+ from pymilvus import MilvusClient
12
+ from tqdm import tqdm
13
+ from transformers import AutoModel, AutoTokenizer
14
+ import torch
15
+ from pdfminer.high_level import extract_pages
16
+ from pdfminer.layout import LTTextContainer
17
+ from pypinyin import lazy_pinyin
18
+ import chromadb
19
+ from chromadb.config import Settings
20
+ from typing import Literal, TypedDict
21
+ # from project.config import base_url
22
+
23
+ # _ = load_dotenv("/Users/zhulang/work/llm/self_rag/.env")
24
class KnowledgeType(BaseModel):
    """
    将用户查询路由到最相关的数据源
    """
    # NOTE: the docstring above and the Field description below are part of
    # the schema handed to the LLM for structured output, so they are kept
    # verbatim. English gloss: "route the user query to the most relevant
    # data source" / "given a question, output the single most relevant one".
    route: Literal[
        '澳门', '青海', '周庄', '上海', '天津', '黄果树', '黔东南', '九寨沟', '广西', '贵阳',
        '扬州', '济南', '香格里拉', '香港', '昆明', '宁波', '林芝', '台北', '三清山', '呼伦贝尔',
        '鼓浪屿', '婺源', '厦门', '张家界', '故宫', '北戴河', '西藏', '杭州', '大同', '泰山',
        '秦皇岛', '成都', '凤凰', '兰州', '华山', '浙江', '哈尔滨', '沈阳', '云台山', '福州',
        '甘南', '三亚', '长沙', '敦煌', '苏州', '青城山', '束河', '南宁', '乌镇', '镇江',
        '丽江', '西塘', '黄山', '平遥', '五台山', '连云港', '拉萨', '西双版纳', '峨眉山', '武夷山',
        '宏村', '衡山', '横店', '北海', '桂林', '山海关', '长岛', '太原', '大连', '高雄',
        '青海湖', '荔波', '野三坡', '蓬莱', '合肥', '绍兴', '云南', '同里', '南京', '青岛',
        '北疆', '千岛湖', '南昌', '武汉', '珠海', '镇远', '武当山', '重庆', '庐山', '大理',
        '海口', '康定', '长白山', '曲阜', '蜀南竹海', '常州', '新疆', '海螺沟', '都江堰', '北京',
        '无锡', '白洋淀', '纳木错', '西溪湿地', '普陀山', '川藏', '日照', '雁荡山', '威海', '深圳',
        '广州', '泸沽湖', '乌鲁木齐', '西安', '稻城亚丁', '惠州', '烟台', '洛阳', '四姑娘山', '舟山',
    ] = Field(..., description="用户给定一个问题,选择最相关一个进行输出")
29
+
30
+ # 灌库
31
class ChatDoc(object):
    """Load travel documents, split them into chunks, embed and store them in Chroma."""

    def __init__(self):
        # Imported locally so the fix below does not depend on the module's
        # import block; the package is already used by this file.
        from langchain_community.document_loaders import TextLoader

        # Maps a file extension to a callable taking the file name.
        self.loader = {
            ".pdf": PyPDFLoader,
            # Plain text is not a .docx archive, so Docx2txtLoader cannot read it.
            ".txt": lambda filename: TextLoader(filename, encoding="utf-8"),
            ".docx": Docx2txtLoader,
            ".md": UnstructuredMarkdownLoader,
            ".csv": CSVLoader,
            ".json": self.handle_json,
        }

        self.txt_splitter = CharacterTextSplitter(chunk_size=240, chunk_overlap=30, length_function=len,
                                                  add_start_index=True)
        self.json_splitter = RecursiveJsonSplitter(max_chunk_size=240)
        self.embeding = embedding()
        self.client = chromadb.PersistentClient(path="database/travel")
        # Context manager so the metadata file handle is always closed.
        with open("database/travel/info.json", "r", encoding="utf-8") as info_file:
            self.database_info = json.load(info_file)
        self.llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
        # Chunks produced by the most recent split_text() call.
        self.end_splitter = []

    def get_knowledge_type(self, query):
        """Ask the LLM which region ``query`` is about; return its name as pinyin."""
        names = self.database_info["names"]
        system_prompt = f"""
        你是一名地区分类专家,主要分别判断以下类别的地区,有且仅有{','.join(names)}这{len(names)}类地区。识别准确后,返回给用户。
        识别到澳门,返回'澳门',以此类推。
        """
        grade_messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"{query}"),
        ]
        collection_name = self.llm.with_structured_output(KnowledgeType).invoke(grade_messages)
        return ''.join(lazy_pinyin(collection_name.route))

    def get_file(self, filename):
        """Load ``filename`` with the loader registered for its extension.

        Returns the raw text for ``.json`` files, the loader's ``load()``
        result for other supported types, and ``None`` when the extension
        is not supported.
        """
        file_extension = os.path.splitext(filename)[-1]
        loader = self.loader.get(file_extension, None)
        if not loader:
            return None
        if file_extension == ".json":
            return loader(filename)
        return loader(filename).load()

    def handle_json(self, filename):
        """Return the content of ``filename`` as text, read as UTF-8."""
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    def is_json(self, data):
        """Return True when ``data`` is a string that parses as JSON."""
        try:
            json.loads(data)
        except (TypeError, ValueError):
            # TypeError: not a str/bytes (e.g. a list of loaded documents).
            # ValueError: a string that is not valid JSON.
            return False
        return True

    def split_text(self, filename):
        """Split ``filename`` into chunks, store them on ``self.end_splitter`` and return them.

        For unsupported or empty files the stored chunks are cleared and the
        message "文件格式不支持" is returned. Clearing matters: otherwise a later
        ``combine``/``vector_storage`` would re-ingest the previous file's chunks.
        """
        load_info = self.get_file(filename)
        if not load_info:
            self.end_splitter = []
            return "文件格式不支持"
        if self.is_json(load_info):
            self.end_splitter = self.json_splitter.split_text(json.loads(load_info), ensure_ascii=False)
        else:
            self.end_splitter = self.txt_splitter.split_documents(load_info)
        return self.end_splitter

    def emb_text(self, text):
        """Embed ``text`` with the local embedding model."""
        return self.embeding.embed_query(text)

    def _embed_chunks(self, start_id=0):
        """Embed the chunks in ``self.end_splitter``; ids start at ``start_id``."""
        records = []
        for offset, chunk in enumerate(tqdm(self.end_splitter, desc="向量化")):
            if isinstance(chunk, Document):
                chunk = chunk.page_content
            records.append({"id": start_id + offset, "vector": self.emb_text(chunk), "text": chunk})
        return records

    def _store(self, collection_name, records):
        """Add ``records`` to the Chroma collection ``collection_name`` (created if missing)."""
        print(f"Collection name: {collection_name}")
        collection = self.client.get_or_create_collection(collection_name)
        collection.add(
            ids=[str(item["id"]) for item in records],
            # emb_text returns a one-row batch; take the single vector.
            embeddings=[item["vector"][0] for item in records],
            documents=[item["text"] for item in records],
        )

    def vector_storage(self, filename):
        """Embed the chunks from the last ``split_text`` call into a collection named after ``filename``."""
        data_name = self.pdf_to_pinyin(filename)
        self._store(data_name, self._embed_chunks())
        return "向量存储成功"

    def pdf_to_pinyin(self, file):
        """Return the pinyin of the file's base name up to its first dot."""
        name = os.path.basename(file).split('.')[0]
        return ''.join(lazy_pinyin(name))

    def combine(self, path):
        """Embed every supported file under ``path`` into the single 'travel' collection."""
        data_name = 'travel'
        data = []
        for entry in os.listdir(path):
            # split_text clears end_splitter for unsupported files, so those
            # entries contribute nothing instead of repeating earlier chunks.
            self.split_text(os.path.join(path, entry))
            data.extend(self._embed_chunks(start_id=len(data)))
        if not data:
            print(f"No supported documents found in {path}; nothing stored.")
            return
        self._store(data_name, data)

    def delete(self):
        """Delete the per-region collections (named by the pinyin of each region)."""
        li = [
            '澳门', '青海', '周庄', '上海', '天津', '黄果树', '黔东南', '九寨沟', '广西', '贵阳',
            '扬州', '济南', '香格里拉', '香港', '昆明', '宁波', '林芝', '台北', '三清山', '呼伦贝尔',
            '鼓浪屿', '婺源', '厦门', '张家界', '故宫', '北戴河', '西藏', '杭州', '大同', '泰山',
            '秦皇岛', '成都', '凤凰', '兰州', '华山', '浙江', '哈尔滨', '沈阳', '云台山', '福州',
            '甘南', '三亚', '长沙', '敦煌', '苏州', '青城山', '束河', '南宁', '乌镇', '镇江',
            '丽江', '西塘', '黄山', '平遥', '五台山', '连云港', '拉萨', '西双版纳', '峨眉山', '武夷山',
            '宏村', '衡山', '横店', '北海', '桂林', '山海关', '长岛', '太原', '大连', '高雄',
            '青海湖', '荔波', '野三坡', '蓬莱', '合肥', '绍兴', '云南', '同里', '南京', '青岛',
            '北疆', '千岛湖', '南昌', '武汉', '珠海', '镇远', '武当山', '重庆', '庐山', '大理',
            '海口', '康定', '长白山', '曲阜', '蜀南竹海', '常州', '新疆', '海螺沟', '都江堰', '北京',
            '无锡', '白洋淀', '纳木错', '西溪湿地', '普陀山', '川藏', '日照', '雁荡山', '威海', '深圳',
            '广州', '泸沽湖', '乌鲁木齐', '西安', '稻城亚丁', '惠州', '烟台', '洛阳', '四姑娘山', '舟山',
        ]
        for region in li:
            self.client.delete_collection(''.join(lazy_pinyin(region)))
        return "删除成功"
151
class embedding(object):
    """Sentence embedder built on the 'maidalun1020/bce-embedding-base_v1' model."""

    def __init__(self):
        # init model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-embedding-base_v1')
        self.model = AutoModel.from_pretrained('maidalun1020/bce-embedding-base_v1')

        # Use the GPU when one is available, otherwise fall back to the CPU
        # instead of failing on hosts without CUDA.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.model.eval()

    def embed_query(self, text):
        """Return the L2-normalised embedding of ``text``.

        The result is a one-row nested list (``[[...]]``) because the model
        is run on a batch of size one; callers index ``[0]`` for the vector.
        """
        # get inputs
        inputs = self.tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs_on_device = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs_on_device, return_dict=True)
            cls_vectors = outputs.last_hidden_state[:, 0]  # first-token (CLS) pooling
            normalized = cls_vectors / cls_vectors.norm(dim=1, keepdim=True)
        return normalized.tolist()
174
+
175
+
176
+
177
# Script tail: build the ingestion helper and list the collections that
# already exist in the persistent Chroma store.
a = ChatDoc()
path = "database/travel/pdf"

# Per-file ingestion, currently disabled:
# for i in os.listdir(path):
#     if i.endswith(".pdf"):
#         a.split_text(os.path.join(path, i))
#         a.vector_storage(os.path.join(path, i))

print(a.client.list_collections())
travel/rag_tool.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ from dotenv import load_dotenv
4
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
5
+ from langchain_openai import ChatOpenAI
6
+ from tqdm import tqdm
7
+
8
+ from langchain_community.document_loaders import Docx2txtLoader, CSVLoader
9
+ from langchain.text_splitter import CharacterTextSplitter
10
+ from pymilvus import MilvusClient
11
+
12
+
13
class RAGTool(object):
    """Minimal RAG helper: embed files into Milvus and answer questions from them."""

    def __init__(self):
        # Imported locally so the fix below does not depend on the module's
        # import block; the package is already used by this file.
        from langchain_community.document_loaders import TextLoader

        self.loader = {
            # Plain text is not a .docx archive, so Docx2txtLoader cannot read it.
            ".txt": lambda filename: TextLoader(filename, encoding="utf-8"),
            ".docx": Docx2txtLoader,
            ".csv": CSVLoader,
        }
        self.milvus_client = MilvusClient(host="127.0.0.1", port="19530")
        self.llm = ChatOpenAI(model="gpt-4o")
        # Running conversation; get_answer() appends every turn to it.
        self.messages = [SystemMessage(
            content="你是一个助手,请根据上下文回答问题,如果无法回答,请说“我不理解”,请尽量简要回答,与问题不相关的内容不用作为分析的内容。")]
        # Embedding client, created on first use by emb_text().
        self._embeddings = None

    def get_file(self, filename):
        """
        Load a file with the loader registered for its extension.
        :param filename: path of the file
        :return: the loaded documents, or None for unsupported extensions
        """
        file_type = os.path.splitext(filename)[-1]
        if file_type not in self.loader:
            return None
        return self.loader[file_type](filename).load()

    def split_sentences(self, filename):
        """
        Split a file into chunks.
        :param filename: path of the file
        :return: the document chunks, or the message "文档格式不支持" when the
                 file could not be loaded
        """
        full_text = self.get_file(filename)
        if not full_text:
            return "文档格式不支持"
        text_splitter = CharacterTextSplitter(chunk_size=240, chunk_overlap=30, add_start_index=True,
                                              length_function=len)
        return text_splitter.split_documents(full_text)

    def emb_text(self, text):
        """
        Embed a text with OpenAI's "text-embedding-3-small" model.
        :param text: text to embed
        :return: the embedding vector
        """
        # Build the client once instead of on every call.
        if getattr(self, "_embeddings", None) is None:
            from langchain_openai import OpenAIEmbeddings
            self._embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
        return self._embeddings.embed_query(text)

    def vector_storage(self, filename):
        """
        Embed every chunk of ``filename`` and insert it into "test_collection".
        :param filename: path of the file
        :return: "success", or the unsupported-format message from
                 split_sentences() when the file cannot be loaded
        """
        text_split = self.split_sentences(filename)
        if isinstance(text_split, str):
            # split_sentences signalled an unsupported file; iterating the
            # message would fail on `.page_content`.
            return text_split

        data_vector = [
            {
                "id": idx,
                "text": chunk.page_content,
                "vector": self.emb_text(chunk.page_content),
            }
            for idx, chunk in enumerate(tqdm(text_split, desc="Embedding"))
        ]

        self.milvus_client.create_collection(
            collection_name="test_collection",
            dimension=1536,
            metric_type="IP",
            consistency_level="Strong"
        )

        self.milvus_client.insert(collection_name="test_collection", data=data_vector)
        return "success"

    def query_data(self, query):
        """Return the concatenated text of the 3 best matches for ``query``."""
        query_vector = self.emb_text(query)
        result = self.milvus_client.search(
            collection_name="test_collection",
            data=[query_vector],
            limit=3,
            output_fields=["text"],
            params={"metric_type": "IP"},
        )
        # One query vector was sent, so the hits are in result[0].
        return "".join(info["entity"]["text"] for info in result[0])

    def get_answer(self, question):
        """
        Answer a question using retrieved context.
        :param question: the user's question
        :return: the model's answer text
        """
        result = self.query_data(question)

        self.messages.append(HumanMessage(content=f"问题:{question},检索内容:{result}"))
        res = self.llm.invoke(self.messages)
        self.messages.append(AIMessage(content=res.content))
        return res.content
114
+
115
+
travel/self_rag_tool.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import chromadb
4
+ from dotenv import load_dotenv
5
+ from langchain.prompts import ChatPromptTemplate
6
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
7
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
8
+ from pydantic.v1 import BaseModel, Field
9
+ from pymilvus import MilvusClient
10
+ from typing import Literal
11
+ from transformers import AutoModel, AutoTokenizer
12
+ from pypinyin import lazy_pinyin
13
+ import json
14
+
15
class LocationType(BaseModel):
    """
    将用户查询路由到最相关的数据源
    """
    # NOTE: the docstring above and the Field description below are part of
    # the schema handed to the LLM for structured output, so they are kept
    # verbatim. English gloss: "route the user query to the most relevant
    # data source" / "given a question, output the single most relevant one".
    route: Literal[
        '澳门', '青海', '周庄', '上海', '天津', '黄果树', '黔东南', '九寨沟', '广西', '贵阳',
        '扬州', '济南', '香格里拉', '香港', '昆明', '宁波', '林芝', '台北', '三清山', '呼伦贝尔',
        '鼓浪屿', '婺源', '厦门', '张家界', '故宫', '北戴河', '西藏', '杭州', '大同', '泰山',
        '秦皇岛', '成都', '凤凰', '兰州', '华山', '浙江', '哈尔滨', '沈阳', '云台山', '福州',
        '甘南', '三亚', '长沙', '敦煌', '苏州', '青城山', '束河', '南宁', '乌镇', '镇江',
        '丽江', '西塘', '黄山', '平遥', '五台山', '连云港', '拉萨', '西双版纳', '峨眉山', '武夷山',
        '宏村', '衡山', '横店', '北海', '桂林', '山海关', '长岛', '太原', '大连', '高雄',
        '青海湖', '荔波', '野三坡', '蓬莱', '合肥', '绍兴', '云南', '同里', '南京', '青岛',
        '北疆', '千岛湖', '南昌', '武汉', '珠海', '镇远', '武当山', '重庆', '庐山', '大理',
        '海口', '康定', '长白山', '曲阜', '蜀南竹海', '常州', '新疆', '海螺沟', '都江堰', '北京',
        '无锡', '白洋淀', '纳木错', '西溪湿地', '普陀山', '川藏', '日照', '雁荡山', '威海', '深圳',
        '广州', '泸沽湖', '乌鲁木齐', '西安', '稻城亚丁', '惠州', '烟台', '洛阳', '四姑娘山', '舟山',
    ] = Field(..., description="用户给定一个问题,选择最相关一个进行输出")
20
+
21
+
22
class GradedRagTool(BaseModel):
    """
    对检索到到文档进行相关性的检查,相关返回yes,不相关返回no
    """
    # Structured-output schema for the document relevance grader. The
    # docstring and description are sent to the LLM and kept verbatim.
    # English gloss: "check a retrieved document for relevance; yes if
    # relevant, no otherwise".
    binary_score: Literal['yes', 'no'] = Field(
        description="文档与问题的相关性,'yes' or 'no'"
    )
28
+
29
+
30
class GradeHallucinations(BaseModel):
    """
    对最终的回答进行一个判断,判断回答中是否存在幻觉,存在则输出yes,不存在则输出no
    """
    # Structured-output schema for the hallucination grader: 'yes' means the
    # answer contains a hallucination, 'no' means it does not. The docstring
    # and description are sent to the LLM; two typos in the docstring are
    # fixed, and the description, previously copy-pasted from the relevance
    # grader, now describes the hallucination verdict.
    binary_score: Literal['yes', 'no'] = Field(description="回答中是否存在幻觉,'yes' or 'no'")
36
+
37
+
38
class GradeAnswer(BaseModel):
    """对最终的回答于问题进行比对,判断回答和问题是相关的,是相关的则输出yes,不相关则输出no"""
    # Structured-output schema for the answer grader. The docstring and
    # description are sent to the LLM and kept verbatim. English gloss:
    # "compare the final answer with the question; yes if the answer is
    # relevant to the question, no otherwise".
    binary_score: Literal['yes', 'no'] = Field(description="问题与回答的相关性, 'yes' or 'no'")
44
+
45
+
46
class GradeAndGenerateTool(object):
    """LLM and vector-store helpers used by the RAG graphs.

    Bundles retrieval from Chroma, relevance grading, answer generation,
    hallucination and answer checks, and question rewriting.
    """

    def __init__(self, database_path="database/travel"):
        self.llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
        # One structured-output wrapper per grader schema.
        self.struct_llm_grader = self.llm.with_structured_output(GradedRagTool)
        self.struct_llm_hallucinations = self.llm.with_structured_output(GradeHallucinations)
        self.struct_llm_answer = self.llm.with_structured_output(GradeAnswer)
        self.embeding = embedding()
        # Context manager so the metadata file handle is always closed.
        with open(os.path.join(database_path, "info.json"), "r", encoding="utf-8") as info_file:
            self.database_info = json.load(info_file)
        self.client = chromadb.PersistentClient(path=database_path)

    # Relevance grading
    def grade(self, question, text):
        """Return 'yes' or 'no': is document ``text`` relevant to ``question``?"""
        system_prompt = """
        你是一名评估检索到到文档与用户到问题相关性到评分员,不需要一个严格的测试,目标是过滤掉错误的检索。如果文档包含与用户问题相关的关键字或者语义,请评为相关,否则请评为不相关。你的回答只能是yes或者no
        """
        grade_messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"问题:{question}\n文档:{text}"),
        ]
        result = self.struct_llm_grader.invoke(grade_messages)
        return result.binary_score

    # Answer generation
    def generate(self, question, text):
        """Answer ``question`` from the context ``text`` and return the answer string.

        The prompt is a real template whose ``{question}`` and ``{text}``
        slots are filled by ``format_messages``. The previous version
        interpolated both values with an f-string first and then parsed the
        result as a template, so any ``{`` or ``}`` in the question or the
        documents was treated as a template variable and formatting failed.
        """
        template = (
            "您是问答任务的助理。使用以下检索到的上下文来回答问题。如果你不知道答案,就说你不知道。"
            "尽量将回答长度控制在三句话内,保持答案简洁。\n问题:{question}\n上下文:{text}\n答案:"
        )
        human_prompt = ChatPromptTemplate.from_template(template)
        messages = human_prompt.format_messages(question=question, text=text)
        result = self.llm.invoke(messages)
        return result.content

    # Hallucination check
    def hallucinations(self, documents, answer):
        """Return 'yes' when ``answer`` is NOT grounded in ``documents``, else 'no'."""
        # The middle of this prompt was mis-encoded in the source; restored to "检索".
        hallucinations_prompt = "您是一名评估LLM生成是否基于一组检索到的事实的评分员。如果是基于检索到的事实回答则返回no,否则返回yes"
        hallucinations_messages = [
            SystemMessage(content=hallucinations_prompt),
            HumanMessage(content=f":回答:{answer}\n文档:{documents}"),
        ]
        result = self.struct_llm_hallucinations.invoke(hallucinations_messages)
        return result.binary_score

    # Does the answer address the question?
    def answer_question(self, question, answer):
        """Return 'yes' when ``answer`` resolves ``question``, else 'no'."""
        answer_question_prompt = """
        你是一名评分员,评估答案是否解决了问题,如果解决了则返回yes,否则返回no
        """
        answer_question_messages = [
            SystemMessage(content=answer_question_prompt),
            HumanMessage(content=f"问题:{question}\n回答:{answer}"),
        ]
        result = self.struct_llm_answer.invoke(answer_question_messages)
        return result.binary_score

    # Question rewriting
    def rewrite_question(self, question):
        """Return a version of ``question`` rewritten for vector-store retrieval."""
        rewrite_promtp = "您是一个将输入问题转换为优化的更好版本的问题重写器\n用于矢量库检索。查看输入并尝试推理潜在的语义意图/含义。"
        rewrite_promtp_messages = [
            SystemMessage(content=rewrite_promtp),
            HumanMessage(content=f"问题:{question}"),
        ]
        result = self.llm.invoke(rewrite_promtp_messages)
        return result.content

    def embed_dim(self, text):
        """Embed ``text``; returns whatever the embedding model's ``embed_query`` returns."""
        return self.embeding.embed_query(text)

    def get_knowledge_type(self, query):
        """Ask the LLM which region ``query`` is about; return its name as pinyin."""
        names = self.database_info["names"]
        system_prompt = f"""
        你是一名地区分类专家,主要分别判断以下类别的地区,有且仅有{','.join(names)}这{len(names)}类地区。识别准确后,返回给用户。
        识别到澳门,返回'澳门',以此类推。
        """
        grade_messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"{query}"),
        ]
        collection_name = self.llm.with_structured_output(LocationType).invoke(grade_messages)
        return ''.join(lazy_pinyin(collection_name.route))

    # Retrieval
    def search_vector(self, question):
        """Return the texts of the 3 nearest documents in the 'travel' collection."""
        collection_name = 'travel'
        result = self.client.get_collection(collection_name).query(
            # embed_dim returns a one-row batch; take the single vector.
            query_embeddings=[self.embed_dim(question)[0]],
            n_results=3,
        )
        # One query was sent, so its documents are in the first row.
        return result['documents'][0]
129
+
130
class embedding(object):
    """Sentence embedder built on the 'maidalun1020/bce-embedding-base_v1' model."""

    def __init__(self):
        # torch is not imported at module level in this file; import it here.
        import torch

        # init model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-embedding-base_v1')
        self.model = AutoModel.from_pretrained('maidalun1020/bce-embedding-base_v1')

        # Use the GPU when one is available, otherwise fall back to the CPU
        # instead of failing on hosts without CUDA.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.model.eval()

    def embed_query(self, text):
        """Return the L2-normalised embedding of ``text``.

        The result is a one-row nested list (``[[...]]``) because the model
        is run on a batch of size one; callers index ``[0]`` for the vector.
        """
        import torch

        # get inputs
        inputs = self.tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs_on_device = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs_on_device, return_dict=True)
            cls_vectors = outputs.last_hidden_state[:, 0]  # first-token (CLS) pooling
            normalized = cls_vectors / cls_vectors.norm(dim=1, keepdim=True)
        return normalized.tolist()
153
+
154
+
155
+ # a=GradeAndGenerateTool()
156
+ # a.search_vector("澳门旅游指南")
157
+
158
+
travel/travel.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 举例:1、高德搜索附近的店 2、高德获取地点的经纬度。 3、RAG功能
3
+
4
+ """
5
+ import functools
6
+ import operator
7
+ import os
8
+ import time
9
+ from typing import Type, TypedDict, Annotated, Sequence
10
+ from langchain_openai import ChatOpenAI
11
+ import aiohttp
12
+ import requests
13
+ from dotenv import load_dotenv
14
+ from langchain_community.output_parsers.ernie_functions import JsonOutputFunctionsParser
15
+ from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage, AIMessage
16
+ from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
17
+ from langchain_core.tools import BaseTool
18
+ from langgraph.constants import END
19
+ from langgraph.graph import StateGraph
20
+ from pydantic.v1 import BaseModel, Field
21
+ import gradio as gr
22
+ from RAGGraph import RAGTool
23
+ from langchain.agents import create_openai_tools_agent, AgentExecutor
24
+
25
# Load variables from a .env file before any client is constructed.
_ = load_dotenv()

# Module-wide RAG pipeline instance; ragTool._run delegates to it.
rag_tool = RAGTool()
28
+
29
+
30
class searchAroundInput(BaseModel):
    # Argument schema for the searchAround tool. The descriptions are shown
    # to the LLM and kept verbatim ("search keyword" / "longitude and
    # latitude of the place").
    keyword: str = Field(
        ...,
        description="搜索关键词",
    )
    location: str = Field(
        ...,
        description="地点的经纬度",
    )
33
+
34
+
35
class searchAround(BaseTool):
    # Tool that queries the AMap "place/around" endpoint for places near a
    # coordinate. `description` is shown to the LLM and kept verbatim.
    args_schema: Type[BaseModel] = searchAroundInput
    description = "这是一个搜索周边信息的方法,需要用户提供关键词和地点的经纬度,才能进行周边信息的搜索。如果用户没有提供关键词或者地点的经纬度,则提示用户给出关键词和地点的经纬度并再进行周边信息的搜索。"
    name = "searchAround"

    @staticmethod
    def _request_args(keyword, location):
        """Return (url, params) shared by the sync and async implementations."""
        # NOTE(review): the key used to be hard-coded twice. It can now be
        # overridden through AMAP_API_KEY; the committed key remains only as
        # a backward-compatible fallback and should be rotated.
        api_key = os.getenv("AMAP_API_KEY", "df8ff851968143fb413203f195fcd7d7")
        around_url = "https://restapi.amap.com/v5/place/around"
        params = {
            "key": api_key,
            "keywords": keyword,
            "location": location
        }
        return around_url, params

    def _run(self, keyword, location):
        """Synchronously search around ``location``; returns the raw response text."""
        around_url, params = self._request_args(keyword, location)
        print("同步调用获取地点周边的方法")
        # Timeout so a stalled connection cannot hang the agent forever.
        res = requests.get(url=around_url, params=params, timeout=10)
        return res.text

    async def _arun(self, keyword, location):
        """Asynchronously search around ``location``; returns the decoded JSON body."""
        around_url, params = self._request_args(keyword, location)
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            print("异步调用获取地点周边的方法")
            async with session.get(url=around_url, params=params) as response:
                return await response.json()
65
+
66
+
67
class getLocationInput(BaseModel):
    # Argument schema for the getLocation tool. The description is shown to
    # the LLM and kept verbatim ("search keyword").
    keyword: str = Field(
        ...,
        description="搜索关键词",
    )
69
+
70
+
71
class getLocation(BaseTool):
    # Tool that looks a place up through the AMap "place/text" endpoint and
    # reports its coordinates. `description` is shown to the LLM and kept verbatim.
    args_schema: Type[BaseModel] = getLocationInput
    description = "这是一个获取地点的经纬度的方法,需要用户提供关键词,才能进行地点的经纬度的获取。如果用户没有提供关键词,则提示用户给出关键词并再进行地点的经纬度的获取。"
    name = "getLocation"

    @staticmethod
    def _request_args(keyword):
        """Return (url, params) shared by the sync and async implementations."""
        # NOTE(review): the key used to be hard-coded twice. It can now be
        # overridden through AMAP_API_KEY; the committed key remains only as
        # a backward-compatible fallback and should be rotated.
        api_key = os.getenv("AMAP_API_KEY", "df8ff851968143fb413203f195fcd7d7")
        url = "https://restapi.amap.com/v5/place/text"
        params = {
            "key": api_key,
            "keywords": keyword,
        }
        return url, params

    @staticmethod
    def _describe(keyword, payload):
        """Format the first POI's location, or a not-found message.

        Previously an empty or missing ``pois`` list raised IndexError/KeyError
        and aborted the whole agent run; the tool now returns a message the
        agent can act on.
        """
        pois = payload.get("pois") or []
        if not pois or not pois[0].get("location"):
            return '未找到{}的经纬度'.format(keyword)
        return '{}的经纬度是:'.format(keyword) + pois[0]["location"]

    def _run(self, keyword):
        """Synchronously look up the coordinates of ``keyword``."""
        url, params = self._request_args(keyword)
        # Timeout so a stalled connection cannot hang the agent forever.
        res = requests.get(url=url, params=params, timeout=10)
        print("同步调用获取地点的经纬度方法")
        return self._describe(keyword, res.json())

    async def _arun(self, keyword):
        """Asynchronously look up the coordinates of ``keyword``."""
        url, params = self._request_args(keyword)
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            print("异步调用获取地点的经纬度方法")
            async with session.get(url=url, params=params) as response:
                res = await response.json()
                return self._describe(keyword, res)
97
+
98
+
99
class ragToolInput(BaseModel):
    # Argument schema for the ragTool tool. The description is shown to the
    # LLM and kept verbatim ("the user's question").
    question: str = Field(
        ...,
        description="用户的问题",
    )
101
+
102
class ragTool(BaseTool):
    # Tool wrapper around the module-level `rag_tool` pipeline.
    # `description` is shown to the LLM and kept verbatim.
    args_schema: Type[BaseModel] = ragToolInput
    description = "这是一个RAG工具,可以提供中国国内旅游的相关指南攻略等信息。"
    name = "ragTool"

    def _run(self, question):
        """Answer ``question`` with the RAG pipeline and return the answer text."""
        # The unused timing local from the original was dropped.
        return rag_tool.get_answer(question)
110
def create_agent(llm, tools, system_prompt):
    """Build a verbose AgentExecutor for an OpenAI-tools agent.

    The prompt consists of ``system_prompt``, the conversation under the
    "messages" placeholder, and the agent's scratchpad.
    """
    prompt_parts = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    agent_prompt = ChatPromptTemplate.from_messages(prompt_parts)
    tools_agent = create_openai_tools_agent(llm, tools, agent_prompt)
    return AgentExecutor(agent=tools_agent, tools=tools, verbose=True)
118
+
119
+
120
def agent_node(state, agent, name):
    """Graph node: run ``agent`` on ``state`` and wrap its output as a message.

    The agent's "output" is returned as a single HumanMessage tagged with
    ``name`` under the "messages" key.
    """
    outcome = agent.invoke(state)
    reply = HumanMessage(content=outcome["output"], name=name)
    return {"messages": [reply]}
125
def chat(message, history=None):
    """Send ``message`` through the agent graph and return the final reply text.

    The conversation is kept in the module-level ``history_message`` list;
    ``history`` is accepted for chat-UI callback signatures but is not used.
    Its default is now ``None`` instead of a shared mutable list.
    """
    history_message.append(HumanMessage(content=message))
    started = time.time()
    res = graph.invoke({"messages": history_message})
    # The last message in the returned state is the reply.
    res = res['messages'][-1].content
    history_message.append(AIMessage(content=res))
    print(time.time() - started)
    print(res)
    return res
134
+
135
+
136
+ # 1.封装agent,其中包含特有的工具,系统提示,agent的执行器
137
+ # 2.封装node,node是一个函数,在函数中会对agent进行调用
138
# LLM shared by the three worker agents; the supervisor gets its own client.
llm = ChatOpenAI(model="gpt-3.5-turbo")
supervisor_llm = ChatOpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY4"))

# Each worker is an agent executor plus a graph node bound to it.
get_location_agent = create_agent(
    llm=llm,
    tools=[getLocation()],
    system_prompt="你是一个获取地点经纬度的助手,当用户需要获取经纬度时,你需要准确提供经纬度信息",
)
get_location_node = functools.partial(agent_node, agent=get_location_agent, name="get_location_agent")

search_around_agent = create_agent(
    llm=llm,
    tools=[searchAround()],
    system_prompt="你是一个地图通,你能够根据提供的经纬度去搜索周边店面信息。并返回给用户",
)
search_around_node = functools.partial(agent_node, agent=search_around_agent, name="search_around_agent")

rag_agent = create_agent(
    llm=llm,
    tools=[ragTool()],
    system_prompt="你是一个RAG工具,主要是对于旅游相关内容的指南攻略等信息。",
)
rag_node = functools.partial(agent_node, agent=rag_agent, name="rag_agent")

# Names of the worker nodes the supervisor can route to.
member = ["search_around_agent", "get_location_agent", 'rag_agent']

system_prompt = f"""
    你是一名任务管理者,负责管理任务的调度,下面是你的工作者{member},给定以下请求,与工作者一起响应,并采取下一步行动。
    每个工作者将执行一个任务并回复执行后的结果和状态,若已经完成后,用FINISH回应。
"""

# Routing choices: every worker plus the terminal "FINISH".
options = member + ["FINISH"]

# OpenAI function schema that makes the supervisor answer with {"next": <option>}.
function_def = {
    "name": "route",
    "description": "选择下一个工作者",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [
                    {"enum": options},
                ],
            },
        },
        "required": ["next"],
    },
}

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            f"基于上述的对话接下来应该是谁来采取行动?请在以下选项中进行选择{options},"
            "如果你认为最后一句能够对以上对话形成较好回复比如完成了打招呼,请尽量选择'FINISH',"
            "如果问题仍未解决且问题与旅游相关请选择'rag_agent',"
            "如果问题仍未解决且问题与获取地点经纬度相关请选择'get_location_agent',"
            "如果问题仍未解决且问题与地点周边信息相关请选择'search_around_agent',",
        ),
    ]
).partial(options=str(options), member=",".join(member))

# Supervisor: prompt -> forced "route" function call -> parsed arguments.
supervisor_chain = (
    prompt
    | supervisor_llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()
)
191
+
192
+
193
+
194
class AgentState(TypedDict):
    """State shared by the supervisor graph."""

    # Conversation so far; operator.add is the reducer attached to this key.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing decision read by the conditional edge after the supervisor node.
    next: str
197
+
198
+
199
+
200
# Assemble the supervisor graph: three worker nodes plus the supervisor.
work_flow = StateGraph(AgentState)

work_flow.add_node("get_location_agent", get_location_node)
work_flow.add_node("search_around_agent", search_around_node)
work_flow.add_node("rag_agent", rag_node)
work_flow.add_node("supervisor", supervisor_chain)

# Every worker hands control back to the supervisor.
for name in member:
    work_flow.add_edge(name, "supervisor")

# Supervisor's "next" value -> node to run; "FINISH" ends the graph.
conditional_map = {
    "get_location_agent": "get_location_agent",
    "search_around_agent": "search_around_agent",
    "rag_agent": "rag_agent",
    "FINISH": END,
}

work_flow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
work_flow.set_entry_point("supervisor")

graph = work_flow.compile()

# Conversation history shared with chat(), followed by a smoke-test query.
history_message = []
chat("北京的旅游攻略")
224
+
225
+ # iface_chat_file = gr.ChatInterface(
226
+ # fn=chat,
227
+ # examples=['北京的经纬度是多少', '天安门附近的餐馆有哪些','北京的旅游攻略'],
228
+ # title="Chat File Interface",
229
+ # )
230
+ # iface_chat_file.launch(share=True, server_name='0.0.0.0', server_port=5001)
231
+ # a=time.time()
232
+ # res = graph.invoke({"messages":[HumanMessage(content="天安门附近的餐馆有哪些")]})
233
+ # print(time.time()-a)
234
+ # print(res)
235
+
236
+
237
+
travel/travel_new.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 举例:1、高德搜索附近的店 2、高德获取地点的经纬度。 3、RAG功能
3
+
4
+ """
5
+ import functools
6
+ import operator
7
+ import os
8
+ import time
9
+ from typing import Type, TypedDict, Annotated, Sequence
10
+
11
+ import aiohttp
12
+ import requests
13
+ from dotenv import load_dotenv
14
+ from langchain_community.output_parsers.ernie_functions import JsonOutputFunctionsParser
15
+ from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage, AIMessage
16
+ from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
17
+ from langchain_core.tools import BaseTool
18
+ from langgraph.constants import END
19
+ from langgraph.graph import StateGraph
20
+ from pydantic.v1 import BaseModel, Field
21
+ from typing import Literal
22
+ from RAGGraph import RAGTool
23
+ from langchain.agents import create_openai_tools_agent, AgentExecutor
24
+
25
+ import gradio as gr
26
+ from langchain_openai import ChatOpenAI
27
# 1. Wrap each agent: its dedicated tools, its system prompt, and its executor.
# 2. Wrap each node: a node is a function that invokes an agent.
load_dotenv()  # load environment variables (e.g. the OpenAI key read via os.getenv below)
rag_tool = RAGTool()  # shared RAG pipeline instance; ragTool._run calls rag_tool.get_answer
31
+
32
class AgentType(BaseModel):
    """
    将用户查询路由到最相关的数据源
    """
    # Structured-output schema for the supervisor's routing decision.
    # NOTE(review): the docstring above and the Field descriptions below are
    # presumably forwarded to the LLM as part of the schema by
    # `with_structured_output(AgentType)`; they are runtime text, so keep them
    # unchanged unless you intend to change the prompt.

    # Name of the graph node that should act next, or 'FINISH' to stop.
    route: Literal["search_around_agent", "get_location_agent", 'rag_agent', 'FINISH', 'normal_chat'] = Field(...,description="用户给定一个问题,选择最相关一个进行输出")
    # Optional arguments for the selected agent; `supervisor` only reads `route`.
    arguments: dict | None = Field(None, description="给选定的代理的参数")
38
+
39
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the StateGraph."""
    # Conversation messages; annotated with operator.add as the combining
    # function (presumably LangGraph uses it to append each node's returned
    # messages to the existing ones — confirm against the LangGraph docs).
    messages : Annotated[Sequence[BaseMessage],operator.add]
    # Routing key written by `supervisor` and read by the conditional edge.
    next : str
42
+
43
+
44
class searchAroundInput(BaseModel):
    # Argument schema for the `searchAround` tool. The Field descriptions are
    # runtime text (part of the tool schema), so they are left unchanged.
    # Search keyword, sent to the API as `keywords`.
    keyword: str = Field(..., description="搜索关键词")
    # Coordinates of the place to search around, sent to the API as `location`.
    location: str = Field(..., description="地点的经纬度")
47
+
48
+
49
class searchAround(BaseTool):
    # Tool that queries the Amap "place/around" REST endpoint for places
    # matching a keyword near a given coordinate. `_run` is the blocking
    # implementation and `_arun` the asyncio one.
    args_schema: Type[BaseModel] = searchAroundInput
    description = "这是一个搜索周边信息的方法,需要用户提供关键词和地点的经纬度,才能进行周边信息的搜索。如果用户没有提供关键词或者地点的经纬度,则提示用户给出关键词和地点的经纬度并再进行周边信息的搜索。"
    name = "searchAround"

    def _amap_params(self, keyword, location):
        """Build the query parameters shared by the sync and async paths.

        The API key is read from the ``AMAP_API_KEY`` environment variable
        (the file already loads a .env file). The previously hard-coded key
        is kept as a fallback so existing deployments keep working.
        """
        return {
            # SECURITY NOTE: this fallback key is committed to the repository;
            # it should be rotated and supplied only through the environment.
            "key": os.getenv("AMAP_API_KEY", "df8ff851968143fb413203f195fcd7d7"),
            "keywords": keyword,
            "location": location,
        }

    def _run(self, keyword, location):
        """Synchronously search around ``location`` for ``keyword``.

        Returns:
            The raw response body as text.

        Raises:
            requests.exceptions.RequestException: on network failure or when
                the request exceeds the 10 second timeout.
        """
        around_url = "https://restapi.amap.com/v5/place/around"
        params = self._amap_params(keyword, location)
        print("同步调用获取地点周边的方法")
        # An explicit timeout keeps a stalled connection from blocking the
        # agent forever (requests has no default timeout).
        res = requests.get(url=around_url, params=params, timeout=10)
        return res.text

    async def _arun(self, keyword, location):
        """Asynchronously search around ``location`` for ``keyword``.

        Returns:
            The decoded JSON response. Note that this differs from `_run`,
            which returns the raw text; that behavior is preserved here.
        """
        around_url = "https://restapi.amap.com/v5/place/around"
        params = self._amap_params(keyword, location)
        # Same 10 second budget as the synchronous path.
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            print("异步调用获取地点周边的方法")
            async with session.get(url=around_url, params=params) as response:
                return await response.json()
79
+
80
+
81
class getLocationInput(BaseModel):
    # Argument schema for the `getLocation` tool. The Field description is
    # runtime text (part of the tool schema), so it is left unchanged.
    # Search keyword, sent to the API as `keywords`.
    keyword: str = Field(..., description="搜索关键词")
83
+
84
+
85
class getLocation(BaseTool):
    # Tool that looks a keyword up through the Amap "place/text" REST endpoint
    # and reports the `location` field of the first returned POI.
    args_schema: Type[BaseModel] = getLocationInput
    description = "这是一个获取地点的经纬度的方法,需要用户提供关键词,才能进行地点的经纬度的获取。如果用户没有提供关键词,则提示用户给出关键词并再进行地点的经纬度的获取。"
    name = "getLocation"

    def _amap_params(self, keyword):
        """Build the query parameters shared by the sync and async paths.

        The API key is read from the ``AMAP_API_KEY`` environment variable
        (the file already loads a .env file). The previously hard-coded key
        is kept as a fallback so existing deployments keep working.
        """
        return {
            # SECURITY NOTE: this fallback key is committed to the repository;
            # it should be rotated and supplied only through the environment.
            "key": os.getenv("AMAP_API_KEY", "df8ff851968143fb413203f195fcd7d7"),
            "keywords": keyword,
        }

    def _describe_location(self, keyword, payload):
        """Turn the decoded API response into the tool's answer string.

        An error payload (no ``pois`` key) or an empty result list used to
        raise KeyError/IndexError; it now yields a "not found" message the
        agent can relay to the user.
        """
        pois = payload.get("pois") or []
        location = pois[0].get("location") if pois else None
        if not location:
            return '未找到{}的经纬度信息,请提供更具体的关键词'.format(keyword)
        return '{}的经纬度是:'.format(keyword) + location

    def _run(self, keyword):
        """Synchronously look up the coordinates for ``keyword``.

        Raises:
            requests.exceptions.RequestException: on network failure or when
                the request exceeds the 10 second timeout.
        """
        url = "https://restapi.amap.com/v5/place/text"
        # An explicit timeout keeps a stalled connection from blocking the
        # agent forever (requests has no default timeout).
        res = requests.get(url=url, params=self._amap_params(keyword), timeout=10)
        print("同步调用获取地点的经纬度方法")
        return self._describe_location(keyword, res.json())

    async def _arun(self, keyword):
        """Asynchronously look up the coordinates for ``keyword``."""
        url = "https://restapi.amap.com/v5/place/text"
        params = self._amap_params(keyword)
        # Same 10 second budget as the synchronous path.
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            print("异步调用获取地点的经纬度方法")
            async with session.get(url=url, params=params) as response:
                res = await response.json()
                return self._describe_location(keyword, res)
111
+
112
class ragToolInput(BaseModel):
    # Argument schema for the `ragTool` tool. The Field description is
    # runtime text (part of the tool schema), so it is left unchanged.
    # The user's question, passed to rag_tool.get_answer by ragTool._run.
    question: str = Field(..., description="用户的问题")
114
+
115
class ragTool(BaseTool):
    # Tool wrapper around the module-level `rag_tool` pipeline.
    args_schema: Type[BaseModel] = ragToolInput
    description = "这是一个RAG工具,可以提供中国国内旅游的相关指南攻略等信息。"
    name = "ragTool"

    def _run(self, question):
        """Answer ``question`` through `rag_tool.get_answer` and print the elapsed time."""
        started_at = time.time()
        answer = rag_tool.get_answer(question)
        elapsed = time.time() - started_at
        print('ragTool', elapsed)
        return answer
125
def create_agent(llm, tools, system_prompt):
    """Build an AgentExecutor for ``llm`` equipped with ``tools``.

    The prompt consists of the given system prompt followed by two
    placeholders: ``messages`` and ``agent_scratchpad``.

    Args:
        llm: chat model passed to ``create_openai_tools_agent``.
        tools: tools made available to the agent and its executor.
        system_prompt: text of the leading system message.

    Returns:
        A verbose ``AgentExecutor`` wrapping the created agent.
    """
    prompt_parts = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    chat_prompt = ChatPromptTemplate.from_messages(prompt_parts)
    tools_agent = create_openai_tools_agent(llm, tools, chat_prompt)
    return AgentExecutor(agent=tools_agent, tools=tools, verbose=True)
133
+
134
+
135
def agent_node(state, agent, name):
    """Run ``agent`` on the graph ``state`` and return its output as a state update.

    Args:
        state: current graph state, handed to ``agent.invoke`` unchanged.
        agent: object exposing ``invoke``; its result must contain "output".
        name: label attached to the produced message.

    Returns:
        A dict whose "messages" list holds one HumanMessage with the output.
    """
    outcome = agent.invoke(state)
    reply = HumanMessage(content=outcome["output"], name=name)
    return {"messages": [reply]}
140
+
141
+
142
def supervisor(state):
    """Ask the supervisor LLM which node should act next.

    Appends a routing instruction (a SystemMessage listing the module-level
    ``options``) to the conversation, requests an ``AgentType`` structured
    answer from ``supervisor_llm``, prints the elapsed time, and returns the
    chosen route.

    Args:
        state: graph state; only ``state['messages']`` is read.

    Returns:
        ``{'next': <route>}`` where route is the ``route`` field of the answer.
    """
    conversation = state['messages']
    routing_instruction = SystemMessage(content=(
        f"基于上述的对话接下来应该是谁来采取行动?请在以下选项中进行选择{options},"
        "如果你认为最后一句能够对以上对话形成较好回复比如完成了打招呼,请尽量选择'FINISH',"
        "如果问题仍未解决且问题与旅游相关请选择'rag_agent',"
        "如果问题仍未解决且问题与获取地点经纬度相关请选择'get_location_agent',"
        "如果问题仍未解决且问题与地点周边信息相关请选择'search_around_agent',"
        "如果问题仍未解决且问题是一般性问题请选择'normal_chat'。"
    ))
    started_at = time.time()
    router = supervisor_llm.with_structured_output(AgentType)
    decision = router.invoke(conversation + [routing_instruction])
    print('supervisor', time.time() - started_at)
    return {'next': decision.route}
156
+ # if res.route == "FINISH":
157
+ # state['messages'].append(SystemMessage(content="任务结束"))
158
+ # state['next'] = "FINISH"
159
+ # return state
160
+ # else:
161
+ # state['next'] = res.route
162
+ # return state
163
+
164
def noraml_chat(state,name):
    """Answer the conversation directly with the module-level ``llm``.

    Args:
        state: graph state; ``state['messages']`` is sent to the model.
        name: label attached to the produced message.

    Returns:
        A dict whose "messages" list holds one HumanMessage with the reply text.
    """
    reply = llm.invoke(state['messages'])
    answer = HumanMessage(content=reply.content, name=name)
    return {"messages": [answer]}
170
def chat(message, history=None):
    """Run one chat turn through the compiled graph and return the reply text.

    The conversation is kept in the module-level ``history_message`` list:
    the user message is appended, the graph is invoked with the whole list,
    and the content of the last resulting message is appended as an AIMessage.

    Args:
        message: the user's input text.
        history: accepted for the Gradio ChatInterface callback signature and
            not used. Defaults to None rather than a shared mutable list.

    Returns:
        The content of the last message in the graph's result (also printed).
    """
    history_message.append(HumanMessage(content=message))
    final_state = graph.invoke({"messages": history_message})
    answer = final_state['messages'][-1].content
    history_message.append(AIMessage(content=answer))
    print(answer)
    return answer
178
+
179
if __name__ == '__main__':
    # Script entry point: build the models, agents and graph, then serve a
    # Gradio chat UI. The names assigned here (llm, supervisor_llm, options,
    # graph, history_message) are module globals that the functions above
    # read, so those functions only work after this section has run.
    llm = ChatOpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY4"))
    supervisor_llm = ChatOpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY4"))

    # One executor plus one graph-node callable per worker agent; partial()
    # pre-binds the agent and its name so the node only takes the state.
    get_location_agent = create_agent(llm=llm, tools=[getLocation()],
                                      system_prompt="你是一个获取地点经纬度的助手,当用户需要获取经纬度时,你需要准确提供经纬度信息")
    get_location_node = functools.partial(agent_node, agent=get_location_agent, name="get_location_agent")
    search_around_agent = create_agent(llm=llm, tools=[searchAround()],
                                       system_prompt="你是一个地图通,你能够根据提供的经纬度去搜索周边店面信息。并返回给用户")
    search_around_node = functools.partial(agent_node, agent=search_around_agent, name="search_around_agent")
    rag_agent = create_agent(llm=llm, tools=[ragTool()],
                             system_prompt="你是一个RAG工具,主要是对于旅游相关内容的指南攻略等信息。")
    rag_node = functools.partial(agent_node, agent=rag_agent, name="rag_agent")
    # Disabled alternative: a tool-less agent built with create_agent.
    # normal_agent = create_agent(llm=llm, tools=[],
    #                             system_prompt="你能回答一些不需要联网搜索的一般性的问题")
    normal_node = functools.partial(noraml_chat, name="normal_chat")
    # Worker node names; each gets an edge back to the supervisor below.
    member = ["search_around_agent", "get_location_agent", 'rag_agent', 'normal_chat']

    # Choices interpolated into the supervisor's routing prompt.
    options = member + ["FINISH"]

    work_flow = StateGraph(AgentState)

    work_flow.add_node("get_location_agent",get_location_node)
    work_flow.add_node("search_around_agent",search_around_node)
    work_flow.add_node("rag_agent",rag_node)
    work_flow.add_node("normal_chat",normal_node)
    work_flow.add_node("supervisor",supervisor)

    # Every worker hands control back to the supervisor after it runs.
    for name in member:
        work_flow.add_edge(name,"supervisor")

    # Maps the supervisor's `next` value to the node to run; "FINISH" ends the graph.
    conditional_map = {
        "get_location_agent":"get_location_agent",
        "search_around_agent":"search_around_agent",
        "rag_agent":"rag_agent",
        "normal_chat":"normal_chat",
        "FINISH":END,
    }

    work_flow.add_conditional_edges("supervisor",lambda x : x["next"],conditional_map)

    work_flow.set_entry_point("supervisor")

    graph = work_flow.compile()

    # System prompt that seeds the conversation history (runtime text; unchanged).
    system_prompt = f"""
    你是一名任务管理者,负责管理任务的调度,下面是你的工作者{member},给定以下请求,与工作者一起响应,并采取下一步行动。
    每个工作者将执行一个任务并回复执行后的结果和状态,若对话中最后一句能够形成较好回复比如完成了打招呼,请选择'FINISH'。
    """
    # Single module-level history list; `chat` appends every user and AI
    # message to it, so it is shared by all callers of `chat` in this process.
    history_message = [SystemMessage(content=system_prompt)]
    # chat('hello')


    iface_chat_file = gr.ChatInterface(
        fn=chat,
        examples=['北京的经纬度是多少', '天安门附近的餐馆有哪些','北京的旅游攻略'],
        title="Chat File Interface",
    )
    # NOTE(review): share=True together with server_name='0.0.0.0' presumably
    # makes the UI reachable beyond localhost — confirm this exposure is intended.
    iface_chat_file.launch(share=True, server_name='0.0.0.0', server_port=5001)