Update build_langchain_vector_store.py
build_langchain_vector_store.py (+121 -0, new file)
#!/usr/bin/env python3
"""
Builds and persists a LangChain vector store over the Pulze documentation using Chroma.
Source: https://github.com/Arize-ai/phoenix/blob/main/scripts/data/build_langchain_vector_store.py
"""

import argparse
import getpass
import logging
import shutil
import sys
from functools import partial
from typing import List

from langchain.docstore.document import Document as LangChainDocument
from langchain.document_loaders import GitbookLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from tiktoken import Encoding, encoding_for_model


def load_gitbook_docs(docs_url: str) -> List[LangChainDocument]:
    """Loads documents from a Gitbook URL.

    Args:
        docs_url (str): URL to Gitbook docs.

    Returns:
        List[LangChainDocument]: List of documents in LangChain format.
    """
    loader = GitbookLoader(
        docs_url,
        load_all_paths=True,
    )
    return loader.load()


def tiktoken_len(text: str, tokenizer: Encoding) -> int:
    """Returns the length of a text in tokens.

    Args:
        text (str): The text to tokenize and count.
        tokenizer (tiktoken.Encoding): The tokenizer.

    Returns:
        int: The number of tokens in the text.
    """
    tokens = tokenizer.encode(text, disallowed_special=())
    return len(tokens)


def chunk_docs(
    documents: List[LangChainDocument],
    tokenizer: Encoding,
    chunk_size: int = 400,
    chunk_overlap: int = 20,
) -> List[LangChainDocument]:
    """Chunks the documents.

    The chunking strategy used in this function is from the following notebook and accompanying
    video:

    - https://github.com/pinecone-io/examples/blob/master/generation/langchain/handbook/
      xx-langchain-chunking.ipynb
    - https://www.youtube.com/watch?v=eqOfr4AGLk8

    Args:
        documents (List[LangChainDocument]): A list of input documents.

        tokenizer (tiktoken.Encoding): The tokenizer used to count the number of tokens in a text.

        chunk_size (int, optional): The size of the chunks in tokens.

        chunk_overlap (int, optional): The chunk overlap in tokens.

    Returns:
        List[LangChainDocument]: The chunked documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=partial(tiktoken_len, tokenizer=tokenizer),
        separators=["\n\n", "\n", " ", ""],
    )
    return text_splitter.split_documents(documents)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--persist-path",
        type=str,
        required=False,
        help="Path to persist index.",
        default="langchain-chroma-pulze-docs",
    )
    args = parser.parse_args()

    # Load the Pulze docs from Gitbook and split them into ~200-token chunks.
    docs_url = "https://docs.pulze.ai/"
    embedding_model_name = "text-embedding-ada-002"
    langchain_documents = load_gitbook_docs(docs_url)
    chunked_langchain_documents = chunk_docs(
        langchain_documents,
        tokenizer=encoding_for_model(embedding_model_name),
        chunk_size=200,
    )

    # Embed the chunks and persist a fresh Chroma index, removing any stale copy first.
    embedding_model = OpenAIEmbeddings(model=embedding_model_name)
    shutil.rmtree(args.persist_path, ignore_errors=True)
    vector_store = Chroma.from_documents(
        chunked_langchain_documents, embedding=embedding_model, persist_directory=args.persist_path
    )
    # Re-open the persisted index to verify it can be loaded for querying.
    read_vector_store = Chroma(
        persist_directory=args.persist_path, embedding_function=embedding_model
    )
    # print(read_vector_store.similarity_search("How do I use Pulze?"))