acpotts commited on
Commit
1c8239b
1 Parent(s): 1ec0b15

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +36 -64
  2. requirements.txt +0 -1
app.py CHANGED
@@ -7,21 +7,18 @@ from aimakerspace.openai_utils.prompts import (
7
  SystemRolePrompt,
8
  AssistantRolePrompt,
9
  )
10
- # from aimakerspace.openai_utils.embedding import EmbeddingModel
11
- # from aimakerspace.vectordatabase import VectorDatabase
12
- from langchain_openai import ChatOpenAI
13
- # from aimakerspace.openai_utils.chatmodel import ChatOpenAI
14
  import chainlit as cl
15
  from langchain_text_splitters import RecursiveCharacterTextSplitter
16
- # from langchain_experimental.text_splitter import SemanticChunker
17
- # from langchain_openai.embeddings import OpenAIEmbeddings
18
- from sentence_transformers import SentenceTransformer
19
  from langchain_huggingface import HuggingFaceEmbeddings
20
- from langchain_community.vectorstores import FAISS
21
- from langchain_openai.embeddings import OpenAIEmbeddings
22
- from langchain_core.documents import Document
23
  from dotenv import load_dotenv
24
- # from langchain.chains import RetrievalQA
25
 
26
  load_dotenv()
27
 
@@ -37,27 +34,27 @@ Question:
37
  """
38
  user_role_prompt = UserRolePrompt(user_prompt_template)
39
 
40
- # class RetrievalAugmentedQAPipeline:
41
- # def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
42
- # self.llm = llm
43
- # self.vector_db_retriever = vector_db_retriever
44
 
45
- # async def arun_pipeline(self, user_query: str):
46
- # context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
47
 
48
- # context_prompt = ""
49
- # for context in context_list:
50
- # context_prompt += context[0] + "\n"
51
 
52
- # formatted_system_prompt = system_role_prompt.create_message()
53
 
54
- # formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
55
 
56
- # async def generate_response():
57
- # async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
58
- # yield chunk
59
 
60
- # return {"response": generate_response(), "context": context_list}
61
 
62
  text_splitter = RecursiveCharacterTextSplitter()
63
 
@@ -80,8 +77,8 @@ def process_text_file(file: AskFileResponse):
80
  documents = pdf_loader.load()
81
  else:
82
  raise ValueError("Provide a .txt or .pdf file")
83
- #texts = [x.page_content for x in text_splitter.transform_documents(documents)]
84
- # texts = [x.page_content for x in text_splitter.split_documents(documents)]
85
  return text_splitter.split_documents(documents)
86
 
87
 
@@ -100,7 +97,8 @@ async def on_chat_start():
100
  max_files=10
101
  ).send()
102
 
103
- processed_documents = []
 
104
  for file in files:
105
 
106
  msg = cl.Message(
@@ -110,50 +108,26 @@ async def on_chat_start():
110
 
111
  # load the file
112
  texts = process_text_file(file)
113
- processed_documents.extend(texts)
114
  print(f"Processing {len(texts)} text chunks")
115
 
116
  # Create a dict vector store
117
- # vector_db = VectorDatabase()
118
- # vector_db = await vector_db.abuild_from_list(texts)
119
 
120
- # chat_openai = ChatOpenAI()
121
 
122
- # Create a chain
123
- # retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
124
- # vector_db_retriever=vector_db,
125
- # llm=chat_openai
126
- # )
127
-
128
- # model = SentenceTransformer("acpotts/finetuned_arctic")
129
-
130
- finetune_embeddings = HuggingFaceEmbeddings(model_name='acpotts/finetuned_arctic')
131
 
132
- finetune_vectorstore = FAISS.from_documents(processed_documents, finetune_embeddings)
133
- finetune_retriever = finetune_vectorstore.as_retriever(search_kwargs={"k": 6})
134
-
135
- from operator import itemgetter
136
- from langchain_core.output_parsers import StrOutputParser
137
- from langchain_core.runnables import RunnablePassthrough, RunnableParallel
138
-
139
- rag_llm = ChatOpenAI(
140
- model="gpt-4o-mini",
141
- temperature=0
142
- )
143
-
144
-
145
- finetune_rag_chain = (
146
- {"context": itemgetter("question") | finetune_retriever, "question": itemgetter("question")}
147
- | RunnablePassthrough.assign(context=itemgetter("context"))
148
- | {"response": system_template | rag_llm | StrOutputParser(), "context": itemgetter("context")}
149
  )
150
-
151
 
152
  # Let the user know that the system is ready
153
  msg.content = f"Processing `{file.name}` done. You can now ask questions!"
154
  await msg.update()
155
 
156
- cl.user_session.set("chain", finetune_rag_chain)
157
 
158
 
159
  @cl.on_message
@@ -161,9 +135,7 @@ async def main(message):
161
  chain = cl.user_session.get("chain")
162
 
163
  msg = cl.Message(content="")
164
- # finetune_rag_chain.invoke({"question": message.content})
165
- # result = await chain.arun_pipeline(message.content)
166
- result = await chain.arun_pipeline({'question': message.content})
167
 
168
  async for stream_resp in result["response"]:
169
  await msg.stream_token(stream_resp)
 
7
  SystemRolePrompt,
8
  AssistantRolePrompt,
9
  )
10
+ from aimakerspace.openai_utils.embedding import EmbeddingModel
11
+ from aimakerspace.vectordatabase import VectorDatabase
12
+ from aimakerspace.openai_utils.chatmodel import ChatOpenAI
 
13
  import chainlit as cl
14
  from langchain_text_splitters import RecursiveCharacterTextSplitter
15
+ # from sentence_transformers import SentenceTransformer
 
 
16
  from langchain_huggingface import HuggingFaceEmbeddings
17
+ # from langchain_community.vectorstores import FAISS
18
+ # from langchain_openai.embeddings import OpenAIEmbeddings
19
+ # from langchain_core.documents import Document
20
  from dotenv import load_dotenv
21
+
22
 
23
  load_dotenv()
24
 
 
34
  """
35
  user_role_prompt = UserRolePrompt(user_prompt_template)
36
 
37
+ class RetrievalAugmentedQAPipeline:
38
+ def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
39
+ self.llm = llm
40
+ self.vector_db_retriever = vector_db_retriever
41
 
42
+ async def arun_pipeline(self, user_query: str):
43
+ context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
44
 
45
+ context_prompt = ""
46
+ for context in context_list:
47
+ context_prompt += context[0] + "\n"
48
 
49
+ formatted_system_prompt = system_role_prompt.create_message()
50
 
51
+ formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
52
 
53
+ async def generate_response():
54
+ async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
55
+ yield chunk
56
 
57
+ return {"response": generate_response(), "context": context_list}
58
 
59
  text_splitter = RecursiveCharacterTextSplitter()
60
 
 
77
  documents = pdf_loader.load()
78
  else:
79
  raise ValueError("Provide a .txt or .pdf file")
80
+ texts = [x.page_content for x in text_splitter.transform_documents(documents)]
81
+
82
  return text_splitter.split_documents(documents)
83
 
84
 
 
97
  max_files=10
98
  ).send()
99
 
100
+ embedding_model = HuggingFaceEmbeddings(model_name='acpotts/finetuned_arctic')
101
+ vector_db = VectorDatabase(embedding_model=embedding_model)
102
  for file in files:
103
 
104
  msg = cl.Message(
 
108
 
109
  # load the file
110
  texts = process_text_file(file)
111
+
112
  print(f"Processing {len(texts)} text chunks")
113
 
114
  # Create a dict vector store
 
 
115
 
116
+ vector_db = await vector_db.abuild_from_list(texts)
117
 
118
+ chat_openai = ChatOpenAI()
 
 
 
 
 
 
 
 
119
 
120
+ #Create a chain
121
+ retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
122
+ vector_db_retriever=vector_db,
123
+ llm=chat_openai
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  )
 
125
 
126
  # Let the user know that the system is ready
127
  msg.content = f"Processing `{file.name}` done. You can now ask questions!"
128
  await msg.update()
129
 
130
+ cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
131
 
132
 
133
  @cl.on_message
 
135
  chain = cl.user_session.get("chain")
136
 
137
  msg = cl.Message(content="")
138
+ result = await chain.arun_pipeline(message.content)
 
 
139
 
140
  async for stream_resp in result["response"]:
141
  await msg.stream_token(stream_resp)
requirements.txt CHANGED
@@ -7,4 +7,3 @@ pypdf
7
  sentence_transformers
8
  langchain_text_splitters
9
  langchain-community
10
- faiss-cpu
 
7
  sentence_transformers
8
  langchain_text_splitters
9
  langchain-community