Spaces:
Runtime error
Commit 9e6b8ed
Parent(s): 1a9b6de
Adding application, not finished yet, still won't expand text
- app.py +130 -0
- backend/query_llm.py +86 -0
- backend/retrieval_augmented_generation.py +38 -0
- backend/semantic_search.py +10 -0
- utilities/retrievers.py +26 -0
app.py
ADDED
@@ -0,0 +1,130 @@
from time import perf_counter

from jinja2 import Template

from backend.semantic_search import qd_retriever

template_string = """
Instructions: Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context:
---
{% for doc in documents %}
{{ doc.content }}
---
{% endfor %}
Query: {{ query }}
"""

md_template_string = """
<b>Instructions</b>:
<span style="color: green;">Use the following pieces of context to answer the question at the end.<br>If you don't know the answer, just say that you don't know, <span style="color: green; font-weight: bold;">don't try to make up an answer.</span></span><br>

<b>Context</b>:
{% for doc in documents %}
<div id="box{{ loop.index }}" style="border: 2px solid #aaa; padding: 10px; margin-top: 10px; border-radius: 5px; background-color: #1E90FF; position: relative; cursor: pointer;">
    <div style="font-size: 0.8em; position: absolute; top: 10px; left: 10px;"><b>Doc {{ loop.index }}</b></div>
    <span id="doc{{ loop.index }}-short" style="color: white; display: block; margin-top: 20px;">{{ doc.content[:50] }}...</span>
    <span id="doc{{ loop.index }}-full" style="color: white; display: none; margin-top: 20px;">{{ doc.content }}</span>
</div>
{% endfor %}
<b>Query</b>: <span style="color: yellow;">{{ query }}</span>
<script>
document.addEventListener("DOMContentLoaded", function() {
    {% for doc in documents %}
    document.getElementById("box{{ loop.index }}").addEventListener("click", function() {
        toggleContent('doc{{ loop.index }}');
    });
    {% endfor %}
});

function toggleContent(docID) {
    var shortContent = document.getElementById(docID + '-short');
    var fullContent = document.getElementById(docID + '-full');
    if (fullContent.style.display === 'none') {
        shortContent.style.display = 'none';
        fullContent.style.display = 'block';
    } else {
        shortContent.style.display = 'block';
        fullContent.style.display = 'none';
    }
}
</script>
"""

template = Template(template_string)
md_template = Template(md_template_string)
import gradio as gr

from backend.query_llm import generate


def add_text(history, text):
    history = [] if history is None else history
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def bot(history, system_prompt=""):
    top_k = 5
    query = history[-1][0]

    # Retrieve documents relevant to query
    document_start = perf_counter()
    documents = qd_retriever.retrieve(query, top_k=top_k)
    document_time = perf_counter() - document_start

    # Create Prompt
    prompt = template.render(documents=documents, query=query)
    md_prompt = md_template.render(documents=documents, query=query)

    # # Query LLM with prompt based on relevant documents
    # llm_start = perf_counter()
    # result = generate(prompt=prompt, history='')
    # llm_time = perf_counter() - llm_start
    # times = (document_time, llm_time)

    history[-1][1] = ""
    # generate() yields the cumulative output so far, so assign rather than append
    for chunk in generate(prompt, history[:-1]):
        history[-1][1] = chunk
        yield history, md_prompt


with gr.Blocks() as demo:
    with gr.Tab("Application"):
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),
            bubble_full_width=False,
            show_copy_button=True,
            show_share_button=True,
        )

        with gr.Row():
            txt = gr.Textbox(
                scale=3,
                show_label=False,
                placeholder="Enter text and press enter",
                container=False,
            )
            txt_btn = gr.Button(value="Submit text", scale=1)

        prompt_md = gr.HTML()
        # Turn off interactivity while generating when the button is clicked
        txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, [chatbot, prompt_md])

        # Turn it back on
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

        # Turn off interactivity while generating when you hit enter
        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, [chatbot, prompt_md])

        # Turn it back on
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

        gr.Examples(['What is the largest city on earth?', 'Who has the record for the fastest mile?'], txt)

demo.queue()
demo.launch(debug=True)
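The commit message notes that the document boxes still won't expand. One plausible cause is that gr.HTML injects its content the way innerHTML does, and browsers do not execute <script> tags inserted that way, so toggleContent is never defined. A minimal sketch of a script-free alternative, under that assumption: a hypothetical variant of md_template_string built around the HTML <details>/<summary> element, which expands and collapses on its own.

# Hypothetical alternative template (not part of this commit): <details> expands
# without JavaScript, which scripts injected via gr.HTML may never run.
md_template_string_details = """
<b>Context</b>:
{% for doc in documents %}
<details style="border: 2px solid #aaa; padding: 10px; margin-top: 10px; border-radius: 5px; background-color: #1E90FF; color: white;">
  <summary><b>Doc {{ loop.index }}</b>: {{ doc.content[:50] }}...</summary>
  {{ doc.content }}
</details>
{% endfor %}
<b>Query</b>: <span style="color: yellow;">{{ query }}</span>
"""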
backend/query_llm.py
ADDED
@@ -0,0 +1,86 @@
from typing import Any, Dict, Generator, List

import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)


def format_prompt(message: str) -> str:
    """
    Formats the given message using a chat template.

    Args:
        message (str): The user message to be formatted.

    Returns:
        str: Formatted message after applying the chat template.
    """

    # Create a list of message dictionaries with role and content
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]

    # Return the message after applying the chat template
    return tokenizer.apply_chat_template(messages, tokenize=False)


def generate(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
             top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Generate a sequence of tokens based on a given prompt and history using Mistral client.

    Args:
        prompt (str): The initial prompt for the text generation.
        history (str): Context or history for the text generation (currently unused).
        temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
        max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.

    Returns:
        Generator[str, None, str]: A generator yielding the cumulative generated text.
            Yields an apology message if an error occurs.
    """

    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
    top_p = float(top_p)

    generate_kwargs = {
        'temperature': temperature,
        'max_new_tokens': max_new_tokens,
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'do_sample': True,
        'seed': 42,
    }

    formatted_prompt = format_prompt(prompt)

    try:
        stream = text_client.text_generation(formatted_prompt, **generate_kwargs,
                                             stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            # Yield (rather than return) so the message reaches the caller's stream
            yield "Unfortunately, I am not able to process your request now."
            return
        else:
            print("Unhandled Exception:", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            yield "I do not know what happened, but I couldn't understand you."
            return

    return output
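Note that generate() yields the cumulative output after each token rather than per-token deltas, which is why app.py assigns (rather than appends) each yielded value. A minimal sketch of consuming the stream outside Gradio, assuming the Hugging Face Inference API is reachable (the example question is illustrative):

from backend.query_llm import generate

previous = ""
for chunk in generate("What is retrieval augmented generation?", history=""):
    # Each chunk is the full text so far; print only the newly added suffix
    print(chunk[len(previous):], end="", flush=True)
    previous = chunk
print()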
backend/retrieval_augmented_generation.py
ADDED
@@ -0,0 +1,38 @@
from time import perf_counter
from jinja2 import Template

from backend.query_llm import generate
from backend.semantic_search import qd_retriever

template_string = """
Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context:
---
{% for doc in documents %}
{{ doc.content }}
---
{% endfor %}
Query: {{ query }}
"""

template = Template(template_string)


def rag(query, top_k=5):

    # Retrieve documents relevant to query
    document_start = perf_counter()
    documents = qd_retriever.retrieve(query, top_k=top_k)
    document_time = perf_counter() - document_start

    # Create Prompt
    prompt = template.render(documents=documents, query=query)

    # Query LLM with prompt based on relevant documents
    # Note: generate() returns a generator, so tokens stream lazily when `result` is consumed
    llm_start = perf_counter()
    result = generate(prompt=prompt, history='')
    llm_time = perf_counter() - llm_start

    times = (document_time, llm_time)
    return prompt, result
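Because generate() is a generator, the `result` returned by rag() has not produced any text yet; the caller has to iterate it. A minimal usage sketch, assuming the Qdrant index that backend/semantic_search.py opens has already been built:

from backend.retrieval_augmented_generation import rag

prompt, result = rag("Who has the record for the fastest mile?", top_k=3)

# Drain the stream; the last yielded chunk is the complete answer
answer = ""
for chunk in result:
    answer = chunk

print(prompt)
print(answer)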
backend/semantic_search.py
ADDED
@@ -0,0 +1,10 @@
from qdrant_haystack import QdrantDocumentStore
from haystack.nodes import EmbeddingRetriever
from pathlib import Path

proj_dir = Path(__file__).parents[1]
qd_document_store = QdrantDocumentStore(path=str(proj_dir/'Qdrant'), index='RAGDemo')
qd_retriever = EmbeddingRetriever(document_store=qd_document_store,
                                  embedding_model="BAAI/bge-base-en-v1.5",
                                  model_format="sentence_transformers",
                                  use_gpu=False)
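semantic_search.py only opens an existing Qdrant index named 'RAGDemo'; the commit does not include the script that builds it. A rough sketch of how such an index might be populated with the Haystack 1.x API, where the example documents are placeholders and embedding_dim=768 is assumed from BAAI/bge-base-en-v1.5:

# Hypothetical one-off indexing script (not part of this commit): populate the
# on-disk Qdrant collection that semantic_search.py expects to find.
from pathlib import Path

from haystack.nodes import EmbeddingRetriever
from qdrant_haystack import QdrantDocumentStore

proj_dir = Path(__file__).parents[1]

document_store = QdrantDocumentStore(path=str(proj_dir / 'Qdrant'),
                                     index='RAGDemo',
                                     embedding_dim=768)
retriever = EmbeddingRetriever(document_store=document_store,
                               embedding_model="BAAI/bge-base-en-v1.5",
                               model_format="sentence_transformers",
                               use_gpu=False)

# Placeholder corpus; the real Space would index its own documents here
document_store.write_documents([
    {"content": "Tokyo is often described as the largest city on earth by metro population."},
    {"content": "Hicham El Guerrouj ran the fastest mile on record, 3:43.13, in 1999."},
])
document_store.update_embeddings(retriever)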
utilities/retrievers.py
ADDED
@@ -0,0 +1,26 @@
import pickle
from logging import getLogger
from pathlib import Path

import torch
from haystack.nodes import EmbeddingRetriever
from qdrant_haystack import QdrantDocumentStore

logger = getLogger(__name__)

proj_dir = Path(__file__).parents[1]

st_document_store_path = proj_dir / 'haystack_pickles' / 'simple-wiki_all-mpnet-base-v2_document-store.pkl'

logger.warning('Loading Document Store...')
with open(st_document_store_path, 'rb') as handle:
    st_document_store = pickle.load(handle)
logger.warning('Loaded Document Store...')

qd_document_store = QdrantDocumentStore(path=str(proj_dir/'Qdrant'))

qd_document_store.main_device = torch.device('cpu')
qd_retriever = EmbeddingRetriever(document_store=qd_document_store,
                                  embedding_model="sentence-transformers/all-mpnet-base-v2",
                                  model_format="sentence_transformers",
                                  use_gpu=True)
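utilities/retrievers.py pins the Qdrant store to the CPU while creating the retriever with use_gpu=True, which may not behave as intended on CPU-only Spaces hardware. A small sketch of one way to pick the device at import time instead (a variant, not what the commit does; it reuses the qd_document_store defined above):

import torch
from haystack.nodes import EmbeddingRetriever

# Use the GPU only when one is actually available
use_gpu = torch.cuda.is_available()

qd_retriever = EmbeddingRetriever(document_store=qd_document_store,
                                  embedding_model="sentence-transformers/all-mpnet-base-v2",
                                  model_format="sentence_transformers",
                                  use_gpu=use_gpu)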