Update app.py
app.py CHANGED
@@ -1,29 +1,34 @@
 import gradio as gr
 import os
 
-from
+from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from
+from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
-from
-from
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
-from
+from langchain_community.llms import HuggingFaceEndpoint
 
 from pathlib import Path
 import chromadb
+from unidecode import unidecode
 
 from transformers import AutoTokenizer
 import transformers
 import torch
 import tqdm
 import accelerate
+import re
+
 
 
 # default_persist_directory = './chroma_HF/'
-list_llm = ["mistralai/
-"
+list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
+    "google/gemma-7b-it","google/gemma-2b-it", \
+    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
+    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
     "google/flan-t5-xxl"
 ]
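
Note on the import changes above: the loaders, vector store, embeddings, and LLM wrappers now come from the split-out langchain_community package, and unidecode plus re are pulled in for the collection-name helper added further down. A small, optional availability check, purely illustrative and not part of this commit, if you are reproducing the Space locally:

# Quick sanity check that the presumed extra dependencies are importable
# (module names are assumptions based on the imports above; exact versions are unknown).
from importlib.util import find_spec

for module in ("langchain_community", "unidecode", "chromadb", "gradio", "transformers"):
    print(f"{module}: {'found' if find_spec(module) else 'MISSING'}")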
@@ -98,32 +103,58 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     # Warning: langchain issue
     # URL: https://github.com/langchain-ai/langchain/issues/6080
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm =
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+            load_in_8bit = True,
+        )
+    elif llm_model == "HuggingFaceH4/zephyr-7b-gemma-v0.1":
+        raise gr.Error("zephyr-7b-gemma-v0.1 is too large to be loaded automatically on free inference endpoint")
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
         )
     elif llm_model == "microsoft/phi-2":
         raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-        llm =
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+            trust_remote_code = True,
+            torch_dtype = "auto",
         )
     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm =
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = 250,
+            top_k = top_k,
         )
     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
         raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-        llm =
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
         )
     else:
-        llm =
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
         )
 
     progress(0.75, desc="Defining buffer memory...")
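
The hunk above replaces the old model_kwargs={...} dictionary with direct keyword arguments on langchain_community's HuggingFaceEndpoint. A minimal standalone sketch of that call pattern, with illustrative values only (the repo id, sampling settings, and the HUGGINGFACEHUB_API_TOKEN environment variable are assumptions, not part of this commit):

import os
from langchain_community.llms import HuggingFaceEndpoint

# Generation settings are passed as top-level keyword arguments, not via model_kwargs.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",   # any entry from list_llm
    temperature=0.7,
    max_new_tokens=256,
    top_k=3,
    huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),  # assumed to be set
)

print(llm.invoke("In one sentence, what does a PDF question-answering assistant do?"))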
@@ -149,18 +180,36 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     return qa_chain
 
 
+# Generate collection name for vector database
+# - Use filepath as input, ensuring unicode text
+def create_collection_name(filepath):
+    # Extract filename without extension
+    collection_name = Path(filepath).stem
+    # Fix potential issues from naming convention
+    ## Remove space
+    collection_name = collection_name.replace(" ","-")
+    ## ASCII transliterations of Unicode text
+    collection_name = unidecode(collection_name)
+    ## Remove special characters
+    #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
+    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+    ## Limit length to 50 characters
+    collection_name = collection_name[:50]
+    ## Minimum length of 3 characters
+    if len(collection_name) < 3:
+        collection_name = collection_name + 'xyz'
+    print('Filepath: ', filepath)
+    print('Collection name: ', collection_name)
+    return collection_name
+
+
 # Initialize database
 def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     # Create list of documents (when valid)
     list_file_path = [x.name for x in list_file_obj if x is not None]
     # Create collection_name for vector database
     progress(0.1, desc="Creating collection name...")
-    collection_name =
-    # Fix potential issues from naming convention
-    collection_name = collection_name.replace(" ","-")
-    collection_name = collection_name[:50]
-    # print('list_file_path: ', list_file_path)
-    print('Collection name: ', collection_name)
+    collection_name = create_collection_name(list_file_path[0])
     progress(0.25, desc="Loading document...")
     # Load document and create splits
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
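
As a quick illustration of the new create_collection_name() helper, here is a trace with a made-up file path (the path is hypothetical; the helper and its imports come from app.py above):

name = create_collection_name("/tmp/Étude de cas 2024.pdf")
# Path(...).stem                     -> 'Étude de cas 2024'
# .replace(" ", "-")                 -> 'Étude-de-cas-2024'
# unidecode(...)                     -> 'Etude-de-cas-2024'
# re.sub('[^A-Za-z0-9]+', '-', ...)  -> 'Etude-de-cas-2024'  (runs of other characters become a single '-')
# [:50] and the minimum-length check leave it unchanged
print(name)  # Etude-de-cas-2024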
@@ -195,19 +244,23 @@ def conversation(qa_chain, message, history):
     # Generate response using QA chain
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
+    if response_answer.find("Helpful Answer:") != -1:
+        response_answer = response_answer.split("Helpful Answer:")[-1]
     response_sources = response["source_documents"]
     response_source1 = response_sources[0].page_content.strip()
     response_source2 = response_sources[1].page_content.strip()
+    response_source3 = response_sources[2].page_content.strip()
     # Langchain sources are zero-based
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
+    response_source3_page = response_sources[2].metadata["page"] + 1
     # print ('chat response: ', response_answer)
     # print('DB source', response_sources)
 
     # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
 
 def upload_file(file_obj):
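
The change above unpacks a third source document, so conversation() now returns nine values and the retriever must surface at least three documents for response_sources[2] to exist. A purely illustrative defensive variant of the same unpacking, not part of the commit:

def unpack_sources(response_sources, n=3):
    """Mirror the unpacking above, but tolerate fewer than n returned documents."""
    texts, pages = [], []
    for i in range(n):
        if i < len(response_sources):
            texts.append(response_sources[i].page_content.strip())
            pages.append(response_sources[i].metadata["page"] + 1)  # Langchain pages are zero-based
        else:
            texts.append("")
            pages.append(0)
    return texts, pages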
@@ -274,6 +327,9 @@ def demo():
             with gr.Row():
                 doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                 source2_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                source3_page = gr.Number(label="Page", scale=1)
             with gr.Row():
                 msg = gr.Textbox(placeholder="Type message", container=True)
             with gr.Row():
@@ -287,23 +343,23 @@ def demo():
             outputs=[vector_db, collection_name, db_progress])
         qachain_btn.click(initialize_LLM, \
             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
-            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0], \
+            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
             inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
 
         # Chatbot events
         msg.submit(conversation, \
             inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
         submit_btn.click(conversation, \
             inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0], \
+        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
             inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
     demo.queue().launch(debug=True)
 
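
One detail worth noting in the event wiring above: each callback must return exactly one value per output component, which is why the clearing lambda grows from five to seven values once the third reference box and its page number are added. A small sketch of the correspondence, using the component names from the diff:

# Seven outputs -> seven reset values, in the same order as the outputs list.
reset_values = [None,    # chatbot
                "", 0,   # doc_source1, source1_page
                "", 0,   # doc_source2, source2_page
                "", 0]   # doc_source3, source3_page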