|
import gradio as gr |
|
from datasets import load_dataset |
|
import re |
|
|
|
|
|
def load_dataset_demo(name):
    """Load the ``train`` split of a Hub dataset, keeping only flagged rows.

    When the split has a ``flags`` column, only rows whose ``flags`` value is
    truthy are kept; otherwise the whole split is returned unfiltered.

    Args:
        name: Dataset repository id passed to ``datasets.load_dataset``.

    Returns:
        The (possibly filtered) ``train`` split.
    """
    dataset = load_dataset(name)["train"]
    # Inspect the schema up front instead of catching every exception: this
    # loads the dataset only once, and genuine failures (network errors, a
    # missing "train" split) are no longer mistaken for "no flags column".
    if "flags" in dataset.column_names:
        dataset = dataset.filter(lambda x: x["flags"])
    return dataset
|
|
|
|
|
# Hugging Face Hub repositories offered in the "Dataset" dropdown.
NAME_DATASETS = [
    "Self-GRIT/selfrag_dataset-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-0.01",
    "Self-GRIT/selfrag_dataset-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-1.0",
    "Self-GRIT/selfrag_dataset_mini-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-0.01",
    "Self-GRIT/selfrag_dataset_mini-embed_query_instruct-Meta-Llama-3-8B-Instruct",
]

# Every dataset is loaded eagerly at import time, in NAME_DATASETS order, and
# keyed by its repository name.
DATASETS = dict(zip(NAME_DATASETS, map(load_dataset_demo, NAME_DATASETS)))

# Column names read from each dataset example.
INSTRUCTION_COL = "instruction"
OUTPUT_COL = "output"
OUTPUT_ORIGIN = "output_origin"
|
|
|
|
|
def extract_pairs(text):
    """Return every ``(query, passage)`` pair tagged in *text*.

    A pair is an ``<embed>...</embed>`` span immediately followed (no gap) by
    a ``<passage>...</passage>`` span. Both captures are non-greedy, and
    ``re.DOTALL`` lets them span newlines.

    Returns:
        A list of ``(query, passage)`` string tuples in order of appearance;
        an empty list when nothing matches.
    """
    tagged_pair = re.compile(
        r"<embed>(.*?)</embed><passage>(.*?)</passage>", re.DOTALL
    )
    return tagged_pair.findall(text)
|
|
|
|
|
def preprocess_qa_pairs(text):
    """Render the query/passage pairs tagged in *text* as readable text.

    Each pair found by ``extract_pairs`` becomes a numbered section (starting
    at 1) with the stripped query and passage.

    Args:
        text: Model output containing ``<embed>``/``<passage>`` tags.

    Returns:
        The formatted sections concatenated, or the message
        ``"No query-passage pairs found."`` when no pair is present.
    """
    qa_pairs = extract_pairs(text)
    if not qa_pairs:
        return "No query-passage pairs found."

    # Collect the sections and join once rather than growing a str with +=.
    sections = [
        f"========================== QP-Pair {i} =============================\n"
        f"Query:\n{query.strip()}\n"
        f"Passage:\n{passage.strip()}\n\n"
        for i, (query, passage) in enumerate(qa_pairs, start=1)
    ]
    return "".join(sections)
|
|
|
|
|
def output_fn(dropdown, slider):
    """Return the fields of one dataset example for the output textboxes.

    Args:
        dropdown: Selected dataset name (expected to be a key of
            ``DATASETS``); may be ``None`` when nothing is selected.
        slider: Example index from the slider; converted with ``int``.

    Returns:
        Tuple of (instruction, ``output_origin`` value, ``output`` value,
        formatted query/passage pairs extracted from ``output``).

    Raises:
        gr.Error: If no known dataset is selected, or if the index is outside
            the selected dataset (e.g. the slider allows indices beyond a
            smaller dataset's last example).
    """
    if dropdown not in DATASETS:
        raise gr.Error("Please select a dataset first.")
    dataset = DATASETS[dropdown]

    index = int(slider)
    if not 0 <= index < len(dataset):
        raise gr.Error(
            f"Example index {index} is out of range for {dropdown} "
            f"(it has {len(dataset)} examples)."
        )

    example = dataset[index]
    output = example[OUTPUT_COL]
    return (
        example[INSTRUCTION_COL],
        example[OUTPUT_ORIGIN],
        output,
        preprocess_qa_pairs(output),
    )
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Explore Self-RAG Datasets")
    with gr.Group():
        with gr.Row():
            with gr.Column():
                dropdown = gr.Dropdown(
                    NAME_DATASETS,
                    multiselect=False,
                    label="Dataset",
                    info="Select the dataset name",
                )
            with gr.Column():
                slider = gr.Slider(
                    minimum=0,
                    # Largest valid index over all datasets: len() itself is
                    # one past the last example.
                    maximum=max(
                        0,
                        max((len(ds) for ds in DATASETS.values()), default=0) - 1,
                    ),
                    step=1,
                    label="#example",
                    value=0,
                )
    button = gr.Button(value="Submit", variant="primary")
    with gr.Group():
        with gr.Row():
            output_instruction = gr.Textbox(
                label="Instruction", placeholder="Instruction", type="text"
            )
        with gr.Row():
            with gr.Row():
                output_self_rag = gr.Textbox(
                    label="SELF-RAG output", placeholder="SELF-RAG output", type="text"
                )
                output_self_grit = gr.Textbox(
                    label="SELF-GRIT output",
                    placeholder="SELF-GRIT output",
                    type="text",
                )
    with gr.Group():
        output_qps = gr.Textbox(
            label="Query-Passage Pairs", placeholder="Query-Passage Pairs", type="text"
        )

    def _update_slider_range(name, current):
        """Limit the slider to the valid indices of the selected dataset."""
        if name not in DATASETS:
            # Nothing (or an unknown name) selected: leave the slider as is.
            return gr.update()
        upper = max(0, len(DATASETS[name]) - 1)
        return gr.update(maximum=upper, value=min(int(current or 0), upper))

    # Datasets differ in size, so re-bound the slider whenever the selection changes.
    dropdown.change(
        fn=_update_slider_range,
        inputs=[dropdown, slider],
        outputs=slider,
    )
    button.click(
        fn=output_fn,
        inputs=[dropdown, slider],
        outputs=[output_instruction, output_self_rag, output_self_grit, output_qps],
    )

demo.launch(share=True, debug=True)
|
|