Update main.py
main.py CHANGED
@@ -1,134 +1,134 @@
-import os
-import
-import bcrypt
-import chainlit as cl
-from chainlit.input_widget import TextInput, Select, Switch, Slider
-from chainlit import user_session
-from literalai import LiteralClient
-literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
-
-from operator import itemgetter
-from pinecone import Pinecone
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
-
 from langchain.schema import StrOutputParser
-from
-from langchain.
-from
-
-async def on_chat_start():
-    await cl.Message(f"> REVIEWSTREAM").send()
-    await cl.Message(f"Nous avons le plaisir de vous accueillir dans l'application de recherche et d'analyse des publications.").send()
-    listPrompts_name = f"Liste des revues de recherche"
-    contentPrompts = """<p><img src='/public/hal-logo-header.png' width='32' align='absmiddle' /> <strong>Hal Archives Ouvertes</strong> : Une archive ouverte est un réservoir numérique contenant des documents issus de la recherche scientifique, généralement déposés par leurs auteurs, et permettant au grand public d'y accéder gratuitement et sans contraintes.
-    </p>
-    <p><img src='/public/logo-persee.png' width='32' align='absmiddle' /> <strong>Persée</strong> : offre un accès libre et gratuit à des collections complètes de publications scientifiques (revues, livres, actes de colloques, publications en série, sources primaires, etc.) associé à une gamme d'outils de recherche et d'exploitation.</p>
-    """
-    prompt_elements = []
-    prompt_elements.append(
-        cl.Text(content=contentPrompts, name=listPrompts_name, display="side")
     )
-        ),
-    ]
-    ).send()
-
-@cl.on_message
-async def main(message: cl.Message):
-    os.environ['PINECONE_API_KEY'] = os.environ['PINECONE_API_KEY']
-    embeddings = HuggingFaceEmbeddings()
-    index_name = "all-venus"
-    pc = Pinecone(
-        api_key=os.environ['PINECONE_API_KEY']
     )
-
 {context}
-
 """
-
-    model = HuggingFaceEndpoint(
-        repo_id=repo_id, max_new_tokens=8000, temperature=1.0, task="text2text-generation", streaming=True
-    )
-
-    prompt = ChatPromptTemplate.from_messages(
-        [
-            (
-                "system",
-                f"Contexte : Vous êtes un chercheur de l'enseignement supérieur et vous êtes doué pour faire des analyses d'articles de recherche sur les thématiques liées à la pédagogie. En fonction des informations suivantes et du contexte suivant seulement et strictement.",
-            ),
-            MessagesPlaceholder(variable_name="history"),
-            ("human", "Contexte : {context}, réponds à la question suivante de la manière la plus pertinente, la plus exhaustive et la plus détaillée possible. {question}."),
-        ]
-    )
     runnable = (
-        RunnablePassthrough.assign(
-            history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
-        )
         | prompt
         | model
     )
+from typing import List
+from pathlib import Path
+import os  # used below for the HUGGINGFACEHUB_API_TOKEN lookup
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
+
+#from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain.prompts import ChatPromptTemplate
 from langchain.schema import StrOutputParser
+from langchain_community.document_loaders import (
+    PyMuPDFLoader,
+)
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores.chroma import Chroma
+from langchain.indexes import SQLRecordManager, index
+from langchain.schema import Document
+from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
+from langchain.callbacks.base import BaseCallbackHandler
+
+import chainlit as cl
+
+
+chunk_size = 1024
+chunk_overlap = 50
+
+embeddings_model = HuggingFaceEmbeddings()
+
+PDF_STORAGE_PATH = "./public/pdfs"
+
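Note: chunk_size and chunk_overlap are defined here but never used; the splitter inside process_pdfs below hardcodes 1000/100 instead. If the module-level constants are meant to drive the chunking, the likely intended wiring (an assumption, not what this commit does) would be:

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,        # 1024, defined above
        chunk_overlap=chunk_overlap,  # 50, defined above
    )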
+def process_pdfs(pdf_storage_path: str):
+    pdf_directory = Path(pdf_storage_path)
+    docs = []  # type: List[Document]
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+
+    # Load every PDF in the storage directory and split it into chunks.
+    for pdf_path in pdf_directory.glob("*.pdf"):
+        loader = PyMuPDFLoader(str(pdf_path))
+        documents = loader.load()
+        docs += text_splitter.split_documents(documents)
+
+    doc_search = Chroma.from_documents(docs, embeddings_model)
+
+    # The record manager tracks what has been indexed so re-runs can skip unchanged files.
+    namespace = "chromadb/my_documents"
+    record_manager = SQLRecordManager(
+        namespace, db_url="sqlite:///record_manager_cache.sql"
     )
+    record_manager.create_schema()
+
+    index_result = index(
+        docs,
+        record_manager,
+        doc_search,
+        cleanup="incremental",
+        source_id_key="source",
     )
+
+    print(f"Indexing stats: {index_result}")
+
+    return doc_search
+
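A quirk worth flagging in process_pdfs: Chroma.from_documents already embeds and stores every chunk, and the index(...) call then ingests the same chunks again through the record manager, so the first run effectively writes everything twice. The LangChain indexing API is normally pointed at an empty vector store and left to do all the writing itself; a minimal sketch of that variant, assuming the same embeddings_model and record_manager:

    doc_search = Chroma(embedding_function=embeddings_model)
    index(docs, record_manager, doc_search, cleanup="incremental", source_id_key="source")

With cleanup="incremental" and source_id_key="source", later runs re-embed only source files whose content actually changed.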
+doc_search = process_pdfs(PDF_STORAGE_PATH)
+#model = ChatOpenAI(model_name="gpt-4", streaming=True)
+os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.environ['HUGGINGFACEHUB_API_TOKEN']  # no-op self-assignment; only fails fast (KeyError) if the token is unset
+repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+model = HuggingFaceEndpoint(
+    repo_id=repo_id, max_new_tokens=8000, temperature=1.0, task="text2text-generation", streaming=True
+)
+
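Mixtral-8x7B-Instruct is a causal language model, so task="text2text-generation" (an encoder-decoder task) looks wrong here and the endpoint may reject or misroute the request. A hedged correction, keeping the other parameters as committed:

    model = HuggingFaceEndpoint(
        repo_id=repo_id,
        task="text-generation",  # causal-LM task name for Mixtral
        max_new_tokens=8000,
        temperature=1.0,
        streaming=True,
    )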
+@cl.on_chat_start
+async def on_chat_start():
+    template = """Answer the question based only on the following context:
+
 {context}
+
+Question: {question}
 """
+    prompt = ChatPromptTemplate.from_template(template)
+
+    def format_docs(docs):
+        return "\n\n".join([d.page_content for d in docs])
+
+    retriever = doc_search.as_retriever()
+
     runnable = (
+        {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
         | model
+        | StrOutputParser()
     )
+
+    cl.user_session.set("runnable", runnable)
+
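The dict at the head of the chain is the usual LCEL fan-out: the raw question string goes to the retriever (then format_docs) to fill {context}, and passes through unchanged as {question}. Outside Chainlit the same runnable can be exercised directly; a quick smoke test, with a hypothetical question string:

    answer = runnable.invoke("Qu'est-ce qu'une archive ouverte ?")
    print(answer)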
+@cl.on_message
+async def on_message(message: cl.Message):
+    runnable = cl.user_session.get("runnable")  # type: Runnable
+    msg = cl.Message(content="")
+
+    class PostMessageHandler(BaseCallbackHandler):
+        """
+        Callback handler for the retriever and LLM runs.
+        Posts the sources of the retrieved documents as a Chainlit element.
+        """
+
+        def __init__(self, msg: cl.Message):
+            BaseCallbackHandler.__init__(self)
+            self.msg = msg
+            self.sources = set()  # unique (source, page) pairs
+
+        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
+            for d in documents:
+                source_page_pair = (d.metadata['source'], d.metadata['page'])
+                self.sources.add(source_page_pair)
+
+        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
+            if len(self.sources):
+                sources_text = "\n".join([f"{source}#page={page}" for source, page in self.sources])
+                self.msg.elements.append(
+                    cl.Text(name="Sources", content=sources_text, display="inline")
+                )
+
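PostMessageHandler relies on PyMuPDFLoader putting source (the file path) and page (zero-based) into each chunk's metadata. Since the #page= fragment understood by most PDF viewers is one-based, the links as written land one page early; a one-line fix, if that offset is unintended:

    sources_text = "\n".join(f"{source}#page={page + 1}" for source, page in self.sources)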
+    async with cl.Step(type="run", name="QA Assistant"):
+        # Stream the answer token by token into the Chainlit message.
+        async for chunk in runnable.astream(
+            message.content,
+            config=RunnableConfig(callbacks=[
+                cl.LangchainCallbackHandler(),
+                PostMessageHandler(msg)
+            ]),
+        ):
+            await msg.stream_token(chunk)
+
+    await msg.send()
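Assuming a standard Chainlit layout (the /public asset folder referenced above sits next to main.py), the app is served with the usual CLI; -w reloads on file changes during development:

    pip install chainlit langchain langchain-community chromadb pymupdf
    chainlit run main.py -w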