valeriylo and IlyaGusev committed
Commit 420a8b8 (no parent commits)

Duplicate from IlyaGusev/saiga_13b_llamacpp_retrieval_qa

Co-authored-by: Ilya Gusev <[email protected]>

Files changed (4):
  1. .gitattributes    +34 -0
  2. README.md         +11 -0
  3. app.py            +366 -0
  4. requirements.txt  +8 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
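(This is the stock Hugging Face .gitattributes: any file matching one of these patterns is stored via Git LFS instead of plain Git history. If the Space ever needs to version another binary format, a matching rule can be appended with git-lfs, e.g. running a command of the form `git lfs track "*.sqlite3"` — the extension here is purely illustrative.)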
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Saiga 13b Q4_1 llama.cpp Retrieval QA
+ emoji: 📚
+ colorFrom: green
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 3.32.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: IlyaGusev/saiga_13b_llamacpp_retrieval_qa
+ ---
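The duplicated_from field above is what marks this Space as a copy of IlyaGusev's original. For reference, the same duplication can also be done programmatically; a minimal sketch, assuming huggingface_hub>=0.14 and an authenticated session (the target id defaults to your own namespace):

from huggingface_hub import duplicate_space

# Copies the upstream Space into the caller's namespace; requires a valid
# token (e.g. obtained via `huggingface-cli login`).
duplicate_space("IlyaGusev/saiga_13b_llamacpp_retrieval_qa")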
app.py ADDED
@@ -0,0 +1,366 @@
+ import gradio as gr
+
+ from uuid import uuid4
+ from huggingface_hub import snapshot_download
+ from langchain.document_loaders import (
+     CSVLoader,
+     EverNoteLoader,
+     PDFMinerLoader,
+     TextLoader,
+     UnstructuredEmailLoader,
+     UnstructuredEPubLoader,
+     UnstructuredHTMLLoader,
+     UnstructuredMarkdownLoader,
+     UnstructuredODTLoader,
+     UnstructuredPowerPointLoader,
+     UnstructuredWordDocumentLoader,
+ )
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.docstore.document import Document
+ from chromadb.config import Settings
+ from llama_cpp import Llama
+
+
+ # "You are Saiga, a Russian-language automatic assistant. You talk to people and help them."
+ SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
+ # Token ids of the Saiga role markers and the newline token in the llama.cpp vocabulary.
+ SYSTEM_TOKEN = 1788
+ USER_TOKEN = 1404
+ BOT_TOKEN = 9225
+ LINEBREAK_TOKEN = 13
+
+ ROLE_TOKENS = {
+     "user": USER_TOKEN,
+     "bot": BOT_TOKEN,
+     "system": SYSTEM_TOKEN
+ }
+
+ # Supported upload formats and the langchain loader for each.
+ LOADER_MAPPING = {
+     ".csv": (CSVLoader, {}),
+     ".doc": (UnstructuredWordDocumentLoader, {}),
+     ".docx": (UnstructuredWordDocumentLoader, {}),
+     ".enex": (EverNoteLoader, {}),
+     ".epub": (UnstructuredEPubLoader, {}),
+     ".html": (UnstructuredHTMLLoader, {}),
+     ".md": (UnstructuredMarkdownLoader, {}),
+     ".odt": (UnstructuredODTLoader, {}),
+     ".pdf": (PDFMinerLoader, {}),
+     ".ppt": (UnstructuredPowerPointLoader, {}),
+     ".pptx": (UnstructuredPowerPointLoader, {}),
+     ".txt": (TextLoader, {"encoding": "utf8"}),
+ }
+
+
+ repo_name = "IlyaGusev/saiga_13b_lora_llamacpp"
+ model_name = "ggml-model-q4_1.bin"
+ embedder_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+
+ snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=model_name)
+
+ model = Llama(
+     model_path=model_name,
+     n_ctx=2000,
+     n_parts=1,
+ )
+
+ max_new_tokens = 1500
+ embeddings = HuggingFaceEmbeddings(model_name=embedder_name)
+
+ def get_uuid():
+     return str(uuid4())
+
+
+ def load_single_document(file_path: str) -> Document:
+     ext = "." + file_path.rsplit(".", 1)[-1]
+     assert ext in LOADER_MAPPING
+     loader_class, loader_args = LOADER_MAPPING[ext]
+     loader = loader_class(file_path, **loader_args)
+     return loader.load()[0]
+
+
+ # Wrap one message as <s>{role}\n{content}</s> in token space
+ # (tokenize() prepends BOS, so the role token goes at index 1).
+ def get_message_tokens(model, role, content):
+     message_tokens = model.tokenize(content.encode("utf-8"))
+     message_tokens.insert(1, ROLE_TOKENS[role])
+     message_tokens.insert(2, LINEBREAK_TOKEN)
+     message_tokens.append(model.token_eos())
+     return message_tokens
+
+
+ def get_system_tokens(model):
+     system_message = {"role": "system", "content": SYSTEM_PROMPT}
+     return get_message_tokens(model, **system_message)
+
+
+ def upload_files(files, file_paths):
+     file_paths = [f.name for f in files]
+     return file_paths
+
+
+ def process_text(text):
+     lines = text.split("\n")
+     lines = [line for line in lines if len(line.strip()) > 2]
+     text = "\n".join(lines).strip()
+     if len(text) < 10:
+         return None
+     return text
+
+
+ # Load, split, clean, and embed the uploaded files into an in-memory Chroma index.
+ def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning):
+     documents = [load_single_document(path) for path in file_paths]
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     documents = text_splitter.split_documents(documents)
+     fixed_documents = []
+     for doc in documents:
+         doc.page_content = process_text(doc.page_content)
+         if not doc.page_content:
+             continue
+         fixed_documents.append(doc)
+
+     db = Chroma.from_documents(
+         fixed_documents,
+         embeddings,
+         client_settings=Settings(
+             anonymized_telemetry=False
+         )
+     )
+     file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."  # "Loaded N fragments! You can ask questions."
+     return db, file_warning
+
+
+ def user(message, history, system_prompt):
+     new_history = history + [[message, None]]
+     return "", new_history
+
+
+ # Fetch the k most similar chunks from the Chroma index for the latest user message.
+ def retrieve(history, db, retrieved_docs, k_documents):
+     context = ""
+     if db:
+         last_user_message = history[-1][0]
+         retriever = db.as_retriever(search_kwargs={"k": k_documents})
+         docs = retriever.get_relevant_documents(last_user_message)
+         retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
+     return retrieved_docs
+
+
+ # Serialize the dialog in Saiga chat format, prime an open bot turn,
+ # and stream the answer back token by token.
+ def bot(
+     history,
+     system_prompt,
+     conversation_id,
+     retrieved_docs,
+     top_p,
+     top_k,
+     temp
+ ):
+     if not history:
+         return
+
+     tokens = get_system_tokens(model)[:]
+     tokens.append(LINEBREAK_TOKEN)
+
+     for user_message, bot_message in history[:-1]:
+         message_tokens = get_message_tokens(model=model, role="user", content=user_message)
+         tokens.extend(message_tokens)
+         if bot_message:
+             message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
+             tokens.extend(message_tokens)
+
+     last_user_message = history[-1][0]
+     if retrieved_docs:
+         last_user_message = f"Контекст: {retrieved_docs}\n\nИспользуя контекст, ответь на вопрос: {last_user_message}"
+     message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
+     tokens.extend(message_tokens)
+
+     role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
+     tokens.extend(role_tokens)
+     generator = model.generate(
+         tokens,
+         top_k=top_k,
+         top_p=top_p,
+         temp=temp
+     )
+
+     partial_text = ""
+     for i, token in enumerate(generator):
+         if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens):
+             break
+         partial_text += model.detokenize([token]).decode("utf-8", "ignore")
+         history[-1][1] = partial_text
+         yield history
+
+
+ with gr.Blocks(
+     theme=gr.themes.Soft()
+ ) as demo:
+     db = gr.State(None)
+     conversation_id = gr.State(get_uuid)
+     favicon = '<img src="https://cdn.midjourney.com/b88e5beb-6324-4820-8504-a1a37a9ba36d/0_1.png" width="48px" style="display: inline">'
+     gr.Markdown(
+         f"""<h1><center>{favicon}Saiga 13B llama.cpp: retrieval QA</center></h1>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             file_output = gr.File(file_count="multiple", label="Загрузка файлов")
+             file_paths = gr.State([])
+             file_warning = gr.Markdown(f"Фрагменты ещё не загружены!")  # "No fragments loaded yet!"
+
+         with gr.Column(min_width=200, scale=3):
+             with gr.Tab(label="Параметры нарезки"):
+                 chunk_size = gr.Slider(
+                     minimum=50,
+                     maximum=2000,
+                     value=250,
+                     step=50,
+                     interactive=True,
+                     label="Размер фрагментов",
+                 )
+                 chunk_overlap = gr.Slider(
+                     minimum=0,
+                     maximum=500,
+                     value=30,
+                     step=10,
+                     interactive=True,
+                     label="Пересечение"
+                 )
+
+     with gr.Row():
+         k_documents = gr.Slider(
+             minimum=1,
+             maximum=10,
+             value=2,
+             step=1,
+             interactive=True,
+             label="Кол-во фрагментов для контекста"
+         )
+     with gr.Row():
+         retrieved_docs = gr.Textbox(
+             lines=6,
+             label="Извлеченные фрагменты",
+             placeholder="Появятся после задавания вопросов",
+             interactive=False
+         )
+     with gr.Row():
+         with gr.Column(scale=5):
+             system_prompt = gr.Textbox(label="Системный промпт", placeholder="", value=SYSTEM_PROMPT, interactive=False)
+             chatbot = gr.Chatbot(label="Диалог").style(height=400)
+         with gr.Column(min_width=80, scale=1):
+             with gr.Tab(label="Параметры генерации"):
+                 top_p = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.9,
+                     step=0.05,
+                     interactive=True,
+                     label="Top-p",
+                 )
+                 top_k = gr.Slider(
+                     minimum=10,
+                     maximum=100,
+                     value=30,
+                     step=5,
+                     interactive=True,
+                     label="Top-k",
+                 )
+                 temp = gr.Slider(
+                     minimum=0.0,
+                     maximum=2.0,
+                     value=0.1,
+                     step=0.1,
+                     interactive=True,
+                     label="Temp"
+                 )
+
+     with gr.Row():
+         with gr.Column():
+             msg = gr.Textbox(
+                 label="Отправить сообщение",
+                 placeholder="Отправить сообщение",
+                 show_label=False,
+             ).style(container=False)
+         with gr.Column():
+             with gr.Row():
+                 submit = gr.Button("Отправить")
+                 stop = gr.Button("Остановить")
+                 clear = gr.Button("Очистить")
+
+     # Upload files
+     upload_event = file_output.change(
+         fn=upload_files,
+         inputs=[file_output, file_paths],
+         outputs=[file_paths],
+         queue=True,
+     ).success(
+         fn=build_index,
+         inputs=[file_paths, db, chunk_size, chunk_overlap, file_warning],
+         outputs=[db, file_warning],
+         queue=True
+     )
+
+     # Pressing Enter
+     submit_event = msg.submit(
+         fn=user,
+         inputs=[msg, chatbot, system_prompt],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).success(
+         fn=retrieve,
+         inputs=[chatbot, db, retrieved_docs, k_documents],
+         outputs=[retrieved_docs],
+         queue=True,
+     ).success(
+         fn=bot,
+         inputs=[
+             chatbot,
+             system_prompt,
+             conversation_id,
+             retrieved_docs,
+             top_p,
+             top_k,
+             temp
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+
+     # Pressing the button
+     submit_click_event = submit.click(
+         fn=user,
+         inputs=[msg, chatbot, system_prompt],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).success(
+         fn=retrieve,
+         inputs=[chatbot, db, retrieved_docs, k_documents],
+         outputs=[retrieved_docs],
+         queue=True,
+     ).success(
+         fn=bot,
+         inputs=[
+             chatbot,
+             system_prompt,
+             conversation_id,
+             retrieved_docs,
+             top_p,
+             top_k,
+             temp
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+
+     # Stop generation
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, submit_click_event],
+         queue=False,
+     )
+
+     # Clear history
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.queue(max_size=128, concurrency_count=1)
+ demo.launch()
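For reference, bot() above serializes the whole dialog into Saiga's chat format before calling model.generate(): each message becomes BOS, a role token, a newline token, the message text, then EOS, and the answer is primed with an open bot turn. The final user turn prepends the retrieved chunks as "Контекст: …" ("Context: …") and asks the model to answer the question using that context. Rendered as text (a sketch: the role token ids 1788/1404/9225 decode to the literal words system/user/bot, token 13 is a newline, and BOS/EOS print as <s>/</s>), the prompt for one retrieval-augmented turn looks roughly like:

<s>system
Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им.</s>
<s>user
Контекст: {retrieved_docs}

Используя контекст, ответь на вопрос: {last_user_message}</s>
<s>bot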
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ llama-cpp-python==0.1.53
+ langchain==0.0.174
+ huggingface-hub==0.14.1
+ chromadb==0.3.23
+ pdfminer.six==20221105
+ unstructured==0.6.10
+ gradio==3.32.0
+ tabulate
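These pins match the Space's declared sdk_version (gradio 3.32.0). To try the app outside of Spaces, a reasonable sketch — assuming a recent Python 3 and enough free RAM to hold the q4_1 13B weights (on the order of 8-10 GB) — is to run "pip install -r requirements.txt" followed by "python app.py"; on first start, the snapshot_download call in app.py pulls ggml-model-q4_1.bin from IlyaGusev/saiga_13b_lora_llamacpp into the working directory.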