File size: 6,555 Bytes
b2e325f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e418d71
b2e325f
 
e418d71
b2e325f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e418d71
b2e325f
e418d71
 
 
 
 
b2e325f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
1. rag_reponse_002.py is a modified version of rag_reponse_001.py. 主要是为了测试用ChatGPT+Reranker+最后给出相似查询的页面结构。

"""
##TODO: 1. 将LLM改成ChatGPT. 2. Reranker. 3. 最后给出相似查询的页面结构

from langchain_community.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableParallel
import streamlit as st
import re
import openai
import os
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from typing import Dict, List, Optional, Tuple, Union
# import chatgpt
import qwen_response
from dotenv import load_dotenv
import dashscope

load_dotenv()
### 设置openai的API key
os.environ["OPENAI_API_KEY"] = os.environ['user_token']
openai.api_key = os.environ['user_token']
bing_search_api_key = os.environ['bing_api_key']
dashscope.api_key = os.environ['dashscope_api_key']


from langchain.embeddings.openai import OpenAIEmbeddings

# embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese') ## 这里是联网情况下,部署在Huggingface上后使用。
# embeddings = OpenAIEmbeddings(disallowed_special=())  ## 这里是联网情况下,部署在Huggingface上后使用。
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/RAG/bge-large-zh') ## 切换成BGE的embedding。
# vector_store = FAISS.load_local("./faiss_index/", embeddings=embeddings, allow_dangerous_deserialization=True) ## 加载vector store到本地。
# vector_store = FAISS.load_local("./faiss_index/", embeddings=embeddings) ## 加载vector store到本地。 ### original code here.

# ## 配置ChatGLM的类与后端api server对应。
# class ChatGLM(LLM):
#     max_token: int = 8096 ###  无法输出response的时候,可以看一下这里。
#     temperature: float = 0.7
#     top_p = 0.9
#     history = []

#     def __init__(self):
#         super().__init__()

#     @property
#     def _llm_type(self) -> str:
#         return "ChatGLM"

#     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
#         # headers中添加上content-type这个参数,指定为json格式
#         headers = {'Content-Type': 'application/json'}
#         data=json.dumps({
#             'prompt':prompt,
#             'temperature':self.temperature,
#             'history':self.history,
#             'max_length':self.max_token
#         })
#         print("ChatGLM prompt:",prompt)
#         # 调用api
#         # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。
#         response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。
#         print("ChatGLM resp:", response)
        
#         if response.status_code!=200:
#             return "查询结果错误"
#         resp = response.json()
#         if stop is not None:
#             response = enforce_stop_tokens(response, stop)
#         self.history = self.history+[[None, resp['response']]] ##original
#         return resp['response'] ##original.

## 在绝对路径中提取完整的文件名
def extract_document_name(path):
    # 路径分割
    path_segments = path.split("/")
    # 文件名提取
    document_name = path_segments[-1]
    return document_name

## 从一段话中提取 1 句完整的句子,且该句子的长度必须超过 5 个词,同时去除了换行符'\n\n'。
import re
def extract_sentence(text):
    """
    从一段话中提取 1 句完整的句子,且该句子的长度必须超过 5 个词。

    Args:
        text: 一段话。

    Returns:
        提取到的句子。
    """

    # 去除换行符。
    text = text.replace('\n\n', '')
    # 使用正则表达式匹配句子。
    sentences = re.split(r'[。?!;]', text)

    # 过滤掉长度小于 5 个词的句子。
    sentences = [sentence for sentence in sentences if len(sentence.split()) >= 5]

    # 返回第一句句子。
    return sentences[0] if sentences else None

### 综合source的输出内容。
def rag_source(docs):
    print('starting source function!')
    source = ""
    for i, doc in enumerate(docs):
        source += f"**【信息来源 {i+1}】** " + extract_document_name(doc.metadata['source']) + ',' + f"第{docs[i].metadata['page']+1}页" + ',部分内容摘录:' + extract_sentence(doc.page_content) + '\n\n'
    print('source:', source)
    return source

def rag_response(username, user_input, k=3):
    # docs = vector_store.similarity_search('user_input', k=k) ## Original。
    
    embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-zh-v1.5') ## 这里是联网情况下,部署在Huggingface上后使用。
    # embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese') ## 这里是联网情况下,部署在Huggingface上后使用。
    print('embeddings:', embeddings)
    vector_store = FAISS.load_local(f"./{username}/faiss_index/", embeddings=embeddings, allow_dangerous_deserialization=True) ## 加载vector store到本地。
    docs = vector_store.similarity_search(user_input, k=k) ##TODO 'user_input' to user_input?
    context = [doc.page_content for doc in docs]
    # print('context: {}'.format(context))

    source = rag_source(docs=docs) ## 封装到一个函数中。
    
    ## 用大模型来回答问题。
    # llm = ChatGLM() ## 启动一个实例。
    # final_prompt = f"已知信息:\n{context}\n 根据这些已知信息来回答问题:\n{user_input}"
    final_prompt = f"已知信息:\n{context}\n 根据这些已知信息尽可能详细且专业地来回答问题:\n{user_input}"
    
    ## LLM的回答
    # response = llm(prompt=final_prompt) ## 通过实例化之后的LLM来输出结果。
    # response = chatgpt.chatgpt(user_prompt=final_prompt) ## 通过ChatGPT实例化之后的LLM来输出结果。
    response = qwen_response.call_with_messages(prompt=final_prompt)# import 
    # response = llm(prompt=final_prompt) ## 通过实例化之后的LLM来输出结果。
    # response = llm(prompt='where is shanghai')
    # print('response now:' + response)
    
    return response, source

# # import asyncio
# response, source = rag_response('我是一个企业主,我需要关注哪些存货的数据资源规则?')
# print(response)
# print(source)