Spaces:
Build error
Build error
Commit
•
8b15eea
1
Parent(s):
1089f86
Add gradio app!
Browse files- app.py +93 -0
- backend/query_llm.py +21 -0
- backend/semantic_search.py +45 -0
- templates/template.j2 +10 -0
- templates/template_html.j2 +47 -0
app.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
from time import perf_counter
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
from jinja2 import Environment, FileSystemLoader
|
7 |
+
|
8 |
+
from backend.query_llm import generate
|
9 |
+
from backend.semantic_search import retriever
|
10 |
+
|
11 |
+
proj_dir = Path(__file__).parent
|
12 |
+
# Setting up the logging
|
13 |
+
logging.basicConfig(level=logging.INFO)
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
# Set up the template environment with the templates directory
|
17 |
+
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
|
18 |
+
|
19 |
+
# Load the templates directly from the environment
|
20 |
+
template = env.get_template('template.j2')
|
21 |
+
template_html = env.get_template('template_html.j2')
|
22 |
+
|
23 |
+
|
24 |
+
def add_text(history, text):
|
25 |
+
history = [] if history is None else history
|
26 |
+
history = history + [(text, None)]
|
27 |
+
return history, gr.Textbox(value="", interactive=False)
|
28 |
+
|
29 |
+
|
30 |
+
def bot(history, system_prompt=""):
|
31 |
+
top_k = 3
|
32 |
+
query = history[-1][0]
|
33 |
+
|
34 |
+
logger.warning('Retrieving documents...')
|
35 |
+
# Retrieve documents relevant to query
|
36 |
+
document_start = perf_counter()
|
37 |
+
documents = retriever(query, top_k=top_k)
|
38 |
+
document_time = document_start - perf_counter()
|
39 |
+
logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
40 |
+
|
41 |
+
# Create Prompt
|
42 |
+
prompt = template.render(documents=documents, query=query)
|
43 |
+
prompt_html = template_html.render(documents=documents, query=query)
|
44 |
+
logger.warning(prompt)
|
45 |
+
|
46 |
+
history[-1][1] = ""
|
47 |
+
for character in generate(prompt):
|
48 |
+
history[-1][1] = character
|
49 |
+
yield history, prompt_html
|
50 |
+
|
51 |
+
|
52 |
+
with gr.Blocks() as demo:
|
53 |
+
with gr.Tab("Application"):
|
54 |
+
chatbot = gr.Chatbot(
|
55 |
+
[],
|
56 |
+
elem_id="chatbot",
|
57 |
+
avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
|
58 |
+
'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
|
59 |
+
bubble_full_width=False,
|
60 |
+
show_copy_button=True,
|
61 |
+
show_share_button=True,
|
62 |
+
)
|
63 |
+
|
64 |
+
with gr.Row():
|
65 |
+
txt = gr.Textbox(
|
66 |
+
scale=3,
|
67 |
+
show_label=False,
|
68 |
+
placeholder="Enter text and press enter",
|
69 |
+
container=False,
|
70 |
+
)
|
71 |
+
txt_btn = gr.Button(value="Submit text", scale=1)
|
72 |
+
|
73 |
+
prompt_html = gr.HTML()
|
74 |
+
# Turn off interactivity while generating if you hit enter
|
75 |
+
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
76 |
+
bot, chatbot, [chatbot, prompt_html])
|
77 |
+
|
78 |
+
# Turn it back on
|
79 |
+
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
80 |
+
|
81 |
+
# Turn off interactivity while generating if you hit enter
|
82 |
+
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
83 |
+
bot, chatbot, [chatbot, prompt_html])
|
84 |
+
|
85 |
+
# Turn it back on
|
86 |
+
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
87 |
+
|
88 |
+
gr.Examples(['What is the capital of China, I think its Shanghai?',
|
89 |
+
'Why is the sky blue?',
|
90 |
+
'Who won the mens world cup in 2014?',], txt)
|
91 |
+
|
92 |
+
demo.queue()
|
93 |
+
demo.launch(debug=True)
|
backend/query_llm.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from os import getenv
|
3 |
+
|
4 |
+
|
5 |
+
API_URL = getenv('API_URL')
|
6 |
+
BEARER = getenv('BEARER')
|
7 |
+
|
8 |
+
|
9 |
+
headers = {
|
10 |
+
"Authorization": f"Bearer {BEARER}",
|
11 |
+
"Content-Type": "application/json"
|
12 |
+
}
|
13 |
+
|
14 |
+
def call_jais(payload):
|
15 |
+
response = requests.post(API_URL, headers=headers, json=payload)
|
16 |
+
return response.json()
|
17 |
+
|
18 |
+
def generate(prompt: str):
|
19 |
+
payload = {'inputs': '', 'prompt':prompt}
|
20 |
+
response = call_jais(payload)
|
21 |
+
return response
|
backend/semantic_search.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
import time
|
4 |
+
|
5 |
+
import lancedb
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
|
8 |
+
|
9 |
+
# Setting up the logging
|
10 |
+
logging.basicConfig(level=logging.INFO)
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
# Start the timer for loading the QdrantDocumentStore
|
14 |
+
start_time = time.perf_counter()
|
15 |
+
|
16 |
+
proj_dir = Path(__file__).parents[1]
|
17 |
+
|
18 |
+
# Log the time taken to load the QdrantDocumentStore
|
19 |
+
db = lancedb.connect(proj_dir/"lancedb")
|
20 |
+
tbl = db.open_table('arabic-wiki')
|
21 |
+
lancedb_loading_time = time.perf_counter() - start_time
|
22 |
+
logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
|
23 |
+
|
24 |
+
# Start the timer for loading the EmbeddingRetriever
|
25 |
+
start_time = time.perf_counter()
|
26 |
+
|
27 |
+
name="sentence-transformers/paraphrase-multilingual-minilm-l12-v2"
|
28 |
+
st_model = SentenceTransformer(name, device='cuda')
|
29 |
+
|
30 |
+
# used for both training and querying
|
31 |
+
def embed_func(query):
|
32 |
+
return st_model.encode(query)
|
33 |
+
|
34 |
+
def vector_search(query_vector, top_k):
|
35 |
+
return tbl.search(query_vector).limit(top_k).to_list()
|
36 |
+
|
37 |
+
def retriever(query, top_k=3):
|
38 |
+
query_vector = embed_func(query)
|
39 |
+
documents = vector_search(query_vector, top_k)
|
40 |
+
return documents
|
41 |
+
|
42 |
+
|
43 |
+
# Log the time taken to load the EmbeddingRetriever
|
44 |
+
retriever_loading_time = time.perf_counter() - start_time
|
45 |
+
logger.info(f"Time taken to load EmbeddingRetriever: {retriever_loading_time:.6f} seconds")
|
templates/template.j2
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Instruction: Use the following unique documents in the Context section to answer the Query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
2 |
+
### Context
|
3 |
+
{% for doc in documents %}
|
4 |
+
---
|
5 |
+
{{ doc.content }}
|
6 |
+
{% endfor %}
|
7 |
+
---
|
8 |
+
[|AI|]:
|
9 |
+
### Query: [|Human|] {{query}}
|
10 |
+
### Response: [|AI|]
|
templates/template_html.j2
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h2>Prompt</h2>
|
2 |
+
Below is the prompt that is given to the model. <hr>
|
3 |
+
<h2>Instruction:</h2>
|
4 |
+
<span style="color: #FF00FF;">Use the following unique documents in the Context section to answer the Query at the end. If you don't know the answer, just say that you don't know, <span style="color: #FF00FF; font-weight: bold;">don't try to make up an answer.</span></span><br>
|
5 |
+
<h2>Context</h2>
|
6 |
+
{% for doc in documents %}
|
7 |
+
<details class="doc-box">
|
8 |
+
<summary>
|
9 |
+
<b>Doc {{ loop.index }}:</b> <span class="doc-short">{{ doc.content[:100] }}...</span>
|
10 |
+
</summary>
|
11 |
+
<div class="doc-full">{{ doc.content }}</div>
|
12 |
+
</details>
|
13 |
+
{% endfor %}
|
14 |
+
|
15 |
+
<h2>Query</h2> <span style="color: #801616;">{{ query }}</span>
|
16 |
+
|
17 |
+
<style>
|
18 |
+
.doc-box {
|
19 |
+
padding: 10px;
|
20 |
+
margin-top: 10px;
|
21 |
+
background-color: #48a3ff;
|
22 |
+
border: none;
|
23 |
+
}
|
24 |
+
.doc-short, .doc-full {
|
25 |
+
color: white;
|
26 |
+
}
|
27 |
+
summary::-webkit-details-marker {
|
28 |
+
color: white;
|
29 |
+
}
|
30 |
+
</style>
|
31 |
+
|
32 |
+
<script>
|
33 |
+
document.addEventListener("DOMContentLoaded", function() {
|
34 |
+
const detailsElements = document.querySelectorAll('.doc-box');
|
35 |
+
|
36 |
+
detailsElements.forEach(detail => {
|
37 |
+
detail.addEventListener('toggle', function() {
|
38 |
+
const docShort = this.querySelector('.doc-short');
|
39 |
+
if (this.open) {
|
40 |
+
docShort.style.display = 'none';
|
41 |
+
} else {
|
42 |
+
docShort.style.display = 'inline';
|
43 |
+
}
|
44 |
+
});
|
45 |
+
});
|
46 |
+
});
|
47 |
+
</script>
|