herMaster committed on
Commit
03f2f12
β€’
1 Parent(s): 76a5d4c

uploaded the model and changed the inference library

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_pdf import PDF
3
+ from qdrant_client import models, QdrantClient
4
+ from sentence_transformers import SentenceTransformer
5
+ from PyPDF2 import PdfReader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain.callbacks.manager import CallbackManager
8
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
9
+ # from langchain.llms import LlamaCpp
10
+ from langchain.vectorstores import Qdrant
11
+ from qdrant_client.http import models
12
+ # from langchain.llms import CTransformers
13
+ from ctransformers import AutoModelForCausalLM
14
+
15
+
16
+ # loading the embedding model -
17
+
18
+ encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
19
+
20
+ print("embedding model loaded.............................")
21
+ print("####################################################")
22
+
23
+ # loading the LLM
24
+
25
+ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
26
+
27
+ print("loading the LLM......................................")
28
+
29
+ llm = LlamaCpp(
30
+ model_path="./llama-2-7b-chat.Q3_K_S.gguf",
31
+ n_ctx=2048,
32
+ f16_kv=True, # MUST set to True, otherwise you will run into problem after a couple of calls
33
+ callback_manager=callback_manager,
34
+ verbose=True,
35
+ )
36
+
37
+ # llm = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-GGUF",
38
+ # model_file="llama-2-7b-chat.Q3_K_S.gguf",
39
+ # model_type="llama",
40
+ # temperature = 0.2,
41
+ # repetition_penalty = 1.5,
42
+ # max_new_tokens = 300,
43
+ # )
44
+
45
+
46
+
47
+ print("LLM loaded........................................")
48
+ print("################################################################")
49
+
50
+ # def get_chunks(text):
51
+ # text_splitter = RecursiveCharacterTextSplitter(
52
+ # # seperator = "\n",
53
+ # chunk_size = 250,
54
+ # chunk_overlap = 50,
55
+ # length_function = len,
56
+ # )
57
+
58
+ # chunks = text_splitter.split_text(text)
59
+ # return chunks
60
+
61
+
62
+ # pdf_path = './100 Weird Facts About the Human Body.pdf'
63
+
64
+
65
+ # reader = PdfReader(pdf_path)
66
+ # text = ""
67
+ # num_of_pages = len(reader.pages)
68
+
69
+ # for page in range(num_of_pages):
70
+ # current_page = reader.pages[page]
71
+ # text += current_page.extract_text()
72
+
73
+
74
+ # chunks = get_chunks(text)
75
+ # print(chunks)
76
+ # print("Chunks are ready.....................................")
77
+ # print("######################################################")
78
+
79
+ # client = QdrantClient(path = "./db")
80
+ # print("db created................................................")
81
+ # print("#####################################################################")
82
+
83
+ # client.recreate_collection(
84
+ # collection_name="my_facts",
85
+ # vectors_config=models.VectorParams(
86
+ # size=encoder.get_sentence_embedding_dimension(), # Vector size is defined by used model
87
+ # distance=models.Distance.COSINE,
88
+ # ),
89
+ # )
90
+
91
+ # print("Collection created........................................")
92
+ # print("#########################################################")
93
+
94
+
95
+
96
+ # li = []
97
+ # for i in range(len(chunks)):
98
+ # li.append(i)
99
+
100
+ # dic = zip(li, chunks)
101
+ # dic= dict(dic)
102
+
103
+ # client.upload_records(
104
+ # collection_name="my_facts",
105
+ # records=[
106
+ # models.Record(
107
+ # id=idx,
108
+ # vector=encoder.encode(dic[idx]).tolist(),
109
+ # payload= {dic[idx][:5] : dic[idx]}
110
+ # ) for idx in dic.keys()
111
+ # ],
112
+ # )
113
+
114
+ # print("Records uploaded........................................")
115
+ # print("###########################################################")
116
+
117
+ def chat(file, question):
118
+ def get_chunks(text):
119
+ text_splitter = RecursiveCharacterTextSplitter(
120
+ # seperator = "\n",
121
+ chunk_size = 250,
122
+ chunk_overlap = 50,
123
+ length_function = len,
124
+ )
125
+
126
+ chunks = text_splitter.split_text(text)
127
+ return chunks
128
+
129
+
130
+ pdf_path = file
131
+
132
+
133
+ reader = PdfReader(pdf_path)
134
+ text = ""
135
+ num_of_pages = len(reader.pages)
136
+
137
+ for page in range(num_of_pages):
138
+ current_page = reader.pages[page]
139
+ text += current_page.extract_text()
140
+
141
+
142
+ chunks = get_chunks(text)
143
+ # print(chunks)
144
+ # print("Chunks are ready.....................................")
145
+ # print("######################################################")
146
+
147
+ client = QdrantClient(path = "./db")
148
+ # print("db created................................................")
149
+ # print("#####################################################################")
150
+
151
+ client.recreate_collection(
152
+ collection_name="my_facts",
153
+ vectors_config=models.VectorParams(
154
+ size=encoder.get_sentence_embedding_dimension(), # Vector size is defined by used model
155
+ distance=models.Distance.COSINE,
156
+ ),
157
+ )
158
+
159
+ # print("Collection created........................................")
160
+ # print("#########################################################")
161
+
162
+
163
+
164
+ li = []
165
+ for i in range(len(chunks)):
166
+ li.append(i)
167
+
168
+ dic = zip(li, chunks)
169
+ dic= dict(dic)
170
+
171
+ client.upload_records(
172
+ collection_name="my_facts",
173
+ records=[
174
+ models.Record(
175
+ id=idx,
176
+ vector=encoder.encode(dic[idx]).tolist(),
177
+ payload= {dic[idx][:5] : dic[idx]}
178
+ ) for idx in dic.keys()
179
+ ],
180
+ )
181
+
182
+ # print("Records uploaded........................................")
183
+ # print("###########################################################")
184
+
185
+
186
+ hits = client.search(
187
+ collection_name="my_facts",
188
+ query_vector=encoder.encode(question).tolist(),
189
+ limit=3
190
+ )
191
+ context = []
192
+ for hit in hits:
193
+ context.append(list(hit.payload.values())[0])
194
+
195
+ context = context[0] + context[1] + context[2]
196
+
197
+ system_prompt = """You are a helpful assistant, you will use the provided context to answer user questions.
198
+ Read the given context before answering questions and think step by step. If you can not answer a user question based on
199
+ the provided context, inform the user. Do not use any other information for answering user. Provide a detailed answer to the question."""
200
+
201
+
202
+ B_INST, E_INST = "[INST]", "[/INST]"
203
+
204
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
205
+
206
+ SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
207
+
208
+ instruction = f"""
209
+ Context: {context}
210
+ User: {question}"""
211
+
212
+ prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
213
+
214
+ result = llm(prompt_template)
215
+ return result
# Gradio UI: a PDF upload plus a question box, wired to chat().
pdf_input = PDF(label="Upload a PDF", interactive=True)
question_box = gr.Textbox(lines=10, placeholder="Enter your question here 👉")
answer_box = gr.Textbox(lines=10, placeholder="Your answer will be here soon 🚀")

screen = gr.Interface(
    fn=chat,
    inputs=[pdf_input, question_box],
    outputs=answer_box,
    title="Q&A with PDF 👩🏻‍💻📓✍🏻💡",
    description="This app facilitates a conversation with PDFs available on https://www.delo.si/assets/media/other/20110728/100%20Weird%20Facts%20About%20the%20Human%20Body.pdf💡",
    theme="soft",
)

screen.launch()