os1187 committed on
Commit f19a164
1 Parent(s): 761b3dd

Create app.py

Files changed (1):
  app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
import chromadb
from transformers import AutoTokenizer
import transformers
import torch

# Constants and configuration
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1", "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl"
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

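# Note: HuggingFaceHub calls the hosted Inference API, which needs an API
# token; LangChain reads it from the HUGGINGFACEHUB_API_TOKEN environment
# variable. A minimal sketch (the token value is a placeholder; in a Space
# it would normally come from a secret):
#   os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hf_..."
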
# Load PDF documents and split them into overlapping text chunks
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

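# Sanity-check sketch for load_doc (the file name is a placeholder), e.g.
# from a REPL:
#   splits = load_doc(["example.pdf"], chunk_size=600, chunk_overlap=50)
#   print(len(splits), splits[0].page_content[:200])
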
# Build an in-memory Chroma vector store from the document chunks
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

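# The returned store supports the usual LangChain retrieval calls; a quick
# sketch with placeholder values:
#   vectordb = create_db(splits, collection_name="example-pdf")
#   docs = vectordb.similarity_search("What does the introduction say?", k=3)
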
# Open a Chroma store with default settings (helper, unused by the UI below)
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb

# Create a ConversationalRetrievalChain backed by a Hugging Face Hub LLM
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        # Mixtral is large, so additionally request 8-bit loading
        llm = HuggingFaceHub(
            repo_id=llm_model,
            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
        )
    else:
        llm = HuggingFaceHub(
            repo_id=llm_model,
            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
        )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        return_generated_question=False,
    )
    return qa_chain

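# Once a vector store exists, the chain can be queried directly; a sketch
# with a placeholder question (mirrors the call made in conversation below):
#   chain = initialize_llmchain(list_llm[1], 0.7, 256, 3, vectordb)
#   result = chain({"question": "What is this document about?", "chat_history": []})
#   print(result["answer"])
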
# Build the vector database from the uploaded files
def initialize_database(list_file_obj, chunk_size, chunk_overlap):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    # Derive a Chroma collection name from the first file name
    collection_name = os.path.basename(list_file_path[0]).replace(" ", "-")[:50]
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"

# Initialize the QA chain for the model picked in the UI. The radio button
# returns the short display name, so map it back to the full repo id.
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    llm_name = list_llm[list_llm_simple.index(llm_option)]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "Complete!"

# Flatten (user, bot) tuples into the "User: ... / Assistant: ..." strings
# expected as chat history (message is accepted for signature compatibility
# but not used)
def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

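# For example, with hypothetical history values:
#   format_chat_history("next q", [("What is RAG?", "Retrieval-augmented generation.")])
#   -> ['User: What is RAG?', 'Assistant: Retrieval-augmented generation.']
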
# Run one conversational turn and surface the top two source chunks
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    # Assumes the retriever returned at least two source chunks
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    # PyPDF pages are 0-indexed; display them 1-indexed
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1

    new_history = history + [(message, response_answer)]
    # Clear the message box and append the new turn to the chat display
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page

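# The tuple returned above maps positionally onto the outputs list wired to
# submit_btn in gradio_ui below. The same pattern in miniature (sketch only,
# with a hypothetical echo handler):
#   def echo(message, history):
#       history = (history or []) + [(message, message.upper())]
#       return gr.update(value=""), history
#   box.submit(echo, inputs=[box, chat], outputs=[box, chat])
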
# Collect file paths from a Gradio upload (helper, unused by the UI below)
def upload_file(file_obj):
    list_file_path = [file.name for file in file_obj]
    return list_file_path


def gradio_ui():
    with gr.Blocks(theme="base") as demo:
        # States shared across tabs
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        with gr.Tabs():
            # Tab 1: Document Pre-processing
            with gr.Tab("Step 1 - Document Pre-processing"):
                with gr.Row():
                    document = gr.File(label="Upload your PDF documents", file_count="multiple", file_types=[".pdf"])
                with gr.Row():
                    chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=50, label="Chunk size", interactive=True)
                    chunk_overlap = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Chunk overlap", interactive=True)
                with gr.Row():
                    db_progress = gr.Textbox(label="Vector database status", value="None")
                with gr.Row():
                    db_init_btn = gr.Button("Initialize Vector Database")

            # Tab 2: QA Chain Initialization
            with gr.Tab("Step 2 - QA Chain Initialization"):
                with gr.Row():
                    llm_selection = gr.Radio(list_llm_simple, label="Choose LLM Model", value=list_llm_simple[0])
                with gr.Row():
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", interactive=True)
                    max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, step=64, label="Max Tokens", interactive=True)
                    top_k = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top K", interactive=True)
                with gr.Row():
                    llm_progress = gr.Textbox(label="QA chain status", value="None")
                with gr.Row():
                    qa_init_btn = gr.Button("Initialize QA Chain")

            # Tab 3: Conversation with Chatbot
            with gr.Tab("Step 3 - Conversation with Chatbot"):
                chatbot = gr.Chatbot()
                with gr.Row():
                    doc_source1 = gr.Textbox(label="Reference 1")
                    source1_page = gr.Number(label="Page")
                with gr.Row():
                    doc_source2 = gr.Textbox(label="Reference 2")
                    source2_page = gr.Number(label="Page")
                with gr.Row():
                    msg = gr.Textbox(placeholder="Type message")
                    submit_btn = gr.Button("Submit")

        # Handlers
        db_init_btn.click(initialize_database, inputs=[document, chunk_size, chunk_overlap], outputs=[vector_db, collection_name, db_progress])
        qa_init_btn.click(initialize_LLM, inputs=[llm_selection, temperature, max_tokens, top_k, vector_db], outputs=[qa_chain, llm_progress])
        # The chatbot component itself holds the (user, bot) history
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page])

    return demo

if __name__ == "__main__":
    gradio_ui().launch()
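
# For a Space, request queuing is often enabled so long generations don't
# time out; a minimal alternative launch (sketch, not part of the original
# file):
#   demo = gradio_ui()
#   demo.queue()
#   demo.launch(debug=False)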