kartavya23 commited on
Commit
47ad957
1 Parent(s): 5f9e152

Upload 4 files

Browse files
rag_101/__pycache__/retriever.cpython-39.pyc ADDED
Binary file (4.88 kB). View file
 
rag_101/client.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.callbacks import FileCallbackHandler
2
+ from langchain_community.chat_models import ChatOllama
3
+ from langchain_core.output_parsers import StrOutputParser
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from loguru import logger
6
+
7
+ from rag_101.retriever import (
8
+ RAGException,
9
+ create_parent_retriever,
10
+ load_embedding_model,
11
+ load_pdf,
12
+ load_reranker_model,
13
+ retrieve_context,
14
+ )
15
+
16
+
17
+ class RAGClient:
18
+ embedding_model = load_embedding_model()
19
+ reranker_model = load_reranker_model()
20
+
21
+ def __init__(self, files, model="mistral"):
22
+ docs = load_pdf(files=files)
23
+ self.retriever = create_parent_retriever(docs, self.embedding_model)
24
+
25
+ llm = ChatOllama(model=model)
26
+ prompt_template = ChatPromptTemplate.from_template(
27
+ (
28
+ "Please answer the following question based on the provided `context` that follows the question.\n"
29
+ "Think step by step before coming to answer. If you do not know the answer then just say 'I do not know'\n"
30
+ "question: {question}\n"
31
+ "context: ```{context}```\n"
32
+ )
33
+ )
34
+ self.chain = prompt_template | llm | StrOutputParser()
35
+
36
+ def stream(self, query: str) -> dict:
37
+ try:
38
+ context, similarity_score = self.retrieve_context(query)[0]
39
+ context = context.page_content
40
+ if similarity_score < 0.005:
41
+ context = "This context is not confident. " + context
42
+ except RAGException as e:
43
+ context, similarity_score = e.args[0], 0
44
+ logger.info(context)
45
+ for r in self.chain.stream({"context": context, "question": query}):
46
+ yield r
47
+
48
+ def retrieve_context(self, query: str):
49
+ return retrieve_context(
50
+ query, retriever=self.retriever, reranker_model=self.reranker_model
51
+ )
52
+
53
+ def generate(self, query: str) -> dict:
54
+ contexts = self.retrieve_context(query)
55
+
56
+ return {
57
+ "contexts": contexts,
58
+ "response": self.chain.invoke(
59
+ {"context": contexts[0][0].page_content, "question": query}
60
+ ),
61
+ }
rag_101/rag.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import List, Optional, Union
3
+
4
+ from langchain_community.chat_models import ChatOllama
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from retriever import (
8
+ create_parent_retriever,
9
+ load_embedding_model,
10
+ load_pdf,
11
+ load_reranker_model,
12
+ retrieve_context,
13
+ )
14
+
15
+
16
+ def main(
17
+ file: str = "2401.08406v3.pdf",
18
+ llm_name="mistral",
19
+ ):
20
+ docs = load_pdf(files=file)
21
+
22
+ embedding_model = load_embedding_model()
23
+ retriever = create_parent_retriever(docs, embedding_model)
24
+ reranker_model = load_reranker_model()
25
+
26
+ llm = ChatOllama(model=llm_name)
27
+ prompt_template = ChatPromptTemplate.from_template(
28
+ (
29
+ "Please answer the following question based on the provided `context` that follows the question.\n"
30
+ "If you do not know the answer then just say 'I do not know'\n"
31
+ "question: {question}\n"
32
+ "context: ```{context}```\n"
33
+ )
34
+ )
35
+ chain = prompt_template | llm | StrOutputParser()
36
+
37
+ while True:
38
+ query = input("Ask question: ")
39
+ context = retrieve_context(
40
+ query, retriever=retriever, reranker_model=reranker_model
41
+ )[0]
42
+ print("LLM Response: ", end="")
43
+ for e in chain.stream({"context": context[0].page_content, "question": query}):
44
+ print(e, end="")
45
+ print()
46
+ time.sleep(0.1)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ from jsonargparse import CLI
51
+
52
+ CLI(main)
rag_101/retriever.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["HF_HOME"] = "weights"
4
+ os.environ["TORCH_HOME"] = "weights"
5
+
6
+ from typing import List, Optional, Union
7
+
8
+ from langchain.callbacks import FileCallbackHandler
9
+ from langchain.retrievers import ContextualCompressionRetriever, ParentDocumentRetriever
10
+ from langchain.retrievers.document_compressors import EmbeddingsFilter
11
+ from langchain.storage import InMemoryStore
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain_community.document_loaders import UnstructuredFileLoader
14
+ from langchain_community.embeddings import HuggingFaceEmbeddings
15
+ from langchain_community.vectorstores import FAISS, Chroma
16
+ from langchain_core.documents import Document
17
+ from loguru import logger
18
+ from rich import print
19
+ from sentence_transformers import CrossEncoder
20
+ from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
21
+
22
+ logfile = "log/output.log"
23
+ logger.add(logfile, colorize=True, enqueue=True)
24
+ handler = FileCallbackHandler(logfile)
25
+
26
+
27
+ persist_directory = None
28
+
29
+
30
+ class RAGException(Exception):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+
35
+ def rerank_docs(reranker_model, query, retrieved_docs):
36
+ query_and_docs = [(query, r.page_content) for r in retrieved_docs]
37
+ scores = reranker_model.predict(query_and_docs)
38
+ return sorted(list(zip(retrieved_docs, scores)), key=lambda x: x[1], reverse=True)
39
+
40
+
41
+ def load_pdf(
42
+ files: Union[str, List[str]] = "2401.08406v3.pdf"
43
+ ) -> List[Document]:
44
+ if isinstance(files, str):
45
+ loader = UnstructuredFileLoader(
46
+ files,
47
+ post_processors=[clean_extra_whitespace, group_broken_paragraphs],
48
+ )
49
+ return loader.load()
50
+
51
+ loaders = [
52
+ UnstructuredFileLoader(
53
+ file,
54
+ post_processors=[clean_extra_whitespace, group_broken_paragraphs],
55
+ )
56
+ for file in files
57
+ ]
58
+ docs = []
59
+ for loader in loaders:
60
+ docs.extend(
61
+ loader.load(),
62
+ )
63
+ return docs
64
+
65
+
66
+ def create_parent_retriever(
67
+ docs: List[Document], embeddings_model: HuggingFaceEmbeddings()
68
+ ):
69
+ parent_splitter = RecursiveCharacterTextSplitter(
70
+ separators=["\n\n\n", "\n\n"],
71
+ chunk_size=2000,
72
+ length_function=len,
73
+ is_separator_regex=False,
74
+ )
75
+
76
+ # This text splitter is used to create the child documents
77
+ child_splitter = RecursiveCharacterTextSplitter(
78
+ separators=["\n\n\n", "\n\n"],
79
+ chunk_size=1000,
80
+ chunk_overlap=300,
81
+ length_function=len,
82
+ is_separator_regex=False,
83
+ )
84
+ # The vectorstore to use to index the child chunks
85
+ vectorstore = Chroma(
86
+ collection_name="split_documents",
87
+ embedding_function=embeddings_model,
88
+ persist_directory=persist_directory,
89
+ )
90
+ # The storage layer for the parent documents
91
+ store = InMemoryStore()
92
+ retriever = ParentDocumentRetriever(
93
+ vectorstore=vectorstore,
94
+ docstore=store,
95
+ child_splitter=child_splitter,
96
+ parent_splitter=parent_splitter,
97
+ k=10,
98
+ )
99
+ retriever.add_documents(docs)
100
+ return retriever
101
+
102
+
103
+ def retrieve_context(query, retriever, reranker_model):
104
+ retrieved_docs = retriever.get_relevant_documents(query)
105
+
106
+ if len(retrieved_docs) == 0:
107
+ raise RAGException(
108
+ f"Couldn't retrieve any relevant document with the query `{query}`. Try modifying your question!"
109
+ )
110
+ reranked_docs = rerank_docs(
111
+ query=query, retrieved_docs=retrieved_docs, reranker_model=reranker_model
112
+ )
113
+ return reranked_docs
114
+
115
+
116
+ def load_embedding_model(
117
+ model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"
118
+ ) -> HuggingFaceEmbeddings:
119
+ model_kwargs = {"device": device}
120
+ encode_kwargs = {
121
+ "normalize_embeddings": True
122
+ } # set True to compute cosine similarity
123
+ embedding_model = HuggingFaceEmbeddings(
124
+ model_name=model_name,
125
+ model_kwargs=model_kwargs,
126
+ encode_kwargs=encode_kwargs,
127
+ )
128
+ return embedding_model
129
+
130
+
131
+ def load_reranker_model(
132
+ reranker_model_name: str = "BAAI/bge-reranker-large", device: str = "cuda"
133
+ ) -> CrossEncoder:
134
+ reranker_model = CrossEncoder(
135
+ model_name=reranker_model_name, max_length=1024, device=device
136
+ )
137
+ return reranker_model
138
+
139
+
140
+ def main(
141
+ file: str = "2401.08406v3.pdf",
142
+ query: Optional[str] = None,
143
+ llm_name="mistral",
144
+ ):
145
+ docs = load_pdf(files=file)
146
+
147
+ embedding_model = load_embedding_model()
148
+ retriever = create_parent_retriever(docs, embedding_model)
149
+ reranker_model = load_reranker_model()
150
+
151
+ context = retrieve_context(
152
+ query, retriever=retriever, reranker_model=reranker_model
153
+ )[0]
154
+ print("context:\n", context, "\n", "=" * 50, "\n")
155
+
156
+
157
+ if __name__ == "__main__":
158
+ from jsonargparse import CLI
159
+
160
+ CLI(main)