will33am's picture
update
fd3b2e0
raw
history blame
3.48 kB
import gradio as gr
from datasets import load_dataset
import re
def load_dataset_demo(name):
    """Load the ``train`` split of dataset *name* for browsing.

    When the split has a ``flags`` column, only rows whose ``flags`` value is
    truthy are kept; otherwise the split is returned unfiltered.

    The dataset is loaded exactly once. Loading errors (missing dataset,
    network failure, missing ``train`` split) propagate to the caller instead
    of being silently retried.
    """
    dataset = load_dataset(name)["train"]
    # Not every dataset in NAME_DATASETS is guaranteed to carry a ``flags``
    # column; check the schema rather than catching a failure from ``filter``.
    if "flags" in dataset.column_names:
        dataset = dataset.filter(lambda x: x["flags"])
    return dataset
# Dataset identifiers passed to ``load_dataset`` and offered, in this order,
# in the "Dataset" dropdown. All of them live under the ``Self-GRIT/`` namespace.
NAME_DATASETS = [
    f"Self-GRIT/{suffix}"
    for suffix in (
        "selfrag_dataset-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-0.01",
        "selfrag_dataset-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-1.0",
        "selfrag_dataset_mini-embed_query_instruct-Meta-Llama-3-70B-Instruct_temp-0.01",
        "selfrag_dataset_mini-embed_query_instruct-Meta-Llama-3-8B-Instruct",
    )
]
# Load every dataset once, at import time, keyed by its identifier
# (insertion order follows NAME_DATASETS).
DATASETS = dict(zip(NAME_DATASETS, map(load_dataset_demo, NAME_DATASETS)))
# Column names read from each dataset example by ``output_fn``.
INSTRUCTION_COL, OUTPUT_COL, OUTPUT_ORIGIN = "instruction", "output", "output_origin"
def extract_pairs(text):
    """Return every ``(query, passage)`` pair tagged in *text*.

    A pair is an ``<embed>...</embed>`` element immediately followed by a
    ``<passage>...</passage>`` element. Matching is non-greedy and ``.`` also
    matches newlines, so multi-line contents are captured verbatim (unstripped).
    The result is a list of 2-tuples of strings, in order of appearance.
    """
    tag_pattern = re.compile(r"<embed>(.*?)</embed><passage>(.*?)</passage>", re.DOTALL)
    return [match.groups() for match in tag_pattern.finditer(text)]
def preprocess_qa_pairs(text):
    """Format the query/passage pairs found in *text* for display.

    Each pair returned by ``extract_pairs`` is rendered as a numbered banner
    line followed by its whitespace-stripped query and passage, with a blank
    line after each pair.

    Returns:
        The formatted string, or ``"No query-passage pairs found."`` when
        *text* is empty/None or contains no pairs.
    """
    # Guard so an empty or missing output value yields the message below
    # instead of a TypeError from ``re``.
    # NOTE(review): confirm whether the output column can actually be None.
    qa_pairs = extract_pairs(text) if text else []
    if not qa_pairs:
        return "No query-passage pairs found."
    # Join once instead of repeated ``+=`` (which is quadratic in general).
    return "".join(
        f"========================== QP-Pair {i} =============================\n"
        f"Query:\n{query.strip()}\n"
        f"Passage:\n{passage.strip()}\n\n"
        for i, (query, passage) in enumerate(qa_pairs, start=1)
    )
def output_fn(dropdown, slider):
    """Fetch one example and return the four values displayed by the UI.

    Args:
        dropdown: Selected dataset name; expected to be a key of ``DATASETS``.
        slider: Example index from the slider (may arrive as a float).

    Returns:
        Tuple of (instruction, ``OUTPUT_ORIGIN`` value, ``OUTPUT_COL`` value,
        formatted query-passage pairs), in the order of ``button.click``'s
        outputs.

    Raises:
        gr.Error: If no known dataset is selected or the index is outside the
            selected dataset. Gradio shows this as a message rather than a
            stack trace.
    """
    if dropdown not in DATASETS:
        raise gr.Error("Please select a dataset first.")
    dataset = DATASETS[dropdown]
    index = int(slider)
    # The slider range is not guaranteed to match this dataset's size,
    # so validate instead of letting dataset[index] raise IndexError.
    if not 0 <= index < len(dataset):
        raise gr.Error(
            f"Example #{index} is out of range for {dropdown} ({len(dataset)} examples)."
        )
    example = dataset[index]
    return (
        example[INSTRUCTION_COL],
        example[OUTPUT_ORIGIN],
        example[OUTPUT_COL],
        preprocess_qa_pairs(example[OUTPUT_COL]),
    )
def _last_index(name):
    """Return the largest valid example index for dataset *name* (0 if unknown or empty)."""
    dataset = DATASETS.get(name)
    return max(len(dataset) - 1, 0) if dataset is not None else 0


def _reset_slider(name):
    """Resize the example slider to dataset *name* and rewind it to index 0."""
    return gr.update(maximum=_last_index(name), value=0)


with gr.Blocks() as demo:
    gr.Markdown("# Explore Self-RAG Datasets")
    with gr.Group():
        with gr.Row():
            with gr.Column():
                # Preselect the first dataset so "Submit" works immediately
                # and the slider's initial range below matches it.
                dropdown = gr.Dropdown(
                    NAME_DATASETS,
                    value=NAME_DATASETS[0],
                    multiselect=False,
                    label="Dataset",
                    info="Select the dataset name",
                )
            with gr.Column():
                # Indices are 0-based: the last valid one is len(dataset) - 1.
                slider = gr.Slider(
                    minimum=0,
                    maximum=_last_index(NAME_DATASETS[0]),
                    step=1,
                    label="#example",
                    value=0,
                )
        button = gr.Button(value="Submit", variant="primary")
    with gr.Group():
        with gr.Row():
            output_instruction = gr.Textbox(
                label="Instruction", placeholder="Instruction", type="text"
            )
        with gr.Row():
            output_self_rag = gr.Textbox(
                label="SELF-RAG output", placeholder="SELF-RAG output", type="text"
            )
            output_self_grit = gr.Textbox(
                label="SELF-GRIT output",
                placeholder="SELF-GRIT output",
                type="text",
            )
    with gr.Group():
        output_qps = gr.Textbox(
            label="Query-Passage Pairs", placeholder="Query-Passage Pairs", type="text"
        )
    # Datasets differ in size, so keep the slider's range in sync with the selection.
    dropdown.change(fn=_reset_slider, inputs=dropdown, outputs=slider)
    button.click(
        fn=output_fn,
        inputs=[dropdown, slider],
        outputs=[output_instruction, output_self_rag, output_self_grit, output_qps],
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)