derek-thomas committed
Commit 9e6b8ed
1 Parent(s): 1a9b6de

Adding application, not finished yet, still won't expand text

app.py ADDED
@@ -0,0 +1,130 @@
+ from time import perf_counter
+
+ import gradio as gr
+ from jinja2 import Template
+
+ from backend.query_llm import generate
+ from backend.semantic_search import qd_retriever
+
+ template_string = """
+ Instructions: Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ Context:
+ ---
+ {% for doc in documents %}
+ {{ doc.content }}
+ ---
+ {% endfor %}
+ Query: {{ query }}
+ """
+
+ # NOTE: gr.HTML injects this markup via innerHTML, so the <script> block
+ # below never executes -- the likely reason the text still won't expand.
+ md_template_string = """
+ <b>Instructions</b>:
+ <span style="color: green;">Use the following pieces of context to answer the question at the end.<br>If you don't know the answer, just say that you don't know, <span style="color: green; font-weight: bold;">don't try to make up an answer.</span></span><br>
+
+ <b>Context</b>:
+ {% for doc in documents %}
+ <div id="box{{ loop.index }}" style="border: 2px solid #aaa; padding: 10px; margin-top: 10px; border-radius: 5px; background-color: #1E90FF; position: relative; cursor: pointer;">
+ <div style="font-size: 0.8em; position: absolute; top: 10px; left: 10px;"><b>Doc {{ loop.index }}</b></div>
+ <span id="doc{{ loop.index }}-short" style="color: white; display: block; margin-top: 20px;">{{ doc.content[:50] }}...</span>
+ <span id="doc{{ loop.index }}-full" style="color: white; display: none; margin-top: 20px;">{{ doc.content }}</span>
+ </div>
+ {% endfor %}
+ <b>Query</b>: <span style="color: yellow;">{{ query }}</span>
+ <script>
+ document.addEventListener("DOMContentLoaded", function() {
+ {% for doc in documents %}
+     document.getElementById("box{{ loop.index }}").addEventListener("click", function() {
+         toggleContent('doc{{ loop.index }}');
+     });
+ {% endfor %}
+ });
+
+ function toggleContent(docID) {
+     var shortContent = document.getElementById(docID + '-short');
+     var fullContent = document.getElementById(docID + '-full');
+     if (fullContent.style.display === 'none') {
+         shortContent.style.display = 'none';
+         fullContent.style.display = 'block';
+     } else {
+         shortContent.style.display = 'block';
+         fullContent.style.display = 'none';
+     }
+ }
+ </script>
+ """
+
+ template = Template(template_string)
+ md_template = Template(md_template_string)
+
+
+ def add_text(history, text):
+     history = [] if history is None else history
+     history = history + [[text, None]]  # a list, not a tuple, so bot() can assign into it
+     return history, gr.Textbox(value="", interactive=False)
+
+
+ def bot(history, system_prompt=""):
+     top_k = 5
+     query = history[-1][0]
+
+     # Retrieve documents relevant to the query
+     document_start = perf_counter()
+     documents = qd_retriever.retrieve(query, top_k=top_k)
+     document_time = perf_counter() - document_start  # elapsed seconds (was negated)
+
+     # Create the prompts: plain text for the LLM, HTML for display
+     prompt = template.render(documents=documents, query=query)
+     md_prompt = md_template.render(documents=documents, query=query)
+
+     # # Query LLM with prompt based on relevant documents
+     # llm_start = perf_counter()
+     # result = generate(prompt=prompt, history='')
+     # llm_time = perf_counter() - llm_start
+     # times = (document_time, llm_time)
+
+     history[-1][1] = ""
+     # generate() yields the accumulated output so far, not single characters
+     for chunk in generate(prompt, history[:-1]):
+         history[-1][1] = chunk
+         yield history, md_prompt
+
+
+ with gr.Blocks() as demo:
+     with gr.Tab("Application"):
+         chatbot = gr.Chatbot(
+             [],
+             elem_id="chatbot",
+             avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),
+             bubble_full_width=False,
+             show_copy_button=True,
+             show_share_button=True,
+         )
+
+         with gr.Row():
+             txt = gr.Textbox(
+                 scale=3,
+                 show_label=False,
+                 placeholder="Enter text and press enter",
+                 container=False,
+             )
+             txt_btn = gr.Button(value="Submit text", scale=1)
+
+         prompt_md = gr.HTML()
+
+         # Turn off interactivity while generating if you click the button
+         txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+             bot, chatbot, [chatbot, prompt_md])
+
+         # Turn it back on
+         txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+
+         # Turn off interactivity while generating if you hit enter
+         txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+             bot, chatbot, [chatbot, prompt_md])
+
+         # Turn it back on
+         txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+
+         gr.Examples(['What is the largest city on earth?', 'Who has the record for the fastest mile?'], txt)
+
+ demo.queue()
+ demo.launch(debug=True)
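
As noted in the code comment above, the <script> block never runs once gr.HTML injects it, so the document boxes stay collapsed. A script-free sketch of the same expand/collapse behavior using native <details>/<summary> elements, offered as a hypothetical replacement for md_template_string (untested against this app):

    md_template_string = """
    <b>Instructions</b>: Use the following pieces of context to answer the question at the end.<br>

    <b>Context</b>:
    {% for doc in documents %}
    <details style="border: 2px solid #aaa; padding: 10px; margin-top: 10px; border-radius: 5px; background-color: #1E90FF; color: white;">
      <summary style="cursor: pointer;"><b>Doc {{ loop.index }}</b> {{ doc.content[:50] }}...</summary>
      {{ doc.content }}
    </details>
    {% endfor %}
    <b>Query</b>: <span style="color: yellow;">{{ query }}</span>
    """
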
backend/query_llm.py ADDED
@@ -0,0 +1,86 @@
+ from typing import Any, Dict, Generator, List
+
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+
+ # Module-level defaults (generate() currently defines its own; kept for reference)
+ temperature = 0.9
+ top_p = 0.6
+ repetition_penalty = 1.2
+
+ text_client = InferenceClient(
+     "mistralai/Mistral-7B-Instruct-v0.1"
+ )
+
+
+ def format_prompt(message: str) -> str:
+     """
+     Formats the given message using the model's chat template.
+
+     Args:
+         message (str): The user message to be formatted.
+
+     Returns:
+         str: Formatted message after applying the chat template.
+     """
+     # Create a list of message dictionaries with role and content
+     messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
+
+     # Return the message after applying the chat template
+     return tokenizer.apply_chat_template(messages, tokenize=False)
+
+
+ def generate(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
+              top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, None]:
+     """
+     Generate a sequence of tokens based on a given prompt and history using the Mistral client.
+
+     Args:
+         prompt (str): The initial prompt for the text generation.
+         history (str): Context or history for the text generation (currently unused).
+         temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
+         max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
+         top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
+         repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
+
+     Yields:
+         str: The accumulated generated text so far, or an error message if one occurs.
+     """
+     temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
+     top_p = float(top_p)
+
+     generate_kwargs = {
+         'temperature': temperature,
+         'max_new_tokens': max_new_tokens,
+         'top_p': top_p,
+         'repetition_penalty': repetition_penalty,
+         'do_sample': True,
+         'seed': 42,
+     }
+
+     formatted_prompt = format_prompt(prompt)
+
+     try:
+         stream = text_client.text_generation(formatted_prompt, **generate_kwargs,
+                                              stream=True, details=True, return_full_text=False)
+         output = ""
+         for response in stream:
+             output += response.token.text
+             yield output
+
+     except Exception as e:
+         # Yield (rather than return) the message: a value returned from a
+         # generator is swallowed by StopIteration and never reaches the UI.
+         if "Too Many Requests" in str(e):
+             print("ERROR: Too many requests on Mistral client")
+             gr.Warning("Unfortunately Mistral is unable to process")
+             yield "Unfortunately, I am not able to process your request now."
+         else:
+             print("Unhandled Exception:", str(e))
+             gr.Warning("Unfortunately Mistral is unable to process")
+             yield "I do not know what happened, but I couldn't understand you."
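
For reference, a minimal sketch of how this streaming generate() is consumed; each yielded value is the full text so far, so the caller prints only the new suffix (the query string is illustrative):

    from backend.query_llm import generate

    previous = ""
    for output in generate("What is the largest city on earth?", history=""):
        print(output[len(previous):], end="", flush=True)  # print only the newly generated part
        previous = output
    print()
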
backend/retrieval_augmented_generation.py ADDED
@@ -0,0 +1,38 @@
+ from time import perf_counter
+
+ from jinja2 import Template
+
+ from backend.query_llm import generate
+ from backend.semantic_search import qd_retriever
+
+ template_string = """
+ Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ Context:
+ ---
+ {% for doc in documents %}
+ {{ doc.content }}
+ ---
+ {% endfor %}
+ Query: {{ query }}
+ """
+
+ template = Template(template_string)
+
+
+ def rag(query, top_k=5):
+     # Retrieve documents relevant to the query
+     document_start = perf_counter()
+     documents = qd_retriever.retrieve(query, top_k=top_k)
+     document_time = perf_counter() - document_start  # elapsed seconds (was negated)
+
+     # Create the prompt
+     prompt = template.render(documents=documents, query=query)
+
+     # Query the LLM with a prompt built from the relevant documents.
+     # generate() returns a lazy generator, so this measures only its creation;
+     # tokens are produced as the caller iterates over `result`.
+     llm_start = perf_counter()
+     result = generate(prompt=prompt, history='')
+     llm_time = perf_counter() - llm_start
+
+     times = (document_time, llm_time)  # computed but not yet returned
+     return prompt, result
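
A minimal usage sketch for rag(); because generate() is lazy, result must be iterated to actually produce tokens (the example query is illustrative):

    from backend.retrieval_augmented_generation import rag

    prompt, result = rag("Who has the record for the fastest mile?", top_k=3)
    print(prompt)           # the rendered context + query sent to the LLM
    answer = ""
    for answer in result:   # each value is the accumulated output so far
        pass
    print(answer)           # the final answer once the stream is exhausted
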
backend/semantic_search.py ADDED
@@ -0,0 +1,10 @@
+ from pathlib import Path
+
+ from haystack.nodes import EmbeddingRetriever
+ from qdrant_haystack import QdrantDocumentStore
+
+ proj_dir = Path(__file__).parents[1]
+
+ qd_document_store = QdrantDocumentStore(path=str(proj_dir / 'Qdrant'), index='RAGDemo')
+ qd_retriever = EmbeddingRetriever(document_store=qd_document_store,
+                                   embedding_model="BAAI/bge-base-en-v1.5",
+                                   model_format="sentence_transformers",
+                                   use_gpu=False)
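
A minimal usage sketch, assuming the Qdrant index under proj_dir/'Qdrant' has already been populated; Haystack's retrieve() returns Document objects with a .content field:

    from backend.semantic_search import qd_retriever

    docs = qd_retriever.retrieve("What is the largest city on earth?", top_k=5)
    for doc in docs:
        print(doc.content[:80])  # first 80 characters of each retrieved passage
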
utilities/retrievers.py ADDED
@@ -0,0 +1,26 @@
+ import pickle
+ from logging import getLogger
+ from pathlib import Path
+
+ import torch
+ from haystack.nodes import EmbeddingRetriever
+ from qdrant_haystack import QdrantDocumentStore
+
+ logger = getLogger(__name__)
+
+ proj_dir = Path(__file__).parents[1]
+
+ st_document_store_path = proj_dir / 'haystack_pickles' / 'simple-wiki_all-mpnet-base-v2_document-store.pkl'
+
+ logger.warning('Loading Document Store...')
+ with open(st_document_store_path, 'rb') as handle:
+     st_document_store = pickle.load(handle)
+ logger.warning('Loaded Document Store...')
+
+ qd_document_store = QdrantDocumentStore(path=str(proj_dir / 'Qdrant'))
+
+ qd_document_store.main_device = torch.device('cpu')
+ qd_retriever = EmbeddingRetriever(document_store=qd_document_store,
+                                   embedding_model="sentence-transformers/all-mpnet-base-v2",
+                                   model_format="sentence_transformers",
+                                   use_gpu=True)