santuchal committed
Commit: eaf88b9
Parent: 4015508

Update app.py

Files changed (1): app.py (+86, -30)
app.py CHANGED
@@ -1,29 +1,34 @@
 import gradio as gr
 import os
 
-from langchain.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import Chroma
+from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.llms import HuggingFacePipeline
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
-from langchain.llms import HuggingFaceHub
+from langchain_community.llms import HuggingFaceEndpoint
 
 from pathlib import Path
 import chromadb
+from unidecode import unidecode
 
 from transformers import AutoTokenizer
 import transformers
 import torch
 import tqdm
 import accelerate
+import re
+
 
 
 # default_persist_directory = './chroma_HF/'
-list_llm = ["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1", \
-    "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
+list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
+    "google/gemma-7b-it","google/gemma-2b-it", \
+    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
+    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
    "google/flan-t5-xxl"
 ]
@@ -98,32 +103,58 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     # Warning: langchain issue
     # URL: https://github.com/langchain-ai/langchain/issues/6080
     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceHub(
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+            load_in_8bit = True,
+        )
+    elif llm_model == "HuggingFaceH4/zephyr-7b-gemma-v0.1":
+        raise gr.Error("zephyr-7b-gemma-v0.1 is too large to be loaded automatically on free inference endpoint")
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
         )
     elif llm_model == "microsoft/phi-2":
         raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-        llm = HuggingFaceHub(
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
+            trust_remote_code = True,
+            torch_dtype = "auto",
         )
     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceHub(
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = 250,
+            top_k = top_k,
         )
     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
         raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-        llm = HuggingFaceHub(
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
         )
     else:
-        llm = HuggingFaceHub(
+        llm = HuggingFaceEndpoint(
             repo_id=llm_model,
             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+            temperature = temperature,
+            max_new_tokens = max_tokens,
+            top_k = top_k,
         )
 
     progress(0.75, desc="Defining buffer memory...")
@@ -149,18 +180,36 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     return qa_chain
 
 
+# Generate collection name for vector database
+#  - Use filepath as input, ensuring unicode text
+def create_collection_name(filepath):
+    # Extract filename without extension
+    collection_name = Path(filepath).stem
+    # Fix potential issues from naming convention
+    ## Remove space
+    collection_name = collection_name.replace(" ","-")
+    ## ASCII transliterations of Unicode text
+    collection_name = unidecode(collection_name)
+    ## Remove special characters
+    #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
+    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+    ## Limit length to 50 characters
+    collection_name = collection_name[:50]
+    ## Minimum length of 3 characters
+    if len(collection_name) < 3:
+        collection_name = collection_name + 'xyz'
+    print('Filepath: ', filepath)
+    print('Collection name: ', collection_name)
+    return collection_name
+
+
 # Initialize database
 def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     # Create list of documents (when valid)
     list_file_path = [x.name for x in list_file_obj if x is not None]
     # Create collection_name for vector database
     progress(0.1, desc="Creating collection name...")
-    collection_name = Path(list_file_path[0]).stem
-    # Fix potential issues from naming convention
-    collection_name = collection_name.replace(" ","-")
-    collection_name = collection_name[:50]
-    # print('list_file_path: ', list_file_path)
-    print('Collection name: ', collection_name)
+    collection_name = create_collection_name(list_file_path[0])
     progress(0.25, desc="Loading document...")
     # Load document and create splits
     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
@@ -195,19 +244,23 @@ def conversation(qa_chain, message, history):
     # Generate response using QA chain
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
+    if response_answer.find("Helpful Answer:") != -1:
+        response_answer = response_answer.split("Helpful Answer:")[-1]
     response_sources = response["source_documents"]
     response_source1 = response_sources[0].page_content.strip()
     response_source2 = response_sources[1].page_content.strip()
+    response_source3 = response_sources[2].page_content.strip()
     # Langchain sources are zero-based
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
+    response_source3_page = response_sources[2].metadata["page"] + 1
     # print ('chat response: ', response_answer)
     # print('DB source', response_sources)
 
     # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
 
 def upload_file(file_obj):
@@ -274,6 +327,9 @@ def demo():
        with gr.Row():
            doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
            source2_page = gr.Number(label="Page", scale=1)
+        with gr.Row():
+            doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+            source3_page = gr.Number(label="Page", scale=1)
        with gr.Row():
            msg = gr.Textbox(placeholder="Type message", container=True)
        with gr.Row():
@@ -287,23 +343,23 @@ def demo():
            outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, \
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
-            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0], \
+            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
 
        # Chatbot events
        msg.submit(conversation, \
            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
        submit_btn.click(conversation, \
            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot], \
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0], \
+        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
            inputs=None, \
-            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
            queue=False)
    demo.queue().launch(debug=True)
 
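For reference, a minimal standalone sketch (not part of this commit) of the calling pattern the diff migrates to: langchain_community's HuggingFaceEndpoint takes generation settings as top-level keyword arguments rather than the model_kwargs dict that the HuggingFaceHub wrapper used. The repo id and parameter values below are illustrative placeholders; in app.py they come from list_llm and the Gradio sliders, and the snippet assumes langchain_community is installed and HUGGINGFACEHUB_API_TOKEN is set in the environment.

from langchain_community.llms import HuggingFaceEndpoint

# Illustrative example of the new initialization style (values are placeholders).
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # any entry from list_llm
    temperature=0.7,        # slider_temperature in the app
    max_new_tokens=1024,    # slider_maxtokens
    top_k=3,                # slider_topk
)
print(llm.invoke("Summarize the uploaded PDF in one sentence."))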