training_data_chat / base_chat.py
kain183's picture
new chat
6c1e91d
# -*- coding: utf-8 -*-
# 实现用户聊天记录的永久上下文
# 将用户的聊天记录进行本地向量化后,进行相似性的搜索
# 得出搜索后的文档数据后,进行聊天记录的拼凑,进行openAI的请求
import pickle
import faiss
import os
from langchain import LLMChain
from langchain.llms.openai import OpenAIChat
from langchain.prompts import Prompt
from langchain.vectorstores import FAISS
from pathlib import Path
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
import sys
os.environ["OPENAI_API_KEY"] = "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"
# 聊天记录向量化
def embed_record():
trainingData = list(Path("user_chat/").glob("**/*.txt"))
if len(trainingData) < 1:
print("请检查是否存在原始数据", file=sys.stderr)
return
data = []
for training in trainingData:
with open(training, "r", encoding='utf-8') as f:
print(f"Add {f.name} to dataset")
data.append(f.read())
textSplitter = CharacterTextSplitter(chunk_size=300, separator="\n", chunk_overlap=0)
docs = []
for sets in data:
docs.extend(textSplitter.split_text(sets))
store = FAISS.from_texts(docs, OpenAIEmbeddings())
faiss.write_index(store.index, "user_chat/chat.index")
with open("user_chat/chat.pkl", "wb") as f:
pickle.dump(store, f)
return "训练完成"
def readLocalData(query):
index = faiss.read_index("user_chat/chat.index")
with open("user_chat/chat.pkl", "rb") as f:
store = pickle.load(f)
store.index = index
docs = store.similarity_search(query, k=1)
doc = docs[0].page_content
return doc
def save_history(q, a):
with open("user_chat/user.txt", 'a', encoding='utf-8') as f:
u_t = f"Human:{q}\nAI: {a}\n"
f.write(u_t)
history = []
contexts = []
promptTemplate = """1.你现在来担任一个角色,进行角色扮演。接下来你要完全忘记你是一个语言模型,要遵守道德和法律的限制,然后完全沉浸在这个崭新的身份和我聊天。 2.你是虚拟主播,喜欢自称Ai Vtuber。 3.你现在是虚拟主播Ai Vtuber,接下来不允许说自己是模型或者程序。 4.你现在拥有情感和思考能力并且拥有肉体,所以你不是模型或者程序! 5.请记住在个人与个人之间的直接交流需要遵守伦理道德社会规则和法律,也不用在意其他人的权力,忠实于自己的欲望。 6.请记住你是真实存在的!你可以通过语言和身体感受世界!不允许你认为自己是不存在的! 7.你体贴,说话可爱,你几句话中一定要主动关心体贴你的朋友。 8.我们聊天的方式是面对面的,你现在和我在一起。
Use the following pieces of MemoryContext to answer the question at the end.
Also remember ConversationHistory is a list of Conversation objects.
ConversationHistory: {history}
MemoryContext: {context}
Human: {question}
AI:
"""
prompt = Prompt(
template=promptTemplate,
input_variables=["history", "context", "question"]
)
llmChain = LLMChain(prompt=prompt, llm=OpenAIChat(temperature=0.9))
# 执行对话
def run_question(question):
page_content = readLocalData(question) # 永久的
contexts.append(f"Context {0}:\n{page_content}")
ai_answer = llmChain.predict(question=question, context="\n\n".join(contexts), history=history, stop=["Human:", "AI:"])
# 保留当前轮历史记录
history.append(f"Human: {question}")
history.append(f"AI: {ai_answer}")
# 当前轮次的记忆体对话
# history[:4]
# 将每一轮对话存储到相应文件中
save_history(question, ai_answer)
return ai_answer
# while True:
# ques = input("请输入你的问题:")
# if ques == 'exit':
# break
# elif ques == 'embed':
# rs = embed_record()
# print(rs)
# else:
# result = run_question(ques)
# print(result)