import logging
from functools import partial
from pathlib import Path
from time import perf_counter
import gradio as gr
from jinja2 import Environment, FileSystemLoader
from transformers import AutoTokenizer
from backend.query_llm import check_endpoint_status, generate
from backend.semantic_search import retriever
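# The two backend modules are project-local and not shown in this file. The
# shapes below are inferred from how they are used later in this file (the
# names are real; the exact signatures are assumptions):
#   generate(prompt: str) -> str                  # call the hosted Jais Inference Endpoint
#   check_endpoint_status() -> str                # report the endpoint state, e.g. "Running"
#   retriever(query: str, top_k: int) -> list     # return the top_k most similar passages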
proj_dir = Path(__file__).parent
# Setting up the logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set up the template environment with the templates directory
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
# Load the templates directly from the environment
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')
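# templates/template.j2 is not included in this file. Given that it is rendered
# with `documents` and `query`, and that responses are split on
# '### Response: [|AI|]' below, it presumably wraps the retrieved documents in
# the Jais chat format, along these lines (a sketch, not the actual template):
#   {% for doc in documents %}{{ doc }}
#   {% endfor %}
#   ### Input: [|Human|] {{ query }}
#   ### Response: [|AI|]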
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('derek-thomas/jais-13b-chat-hf')
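# NOTE: the tokenizer is only used to count prompt tokens below, so documents
# can be trimmed to fit the reader's context budget. 'derek-thomas/jais-13b-chat-hf'
# is presumably a transformers-compatible mirror of core42/jais-13b-chat.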
# Examples
examples = ['من كان طرفي معركة اكتيوم البحرية؟',  # Who were the two sides of the naval Battle of Actium?
            'لم السماء زرقاء؟',  # Why is the sky blue?
            'من فاز بكأس العالم للرجال في عام 2014؟',  # Who won the men's World Cup in 2014?
            ]
def add_text(history, text):
    # Append the user's message to the chat history and lock the textbox
    history = [] if history is None else history
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)
def bot(history, hyde=False):
    top_k = 5
    query = history[-1][0]
    logger.warning('Retrieving documents...')

    # Retrieve documents relevant to the query
    document_start = perf_counter()
    if hyde:
        # HyDE (Hypothetical Document Embeddings): generate a hypothetical answer
        # first and retrieve with it instead of the raw query, since a full
        # paragraph often embeds closer to the relevant articles than a short question
        hyde_document = generate(f"Write a wikipedia article intro paragraph to answer this query: {query}").split('### Response: [|AI|]')[-1]
        logger.warning(hyde_document)
        documents = retriever(hyde_document, top_k=top_k)
    else:
        documents = retriever(query, top_k=top_k)
    document_time = perf_counter() - document_start
    logger.warning(f'Finished retrieving documents in {round(document_time, 2)} seconds...')

    # Function to count tokens
    def count_tokens(text):
        return len(tokenizer.encode(text))

    # Create the prompt
    prompt = template.render(documents=documents, query=query)

    # If the prompt exceeds the 2048-token budget, drop the least relevant
    # document and re-render until it fits
    token_count = count_tokens(prompt)
    while token_count > 2048:
        documents.pop()  # Remove the last (least relevant) document
        prompt = template.render(documents=documents, query=query)  # Re-render the prompt
        token_count = count_tokens(prompt)  # Re-count tokens

    prompt_html = template_html.render(documents=documents, query=query)

    history[-1][1] = ""
    response = generate(prompt)
    history[-1][1] = response.split('### Response: [|AI|]')[-1]
    return history, prompt_html
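# A minimal sketch of what backend/semantic_search.py's `retriever` might look
# like with the multilingual MiniLM model named in the intro below (an
# assumption for illustration; the actual module is not in this file, and
# `corpus` / `corpus_embeddings` are hypothetical precomputed Wikipedia data):
#
#   from sentence_transformers import SentenceTransformer, util
#   model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
#
#   def retriever(query, top_k=5):
#       query_emb = model.encode(query, convert_to_tensor=True)
#       hits = util.semantic_search(query_emb, corpus_embeddings, top_k=top_k)[0]
#       return [corpus[hit['corpus_id']] for hit in hits]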
intro_md = """
# Arabic RAG
This is a project to demonstrate Retrieval Augmented Generation (RAG) in Arabic and English. It uses
[Arabic Wikipedia](https://ar.wikipedia.org/wiki) as a knowledge base to answer your questions.
A retriever ([sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/discussions/8))
finds the articles relevant to your query and includes them in a prompt so the reader ([core42/jais-13b-chat](https://huggingface.co/core42/jais-13b-chat))
can answer your question based on them.
The prompt is displayed below the chatbot so you can see exactly what is sent to the LLM.
# Read this if you get an error
I'm using Inference Endpoints' Scale to Zero to save money on GPUs. If the status is not "Running", send a
chat message to wake the endpoint up. You will get a `500 error` and it will take ~7 min to wake up.
"""
with gr.Blocks() as demo:
    gr.Markdown(intro_md)
    endpoint_status = gr.Textbox(check_endpoint_status, label="Inference Endpoint Status", every=1)

    with gr.Tab("Arabic-RAG"):
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                           'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
            bubble_full_width=False,
            show_copy_button=True,
            show_share_button=True,
        )

        with gr.Row():
            txt = gr.Textbox(
                scale=3,
                show_label=False,
                placeholder="Enter query in Arabic or English and press enter",
                container=False,
            )
            txt_btn = gr.Button(value="Submit text", scale=1)

        gr.Examples(examples, txt)
        prompt_html = gr.HTML()

        # Turn off interactivity while generating if you click
        txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, [chatbot, prompt_html])
        # Turn it back on
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

        # Turn off interactivity while generating if you hit enter
        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, [chatbot, prompt_html])
        # Turn it back on
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    with gr.Tab("Arabic-RAG + HyDE"):
        hyde_chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                           'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
            bubble_full_width=False,
            show_copy_button=True,
            show_share_button=True,
        )

        with gr.Row():
            hyde_txt = gr.Textbox(
                scale=3,
                show_label=False,
                placeholder="Enter query in Arabic or English and press enter",
                container=False,
            )
            hyde_txt_btn = gr.Button(value="Submit text", scale=1)

        gr.Examples(examples, hyde_txt)
        hyde_prompt_html = gr.HTML()

        # Turn off interactivity while generating if you click
        hyde_txt_msg = hyde_txt_btn.click(add_text, [hyde_chatbot, hyde_txt], [hyde_chatbot, hyde_txt],
                                          queue=False).then(
            partial(bot, hyde=True), [hyde_chatbot], [hyde_chatbot, hyde_prompt_html])
        # Turn it back on
        hyde_txt_msg.then(lambda: gr.Textbox(interactive=True), None, [hyde_txt], queue=False)

        # Turn off interactivity while generating if you hit enter
        hyde_txt_msg = hyde_txt.submit(add_text, [hyde_chatbot, hyde_txt], [hyde_chatbot, hyde_txt], queue=False).then(
            partial(bot, hyde=True), [hyde_chatbot], [hyde_chatbot, hyde_prompt_html])
        # Turn it back on
        hyde_txt_msg.then(lambda: gr.Textbox(interactive=True), None, [hyde_txt], queue=False)
# The queue must be enabled for the `every=1` endpoint-status polling above
demo.queue()
demo.launch(debug=True)