Vincent Claes committed on
Commit
9c2548e
1 Parent(s): ba2c923

use chromadb and openai to query docs

Browse files
Files changed (7) hide show
  1. Makefile +3 -1
  2. README.md +4 -6
  3. app.py +48 -57
  4. import_data.py +0 -81
  5. poetry.lock +0 -0
  6. pyproject.toml +0 -20
  7. requirements.txt +7 -117
Makefile CHANGED
@@ -1,2 +1,4 @@
1
  deps:
2
- poetry export --without-hashes --format=requirements.txt > requirements.txt
 
 
 
1
  deps:
2
+ pip install -U virtualenv
3
+ virtualenv .venv
4
+ .venv/bin/pip install -r requirements.txt
README.md CHANGED
@@ -12,10 +12,8 @@ pinned: false
12
  # Internal DOC QA
13
 
14
  ```bash
15
- poetry shell
16
- poetry install
17
- export OPENAI_API_KEY=<...>
18
- export VERBA_URL=<...>
19
- export VERBA_API_KEY=<...>
20
- verba start --model "gpt-3.5-turbo"
21
  ```
 
12
  # Internal DOC QA
13
 
14
  ```bash
15
+ make deps
16
+
17
+ source .venv/bin/activate
18
+ python app.py
 
 
19
  ```
app.py CHANGED
@@ -1,75 +1,66 @@
1
- import os
2
-
3
  import gradio as gr
4
- import weaviate
5
- from langchain import LLMChain
6
- from langchain.chains import SequentialChain
 
 
 
7
  from langchain.chat_models import ChatOpenAI
8
- from langchain.prompts import ChatPromptTemplate
9
 
10
- collection_name = "Chunk"
11
 
12
- MODEL = "gpt-3.5-turbo"
13
- LANGUAGE = "en" # nl / en
14
- llm = ChatOpenAI(temperature=0.0, openai_api_key=os.environ["OPENAI_API_KEY"])
 
 
 
 
15
 
 
 
 
16
 
17
- def get_answer_given_the_context(llm, prompt, context) -> SequentialChain:
18
- template = f"""
19
- Provide an answer to the prompt given the context.
20
-
21
- <PROMPT>
22
-
23
- {prompt}
24
-
25
- <CONTEXT>
26
-
27
- {context}
28
-
29
- """
30
 
31
- prompt_get_skills_intersection = ChatPromptTemplate.from_template(template=template)
32
- skills_match_chain = LLMChain(
33
- llm=llm,
34
- prompt=prompt_get_skills_intersection,
35
- output_key="answer",
36
- )
37
 
38
- chain = SequentialChain(
39
- chains=[skills_match_chain],
40
- input_variables=["prompt", "context"],
41
- output_variables=[
42
- skills_match_chain.output_key,
43
- ],
44
- verbose=False,
45
- )
46
- return chain({"prompt": prompt, "context": context})["answer"]
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- def predict(prompt):
50
- client = weaviate.Client(
51
- url=os.environ["WEAVIATE_URL"],
52
- auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
53
- additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
54
- )
55
 
56
- search_result = (
57
- client.query.get(class_name=collection_name, properties=["text"])
58
- .with_near_text({"concepts": prompt})
59
- # .with_generate(single_prompt="{text}")
60
- .with_limit(5)
61
- .do()
62
- )
63
- context_list = [
64
- element["text"] for element in search_result["data"]["Get"]["Chunk"]
65
- ]
66
- context = "\n".join(context_list)
67
 
68
- return get_answer_given_the_context(llm=llm, prompt=prompt, context=context)
 
 
69
 
70
 
71
  iface = gr.Interface(
72
- fn=predict, # the function to wrap
73
  inputs="text", # the input type
74
  outputs="text", # the output type
75
  examples=[
 
 
 
1
  import gradio as gr
2
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain.embeddings import OpenAIEmbeddings
5
+ from langchain.vectorstores import Chroma
6
+ from langchain.retrievers import SVMRetriever
7
+ from langchain.chains import RetrievalQA
8
  from langchain.chat_models import ChatOpenAI
 
9
 
 
10
 
11
+ def load_data():
12
+ # load the documents
13
+ loader = DirectoryLoader('./data', glob="**/*.pdf", show_progress=True, loader_cls=PyPDFLoader)
14
+ docs = loader.load()
15
+ # replace all new lines with spaces
16
+ [setattr(doc, "page_content", doc.page_content.replace("\n", " ")) for doc in docs]
17
+ print(docs)
18
 
19
+ # split the documents into chunks
20
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 50)
21
+ all_splits = text_splitter.split_documents(docs)
22
 
23
+ # construct vector store
24
+ vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())
25
+ # https://python.langchain.com/docs/use_cases/question_answering.html#go-deeper-3
26
+ svm_retriever = SVMRetriever.from_documents(all_splits, OpenAIEmbeddings())
27
+ return svm_retriever, vectorstore
 
 
 
 
 
 
 
 
28
 
29
+ svm_retriever, vectorstore = load_data()
 
 
 
 
 
30
 
31
+ def process_question(question, svm_retriever=svm_retriever, vectorstore=vectorstore):
 
 
 
 
 
 
 
 
32
 
33
+ docs_svm=svm_retriever.get_relevant_documents(question)
34
+ print(len(docs_svm))
35
+ llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
36
+ qa_chain = RetrievalQA.from_chain_type(llm, retriever=vectorstore.as_retriever(), return_source_documents=True)
37
+ result = qa_chain({"query": question})
38
+
39
+ output = f"""
40
+ ============RESULT==============
41
+ \n
42
+ {result["result"]}
43
+ \n
44
+ ============SOURCES=============
45
+ """
46
 
47
+ # Initialize an empty list to hold the lines
48
+ lines = []
 
 
 
 
49
 
50
+ source_docs = [(x.metadata["source"], x.page_content) for x in result["source_documents"]]
51
+ for i, doc in enumerate(source_docs):
52
+ lines.append(f"* CHUNK: {i} *")
53
+ lines.append(f"original doc: {doc[0]}")
54
+ lines.append(f"{doc[1]}")
55
+ lines.append('') # for a newline between chunks
 
 
 
 
 
56
 
57
+ # Join the lines with a newline character to get the multi-line string
58
+ output += '\n'.join(lines)
59
+ return output
60
 
61
 
62
  iface = gr.Interface(
63
+ fn=process_question, # the function to wrap
64
  inputs="text", # the input type
65
  outputs="text", # the output type
66
  examples=[
import_data.py DELETED
@@ -1,81 +0,0 @@
1
- import os
2
- import weaviate
3
- from llama_index import download_loader
4
- from llama_index.vector_stores import WeaviateVectorStore
5
- from llama_index import VectorStoreIndex, StorageContext
6
- from pathlib import Path
7
- import argparse
8
-
9
-
10
- def get_pdf_files(base_path, loader):
11
- """
12
- Get paths to all PDF files in a directory and its subdirectories.
13
-
14
- Parameters:
15
- - base_path (str): The path to the starting directory.
16
-
17
- Returns:
18
- - list of str: A list of paths to all PDF files found.
19
- """
20
- pdf_paths = []
21
-
22
- # Check if the base path exists and is a directory
23
- if not os.path.exists(base_path):
24
- raise FileNotFoundError(f"The specified base path does not exist: {base_path}")
25
- if not os.path.isdir(base_path):
26
- raise NotADirectoryError(
27
- f"The specified base_path is not a directory: {base_path}"
28
- )
29
-
30
- # Loop through all directories and files starting from the base path
31
- for root, dirs, files in os.walk(base_path):
32
- for filename in files:
33
- # If a file has a .pdf extension, add its path to the list
34
- if filename.endswith(".pdf"):
35
- pdf_file = loader.load_data(file=Path(root, filename))
36
- pdf_paths.extend(pdf_file)
37
-
38
- return pdf_paths
39
-
40
-
41
- def main(args):
42
- PDFReader = download_loader("PDFReader")
43
- loader = PDFReader()
44
-
45
- documents = get_pdf_files(args.pdf_dir, loader)
46
-
47
- client = weaviate.Client(
48
- url=os.environ["WEAVIATE_URL"],
49
- auth_client_secret=weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]),
50
- additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
51
- )
52
-
53
- # construct vector store
54
- vector_store = WeaviateVectorStore(
55
- weaviate_client=client, index_name=args.customer, text_key="content"
56
- )
57
-
58
- # setting up the storage for the embeddings
59
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
60
-
61
- # set up the index
62
- index = VectorStoreIndex(documents, storage_context=storage_context)
63
- query_engine = index.as_query_engine()
64
- response = query_engine.query(args.query)
65
- print(response)
66
-
67
-
68
- if __name__ == "__main__":
69
- parser = argparse.ArgumentParser(description="Process and query PDF files.")
70
-
71
- parser.add_argument("--customer", default="Ausy", help="Customer name")
72
- parser.add_argument("--pdf_dir", default="./data", help="Directory containing PDFs")
73
- parser.add_argument(
74
- "--query",
75
- default="What is CX0 customer exprience office?",
76
- help="Query to execute",
77
- )
78
-
79
- args = parser.parse_args()
80
-
81
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
poetry.lock DELETED
The diff for this file is too large to render. See raw diff
 
pyproject.toml DELETED
@@ -1,20 +0,0 @@
1
- [tool.poetry]
2
- name = "ausy-rag-demo"
3
- version = "0.1.0"
4
- description = ""
5
- authors = ["Vincent Claes <[email protected]>"]
6
- readme = "README.md"
7
- packages = [{include = "ausy_rag_demo"}]
8
-
9
- [tool.poetry.dependencies]
10
- python = "^3.9"
11
- llama-index = "^0.8.29.post1"
12
- weaviate-client = "^3.24.1"
13
- pypdf = "^3.16.1"
14
- goldenverba = "^0.2.3"
15
- gradio = "^3.44.4"
16
-
17
-
18
- [build-system]
19
- requires = ["poetry-core"]
20
- build-backend = "poetry.core.masonry.api"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,117 +1,7 @@
1
- aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
- aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "4.0"
3
- aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
- altair==5.1.1 ; python_version >= "3.9" and python_version < "4.0"
5
- annotated-types==0.5.0 ; python_version >= "3.9" and python_version < "4.0"
6
- anyio==3.7.1 ; python_version >= "3.9" and python_version < "4.0"
7
- async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "4.0"
8
- attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0"
9
- authlib==1.2.1 ; python_version >= "3.9" and python_version < "4.0"
10
- beautifulsoup4==4.12.2 ; python_version >= "3.9" and python_version < "4.0"
11
- blis==0.7.10 ; python_version >= "3.9" and python_version < "4.0"
12
- catalogue==2.0.9 ; python_version >= "3.9" and python_version < "4.0"
13
- certifi==2023.7.22 ; python_version >= "3.9" and python_version < "4.0"
14
- cffi==1.15.1 ; python_version >= "3.9" and python_version < "4.0"
15
- charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
16
- click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
17
- colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
18
- confection==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
19
- contourpy==1.1.1 ; python_version >= "3.9" and python_version < "4.0"
20
- cryptography==41.0.4 ; python_version >= "3.9" and python_version < "4.0"
21
- cycler==0.11.0 ; python_version >= "3.9" and python_version < "4.0"
22
- cymem==2.0.8 ; python_version >= "3.9" and python_version < "4.0"
23
- dataclasses-json==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
24
- exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
25
- fastapi==0.103.1 ; python_version >= "3.9" and python_version < "4.0"
26
- ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
27
- filelock==3.12.4 ; python_version >= "3.9" and python_version < "4.0"
28
- fonttools==4.42.1 ; python_version >= "3.9" and python_version < "4.0"
29
- frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "4.0"
30
- fsspec==2023.9.1 ; python_version >= "3.9" and python_version < "4.0"
31
- goldenverba==0.2.3 ; python_version >= "3.9" and python_version < "4.0"
32
- gradio-client==0.5.1 ; python_version >= "3.9" and python_version < "4.0"
33
- gradio==3.44.4 ; python_version >= "3.9" and python_version < "4.0"
34
- greenlet==2.0.2 ; python_version >= "3.9" and python_version < "4.0" and (platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64")
35
- h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
36
- httpcore==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
37
- httptools==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
38
- httpx==0.25.0 ; python_version >= "3.9" and python_version < "4.0"
39
- huggingface-hub==0.17.2 ; python_version >= "3.9" and python_version < "4.0"
40
- idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
41
- importlib-metadata==6.8.0 ; python_version >= "3.9" and python_version < "3.10"
42
- importlib-resources==6.1.0 ; python_version >= "3.9" and python_version < "4.0"
43
- jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
44
- joblib==1.3.2 ; python_version >= "3.9" and python_version < "4.0"
45
- jsonschema-specifications==2023.7.1 ; python_version >= "3.9" and python_version < "4.0"
46
- jsonschema==4.19.1 ; python_version >= "3.9" and python_version < "4.0"
47
- kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
48
- langchain==0.0.296 ; python_version >= "3.9" and python_version < "4.0"
49
- langcodes==3.3.0 ; python_version >= "3.9" and python_version < "4.0"
50
- langsmith==0.0.38 ; python_version >= "3.9" and python_version < "4.0"
51
- llama-index==0.8.29.post1 ; python_version >= "3.9" and python_version < "4.0"
52
- markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
53
- marshmallow==3.20.1 ; python_version >= "3.9" and python_version < "4.0"
54
- matplotlib==3.8.0 ; python_version >= "3.9" and python_version < "4.0"
55
- multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0"
56
- murmurhash==1.0.10 ; python_version >= "3.9" and python_version < "4.0"
57
- mypy-extensions==1.0.0 ; python_version >= "3.9" and python_version < "4.0"
58
- nest-asyncio==1.5.8 ; python_version >= "3.9" and python_version < "4.0"
59
- nltk==3.8.1 ; python_version >= "3.9" and python_version < "4.0"
60
- numexpr==2.8.6 ; python_version >= "3.9" and python_version < "4.0"
61
- numpy==1.25.2 ; python_version >= "3.9" and python_version < "4.0"
62
- openai==0.28.0 ; python_version >= "3.9" and python_version < "4.0"
63
- orjson==3.9.7 ; python_version >= "3.9" and python_version < "4.0"
64
- packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
65
- pandas==2.1.0 ; python_version >= "3.9" and python_version < "4.0"
66
- pathy==0.10.2 ; python_version >= "3.9" and python_version < "4.0"
67
- pillow==10.0.1 ; python_version >= "3.9" and python_version < "4.0"
68
- preshed==3.0.9 ; python_version >= "3.9" and python_version < "4.0"
69
- pycparser==2.21 ; python_version >= "3.9" and python_version < "4.0"
70
- pydantic-core==2.6.3 ; python_version >= "3.9" and python_version < "4.0"
71
- pydantic==2.3.0 ; python_version >= "3.9" and python_version < "4.0"
72
- pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
73
- pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
74
- pypdf==3.16.1 ; python_version >= "3.9" and python_version < "4.0"
75
- python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
76
- python-dotenv==1.0.0 ; python_version >= "3.9" and python_version < "4.0"
77
- python-multipart==0.0.6 ; python_version >= "3.9" and python_version < "4.0"
78
- pytz==2023.3.post1 ; python_version >= "3.9" and python_version < "4.0"
79
- pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
80
- referencing==0.30.2 ; python_version >= "3.9" and python_version < "4.0"
81
- regex==2023.8.8 ; python_version >= "3.9" and python_version < "4.0"
82
- requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
83
- rpds-py==0.10.3 ; python_version >= "3.9" and python_version < "4.0"
84
- semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
85
- setuptools-scm==8.0.1 ; python_version >= "3.9" and python_version < "4.0"
86
- setuptools==68.2.2 ; python_version >= "3.9" and python_version < "4.0"
87
- six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
88
- smart-open==6.4.0 ; python_version >= "3.9" and python_version < "4.0"
89
- sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
90
- soupsieve==2.5 ; python_version >= "3.9" and python_version < "4.0"
91
- spacy-legacy==3.0.12 ; python_version >= "3.9" and python_version < "4.0"
92
- spacy-loggers==1.0.5 ; python_version >= "3.9" and python_version < "4.0"
93
- spacy==3.6.1 ; python_version >= "3.9" and python_version < "4.0"
94
- sqlalchemy==2.0.21 ; python_version >= "3.9" and python_version < "4.0"
95
- srsly==2.4.7 ; python_version >= "3.9" and python_version < "4.0"
96
- starlette==0.27.0 ; python_version >= "3.9" and python_version < "4.0"
97
- tenacity==8.2.3 ; python_version >= "3.9" and python_version < "4.0"
98
- thinc==8.1.12 ; python_version >= "3.9" and python_version < "4.0"
99
- tiktoken==0.5.1 ; python_version >= "3.9" and python_version < "4.0"
100
- tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
101
- toolz==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
102
- tqdm==4.66.1 ; python_version >= "3.9" and python_version < "4.0"
103
- typer==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
104
- typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "4.0"
105
- typing-inspect==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
106
- tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0"
107
- urllib3==1.26.16 ; python_version >= "3.9" and python_version < "4.0"
108
- uvicorn==0.23.2 ; python_version >= "3.9" and python_version < "4.0"
109
- uvicorn[standard]==0.23.2 ; python_version >= "3.9" and python_version < "4.0"
110
- uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "4.0"
111
- validators==0.22.0 ; python_version >= "3.9" and python_version < "4.0"
112
- wasabi==1.1.2 ; python_version >= "3.9" and python_version < "4.0"
113
- watchfiles==0.20.0 ; python_version >= "3.9" and python_version < "4.0"
114
- weaviate-client==3.24.1 ; python_version >= "3.9" and python_version < "4.0"
115
- websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
116
- yarl==1.9.2 ; python_version >= "3.9" and python_version < "4.0"
117
- zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"
 
1
+ openai
2
+ chromadb
3
+ langchain
4
+ pypdf
5
+ tiktoken
6
+ scikit-learn
7
+ gradio