File size: 5,437 Bytes
37d119f
 
 
 
0d73887
5b6661f
 
9e6b8ed
 
83e218d
629069a
9e6b8ed
83e218d
9e6b8ed
 
629069a
5b6661f
 
 
629069a
 
5b6661f
629069a
 
 
 
9e6b8ed
378c66b
0d73887
378c66b
 
 
9e6b8ed
 
 
 
 
 
 
378c66b
00754c3
9e6b8ed
 
5638045
 
 
378c66b
0d73887
378c66b
 
0d73887
 
 
 
 
 
 
5638045
 
 
 
 
9e6b8ed
 
 
 
629069a
9e6b8ed
 
 
378c66b
9e6b8ed
 
 
5b6661f
 
9e6b8ed
 
 
 
 
 
 
 
 
 
 
 
 
 
3014aa6
 
 
629069a
378c66b
9e6b8ed
629069a
9e6b8ed
 
 
 
 
 
629069a
9e6b8ed
 
 
 
378c66b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3014aa6
 
 
378c66b
 
 
0d73887
378c66b
 
 
 
 
 
0d73887
378c66b
 
 
 
9e6b8ed
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import subprocess
import sys

# Install the pinned transformers build at startup, before the remaining
# imports run (presumably so the backend modules pick up this exact version —
# confirm). Invoking pip as ``<this interpreter> -m pip`` guarantees the
# package lands in the environment that is actually running this script,
# which a bare ``pip`` found on PATH does not.
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"]
)
if _pip_result.returncode != 0:
    # Stay best-effort (do not abort startup), but make the failure visible.
    print(f"pip install exited with status {_pip_result.returncode}", file=sys.stderr)

from functools import partial
import logging
from pathlib import Path
from time import perf_counter

import gradio as gr
from jinja2 import Environment, FileSystemLoader

from backend.query_llm import generate
from backend.semantic_search import qd_retriever

# Directory containing this script; the Jinja templates live in a subfolder.
proj_dir = Path(__file__).parent

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Jinja environment that loads templates from <proj_dir>/templates.
# NOTE(review): no autoescape option is passed here — confirm that
# template_html.j2 escapes the query/document text it renders.
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))

# `template` renders the prompt sent to the model; `template_html` renders
# the same inputs for display in the UI.
template, template_html = (
    env.get_template(template_name)
    for template_name in ('template.j2', 'template_html.j2')
)

# Sample questions offered next to the input box.
examples = [
    'What is the capital of China?',
    'Why is the sky blue?',
    'Who won the mens world cup in 2014?',
]


def add_text(history, text):
    """Record the user's message as a new, unanswered chat turn.

    Args:
        history: Existing chat turns, or ``None`` when there are none yet.
        text: The message the user just submitted.

    Returns:
        A pair ``(updated_history, textbox)``: a new list consisting of the
        previous turns followed by ``(text, None)``, and a ``gr.Textbox``
        with an empty value and interactivity switched off.
    """
    previous_turns = history if history is not None else []
    # Build a new list instead of appending so the caller's list is untouched.
    updated_history = previous_turns + [(text, None)]
    return updated_history, gr.Textbox(value="", interactive=False)


def bot(history, hyde=False):
    """Stream a retrieval-augmented answer for the most recent chat turn.

    Args:
        history: Chat turns as ``[user_message, bot_message]`` pairs; the last
            pair holds the query to answer.
        hyde: When true, first ask the model for a hypothetical article
            paragraph answering the query and retrieve documents with that
            text instead of the raw query.

    Yields:
        ``(history, prompt_html)`` after every chunk produced by ``generate``,
        where the last turn of ``history`` holds the latest chunk and
        ``prompt_html`` is the HTML rendering of the prompt.  For an empty or
        ``None`` history a single ``([], "")`` is yielded.
    """
    if not history:
        # No turn to answer: emit an empty update instead of raising
        # IndexError on history[-1].
        yield [], ""
        return

    top_k = 4
    query = history[-1][0]

    logger.warning('Retrieving documents...')
    # Retrieve documents relevant to query
    document_start = perf_counter()
    if hyde:
        hyde_document = ""
        # Each chunk overwrites the previous one, so only the final chunk is
        # kept (presumably generate yields the cumulative text — confirm).
        # NOTE(review): this call passes the full history, including the
        # still-unanswered current turn, whereas the answer generation below
        # passes history[:-1] — confirm which one generate() expects.
        for output_chunk in generate(
                f"Write a wikipedia article intro paragraph to answer this query: {query}", history):
            hyde_document = output_chunk

        logger.warning(hyde_document)
        search_text = hyde_document
    else:
        search_text = query
    documents = qd_retriever.retrieve(search_text, top_k=top_k)
    document_time = perf_counter() - document_start
    logger.warning('Finished Retrieving documents in %s seconds...', round(document_time, 2))

    # Create Prompt
    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    # Replace the last turn with a fresh mutable pair: add_text() creates the
    # turn as a tuple, which would not support item assignment.
    history[-1] = [query, ""]
    for character in generate(prompt, history[:-1]):
        history[-1][1] = character
        yield history, prompt_html


def _build_chat_tab(title, respond, elem_id):
    """Create one chat tab and wire up its submit events.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        title: Label shown on the tab.
        respond: Generator callable that takes the chat history and yields
            ``(history, prompt_html)`` pairs; it runs as the second step of
            every submit chain.
        elem_id: ``elem_id`` for this tab's chatbot.  Each tab needs its own
            value so the element ids on the page stay unique.

    Returns:
        ``(chatbot, textbox, button, prompt_view)`` — the components created
        for the tab.
    """
    with gr.Tab(title):
        chat = gr.Chatbot(
                [],
                elem_id=elem_id,
                avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                               'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
                bubble_full_width=False,
                show_copy_button=True,
                show_share_button=True,
                )

        with gr.Row():
            box = gr.Textbox(
                    scale=3,
                    show_label=False,
                    placeholder="Enter text and press enter",
                    container=False,
                    )
            button = gr.Button(value="Submit text", scale=1)

        # Examples
        gr.Examples(examples, box)

        prompt_view = gr.HTML()

        # Clicking the button and pressing enter trigger the same chain.
        for trigger in (button.click, box.submit):
            # Step 1 stores the message and disables the textbox;
            # step 2 streams the answer and the rendered prompt.
            chain = trigger(add_text, [chat, box], [chat, box], queue=False).then(
                    respond, [chat], [chat, prompt_view])
            # Step 3 turns the textbox back on once generation has finished.
            chain.then(lambda: gr.Textbox(interactive=True), None, [box], queue=False)

    return chat, box, button, prompt_view


with gr.Blocks() as demo:
    # Plain retrieval: documents are looked up with the user's query.
    chatbot, txt, txt_btn, prompt_html = _build_chat_tab(
            "RAGDemo", bot, elem_id="chatbot")

    # HyDE variant: bot() is called with hyde=True.
    hyde_chatbot, hyde_txt, hyde_txt_btn, hyde_prompt_html = _build_chat_tab(
            "RAGDemo + HyDE", partial(bot, hyde=True), elem_id="hyde_chatbot")

demo.queue()
demo.launch(debug=True)